Unverified commit 8edc98bb authored by Derrick Blakely, committed by GitHub

Allow RAG to output decoder cross-attentions (#9789)



* get cross attns

* add cross-attns doc strings

* fix typo

* line length

* Apply suggestions from code review

Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
parent 8f6c12d3
@@ -102,6 +102,12 @@ class RetrievAugLMMarginOutput(ModelOutput):
             Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
             average in the self-attention heads.
+        generator_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Cross-attention weights of the generator decoder, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
     """
     loss: Optional[torch.FloatTensor] = None
@@ -120,6 +126,7 @@ class RetrievAugLMMarginOutput(ModelOutput):
     generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
     generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    generator_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


 @dataclass
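The docstring added above pins down the contract: one tensor per generator decoder layer, each normalized by the attention softmax. A minimal shape sketch (dimension values invented for illustration; in practice RAG expands the leading dimension to batch_size * n_docs, since the generator sees one concatenated context per retrieved document):

import torch

# Hypothetical sizes for a single layer's cross-attention tensor.
batch_size, num_heads, tgt_len, src_len = 2, 12, 7, 300
layer_cross_attn = torch.softmax(torch.randn(batch_size, num_heads, tgt_len, src_len), dim=-1)

# After the attention softmax, each query position's weights over the encoder
# states sum to 1: this is what "used to compute the weighted average" refers to.
assert torch.allclose(layer_cross_attn.sum(dim=-1), torch.ones(batch_size, num_heads, tgt_len))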
@@ -186,6 +193,12 @@ class RetrievAugLMOutput(ModelOutput):
             Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
             average in the self-attention heads.
+        generator_cross_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
+            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
+            sequence_length, sequence_length)`.
+
+            Cross-attention weights of the generator decoder, after the attention softmax, used to compute the
+            weighted average in the cross-attention heads.
     """
     logits: torch.FloatTensor = None
@@ -203,6 +216,7 @@ class RetrievAugLMOutput(ModelOutput):
     generator_enc_attentions: Optional[Tuple[torch.FloatTensor]] = None
     generator_dec_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     generator_dec_attentions: Optional[Tuple[torch.FloatTensor]] = None
+    generator_cross_attentions: Optional[Tuple[torch.FloatTensor]] = None


 class RagPreTrainedModel(PreTrainedModel):
@@ -619,6 +633,7 @@ class RagModel(RagPreTrainedModel):
             decoder_attention_mask=decoder_attention_mask,
             past_key_values=past_key_values,
             use_cache=use_cache,
+            output_attentions=output_attentions,
             return_dict=True,
         )
@@ -655,6 +670,7 @@ class RagModel(RagPreTrainedModel):
             generator_enc_attentions=gen_outputs.encoder_attentions,
             generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
             generator_dec_attentions=gen_outputs.decoder_attentions,
+            generator_cross_attentions=gen_outputs.cross_attentions,
         )
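These two RagModel hunks are the core of the change: the output_attentions flag is now forwarded into the generator call, and the generator's cross_attentions tuple is re-exposed as generator_cross_attentions on the RAG output. Since the generator is an ordinary seq2seq model, the underlying behavior can be checked in isolation; a minimal sketch against BART (assuming the facebook/bart-base checkpoint; cross_attentions on the seq2seq output is the field RagModel reads):

from transformers import BartForConditionalGeneration, BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-base")
bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

enc = tok("hello world", return_tensors="pt")
out = bart(**enc, labels=enc["input_ids"], output_attentions=True)

# One tensor per decoder layer; the attribute is None when
# output_attentions is left at its default of False.
print(len(out.cross_attentions), out.cross_attentions[0].shape)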
@@ -803,6 +819,7 @@ class RagSequenceForGeneration(RagPreTrainedModel):
             generator_enc_attentions=outputs.generator_enc_attentions,
             generator_dec_hidden_states=outputs.generator_dec_hidden_states,
             generator_dec_attentions=outputs.generator_dec_attentions,
+            generator_cross_attentions=outputs.generator_cross_attentions,
         )

     @property
@@ -1264,6 +1281,7 @@ class RagTokenForGeneration(RagPreTrainedModel):
             generator_enc_attentions=outputs.generator_enc_attentions,
             generator_dec_hidden_states=outputs.generator_dec_hidden_states,
             generator_dec_attentions=outputs.generator_dec_attentions,
+            generator_cross_attentions=outputs.generator_cross_attentions,
         )

     @torch.no_grad()
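With the field propagated through both user-facing classes, the new output is reachable end to end. A usage sketch with RagTokenForGeneration (assuming the facebook/rag-token-nq checkpoint with the dummy retrieval index; the target-side tokenization is simplified for illustration):

from transformers import RagRetriever, RagTokenForGeneration, RagTokenizer

tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)

inputs = tokenizer("who holds the record in 100m freestyle", return_tensors="pt")
labels = tokenizer("michael phelps", return_tensors="pt")["input_ids"]

outputs = model(input_ids=inputs["input_ids"], labels=labels, output_attentions=True)

# One tensor per generator decoder layer; the leading dimension is
# batch_size * n_docs because each retrieved document is scored separately.
for layer_idx, attn in enumerate(outputs.generator_cross_attentions):
    print(layer_idx, tuple(attn.shape))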