Unverified commit 1750e629, authored by Mehrad Moradshahi and committed by GitHub
Browse files

Generate can return cross-attention weights too (#10493)

parent b0138422
...@@ -96,6 +96,9 @@ class GreedySearchEncoderDecoderOutput(ModelOutput): ...@@ -96,6 +96,9 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``): decoder_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`. :obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, hidden_size)`. :obj:`torch.FloatTensor` of shape :obj:`(batch_size, generated_length, hidden_size)`.
...@@ -106,6 +109,7 @@ class GreedySearchEncoderDecoderOutput(ModelOutput): ...@@ -106,6 +109,7 @@ class GreedySearchEncoderDecoderOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
...@@ -164,6 +168,9 @@ class SampleEncoderDecoderOutput(ModelOutput): ...@@ -164,6 +168,9 @@ class SampleEncoderDecoderOutput(ModelOutput):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length, :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length,
sequence_length)`. sequence_length)`.
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, generated_length, hidden_size)`. :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_return_sequences, generated_length, hidden_size)`.
...@@ -174,6 +181,7 @@ class SampleEncoderDecoderOutput(ModelOutput): ...@@ -174,6 +181,7 @@ class SampleEncoderDecoderOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
...@@ -239,6 +247,9 @@ class BeamSearchEncoderDecoderOutput(ModelOutput): ...@@ -239,6 +247,9 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, num_heads, :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, num_heads,
generated_length, sequence_length)`. generated_length, sequence_length)`.
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length, :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams*num_return_sequences, generated_length,
...@@ -251,6 +262,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput): ...@@ -251,6 +262,7 @@ class BeamSearchEncoderDecoderOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
...@@ -314,6 +326,9 @@ class BeamSampleEncoderDecoderOutput(ModelOutput): ...@@ -314,6 +326,9 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length, :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length,
sequence_length)`. sequence_length)`.
cross_attentions (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_attentions=True`` is passed or ``config.output_attentions=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, num_heads, generated_length, sequence_length)`.
decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``): decoder_hidden_states (:obj:`tuple(tuple(torch.FloatTensor))`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
:obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`. :obj:`torch.FloatTensor` of shape :obj:`(batch_size*num_beams, generated_length, hidden_size)`.
...@@ -325,6 +340,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput): ...@@ -325,6 +340,7 @@ class BeamSampleEncoderDecoderOutput(ModelOutput):
encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
cross_attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None decoder_hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
...@@ -1177,6 +1193,7 @@ class GenerationMixin: ...@@ -1177,6 +1193,7 @@ class GenerationMixin:
# init attention / hidden states / scores tuples # init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
...@@ -1212,6 +1229,8 @@ class GenerationMixin: ...@@ -1212,6 +1229,8 @@ class GenerationMixin:
decoder_attentions += ( decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
) )
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states: if output_hidden_states:
decoder_hidden_states += ( decoder_hidden_states += (
...@@ -1260,6 +1279,7 @@ class GenerationMixin: ...@@ -1260,6 +1279,7 @@ class GenerationMixin:
encoder_attentions=encoder_attentions, encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions, decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states, decoder_hidden_states=decoder_hidden_states,
) )
else: else:
...@@ -1384,6 +1404,7 @@ class GenerationMixin: ...@@ -1384,6 +1404,7 @@ class GenerationMixin:
# init attention / hidden states / scores tuples # init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
...@@ -1424,6 +1445,8 @@ class GenerationMixin: ...@@ -1424,6 +1445,8 @@ class GenerationMixin:
decoder_attentions += ( decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
) )
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states: if output_hidden_states:
decoder_hidden_states += ( decoder_hidden_states += (
...@@ -1468,6 +1491,7 @@ class GenerationMixin: ...@@ -1468,6 +1491,7 @@ class GenerationMixin:
encoder_attentions=encoder_attentions, encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions, decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states, decoder_hidden_states=decoder_hidden_states,
) )
else: else:
...@@ -1604,6 +1628,7 @@ class GenerationMixin: ...@@ -1604,6 +1628,7 @@ class GenerationMixin:
# init attention / hidden states / scores tuples # init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
...@@ -1656,6 +1681,8 @@ class GenerationMixin: ...@@ -1656,6 +1681,8 @@ class GenerationMixin:
decoder_attentions += ( decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
) )
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states: if output_hidden_states:
decoder_hidden_states += ( decoder_hidden_states += (
...@@ -1716,6 +1743,7 @@ class GenerationMixin: ...@@ -1716,6 +1743,7 @@ class GenerationMixin:
encoder_attentions=encoder_attentions, encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions, decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states, decoder_hidden_states=decoder_hidden_states,
) )
else: else:
...@@ -1865,6 +1893,7 @@ class GenerationMixin: ...@@ -1865,6 +1893,7 @@ class GenerationMixin:
# init attention / hidden states / scores tuples # init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
...@@ -1913,6 +1942,8 @@ class GenerationMixin: ...@@ -1913,6 +1942,8 @@ class GenerationMixin:
decoder_attentions += ( decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
) )
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states: if output_hidden_states:
decoder_hidden_states += ( decoder_hidden_states += (
...@@ -1968,17 +1999,18 @@ class GenerationMixin: ...@@ -1968,17 +1999,18 @@ class GenerationMixin:
if not output_scores: if not output_scores:
sequence_outputs["sequence_scores"] = None sequence_outputs["sequence_scores"] = None
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
return BeamSearchEncoderDecoderOutput( return BeamSampleEncoderDecoderOutput(
sequences=sequence_outputs["sequences"], sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"], sequences_scores=sequence_outputs["sequence_scores"],
scores=scores, scores=scores,
encoder_attentions=encoder_attentions, encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions, decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states, decoder_hidden_states=decoder_hidden_states,
) )
else: else:
return BeamSearchDecoderOnlyOutput( return BeamSampleDecoderOnlyOutput(
sequences=sequence_outputs["sequences"], sequences=sequence_outputs["sequences"],
sequences_scores=sequence_outputs["sequence_scores"], sequences_scores=sequence_outputs["sequence_scores"],
scores=scores, scores=scores,
...@@ -2115,6 +2147,7 @@ class GenerationMixin: ...@@ -2115,6 +2147,7 @@ class GenerationMixin:
# init attention / hidden states / scores tuples # init attention / hidden states / scores tuples
scores = () if (return_dict_in_generate and output_scores) else None scores = () if (return_dict_in_generate and output_scores) else None
decoder_attentions = () if (return_dict_in_generate and output_attentions) else None decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
cross_attentions = () if (return_dict_in_generate and output_attentions) else None
decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
# if model is an encoder-decoder, retrieve encoder attention weights and hidden states # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
...@@ -2238,6 +2271,8 @@ class GenerationMixin: ...@@ -2238,6 +2271,8 @@ class GenerationMixin:
decoder_attentions += ( decoder_attentions += (
(outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
) )
if self.config.is_encoder_decoder:
cross_attentions += (outputs.cross_attentions,)
if output_hidden_states: if output_hidden_states:
decoder_hidden_states += ( decoder_hidden_states += (
...@@ -2263,7 +2298,7 @@ class GenerationMixin: ...@@ -2263,7 +2298,7 @@ class GenerationMixin:
if return_dict_in_generate: if return_dict_in_generate:
if not output_scores: if not output_scores:
sequence_outputs["sequence_scores"] sequence_outputs["sequence_scores"] = None
if self.config.is_encoder_decoder: if self.config.is_encoder_decoder:
return BeamSearchEncoderDecoderOutput( return BeamSearchEncoderDecoderOutput(
sequences=sequence_outputs["sequences"], sequences=sequence_outputs["sequences"],
...@@ -2272,6 +2307,7 @@ class GenerationMixin: ...@@ -2272,6 +2307,7 @@ class GenerationMixin:
encoder_attentions=encoder_attentions, encoder_attentions=encoder_attentions,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
decoder_attentions=decoder_attentions, decoder_attentions=decoder_attentions,
cross_attentions=cross_attentions,
decoder_hidden_states=decoder_hidden_states, decoder_hidden_states=decoder_hidden_states,
) )
else: else:
......
...@@ -39,6 +39,8 @@ if is_torch_available(): ...@@ -39,6 +39,8 @@ if is_torch_available():
TopPLogitsWarper, TopPLogitsWarper,
) )
from transformers.generation_utils import ( from transformers.generation_utils import (
BeamSampleDecoderOnlyOutput,
BeamSampleEncoderDecoderOutput,
BeamSearchDecoderOnlyOutput, BeamSearchDecoderOnlyOutput,
BeamSearchEncoderDecoderOutput, BeamSearchEncoderDecoderOutput,
GreedySearchDecoderOnlyOutput, GreedySearchDecoderOnlyOutput,
...@@ -900,11 +902,11 @@ class GenerationTesterMixin: ...@@ -900,11 +902,11 @@ class GenerationTesterMixin:
) )
if model.config.is_encoder_decoder: if model.config.is_encoder_decoder:
self.assertIsInstance(output_beam_sample, BeamSearchEncoderDecoderOutput) self.assertIsInstance(output_beam_sample, BeamSampleEncoderDecoderOutput)
self.assertIsInstance(output_generate, BeamSearchEncoderDecoderOutput) self.assertIsInstance(output_generate, BeamSampleEncoderDecoderOutput)
else: else:
self.assertIsInstance(output_beam_sample, BeamSearchDecoderOnlyOutput) self.assertIsInstance(output_beam_sample, BeamSampleDecoderOnlyOutput)
self.assertIsInstance(output_generate, BeamSearchDecoderOnlyOutput) self.assertIsInstance(output_generate, BeamSampleDecoderOnlyOutput)
self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist()) self.assertListEqual(output_generate.sequences.tolist(), output_beam_sample.sequences.tolist())
self.assertTrue( self.assertTrue(
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment