Unverified commit ce2298fb, authored by Patrick von Platen, committed by GitHub
Browse files

[T5, generation] Add decoder caching for T5 (#3682)



* initial commit to add decoder caching for T5

* better naming for caching

* finish T5 decoder caching

* correct test

* added extensive past testing for T5

* clean files

* make tests cleaner

* improve docstring

* improve docstring

* better reorder cache

* make style

* Update src/transformers/modeling_t5.py
Co-Authored-By: Yacine Jernite <yjernite@users.noreply.github.com>

* make set output past work for all layers

* improve docstring

* improve docstring
Co-authored-by: Yacine Jernite <yjernite@users.noreply.github.com>
parent 9384e5f6
This diff is collapsed.
......@@ -1417,17 +1417,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = []
for layer_past in past:
# get the correct batch idx from layer past batch dim
# batch dim of `past` and `mems` is at 2nd position
reordered_layer_past = [layer_past[:, i].unsqueeze(1).clone().detach() for i in beam_idx]
reordered_layer_past = torch.cat(reordered_layer_past, dim=1)
# check that shape matches
assert reordered_layer_past.shape == layer_past.shape
reordered_past.append(reordered_layer_past)
past = tuple(reordered_past)
return past
return tuple(layer_past.index_select(1, beam_idx) for layer_past in past)
def calc_banned_ngram_tokens(prev_input_ids, num_hypos, no_repeat_ngram_size, cur_len):
......
......@@ -128,6 +128,7 @@ class ModelTesterMixin:
for model_class in self.all_model_classes:
config.output_attentions = True
config.output_hidden_states = False
config.output_past = False
model = model_class(config)
model.to(torch_device)
model.eval()
......@@ -144,10 +145,9 @@ class ModelTesterMixin:
out_len = len(outputs)
if self.is_encoder_decoder:
correct_outlen = (
4 # decoder_features_or_logits, decoder_attentions, encoder_features, encoder_attentions
)
correct_outlen = 4
decoder_attention_idx = 1
if "lm_labels" in inputs_dict: # loss will come first
correct_outlen += 1 # compute loss
decoder_attention_idx += 1
......
......@@ -167,17 +167,20 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
model = T5Model(config=config)
model.to(torch_device)
model.eval()
decoder_output, encoder_output = model(
decoder_output, decoder_past, encoder_output = model(
input_ids=input_ids,
decoder_input_ids=decoder_input_ids,
attention_mask=attention_mask,
decoder_attention_mask=decoder_attention_mask,
)
decoder_output, encoder_output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
decoder_output, decoder_past, encoder_output = model(
input_ids=input_ids, decoder_input_ids=decoder_input_ids
)
result = {
"encoder_output": encoder_output,
"decoder_output": decoder_output,
"decoder_past": decoder_past,
}
self.parent.assertListEqual(
list(result["encoder_output"].size()), [self.batch_size, self.encoder_seq_length, self.hidden_size]
......@@ -185,6 +188,13 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
self.parent.assertListEqual(
list(result["decoder_output"].size()), [self.batch_size, self.decoder_seq_length, self.hidden_size]
)
self.parent.assertEqual(len(decoder_past), 2)
# decoder_past[0] should correspond to encoder output
self.parent.assertTrue(torch.all(decoder_past[0][0] == encoder_output))
# There should be `num_layers` key value embeddings stored in decoder_past[1]
self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
# There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
self.parent.assertEqual(len(decoder_past[1][0]), 4)
def create_and_check_t5_with_lm_head(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
......@@ -198,8 +208,8 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
decoder_attention_mask=decoder_attention_mask,
lm_labels=lm_labels,
)
loss, prediction_scores, encoder_features = outputs
self.parent.assertEqual(len(outputs), 3)
loss, prediction_scores, _, _ = outputs
self.parent.assertEqual(len(outputs), 4)
result = {
"loss": loss,
"prediction_scores": prediction_scores,
......@@ -209,6 +219,92 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
)
self.check_loss_output(result)
def create_and_check_t5_decoder_model_past(
    self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
    """Check that feeding cached past key/value states for a single new token
    reproduces the output of a full forward pass over the extended sequence."""
    decoder = T5Model(config=config).get_decoder()
    decoder.to(torch_device)
    decoder.eval()

    # Run once over the prefix to obtain the key/value cache.
    prefix_output, cache = decoder(input_ids)

    # Sample a hypothetical next token and extend the input sequence with it.
    next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
    extended_input_ids = torch.cat([input_ids, next_tokens], dim=-1)

    # Full extended sequence without cache vs. only the new token with cache.
    output_from_no_past, _ = decoder(extended_input_ids)
    output_from_past, _ = decoder(next_tokens, past_key_value_states=cache)

    # Compare a randomly chosen hidden-dimension slice at the new position.
    slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
    no_past_slice = output_from_no_past[:, -1, slice_idx].detach()
    with_past_slice = output_from_past[:, 0, slice_idx].detach()
    self.parent.assertTrue(torch.allclose(with_past_slice, no_past_slice, atol=1e-3))
def create_and_check_t5_decoder_model_attention_mask_past(
    self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
    """Check that the past cache stays consistent under an attention mask:
    masked positions may be overwritten without changing the outputs."""
    decoder = T5Model(config=config).get_decoder()
    decoder.to(torch_device)
    decoder.eval()

    # Mask out the second half of every sequence.
    attn_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
    half_len = input_ids.shape[-1] // 2
    attn_mask[:, half_len:] = 0

    # First forward pass over the prefix to fill the cache.
    prefix_output, cache = decoder(input_ids, attention_mask=attn_mask)

    # Sample a hypothetical next token to extend the sequence with.
    next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

    # Overwrite one randomly chosen position in the *masked* half; being
    # masked, it must not influence either forward pass below.
    change_idx = ids_tensor((1,), half_len).item() + 1
    replacement_tokens = ids_tensor((self.batch_size, 1), config.vocab_size).squeeze(-1)
    input_ids[:, -change_idx] = replacement_tokens

    # Extend both the input ids and the attention mask by one unmasked step.
    extended_input_ids = torch.cat([input_ids, next_tokens], dim=-1)
    attn_mask = torch.cat(
        [attn_mask, torch.ones((attn_mask.shape[0], 1), dtype=torch.long, device=torch_device)], dim=1,
    )

    # Full extended sequence without cache vs. only the new token with cache.
    output_from_no_past, _ = decoder(extended_input_ids, attention_mask=attn_mask)
    output_from_past, _ = decoder(
        next_tokens, past_key_value_states=cache, attention_mask=attn_mask
    )

    # Compare a randomly chosen hidden-dimension slice at the new position.
    slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item()
    no_past_slice = output_from_no_past[:, -1, slice_idx].detach()
    with_past_slice = output_from_past[:, 0, slice_idx].detach()
    self.parent.assertTrue(torch.allclose(with_past_slice, no_past_slice, atol=1e-3))
def create_t5_and_check_t5_generate_with_past_key_value_states(
    self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
):
    """Check that generation yields identical sequences with and without the
    decoder past key/value cache, given a fixed RNG seed."""
    config.num_layers = 1
    model = T5ForConditionalGeneration(config=config)
    model.to(torch_device)
    model.eval()

    def _generate(use_past_cache):
        # Reseed so sampling makes the same random choices in both runs.
        torch.manual_seed(0)
        model.set_output_past(use_past_cache)
        return model.generate(input_ids[:1], num_beams=2, max_length=5, do_sample=True)

    output_without_past_cache = _generate(False)
    output_with_past_cache = _generate(True)
    self.parent.assertTrue(torch.all(output_with_past_cache == output_without_past_cache))
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
......@@ -247,6 +343,18 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)
def test_t5_decoder_model_past(self):
    """Run the decoder past-cache equivalence check."""
    inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_t5_decoder_model_past(*inputs)
def test_t5_decoder_model_past_with_attn_mask(self):
    """Run the decoder past-cache check with a padded attention mask."""
    inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_and_check_t5_decoder_model_attention_mask_past(*inputs)
def test_t5_generate_with_past_key_value_states(self):
    """Run the cached-vs-uncached generation equivalence check."""
    inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.create_t5_and_check_t5_generate_with_past_key_value_states(*inputs)
@slow
def test_model_from_pretrained(self):
for model_name in list(T5_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment