Unverified commit fa1ddced, authored by Patrick von Platen, committed by GitHub

[RAG, Bart] Align RAG, Bart cache with T5 and other models of transformers (#9098)

* fix rag

* fix slow test

* fix past in bart
parent 6587cf9f
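
In short: BART's decoder cache moves from one nested tuple per layer — (self_attn_cache, cross_attn_cache), each itself a (key, value) pair — to the flat, T5-style 4-tuple (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value), and RAG's beam-search cache handling is rewritten to match. A minimal sketch of the two layouts (shapes are hypothetical, for illustration only):

import torch

# Hypothetical shapes for illustration only.
batch, heads, past_len, head_dim = 2, 4, 7, 16

def kv():
    return torch.randn(batch, heads, past_len, head_dim)

# Old BART layout: one nested tuple per layer, self-attn cache at
# position 0 and cross-attn cache at position 1, each a (key, value) pair.
old_layer_cache = ((kv(), kv()), (kv(), kv()))

# New T5-aligned layout: one flat 4-tuple per layer.
new_layer_cache = (kv(), kv(), kv(), kv())

# Slicing the flat layout, as the updated BartDecoderLayer does:
self_attn_past = new_layer_cache[:2]    # self-attention key/value
cross_attn_past = new_layer_cache[-2:]  # cross-attention key/value

The [:2] and [-2:] slices in the BartDecoderLayer hunks below follow directly from this layout.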
src/transformers/models/bart/modeling_bart.py

@@ -16,7 +16,7 @@
 import math
 import random
 import warnings
-from typing import Dict, Optional, Tuple
+from typing import Optional, Tuple
 
 import numpy as np
 import torch
...@@ -407,7 +407,7 @@ class BartDecoderLayer(nn.Module): ...@@ -407,7 +407,7 @@ class BartDecoderLayer(nn.Module):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor,
encoder_attn_mask: Optional[torch.Tensor] = None, encoder_attn_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.Tensor]]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
attn_mask: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[torch.Tensor] = False, output_attentions: Optional[torch.Tensor] = False,
): ):
@@ -416,9 +416,10 @@ class BartDecoderLayer(nn.Module):
             hidden_states = self.self_attn_layer_norm(hidden_states)
 
         # Self Attention
-        # decoder uni-directional self-attention cached key/values tuple is at first position
-        self_attn_past_key_value = past_key_value[0] if past_key_value is not None else None
-        hidden_states, self_attn_weights, self_attn_present_key_value = self.self_attn(
+        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
+        self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
+        # add present self-attn cache to positions 1,2 of present_key_value tuple
+        hidden_states, self_attn_weights, present_key_value = self.self_attn(
             hidden_states=hidden_states,
             past_key_value=self_attn_past_key_value,
             attn_mask=attn_mask,
@@ -437,8 +438,8 @@ class BartDecoderLayer(nn.Module):
             if self.normalize_before:
                 hidden_states = self.encoder_attn_layer_norm(hidden_states)
 
-            # cross_attn cached key/values tuple is at second position
-            cross_attn_past_key_value = past_key_value[1] if past_key_value is not None else None
+            # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
+            cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
             hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
                 hidden_states=hidden_states,
                 key_value_states=encoder_hidden_states,
@@ -451,6 +452,9 @@ class BartDecoderLayer(nn.Module):
             if not self.normalize_before:
                 hidden_states = self.encoder_attn_layer_norm(hidden_states)
 
+            # add cross-attn to positions 3,4 of present_key_value tuple
+            present_key_value = present_key_value + cross_attn_present_key_value
+
         # Fully Connected
         residual = hidden_states
         if self.normalize_before:
@@ -463,9 +467,6 @@ class BartDecoderLayer(nn.Module):
         if not self.normalize_before:
             hidden_states = self.final_layer_norm(hidden_states)
 
-        # make sure decoder uni-directional self-attn at 1st position and cross-attn at 2nd position.
-        present_key_value = (self_attn_present_key_value, cross_attn_present_key_value)
-
         return (
             hidden_states,
             self_attn_weights,
@@ -600,7 +601,7 @@ BART_INPUTS_DOCSTRING = r"""
             :obj:`attentions`) :obj:`last_hidden_state` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
             `optional`) is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
             cross-attention of the decoder.
-        past_key_values (:obj:`Tuple[Tuple[Tuple[torch.Tensor]]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+        past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
             Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up decoding.
 
             If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
@@ -857,7 +858,7 @@ class BartDecoder(BartPretrainedModel):
                 - 0 for tokens that are **masked**.
 
                 `What are attention masks? <../glossary.html#attention-mask>`__
-            past_key_values (:obj:`Tuple[Tuple[Tuple[torch.Tensor]]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
+            past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers` with each tuple having 2 tuples each of which has 2 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
                 Contains precomputed key and value hidden-states of the attention blocks. Can be used to speed up
                 decoding.
@@ -897,7 +898,7 @@ class BartDecoder(BartPretrainedModel):
             raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
 
         # past_key_values_length
-        past_key_values_length = past_key_values[0][0][0].shape[2] if past_key_values is not None else 0
+        past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
...@@ -1284,12 +1285,9 @@ class BartForConditionalGeneration(BartPretrainedModel): ...@@ -1284,12 +1285,9 @@ class BartForConditionalGeneration(BartPretrainedModel):
@staticmethod @staticmethod
def _reorder_cache(past, beam_idx): def _reorder_cache(past, beam_idx):
def _reorder_buffer(cache: Tuple[torch.Tensor], new_order) -> Dict:
return tuple(past_state.index_select(0, new_order) for past_state in cache)
reordered_past = () reordered_past = ()
for layer_past in past: for layer_past in past:
reordered_past += (tuple(_reorder_buffer(cache, beam_idx) for cache in layer_past),) reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past return reordered_past
src/transformers/models/rag/modeling_rag.py
@@ -1029,6 +1029,10 @@ class RagTokenForGeneration(RagPreTrainedModel):
         n_docs=None,
         **kwargs
     ):
+        if past is not None:
+            # if past is defined use only last decoder_input_ids
+            decoder_input_ids = decoder_input_ids[:, -1:]
+
         return {
             "input_ids": None,
             "encoder_outputs": encoder_outputs,
@@ -1057,23 +1061,17 @@ class RagTokenForGeneration(RagPreTrainedModel):
     def _reorder_cache(past, beam_idx):
         """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
 
-        def _reorder_stacked(hidden_states):
-            n_docs = hidden_states.shape[0] // beam_idx.shape[0]
+        def _reorder_stacked(hidden_states, new_order):
+            n_docs = hidden_states.shape[0] // new_order.shape[0]
             hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
-            hidden_states = hidden_states.index_select(0, beam_idx)
-            return hidden_states.view(-1, *hidden_states.shape[2:])
+            hidden_states = hidden_states.index_select(0, new_order)
+            result = hidden_states.view(-1, *hidden_states.shape[2:])
+            return result
 
-        def _reorder_buffer(attn_cache):
-            for k, input_buffer_k in attn_cache.items():
-                if input_buffer_k is not None:
-                    attn_cache[k] = _reorder_stacked(input_buffer_k)
-            return attn_cache
-
-        reordered_past = []
+        reordered_past = ()
         for layer_past in past:
             # get the correct batch idx from decoder layer's batch dim for cross and self-attn
-            layer_past_new = {attn_key: _reorder_buffer(attn_cache) for attn_key, attn_cache in layer_past.items()}
-            reordered_past.append(layer_past_new)
+            reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),)
 
         return reordered_past
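
RAG's cache batch dimension is num_beams * n_docs while beam_idx indexes beams only, hence the reshape trick: fold the doc copies into a second dimension, select whole beams, then flatten back. A worked toy example of _reorder_stacked's arithmetic (hypothetical sizes):

import torch

num_beams, n_docs = 3, 2  # hypothetical sizes
# Batch dim is num_beams * n_docs; value i marks original slot i.
hidden_states = torch.arange(num_beams * n_docs, dtype=torch.float).unsqueeze(-1)
new_order = torch.tensor([1, 1, 0])  # beam-level reordering

n = hidden_states.shape[0] // new_order.shape[0]         # recover n_docs
h = hidden_states.view(-1, n, *hidden_states.shape[1:])  # (beams, n_docs, ...)
h = h.index_select(0, new_order)                         # pick whole beams
result = h.view(-1, *h.shape[2:])                        # flatten docs back in

# Beam 1's two doc slots (2, 3) now occupy the first two positions, twice,
# followed by beam 0's (0, 1).
assert result.squeeze(-1).tolist() == [2.0, 3.0, 2.0, 3.0, 0.0, 1.0]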
tests/test_modeling_rag.py
@@ -535,7 +535,6 @@ class RagDPRBartTest(RagTestMixin, unittest.TestCase):
             n_docs=self.n_docs,
             retrieval_vector_size=self.retrieval_vector_size,
             max_combined_length=self.max_combined_length,
-            use_cache=False,
         )
 
         return {
@@ -565,7 +564,6 @@ class RagDPRT5Test(RagTestMixin, unittest.TestCase):
             n_docs=self.n_docs,
             retrieval_vector_size=self.retrieval_vector_size,
             max_combined_length=self.max_combined_length,
-            use_cache=False,
        )
 
         return {
@@ -758,8 +756,8 @@ class RagModelIntegrationTests(unittest.TestCase):
             generator_tokenizer=rag_decoder_tokenizer,
         )
 
-        rag_token = self.sequence_model
-        rag_token.set_retriever(rag_retriever)
+        rag_sequence = self.sequence_model
+        rag_sequence.set_retriever(rag_retriever)
 
         input_ids = rag_question_encoder_tokenizer(
             "who sings does he love me with reba", return_tensors="pt"
input_ids = input_ids.to(torch_device) input_ids = input_ids.to(torch_device)
output_ids = rag_token.generate( output_ids = rag_sequence.generate(
input_ids, input_ids,
decoder_start_token_id=rag_token.generator.config.decoder_start_token_id, decoder_start_token_id=rag_sequence.generator.config.decoder_start_token_id,
num_beams=2, num_beams=2,
num_return_sequences=2, num_return_sequences=2,
) )
@@ -810,7 +808,7 @@ class RagModelIntegrationTests(unittest.TestCase):
         retriever = RagRetriever.from_pretrained(
             "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
         )
-        rag_sequence = RagTokenForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
+        rag_sequence = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever).to(
             torch_device
         )
@@ -844,9 +842,9 @@ class RagModelIntegrationTests(unittest.TestCase):
             " walls of the abdomen",
             " spodumene",
             " obama",
-            " grainger's compound",
+            " new orleans",
             " japan",
-            " old trafford stadium",
+            " old trafford",
         ]
 
         self.assertListEqual(outputs, EXPECTED_OUTPUTS)