Unverified commit bc0d26d1 authored by Yossi Synett, committed by GitHub

[All Seq2Seq models + CLM models that can be used with EncoderDecoder] Add cross-attention weights to outputs (#8071)

* Output cross-attention with decoder attention output

* Update src/transformers/modeling_bert.py

* add cross-attention for t5 and bart as well

* fix tests

* correct typo in docs

* address Sylvain's and Sam's comments

* correct typo
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 30f2507a
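For context, a minimal sketch of what this change enables (the checkpoint names are illustrative, not part of the commit): with `output_attentions=True`, an `EncoderDecoderModel` now returns the decoder's cross-attention weights alongside its self-attention weights.

```python
from transformers import BertTokenizer, EncoderDecoderModel

# Sketch: a BERT2BERT encoder-decoder; the decoder gets cross-attention layers automatically.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(
    input_ids=inputs.input_ids,
    decoder_input_ids=inputs.input_ids,
    output_attentions=True,
    return_dict=True,
)

# One tensor per decoder layer, each of shape (batch_size, num_heads, tgt_len, src_len).
print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)
```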
@@ -65,12 +65,34 @@ BaseModelOutputWithPooling
:members:
BaseModelOutputWithCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithCrossAttentions
:members:
BaseModelOutputWithPoolingAndCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions
:members:
BaseModelOutputWithPast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPast
:members:
BaseModelOutputWithPastAndCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.modeling_outputs.BaseModelOutputWithPastAndCrossAttentions
:members:
Seq2SeqModelOutput
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -85,6 +107,20 @@ CausalLMOutput
:members:
CausalLMOutputWithCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.modeling_outputs.CausalLMOutputWithCrossAttentions
:members:
CausalLMOutputWithPastAndCrossAttentions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.modeling_outputs.CausalLMOutputWithPastAndCrossAttentions
:members:
CausalLMOutputWithPast
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
@@ -35,7 +35,7 @@ from .file_utils import (
)
from .modeling_outputs import (
    BaseModelOutput,
-    BaseModelOutputWithPast,
+    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
@@ -451,11 +451,12 @@ class DecoderLayer(nn.Module):
        assert self.encoder_attn.cache_key != self.self_attn.cache_key
        if self.normalize_before:
            x = self.encoder_attn_layer_norm(x)
-        x, _ = self.encoder_attn(
+        x, cross_attn_weights = self.encoder_attn(
            query=x,
            key=encoder_hidden_states,
            key_padding_mask=encoder_attn_mask,
            layer_state=layer_state,  # mutates layer state
+            output_attentions=output_attentions,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
@@ -477,7 +478,8 @@ class DecoderLayer(nn.Module):
            x,
            self_attn_weights,
            layer_state,
-        )  # just self_attn weights for now, following t5, layer_state = cache for decoding
+            cross_attn_weights,
+        )  # layer_state = cache for decoding
class BartDecoder(nn.Module):
@@ -590,6 +592,7 @@ class BartDecoder(nn.Module):
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
+        all_cross_attentions = () if output_attentions else None
        next_decoder_cache: List[Dict] = []
        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -601,7 +604,7 @@ class BartDecoder(nn.Module):
            layer_state = past_key_values[idx] if past_key_values is not None else None
-            x, layer_self_attn, layer_past = decoder_layer(
+            x, layer_self_attn, layer_past, layer_cross_attn = decoder_layer(
                x,
                encoder_hidden_states,
                encoder_attn_mask=encoder_padding_mask,
@@ -616,6 +619,7 @@ class BartDecoder(nn.Module):
            if output_attentions:
                all_self_attns += (layer_self_attn,)
+                all_cross_attentions += (layer_cross_attn,)
        if self.layer_norm:  # if config.add_final_layer_norm (mBART)
            x = self.layer_norm(x)
@@ -628,9 +632,15 @@ class BartDecoder(nn.Module):
        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
-            return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return BaseModelOutputWithPast(
-            last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns
+            return tuple(
+                v for v in [x, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=x,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attentions,
        )
@@ -934,6 +944,7 @@ class BartModel(PretrainedBartModel):
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
@@ -1078,6 +1089,7 @@ class BartForConditionalGeneration(PretrainedBartModel):
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
@@ -1207,6 +1219,7 @@ class BartForSequenceClassification(PretrainedBartModel):
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
@@ -1317,6 +1330,7 @@ class BartForQuestionAnswering(PretrainedBartModel):
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
......
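The same `cross_attentions` field is now populated on the BART outputs touched above. A minimal usage sketch (the checkpoint name is only an example):

```python
from transformers import BartForConditionalGeneration, BartTokenizer

# Sketch: request attentions and read both decoder self-attention and cross-attention weights.
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

batch = tokenizer("UN Chief Says There Is No Plan to Stop Chemical Weapons in Syria", return_tensors="pt")
outputs = model(**batch, output_attentions=True, return_dict=True)

# decoder_attentions: self-attention per decoder layer; cross_attentions: attention over encoder states.
for self_attn, cross_attn in zip(outputs.decoder_attentions, outputs.cross_attentions):
    print(self_attn.shape, cross_attn.shape)
```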
@@ -37,9 +37,9 @@ from .file_utils import (
    replace_return_docstrings,
)
from .modeling_outputs import (
-    BaseModelOutput,
-    BaseModelOutputWithPooling,
-    CausalLMOutput,
+    BaseModelOutputWithCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    CausalLMOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    NextSentencePredictorOutput,
@@ -449,7 +449,8 @@ class BertEncoder(nn.Module):
        return_dict=False,
    ):
        all_hidden_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
@@ -483,15 +484,24 @@ class BertEncoder(nn.Module):
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+            return tuple(
+                v
+                for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithCrossAttentions(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
        )
@@ -752,7 +762,7 @@ class BertModel(BertPreTrainedModel):
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-uncased",
-        output_type=BaseModelOutputWithPooling,
+        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
@@ -843,11 +853,12 @@ class BertModel(BertPreTrainedModel):
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]
-        return BaseModelOutputWithPooling(
+        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
        )
@@ -984,7 +995,7 @@ class BertLMHeadModel(BertPreTrainedModel):
        return self.cls.predictions.decoder
    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
@@ -1063,11 +1074,12 @@ class BertLMHeadModel(BertPreTrainedModel):
            output = (prediction_scores,) + outputs[2:]
            return ((lm_loss,) + output) if lm_loss is not None else output
-        return CausalLMOutput(
+        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
        )
    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
......
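For BERT, cross-attentions only exist when the model is configured as a decoder with `add_cross_attention=True` and is fed `encoder_hidden_states`. A hedged sketch, using `bert-base-uncased` purely as a convenient starting point (the new cross-attention weights are randomly initialized):

```python
from transformers import BertConfig, BertLMHeadModel, BertModel, BertTokenizer

# Sketch: run a BERT decoder over encoder states and inspect the new output fields.
config = BertConfig.from_pretrained("bert-base-uncased", is_decoder=True, add_cross_attention=True)
decoder = BertLMHeadModel.from_pretrained("bert-base-uncased", config=config)
encoder = BertModel.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
encoder_hidden_states = encoder(**inputs, return_dict=True).last_hidden_state

outputs = decoder(
    input_ids=inputs.input_ids,
    encoder_hidden_states=encoder_hidden_states,
    output_attentions=True,
    return_dict=True,
)
print(type(outputs).__name__)             # CausalLMOutputWithCrossAttentions
print(outputs.cross_attentions[0].shape)  # (batch_size, num_heads, tgt_len, src_len)
```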
@@ -28,7 +28,7 @@ from .file_utils import (
    replace_return_docstrings,
)
from .modeling_bert import BertEncoder
-from .modeling_outputs import BaseModelOutput, CausalLMOutput
+from .modeling_outputs import BaseModelOutputWithCrossAttentions, CausalLMOutputWithCrossAttentions
from .modeling_utils import PreTrainedModel
from .utils import logging
@@ -297,7 +297,7 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="google/bert_for_seq_generation_L-24_bbc_encoder",
-        output_type=BaseModelOutput,
+        output_type=BaseModelOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
@@ -381,10 +381,11 @@ class BertGenerationEncoder(BertGenerationPreTrainedModel):
        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]
-        return BaseModelOutput(
+        return BaseModelOutputWithCrossAttentions(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
        )
@@ -422,7 +423,7 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
        return self.lm_head.decoder
    @add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
-    @replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC)
+    @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids=None,
@@ -499,11 +500,12 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel):
            output = (prediction_scores,) + outputs[1:]
            return ((lm_loss,) + output) if lm_loss is not None else output
-        return CausalLMOutput(
+        return CausalLMOutputWithCrossAttentions(
            loss=lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
+            cross_attentions=outputs.cross_attentions,
        )
    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
......
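A tiny, self-contained sketch of the new `BertGenerationDecoder` return type, using a hypothetical miniature config so nothing needs to be downloaded:

```python
import torch
from transformers import BertGenerationConfig, BertGenerationDecoder

# Sketch: randomly initialized toy decoder, just to show the new output fields.
config = BertGenerationConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
    intermediate_size=64, is_decoder=True, add_cross_attention=True,
)
decoder = BertGenerationDecoder(config)

input_ids = torch.randint(0, 100, (1, 7))
encoder_hidden_states = torch.randn(1, 11, 32)  # stand-in for a real encoder output

out = decoder(
    input_ids=input_ids,
    encoder_hidden_states=encoder_hidden_states,
    output_attentions=True,
    return_dict=True,
)
# Two tuples of per-layer weights: self-attention and the new cross-attention.
print(len(out.attentions), out.attentions[0].shape)              # 2, (1, 2, 7, 7)
print(len(out.cross_attentions), out.cross_attentions[0].shape)  # 2, (1, 2, 7, 11)
```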
@@ -34,7 +34,7 @@ from .file_utils import (
    replace_return_docstrings,
)
from .modeling_outputs import (
-    BaseModelOutput,
+    BaseModelOutputWithCrossAttentions,
    MaskedLMOutput,
    MultipleChoiceModelOutput,
    QuestionAnsweringModelOutput,
@@ -445,7 +445,8 @@ class ElectraEncoder(nn.Module):
        return_dict=False,
    ):
        all_hidden_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
@@ -479,15 +480,24 @@ class ElectraEncoder(nn.Module):
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+            return tuple(
+                v
+                for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithCrossAttentions(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
        )
@@ -697,7 +707,7 @@ class ElectraModel(ElectraPreTrainedModel):
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="google/electra-small-discriminator",
-        output_type=BaseModelOutput,
+        output_type=BaseModelOutputWithCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
......
@@ -426,6 +426,7 @@ class EncoderDecoderModel(PreTrainedModel):
            past_key_values=None,  # TODO(PVP) - need to implement cache for BERT, etc... before this works
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
......
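Once exposed, the weights can be post-processed freely; for example, a rough target-to-source token alignment can be read off them. A sketch on dummy tensors shaped like `outputs.cross_attentions`:

```python
import torch

# Sketch: turn cross-attention weights into a hard target-to-source token alignment.
def hard_alignment(cross_attentions):
    # cross_attentions: tuple of per-layer tensors, each (batch, num_heads, tgt_len, src_len).
    stacked = torch.stack(cross_attentions).mean(dim=0).mean(dim=1)  # average over layers, then heads
    return stacked.argmax(dim=-1)  # (batch, tgt_len): most-attended source position per target token

# Dummy weights standing in for the `cross_attentions` of a real forward pass.
dummy = tuple(torch.softmax(torch.randn(1, 8, 5, 9), dim=-1) for _ in range(6))
print(hard_alignment(dummy))  # shape (1, 5), values in [0, 9)
```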
@@ -46,7 +46,12 @@ from .file_utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
-from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput
+from .modeling_outputs import (
+    BaseModelOutput,
+    BaseModelOutputWithPastAndCrossAttentions,
+    Seq2SeqLMOutput,
+    Seq2SeqModelOutput,
+)
from .modeling_utils import PreTrainedModel
from .utils import logging
@@ -543,11 +548,12 @@ class DecoderLayer(nn.Module):
        # Cross attention
        residual = x
        assert self.encoder_attn.cache_key != self.self_attn.cache_key
-        x, _ = self.encoder_attn(
+        x, cross_attn_weights = self.encoder_attn(
            query=x,
            key=encoder_hidden_states,
            key_padding_mask=encoder_attn_mask,
            layer_state=layer_state,  # mutates layer state
+            output_attentions=output_attentions,
        )
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = residual + x
@@ -565,7 +571,8 @@ class DecoderLayer(nn.Module):
            x,
            self_attn_weights,
            layer_state,
-        )  # just self_attn weights for now, following t5, layer_state = cache for decoding
+            cross_attn_weights,
+        )  # layer_state = cache for decoding
class FSMTDecoder(nn.Module):
@@ -669,6 +676,7 @@ class FSMTDecoder(nn.Module):
        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
+        all_cross_attns = () if output_attentions else None
        next_decoder_cache = []
        for idx, decoder_layer in enumerate(self.layers):
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -680,7 +688,7 @@ class FSMTDecoder(nn.Module):
            layer_state = past_key_values[idx] if past_key_values is not None else None
-            x, layer_self_attn, layer_past = decoder_layer(
+            x, layer_self_attn, layer_past, layer_cross_attn = decoder_layer(
                x,
                encoder_hidden_states,
                encoder_attn_mask=encoder_padding_mask,
@@ -695,6 +703,7 @@ class FSMTDecoder(nn.Module):
            if output_attentions:
                all_self_attns += (layer_self_attn,)
+                all_cross_attns += (layer_cross_attn,)
        # Convert to standard output format: (seq_len, BS, model_dim) -> (BS, seq_len, model_dim)
        if output_hidden_states:
@@ -707,9 +716,15 @@ class FSMTDecoder(nn.Module):
        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
-            return tuple(v for v in [x, next_cache, all_hidden_states, all_self_attns] if v is not None)
-        return BaseModelOutputWithPast(
-            last_hidden_state=x, past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_self_attns
+            return tuple(
+                v for v in [x, next_cache, all_hidden_states, all_self_attns, all_cross_attns] if v is not None
+            )
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=x,
+            past_key_values=next_cache,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attns,
+            cross_attentions=all_cross_attns,
        )
@@ -903,7 +918,7 @@ class FSMTModel(PretrainedFSMTModel):
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="facebook/wmt19-ru-en",
-        output_type=BaseModelOutputWithPast,
+        output_type=Seq2SeqModelOutput,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
@@ -989,6 +1004,7 @@ class FSMTModel(PretrainedFSMTModel):
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
+            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
@@ -1101,6 +1117,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
+            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
......
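The FSMT decoder now follows the same accumulation pattern as BART. A didactic, framework-free sketch of that control flow (fake layers, not the real transformers classes):

```python
import torch

def fake_layer(x, encoder_hidden_states):
    # Stand-in for a decoder layer: returns hidden states, self-attn and cross-attn weights.
    tgt_len, src_len = x.shape[1], encoder_hidden_states.shape[1]
    self_attn = torch.softmax(torch.randn(1, 4, tgt_len, tgt_len), dim=-1)
    cross_attn = torch.softmax(torch.randn(1, 4, tgt_len, src_len), dim=-1)
    return x, self_attn, cross_attn

def run_decoder(layers, x, encoder_hidden_states, output_attentions=True):
    all_self_attns = () if output_attentions else None
    all_cross_attns = () if output_attentions else None
    for layer in layers:
        x, layer_self_attn, layer_cross_attn = layer(x, encoder_hidden_states)
        if output_attentions:
            all_self_attns += (layer_self_attn,)
            all_cross_attns += (layer_cross_attn,)  # the extra tuple added by this commit
    return x, all_self_attns, all_cross_attns

x, self_attns, cross_attns = run_decoder([fake_layer] * 6, torch.randn(1, 5, 16), torch.randn(1, 9, 16))
print(len(cross_attns), cross_attns[0].shape)  # 6, (1, 4, 5, 9)
```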
@@ -33,7 +33,11 @@ from .file_utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
-from .modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
+from .modeling_outputs import (
+    BaseModelOutputWithPastAndCrossAttentions,
+    CausalLMOutputWithPastAndCrossAttentions,
+    SequenceClassifierOutputWithPast,
+)
from .modeling_utils import (
    Conv1D,
    PreTrainedModel,
@@ -311,14 +315,14 @@ class Block(nn.Module):
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = hidden_states + attn_output
-            outputs = outputs + cross_attn_outputs[1:]  # add cross attentions if we output attention weights
+            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights
        feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
        # residual connection
        hidden_states = hidden_states + feed_forward_hidden_states
        outputs = [hidden_states] + outputs
-        return outputs  # hidden_states, present, (cross_attentions, attentions)
+        return outputs  # hidden_states, present, (attentions, cross_attentions)
class GPT2PreTrainedModel(PreTrainedModel):
@@ -506,7 +510,7 @@ class GPT2Model(GPT2PreTrainedModel):
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="gpt2",
-        output_type=BaseModelOutputWithPast,
+        output_type=BaseModelOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
@@ -618,7 +622,8 @@ class GPT2Model(GPT2PreTrainedModel):
        output_shape = input_shape + (hidden_states.size(-1),)
        presents = () if use_cache else None
-        all_attentions = () if output_attentions else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
            if output_hidden_states:
@@ -659,7 +664,9 @@ class GPT2Model(GPT2PreTrainedModel):
                presents = presents + (present,)
            if output_attentions:
-                all_attentions = all_attentions + (outputs[2],)
+                all_self_attentions = all_self_attentions + (outputs[2],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (outputs[3],)
        hidden_states = self.ln_f(hidden_states)
@@ -669,13 +676,14 @@ class GPT2Model(GPT2PreTrainedModel):
            all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
-            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_attentions] if v is not None)
+            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
-        return BaseModelOutputWithPast(
+        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
-            attentions=all_attentions,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
        )
@@ -727,7 +735,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="gpt2",
-        output_type=CausalLMOutputWithPast,
+        output_type=CausalLMOutputWithPastAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
@@ -795,12 +803,13 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output
-        return CausalLMOutputWithPast(
+        return CausalLMOutputWithPastAndCrossAttentions(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
+            cross_attentions=transformer_outputs.cross_attentions,
        )
......
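GPT-2 only builds cross-attention layers when `config.add_cross_attention=True`, e.g. when it serves as the decoder of an `EncoderDecoderModel`. A hedged sketch with a stand-in encoder output (the added cross-attention weights are randomly initialized):

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

# Sketch: enable cross-attention on GPT-2 and feed it fake encoder states.
config = GPT2Config.from_pretrained("gpt2", add_cross_attention=True)
model = GPT2LMHeadModel.from_pretrained("gpt2", config=config)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
encoder_hidden_states = torch.randn(1, 11, config.n_embd)  # stand-in for a real encoder output

outputs = model(
    input_ids=inputs.input_ids,
    encoder_hidden_states=encoder_hidden_states,
    output_attentions=True,
    return_dict=True,
)
print(len(outputs.cross_attentions), outputs.cross_attentions[0].shape)
```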
@@ -24,7 +24,12 @@ from torch.nn import CrossEntropyLoss
from .activations import ACT2FN
from .configuration_layoutlm import LayoutLMConfig
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward
-from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, TokenClassifierOutput
+from .modeling_outputs import (
+    BaseModelOutputWithCrossAttentions,
+    BaseModelOutputWithPoolingAndCrossAttentions,
+    MaskedLMOutput,
+    TokenClassifierOutput,
+)
from .modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
@@ -374,7 +379,8 @@ class LayoutLMEncoder(nn.Module):
        return_dict=False,
    ):
        all_hidden_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
+        all_self_attentions = () if output_attentions else None
+        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
@@ -408,15 +414,24 @@ class LayoutLMEncoder(nn.Module):
            )
            hidden_states = layer_outputs[0]
            if output_attentions:
-                all_attentions = all_attentions + (layer_outputs[1],)
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+                if self.config.add_cross_attention:
+                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)
        if not return_dict:
-            return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
-        return BaseModelOutput(
-            last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
+            return tuple(
+                v
+                for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
+                if v is not None
+            )
+        return BaseModelOutputWithCrossAttentions(
+            last_hidden_state=hidden_states,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
        )
@@ -611,7 +626,7 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
    @add_code_sample_docstrings(
        tokenizer_class=_TOKENIZER_FOR_DOC,
        checkpoint="layoutlm-base-uncased",
-        output_type=BaseModelOutputWithPooling,
+        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
@@ -716,11 +731,12 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]
-        return BaseModelOutputWithPooling(
+        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
+            cross_attentions=encoder_outputs.cross_attentions,
        )
......
@@ -99,6 +99,120 @@ class BaseModelOutputWithPast(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class BaseModelOutputWithCrossAttentions(ModelOutput):
"""
Base class for model's outputs, with potential hidden states and attentions.
Args:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
"""
last_hidden_state: torch.FloatTensor
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class BaseModelOutputWithPoolingAndCrossAttentions(ModelOutput):
"""
Base class for model's outputs that also contains a pooling of the last hidden states.
Args:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
Last layer hidden-state of the first token of the sequence (classification token) further processed by a
Linear layer and a Tanh activation function. The Linear layer weights are trained from the next sentence
prediction (classification) objective during pretraining.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
"""
last_hidden_state: torch.FloatTensor
pooler_output: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
"""
Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
Args:
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
Sequence of hidden-states at the output of the last layer of the model.
If :obj:`past_key_values` is used only the last hidden-state of the sequences of shape :obj:`(batch_size,
1, hidden_size)` is output.
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
:obj:`past_key_values` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` and ``config.add_cross_attention=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
"""
last_hidden_state: torch.FloatTensor
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class Seq2SeqModelOutput(ModelOutput):
"""
@@ -128,6 +242,12 @@ class Seq2SeqModelOutput(ModelOutput):
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads.
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
@@ -147,6 +267,7 @@ class Seq2SeqModelOutput(ModelOutput):
past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
@@ -217,6 +338,85 @@ class CausalLMOutputWithPast(ModelOutput):
attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class CausalLMOutputWithCrossAttentions(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Language modeling loss (for next-token prediction).
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Cross attentions weights after the attention softmax, used to compute the weighted average in the
cross-attention heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class CausalLMOutputWithPastAndCrossAttentions(ModelOutput):
"""
Base class for causal language model (or autoregressive) outputs.
Args:
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
Language modeling loss (for next-token prediction).
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
batch_size, num_heads, sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
:obj:`past_key_values` input) to speed up sequential decoding.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
heads.
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Cross attentions weights after the attention softmax, used to compute the weighted average in the
cross-attention heads.
"""
loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass @dataclass
class SequenceClassifierOutputWithPast(ModelOutput): class SequenceClassifierOutputWithPast(ModelOutput):
""" """
...@@ -309,6 +509,12 @@ class Seq2SeqLMOutput(ModelOutput): ...@@ -309,6 +509,12 @@ class Seq2SeqLMOutput(ModelOutput):
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads. self-attention heads.
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model. Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
...@@ -329,6 +535,7 @@ class Seq2SeqLMOutput(ModelOutput): ...@@ -329,6 +535,7 @@ class Seq2SeqLMOutput(ModelOutput):
past_key_values: Optional[List[torch.FloatTensor]] = None past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
...@@ -420,6 +627,12 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput): ...@@ -420,6 +627,12 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the Attentions weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads. self-attention heads.
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model. Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
...@@ -440,6 +653,7 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput): ...@@ -440,6 +653,7 @@ class Seq2SeqSequenceClassifierOutput(ModelOutput):
past_key_values: Optional[List[torch.FloatTensor]] = None past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
...@@ -566,6 +780,12 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): ...@@ -566,6 +780,12 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
self-attention heads. self-attention heads.
cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
sequence_length, sequence_length)`.
Attention weights of the decoder's cross-attention layer, after the attention softmax, used to compute the
weighted average in the cross-attention heads.
encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
Sequence of hidden-states at the output of the last layer of the encoder of the model. Sequence of hidden-states at the output of the last layer of the encoder of the model.
encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
...@@ -587,6 +807,7 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput): ...@@ -587,6 +807,7 @@ class Seq2SeqQuestionAnsweringModelOutput(ModelOutput):
past_key_values: Optional[List[torch.FloatTensor]] = None past_key_values: Optional[List[torch.FloatTensor]] = None
decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
import copy import copy
import math import math
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional, Tuple from typing import Dict, Optional, Tuple
...@@ -261,7 +262,7 @@ class ProphetNetSeq2SeqLMOutput(ModelOutput): ...@@ -261,7 +262,7 @@ class ProphetNetSeq2SeqLMOutput(ModelOutput):
Attention weights of the predict stream of the decoder, after the attention softmax, used to compute the Attention weights of the predict stream of the decoder, after the attention softmax, used to compute the
weighted average in the self-attention heads. weighted average in the self-attention heads.
decoder_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_attn_heads, Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_attn_heads,
encoder_sequence_length, decoder_sequence_length)`. encoder_sequence_length, decoder_sequence_length)`.
...@@ -288,11 +289,19 @@ class ProphetNetSeq2SeqLMOutput(ModelOutput): ...@@ -288,11 +289,19 @@ class ProphetNetSeq2SeqLMOutput(ModelOutput):
decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
decoder_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
@property
def decoder_cross_attentions(self):
warnings.warn(
"`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions` instead.",
FutureWarning,
)
return self.cross_attentions
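The rename stays backward compatible: the old attribute survives as a read-only property that forwards to cross_attentions and emits a FutureWarning. A standalone sketch of the same pattern (the ExampleOutput class is hypothetical, not the library dataclass):

import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

import torch


@dataclass
class ExampleOutput:
    cross_attentions: Optional[Tuple[torch.FloatTensor]] = None

    @property
    def decoder_cross_attentions(self):
        # Deprecated alias kept so existing code keeps working.
        warnings.warn(
            "`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions` instead.",
            FutureWarning,
        )
        return self.cross_attentions


out = ExampleOutput(cross_attentions=(torch.zeros(1, 4, 5, 7),))
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert out.decoder_cross_attentions is out.cross_attentions
assert any(issubclass(w.category, FutureWarning) for w in caught)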
@dataclass @dataclass
class ProphetNetSeq2SeqModelOutput(ModelOutput): class ProphetNetSeq2SeqModelOutput(ModelOutput):
...@@ -337,7 +346,7 @@ class ProphetNetSeq2SeqModelOutput(ModelOutput): ...@@ -337,7 +346,7 @@ class ProphetNetSeq2SeqModelOutput(ModelOutput):
Attention weights of the predict stream of the decoder, after the attention softmax, used to compute the Attention weights of the predict stream of the decoder, after the attention softmax, used to compute the
weighted average in the self-attention heads. weighted average in the self-attention heads.
decoder_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``): cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_attn_heads, Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_attn_heads,
encoder_sequence_length, decoder_sequence_length)`. encoder_sequence_length, decoder_sequence_length)`.
...@@ -365,11 +374,19 @@ class ProphetNetSeq2SeqModelOutput(ModelOutput): ...@@ -365,11 +374,19 @@ class ProphetNetSeq2SeqModelOutput(ModelOutput):
decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None decoder_ngram_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None decoder_ngram_attentions: Optional[Tuple[torch.FloatTensor]] = None
decoder_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_last_hidden_state: Optional[torch.FloatTensor] = None encoder_last_hidden_state: Optional[torch.FloatTensor] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
@property
def decoder_cross_attentions(self):
warnings.warn(
"`decoder_cross_attentions` is deprecated and will be removed soon. Please use `cross_attentions` instead.",
FutureWarning,
)
return self.cross_attentions
@dataclass @dataclass
class ProphetNetDecoderModelOutput(ModelOutput): class ProphetNetDecoderModelOutput(ModelOutput):
...@@ -1651,7 +1668,7 @@ class ProphetNetModel(ProphetNetPreTrainedModel): ...@@ -1651,7 +1668,7 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
decoder_ngram_hidden_states=decoder_outputs.hidden_states_ngram, decoder_ngram_hidden_states=decoder_outputs.hidden_states_ngram,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
decoder_ngram_attentions=decoder_outputs.ngram_attentions, decoder_ngram_attentions=decoder_outputs.ngram_attentions,
decoder_cross_attentions=decoder_outputs.cross_attentions, cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states, encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions, encoder_attentions=encoder_outputs.attentions,
...@@ -1766,7 +1783,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): ...@@ -1766,7 +1783,7 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
decoder_ngram_hidden_states=outputs.decoder_ngram_hidden_states, decoder_ngram_hidden_states=outputs.decoder_ngram_hidden_states,
decoder_attentions=outputs.decoder_attentions, decoder_attentions=outputs.decoder_attentions,
decoder_ngram_attentions=outputs.decoder_ngram_attentions, decoder_ngram_attentions=outputs.decoder_ngram_attentions,
decoder_cross_attentions=outputs.decoder_cross_attentions, cross_attentions=outputs.cross_attentions,
encoder_last_hidden_state=outputs.encoder_last_hidden_state, encoder_last_hidden_state=outputs.encoder_last_hidden_state,
encoder_hidden_states=outputs.encoder_hidden_states, encoder_hidden_states=outputs.encoder_hidden_states,
encoder_attentions=outputs.encoder_attentions, encoder_attentions=outputs.encoder_attentions,
...@@ -1986,6 +2003,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): ...@@ -1986,6 +2003,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
hidden_states_ngram=outputs.hidden_states_ngram, hidden_states_ngram=outputs.hidden_states_ngram,
attentions=outputs.attentions, attentions=outputs.attentions,
ngram_attentions=outputs.ngram_attentions, ngram_attentions=outputs.ngram_attentions,
cross_attentions=outputs.cross_attentions,
) )
def _compute_loss(self, logits, labels): def _compute_loss(self, logits, labels):
......
...@@ -31,9 +31,9 @@ from .file_utils import ( ...@@ -31,9 +31,9 @@ from .file_utils import (
replace_return_docstrings, replace_return_docstrings,
) )
from .modeling_outputs import ( from .modeling_outputs import (
BaseModelOutput, BaseModelOutputWithCrossAttentions,
BaseModelOutputWithPooling, BaseModelOutputWithPoolingAndCrossAttentions,
CausalLMOutput, CausalLMOutputWithCrossAttentions,
MaskedLMOutput, MaskedLMOutput,
MultipleChoiceModelOutput, MultipleChoiceModelOutput,
QuestionAnsweringModelOutput, QuestionAnsweringModelOutput,
...@@ -393,7 +393,8 @@ class RobertaEncoder(nn.Module): ...@@ -393,7 +393,8 @@ class RobertaEncoder(nn.Module):
return_dict=False, return_dict=False,
): ):
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
for i, layer_module in enumerate(self.layer): for i, layer_module in enumerate(self.layer):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
...@@ -427,15 +428,24 @@ class RobertaEncoder(nn.Module): ...@@ -427,15 +428,24 @@ class RobertaEncoder(nn.Module):
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if output_attentions: if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],) all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict: if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None) return tuple(
return BaseModelOutput( v
last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions for v in [hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithCrossAttentions(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
) )
...@@ -599,7 +609,7 @@ class RobertaModel(RobertaPreTrainedModel): ...@@ -599,7 +609,7 @@ class RobertaModel(RobertaPreTrainedModel):
@add_code_sample_docstrings( @add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC, tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="roberta-base", checkpoint="roberta-base",
output_type=BaseModelOutputWithPooling, output_type=BaseModelOutputWithPoolingAndCrossAttentions,
config_class=_CONFIG_FOR_DOC, config_class=_CONFIG_FOR_DOC,
) )
# Copied from transformers.modeling_bert.BertModel.forward # Copied from transformers.modeling_bert.BertModel.forward
...@@ -689,11 +699,12 @@ class RobertaModel(RobertaPreTrainedModel): ...@@ -689,11 +699,12 @@ class RobertaModel(RobertaPreTrainedModel):
if not return_dict: if not return_dict:
return (sequence_output, pooled_output) + encoder_outputs[1:] return (sequence_output, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling( return BaseModelOutputWithPoolingAndCrossAttentions(
last_hidden_state=sequence_output, last_hidden_state=sequence_output,
pooler_output=pooled_output, pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states, hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions, attentions=encoder_outputs.attentions,
cross_attentions=encoder_outputs.cross_attentions,
) )
...@@ -719,7 +730,7 @@ class RobertaForCausalLM(RobertaPreTrainedModel): ...@@ -719,7 +730,7 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
return self.lm_head.decoder return self.lm_head.decoder
@add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@replace_return_docstrings(output_type=CausalLMOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids=None,
...@@ -799,11 +810,12 @@ class RobertaForCausalLM(RobertaPreTrainedModel): ...@@ -799,11 +810,12 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
output = (prediction_scores,) + outputs[2:] output = (prediction_scores,) + outputs[2:]
return ((lm_loss,) + output) if lm_loss is not None else output return ((lm_loss,) + output) if lm_loss is not None else output
return CausalLMOutput( return CausalLMOutputWithCrossAttentions(
loss=lm_loss, loss=lm_loss,
logits=prediction_scores, logits=prediction_scores,
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
cross_attentions=outputs.cross_attentions,
) )
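Cross-attention weights only appear here when the model is configured as a decoder with cross-attention enabled and is fed encoder states. A checkpoint-free sketch with a tiny random-weight config (all sizes below are arbitrary assumptions):

import torch
from transformers import RobertaConfig, RobertaForCausalLM

config = RobertaConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
    is_decoder=True,           # run RoBERTa as a decoder
    add_cross_attention=True,  # add the cross-attention sub-layers
)
model = RobertaForCausalLM(config)
model.eval()

input_ids = torch.randint(0, config.vocab_size, (1, 5))
encoder_hidden_states = torch.randn(1, 7, config.hidden_size)

with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        encoder_hidden_states=encoder_hidden_states,
        output_attentions=True,
        return_dict=True,
    )

# One tensor per layer, attending from the 5 decoder positions to the 7 encoder positions.
print([tuple(t.shape) for t in outputs.cross_attentions])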
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs): def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
......
...@@ -33,7 +33,12 @@ from .file_utils import ( ...@@ -33,7 +33,12 @@ from .file_utils import (
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
replace_return_docstrings, replace_return_docstrings,
) )
from .modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, Seq2SeqLMOutput, Seq2SeqModelOutput from .modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from .modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
from .utils import logging from .utils import logging
...@@ -503,6 +508,7 @@ class T5Block(nn.Module): ...@@ -503,6 +508,7 @@ class T5Block(nn.Module):
past_key_value=None, past_key_value=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
return_dict=False,
): ):
if past_key_value is not None: if past_key_value is not None:
...@@ -533,7 +539,8 @@ class T5Block(nn.Module): ...@@ -533,7 +539,8 @@ class T5Block(nn.Module):
hidden_states, present_key_value_state = self_attention_outputs[:2] hidden_states, present_key_value_state = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
if self.is_decoder and encoder_hidden_states is not None: do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention:
# the actual query length is unknown for cross attention # the actual query length is unknown for cross attention
# if using past key value states. Need to inject it here # if using past key value states. Need to inject it here
if present_key_value_state is not None: if present_key_value_state is not None:
...@@ -564,7 +571,6 @@ class T5Block(nn.Module): ...@@ -564,7 +571,6 @@ class T5Block(nn.Module):
hidden_states = self.layer[-1](hidden_states) hidden_states = self.layer[-1](hidden_states)
outputs = (hidden_states,) outputs = (hidden_states,)
# Add attentions if we output them
outputs = outputs + (present_key_value_state,) + attention_outputs outputs = outputs + (present_key_value_state,) + attention_outputs
return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) return outputs # hidden-states, present_key_value_states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
...@@ -743,6 +749,7 @@ class T5Stack(T5PreTrainedModel): ...@@ -743,6 +749,7 @@ class T5Stack(T5PreTrainedModel):
present_key_value_states = () if use_cache else None present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and self.is_decoder) else None
position_bias = None position_bias = None
encoder_decoder_position_bias = None encoder_decoder_position_bias = None
...@@ -779,7 +786,9 @@ class T5Stack(T5PreTrainedModel): ...@@ -779,7 +786,9 @@ class T5Stack(T5PreTrainedModel):
present_key_value_states = present_key_value_states + (present_key_value_state,) present_key_value_states = present_key_value_states + (present_key_value_state,)
if output_attentions: if output_attentions:
all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now all_attentions = all_attentions + (layer_outputs[2],)
if self.is_decoder:
all_cross_attentions = all_cross_attentions + (layer_outputs[4 if i == 0 else 3],)
hidden_states = self.final_layer_norm(hidden_states) hidden_states = self.final_layer_norm(hidden_states)
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
...@@ -791,14 +800,21 @@ class T5Stack(T5PreTrainedModel): ...@@ -791,14 +800,21 @@ class T5Stack(T5PreTrainedModel):
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
for v in [hidden_states, present_key_value_states, all_hidden_states, all_attentions] for v in [
hidden_states,
present_key_value_states,
all_hidden_states,
all_attentions,
all_cross_attentions,
]
if v is not None if v is not None
) )
return BaseModelOutputWithPast( return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_value_states, past_key_values=present_key_value_states,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_attentions, attentions=all_attentions,
cross_attentions=all_cross_attentions,
) )
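The `4 if i == 0 else 3` index follows the per-block output layout documented in the T5Block return comment above: only the first block returns the shared relative position biases, so its cross-attention weights sit one slot later. A small schematic check of that assumed layout (a reading of the code, not an authoritative spec):

# Assumed per-block `layer_outputs` layout when output_attentions=True, inferred from
# the T5Block return comment; only block 0 carries the shared position biases.
FIRST_BLOCK = (
    "hidden_states",
    "present_key_value",
    "self_attn_weights",
    "self_attn_position_bias",
    "cross_attn_weights",
    "cross_attn_position_bias",
)
LATER_BLOCKS = ("hidden_states", "present_key_value", "self_attn_weights", "cross_attn_weights")


def cross_attention_index(block_index):
    # Mirrors `layer_outputs[4 if i == 0 else 3]` in T5Stack.forward above.
    return 4 if block_index == 0 else 3


assert FIRST_BLOCK[cross_attention_index(0)] == "cross_attn_weights"
assert LATER_BLOCKS[cross_attention_index(1)] == "cross_attn_weights"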
...@@ -1038,6 +1054,7 @@ class T5Model(T5PreTrainedModel): ...@@ -1038,6 +1054,7 @@ class T5Model(T5PreTrainedModel):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states, encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions, encoder_attentions=encoder_outputs.attentions,
...@@ -1227,6 +1244,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel): ...@@ -1227,6 +1244,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
past_key_values=decoder_outputs.past_key_values, past_key_values=decoder_outputs.past_key_values,
decoder_hidden_states=decoder_outputs.hidden_states, decoder_hidden_states=decoder_outputs.hidden_states,
decoder_attentions=decoder_outputs.attentions, decoder_attentions=decoder_outputs.attentions,
cross_attentions=decoder_outputs.cross_attentions,
encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_last_hidden_state=encoder_outputs.last_hidden_state,
encoder_hidden_states=encoder_outputs.hidden_states, encoder_hidden_states=encoder_outputs.hidden_states,
encoder_attentions=encoder_outputs.attentions, encoder_attentions=encoder_outputs.attentions,
......
...@@ -253,9 +253,7 @@ class ModelTesterMixin: ...@@ -253,9 +253,7 @@ class ModelTesterMixin:
out_len = len(outputs) out_len = len(outputs)
if self.is_encoder_decoder: if self.is_encoder_decoder:
correct_outlen = ( correct_outlen = 5
self.model_tester.base_model_out_len if hasattr(self.model_tester, "base_model_out_len") else 4
)
# loss is at first position # loss is at first position
if "labels" in inputs_dict: if "labels" in inputs_dict:
...@@ -266,6 +264,7 @@ class ModelTesterMixin: ...@@ -266,6 +264,7 @@ class ModelTesterMixin:
self.assertEqual(out_len, correct_outlen) self.assertEqual(out_len, correct_outlen)
# decoder attentions
decoder_attentions = outputs.decoder_attentions decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple)) self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers) self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
...@@ -274,6 +273,19 @@ class ModelTesterMixin: ...@@ -274,6 +273,19 @@ class ModelTesterMixin:
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length], [self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
) )
# cross attentions
cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
decoder_seq_length,
encoder_key_length,
],
)
# Check attention is always last and order is fine # Check attention is always last and order is fine
inputs_dict["output_attentions"] = True inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True inputs_dict["output_hidden_states"] = True
......
...@@ -292,6 +292,62 @@ class EncoderDecoderMixin: ...@@ -292,6 +292,62 @@ class EncoderDecoderMixin:
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,)) outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
) )
def check_encoder_decoder_model_output_attentions(
self,
config,
input_ids,
attention_mask,
encoder_hidden_states,
decoder_config,
decoder_input_ids,
decoder_attention_mask,
labels,
**kwargs
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
output_attentions=True,
return_dict=True,
)
encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
self.assertEqual(len(encoder_attentions), config.num_hidden_layers)
self.assertListEqual(
list(encoder_attentions[0].shape[-3:]),
[config.num_attention_heads, input_ids.shape[-1], input_ids.shape[-1]],
)
decoder_attentions = outputs_encoder_decoder["decoder_attentions"]
num_decoder_layers = (
decoder_config.num_decoder_layers
if hasattr(decoder_config, "num_decoder_layers")
else decoder_config.num_hidden_layers
)
self.assertEqual(len(decoder_attentions), num_decoder_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[decoder_config.num_attention_heads, decoder_input_ids.shape[-1], decoder_input_ids.shape[-1]],
)
cross_attentions = outputs_encoder_decoder["cross_attentions"]
self.assertEqual(len(cross_attentions), num_decoder_layers)
cross_attention_input_seq_len = input_ids.shape[-1] * (
1 + (decoder_config.ngram if hasattr(decoder_config, "ngram") else 0)
)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[decoder_config.num_attention_heads, cross_attention_input_seq_len, decoder_input_ids.shape[-1]],
)
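The same field is surfaced through the EncoderDecoder framework, which is what this new check exercises. A checkpoint-free sketch mirroring it with tiny random-weight BERT configs (the BERT classes and all sizes are illustrative assumptions):

import torch
from transformers import BertConfig, BertLMHeadModel, BertModel, EncoderDecoderModel

encoder_config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64
)
decoder_config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2, num_attention_heads=4, intermediate_size=64,
    is_decoder=True, add_cross_attention=True,
)
model = EncoderDecoderModel(encoder=BertModel(encoder_config), decoder=BertLMHeadModel(decoder_config))
model.eval()

input_ids = torch.randint(0, 100, (1, 7))
decoder_input_ids = torch.randint(0, 100, (1, 5))

with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        output_attentions=True,
        return_dict=True,
    )

# One tensor per decoder layer, shaped (batch, heads, decoder positions, encoder positions).
print([tuple(t.shape) for t in outputs["cross_attentions"]])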
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs): def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model) enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
...@@ -413,6 +469,10 @@ class EncoderDecoderMixin: ...@@ -413,6 +469,10 @@ class EncoderDecoderMixin:
input_ids_dict = self.prepare_config_and_inputs() input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_labels(**input_ids_dict) self.check_encoder_decoder_model_labels(**input_ids_dict)
def test_encoder_decoder_model_output_attentions(self):
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
def test_encoder_decoder_model_generate(self): def test_encoder_decoder_model_generate(self):
input_ids_dict = self.prepare_config_and_inputs() input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_generate(**input_ids_dict) self.check_encoder_decoder_model_generate(**input_ids_dict)
......
...@@ -916,6 +916,116 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test ...@@ -916,6 +916,116 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs) self.model_tester.create_and_check_model_fp16_forward(*config_and_inputs)
# methods overwrite method in `test_modeling_common.py`
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
seq_len = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
decoder_key_length = getattr(self.model_tester, "decoder_key_length", decoder_seq_length)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)
chunk_length = getattr(self.model_tester, "chunk_length", None)
if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
self.assertListEqual(
list(attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)
correct_outlen = 7
# loss is at first position
if "labels" in inputs_dict:
correct_outlen += 1 # loss is added to beginning
self.assertEqual(out_len, correct_outlen)
# decoder attentions
decoder_attentions = outputs.decoder_attentions
self.assertIsInstance(decoder_attentions, (list, tuple))
self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(decoder_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, decoder_seq_length, decoder_key_length],
)
# cross attentions
cross_attentions = outputs.cross_attentions
self.assertIsInstance(cross_attentions, (list, tuple))
self.assertEqual(len(cross_attentions), self.model_tester.num_hidden_layers)
self.assertListEqual(
list(cross_attentions[0].shape[-3:]),
[
self.model_tester.num_attention_heads,
(self.model_tester.ngram + 1) * decoder_seq_length,
encoder_key_length,
],
)
# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
if hasattr(self.model_tester, "num_hidden_states_types"):
added_hidden_states = self.model_tester.num_hidden_states_types
elif self.is_encoder_decoder:
added_hidden_states = 2
else:
added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))
self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
self.assertEqual(len(self_attentions), self.model_tester.num_hidden_layers)
if chunk_length is not None:
self.assertListEqual(
list(self_attentions[0].shape[-4:]),
[self.model_tester.num_attention_heads, encoder_seq_length, chunk_length, encoder_key_length],
)
else:
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
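The extra factor in the expected query length above comes from ProphetNet's predict stream: the decoder attends to the encoder from the main stream plus each of the `ngram` future-token streams, so the query dimension is (ngram + 1) * decoder_seq_length. A worked check with hypothetical numbers:

# Hypothetical sizes, only to make the shape arithmetic concrete.
ngram = 2                 # number of future-token predict streams
decoder_seq_length = 7    # main-stream target length
num_attention_heads = 4
encoder_key_length = 9

cross_attention_query_length = (ngram + 1) * decoder_seq_length  # main stream + ngram streams
expected_shape_suffix = [num_attention_heads, cross_attention_query_length, encoder_key_length]
print(expected_shape_suffix)  # [4, 21, 9]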
@require_torch @require_torch
class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase): class ProphetNetStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
......