"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "e30166118ca3419d1f7960c669f281a072fcdc2d"
Unverified commit 38f7461d, authored by Patrick von Platen, committed by GitHub

[TFT5, Cache] Add cache to TFT5 (#3772)

* correct gpt2 test inputs

* make style

* delete modeling_gpt2 change in test file

* translate from pytorch

* correct tests

* fix conflicts

* fix conflicts

* fix conflicts

* fix conflicts

* make tensorflow t5 caching work

* make style

* clean reorder cache

* remove unnecessary spaces

* fix test
Parent commit: a5b24947
src/transformers/modeling_t5.py

@@ -351,12 +351,11 @@ class T5Attention(nn.Module):
         else:
             k, v = past_key_value_state
 
-        if self.is_decoder and use_cache:
+        if self.is_decoder and use_cache is True:
             present_key_value_state = ((k, v),)
         else:
             present_key_value_state = (None,)
 
-        # q = q / math.sqrt(dim_per_head)  # No scaling in T5
         scores = torch.einsum("bnqd,bnkd->bnqk", q, k)  # (bs, n_heads, qlen, klen)
 
         if position_bias is None:
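For orientation, this is the caching pattern the hunk implements: at each generation step the layer concatenates the cached keys/values with the projections of the newest token, so only one position is recomputed. A minimal PyTorch sketch of that idea (standalone toy code with illustrative names, not the `T5Attention` implementation), including T5's unscaled dot product:

```python
import torch

def attend_with_cache(q, k_new, v_new, past_kv=None):
    # q, k_new, v_new: (batch, n_heads, seq_len, d_head)
    if past_kv is not None:
        past_k, past_v = past_kv
        # grow the cache along the sequence-length axis
        k = torch.cat([past_k, k_new], dim=2)
        v = torch.cat([past_v, v_new], dim=2)
    else:
        k, v = k_new, v_new
    # T5 skips the usual 1/sqrt(d_head) scaling; its relative position
    # bias (not modeled here) is added to the raw scores instead.
    scores = torch.einsum("bnqd,bnkd->bnqk", q, k)
    context = torch.einsum("bnqk,bnkd->bnqd", scores.softmax(dim=-1), v)
    return context, (k, v)  # (k, v) becomes `past_kv` for the next step
```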
@@ -486,11 +485,15 @@ class T5Block(nn.Module):
         if past_key_value_state is not None:
             assert self.is_decoder, "Only decoder can use `past_key_value_states`"
-            assert (
-                len(past_key_value_state) == 4
-            ), "The should be 4 past states. 2 (past / key) for self attention. 2 (past / key) for cross attention. Got {} past key / value states".format(
-                len(past_key_value_state)
+            expected_num_past_key_value_states = 2 if encoder_hidden_states is None else 4
+            error_message = "There should be {} past states. 2 (past / key) for self attention.{} Got {} past key / value states".format(
+                expected_num_past_key_value_states,
+                "2 (past / key) for cross attention" if expected_num_past_key_value_states == 4 else "",
+                len(past_key_value_state),
             )
+            assert len(past_key_value_state) == expected_num_past_key_value_states, error_message
+
             self_attn_past_key_value_state = past_key_value_state[:2]
             cross_attn_past_key_value_state = past_key_value_state[2:]
         else:
@@ -507,7 +510,7 @@ class T5Block(nn.Module):
         hidden_states, present_key_value_state = self_attention_outputs[:2]
         attention_outputs = self_attention_outputs[2:]  # Keep self-attention outputs and relative position weights
 
-        if self.is_decoder:
+        if self.is_decoder and encoder_hidden_states is not None:
             # the actual query length is unknown for cross attention
             # if using past key value states. Need to inject it here
             if present_key_value_state is not None:
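Taken together, the two hunks above let the decoder run without an encoder: standalone, each layer caches only its self-attention key/value pair (2 tensors), while in an encoder-decoder setup it also caches the cross-attention pair (4 tensors). A hedged sketch of that layout check (hypothetical helper, mirroring the assertion logic above):

```python
def split_layer_cache(layer_past, has_encoder_states):
    # 2 tensors (self-attn k, v) without an encoder,
    # 4 tensors (self-attn k, v, cross-attn k, v) with one.
    expected = 4 if has_encoder_states else 2
    if len(layer_past) != expected:
        raise ValueError(f"expected {expected} past states, got {len(layer_past)}")
    self_attn_past = layer_past[:2]
    cross_attn_past = layer_past[2:] if has_encoder_states else None
    return self_attn_past, cross_attn_past
```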
@@ -691,7 +694,6 @@ class T5Stack(T5PreTrainedModel):
         if past_key_value_states is None:
             past_key_value_states = [None] * len(self.block)
 
-        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
         extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
@@ -732,7 +734,7 @@ class T5Stack(T5PreTrainedModel):
             # We share the position biases between the layers - the first layer store them
             # layer_outputs = hidden-states, key-value-states (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
             position_bias = layer_outputs[3 if self.output_attentions else 2]
-            if self.is_decoder:
+            if self.is_decoder and encoder_hidden_states is not None:
                 encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 3]
             # append next layer key value states
             present_key_value_states = present_key_value_states + (present_key_value_state,)
(One file's diff is collapsed and is not shown.)
src/transformers/modeling_tf_utils.py

@@ -1299,17 +1299,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
     @staticmethod
     def _reorder_cache(past, beam_idx):
-        reordered_past = []
-        for layer_past in past:
-            # get the correct batch idx from layer past batch dim
-            # batch dim of `past` and `mems` is at 2nd position
-            reordered_layer_past = [tf.identity(tf.expand_dims(layer_past[:, i], 1)) for i in beam_idx]
-            reordered_layer_past = tf.concat(reordered_layer_past, axis=1)
-            # check that shape matches
-            assert shape_list(reordered_layer_past) == shape_list(layer_past)
-            reordered_past.append(reordered_layer_past)
-        past = tuple(reordered_past)
-        return past
+        return tuple(tf.gather(layer_past, beam_idx, axis=1) for layer_past in past)
 
 
 def _create_next_token_logits_penalties(input_ids, logits, repetition_penalty):
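The `_reorder_cache` rewrite collapses the per-index expand/concat loop into a single `tf.gather` along axis 1, where the batch dimension of `past`/`mems` lives. A quick standalone check that the two formulations agree (toy shapes, assumed only for illustration):

```python
import tensorflow as tf

# toy cache tensor: (2, batch=3, n_heads=2, seq=4, d_head=5); batch is axis 1
layer_past = tf.random.uniform((2, 3, 2, 4, 5))
beam_idx = tf.constant([2, 0, 1])

# old formulation: slice out each beam, restore the batch dim, concatenate
old = tf.concat(
    [tf.expand_dims(layer_past[:, i], 1) for i in beam_idx.numpy()], axis=1
)
# new formulation: one gather along the batch axis
new = tf.gather(layer_past, beam_idx, axis=1)

tf.debugging.assert_near(old, new)  # both reorder the cache identically
```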
tests/test_modeling_t5.py

@@ -244,7 +244,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
         output_from_past_slice = output_from_past[:, 0, random_slice_idx].detach()
 
         # test that outputs are equal for slice
-        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3))
+        self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-6))
 
     def create_and_check_t5_decoder_model_attention_mask_past(
         self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
@@ -293,7 +293,6 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
     def create_t5_and_check_t5_generate_with_past_key_value_states(
         self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, lm_labels,
     ):
-        config.num_layers = 1
         model = T5ForConditionalGeneration(config=config)
         model.to(torch_device)
         model.eval()
tests/test_modeling_tf_gpt2.py

@@ -191,7 +191,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
         output_from_past_slice = output_from_past[:, 0, random_slice_idx]
 
         # test that outputs are equal for slice
-        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-12)
+        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
 
     def create_and_check_gpt2_model_attention_mask_past(
         self, config, input_ids, input_mask, head_mask, token_type_ids, *args
tests/test_modeling_tf_t5.py

@@ -24,6 +24,7 @@ from .utils import CACHE_DIR, require_tf, slow
 if is_tf_available():
+    import tensorflow as tf
     from transformers import TFT5Model, TFT5ForConditionalGeneration, T5Tokenizer
@@ -111,14 +112,14 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
             "decoder_input_ids": input_ids,
             "decoder_attention_mask": input_mask,
         }
-        encoder_output, decoder_output = model(inputs)
+        decoder_output, decoder_past, encoder_output = model(inputs)
 
-        encoder_output, decoder_output = model(
+        decoder_output, decoder_past, encoder_output = model(
            input_ids, decoder_attention_mask=input_mask, decoder_input_ids=input_ids
         )
 
         result = {
             "encoder_output": encoder_output.numpy(),
+            "decoder_past": decoder_past,
             "decoder_output": decoder_output.numpy(),
         }
         self.parent.assertListEqual(
@@ -127,6 +128,13 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         self.parent.assertListEqual(
             list(result["decoder_output"].shape), [self.batch_size, self.seq_length, self.hidden_size]
         )
+        self.parent.assertEqual(len(decoder_past), 2)
+        # decoder_past[0] should correspond to encoder output
+        self.parent.assertTrue(tf.reduce_all(tf.math.equal(decoder_past[0][0], encoder_output)))
+        # There should be `num_layers` key value embeddings stored in decoder_past[1]
+        self.parent.assertEqual(len(decoder_past[1]), config.num_layers)
+        # There should be a self attn key, a self attn value, a cross attn key and a cross attn value stored in each decoder_past[1] tuple
+        self.parent.assertEqual(len(decoder_past[1][0]), 4)
 
     def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
         model = TFT5ForConditionalGeneration(config=config)
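As these assertions spell out, the TFT5 `decoder_past` is a 2-tuple: element 0 holds the encoder outputs (so generation can reuse them without re-encoding the input), and element 1 holds one 4-tuple of key/value tensors per decoder layer. A small sketch of walking that structure (hypothetical helper name; the layout follows the assertions above):

```python
def inspect_tft5_cache(decoder_past, num_layers):
    # decoder_past[0]: encoder outputs, reused across decoding steps
    # decoder_past[1]: per-layer key/value caches
    encoder_states, layer_caches = decoder_past
    assert len(layer_caches) == num_layers
    for layer_cache in layer_caches:
        # self-attention k/v plus cross-attention k/v
        self_k, self_v, cross_k, cross_v = layer_cache
    return encoder_states, layer_caches
```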
@@ -136,7 +144,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
             "decoder_attention_mask": input_mask,
         }
-        prediction_scores, decoder_output = model(inputs_dict)
+        prediction_scores, _, _ = model(inputs_dict)
 
         result = {
             "prediction_scores": prediction_scores.numpy(),
@@ -145,6 +153,76 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
             list(result["prediction_scores"].shape), [self.batch_size, self.seq_length, self.vocab_size]
         )
 
+    def create_and_check_t5_decoder_model_past(self, config, input_ids, decoder_input_ids, attention_mask):
+        model = TFT5Model(config=config).get_decoder()
+
+        input_ids = input_ids[:1, :]
+        self.batch_size = 1
+
+        # first forward pass
+        _, past_key_value_states = model(input_ids, use_cache=True)
+
+        # create hypothetical next token and extend to next_input_ids
+        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+        # append to next input_ids
+        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
+
+        output_from_no_past = model(next_input_ids)[0]
+        output_from_past = model(next_tokens, past_key_value_states=past_key_value_states)[0]
+
+        # select random slice
+        random_slice_idx = int(ids_tensor((1,), output_from_past.shape[-1]))
+        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
+        output_from_past_slice = output_from_past[:, 0, random_slice_idx]
+
+        # test that outputs are equal for slice
+        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
+
+    def create_and_check_t5_decoder_model_attention_mask_past(
+        self, config, input_ids, decoder_input_ids, attention_mask
+    ):
+        model = TFT5Model(config=config).get_decoder()
+
+        # create attention mask
+        half_seq_length = self.seq_length // 2
+        attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32)
+        attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32)
+        attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)
+
+        # first forward pass
+        _, past_key_value_states = model(input_ids, attention_mask=attn_mask, use_cache=True)
+
+        # create hypothetical next token and extend to next_input_ids
+        next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
+
+        # change a random masked slice from input_ids
+        random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1
+        random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size)
+        vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change)
+        condition = tf.transpose(
+            tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size))
+        )
+        input_ids = tf.where(condition, random_other_next_tokens, input_ids)
+
+        # append to next input_ids and attn_mask
+        next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
+        attn_mask = tf.concat([attn_mask, tf.ones((attn_mask.shape[0], 1), dtype=tf.int32)], axis=1)
+
+        # get two different outputs
+        output_from_no_past = model(next_input_ids, attention_mask=attn_mask)[0]
+        output_from_past = model(
+            next_tokens, past_key_value_states=past_key_value_states, attention_mask=attn_mask
+        )[0]
+
+        # select random slice
+        random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).numpy().item()
+        output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
+        output_from_past_slice = output_from_past[:, 0, random_slice_idx]
+
+        # test that outputs are equal for slice
+        tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-6)
+
     def prepare_config_and_inputs_for_common(self):
         config_and_inputs = self.prepare_config_and_inputs()
         (config, input_ids, input_mask, token_labels) = config_and_inputs
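The two new test methods above pin down the core cache invariant: feeding the full sequence at once and feeding one token at a time with `past_key_value_states` must yield the same output at the latest position. A hedged sketch of the incremental decoding loop this enables (hypothetical helper; `pick_next_token` stands in for a projection to logits plus argmax, which is not part of the decoder shown here):

```python
import tensorflow as tf

def decode_incrementally(decoder, prefix_ids, num_steps, pick_next_token):
    # first pass: run the whole prefix once and obtain the cache
    hidden, past = decoder(prefix_ids, use_cache=True)
    ids = prefix_ids
    for _ in range(num_steps):
        next_token = pick_next_token(hidden[:, -1, :])  # shape (batch, 1), int32
        ids = tf.concat([ids, next_token], axis=-1)
        # later passes: feed only the newest token; earlier keys/values
        # come from the cache instead of being recomputed
        hidden, past = decoder(next_token, past_key_value_states=past)
    return ids
```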
@@ -152,6 +230,7 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
             "inputs": input_ids,
             "decoder_input_ids": input_ids,
             "decoder_attention_mask": input_mask,
+            "use_cache": tf.convert_to_tensor([False]),
         }
         return config, inputs_dict
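One detail worth flagging: the shared `inputs_dict` carries `use_cache` as `tf.convert_to_tensor([False])` rather than a plain Python bool, keeping every value the common tester passes through `model(inputs_dict)` a TF tensor (our reading of the change; the PR does not state the rationale). A model consuming such a flag then has to normalize it, roughly like this (assumed pattern, not the TFT5 code):

```python
import tensorflow as tf

def resolve_use_cache(use_cache):
    # accept either a Python bool or a one-element tensor such as
    # tf.convert_to_tensor([False]) and return a plain bool
    if isinstance(use_cache, tf.Tensor):
        return bool(tf.reshape(use_cache, []).numpy())
    return bool(use_cache)
```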
@@ -170,6 +249,14 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)
 
+    def test_t5_decoder_model_past(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_t5_decoder_model_past(*config_and_inputs)
+
+    def test_t5_decoder_model_past_with_attn_mask(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_t5_decoder_model_attention_mask_past(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in ["t5-small"]: