"vscode:/vscode.git/clone" did not exist on "923dd4e5ef2afff41e2a816c235db9acd4dd3d7d"
Unverified Commit 8edc98bb authored by Derrick Blakely, committed by GitHub

Allow RAG to output decoder cross-attentions (#9789)



* get cross attns

* add cross-attns doc strings

* fix typo

* line length

* Apply suggestions from code review
Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
parent 8f6c12d3
@@ -102,6 +102,12 @@ class RetrievAugLMMarginOutput(ModelOutput):
            Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
        generator_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.
            Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
    """

    loss: Optional[torch.FloatTensor] = None
@@ -120,6 +126,7 @@ class RetrievAugLMMarginOutput(ModelOutput):
    generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
    generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
    generator_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


@dataclass
@@ -186,6 +193,12 @@ class RetrievAugLMOutput(ModelOutput):
            Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
            average in the self-attention heads.
        generator_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
            sequence_length, sequence_length)`.
            Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
            weighted average in the cross-attention heads.
    """

    logits: torch.FloatTensor = None
@@ -203,6 +216,7 @@ class RetrievAugLMOutput(ModelOutput):
    generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
    generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
    generator_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


class RagPreTrainedModel(PreTrainedModel):
@@ -619,6 +633,7 @@ class RagModel(RagPreTrainedModel):
            decoder_attention_mask=decoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            return_dict=True,
        )
@@ -655,6 +670,7 @@ class RagModel(RagPreTrainedModel):
            generator_enc_attentions=gen_outputs.encoder_attentions,
            generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
            generator_dec_attentions=gen_outputs.decoder_attentions,
            generator_cross_attentions=gen_outputs.cross_attentions,
        )
@@ -803,6 +819,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
            generator_enc_attentions=outputs.generator_enc_attentions,
            generator_dec_hidden_states=outputs.generator_dec_hidden_states,
            generator_dec_attentions=outputs.generator_dec_attentions,
            generator_cross_attentions=outputs.generator_cross_attentions,
        )

    @property
@@ -1264,6 +1281,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
            generator_enc_attentions=outputs.generator_enc_attentions,
            generator_dec_hidden_states=outputs.generator_dec_hidden_states,
            generator_dec_attentions=outputs.generator_dec_attentions,
            generator_cross_attentions=outputs.generator_cross_attentions,
        )

    @torch.no_grad()
......
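
For context, a minimal sketch of how the new output surfaces after this change, assuming the `facebook/rag-token-nq` checkpoint and the lightweight dummy retrieval index from the library's examples; the exact tensor shapes depend on `n_docs` and the generator config, so treat the printed shapes as illustrative:

```python
import torch
from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
# use_dummy_dataset keeps the example small; use the full index for real retrieval.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

inputs = tokenizer("who holds the record in 100m freestyle", return_tensors="pt")
# Targets go through the generator's (BART) tokenizer.
labels = tokenizer.generator("michael phelps", return_tensors="pt")["input_ids"]

with torch.no_grad():
    outputs = model(input_ids=inputs["input_ids"], labels=labels, output_attentions=True)

# One tensor per generator decoder layer, as documented in the docstrings above.
for layer, attn in enumerate(outputs.generator_cross_attentions):
    print(f"layer {layer}: {tuple(attn.shape)}")
```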