Unverified Commit 12eb528b authored by Arthur's avatar Arthur Committed by GitHub
Browse files

[CI ] Remove `past` in favor of `pat_key_values` (#21443)

* fix past renamed to past_key_value

* update more `past`that were ski^êd

* fixup

* remove changes made to rag

* refactor `_reorder_cache` to use `past_key_values`

* fix git `prepare_inputs_for_generation` to pass tests when false is needed in use_cache
parent 5b493762
...@@ -148,7 +148,7 @@ class TFGPTJModelTester: ...@@ -148,7 +148,7 @@ class TFGPTJModelTester:
self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf))
self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1)
output, past = outputs.to_tuple() output, past_key_values = outputs.to_tuple()
# create hypothetical next token and extent to next_input_ids # create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
...@@ -159,7 +159,9 @@ class TFGPTJModelTester: ...@@ -159,7 +159,9 @@ class TFGPTJModelTester:
next_token_type_ids = tf.concat([token_type_ids, next_token_types], axis=-1) next_token_type_ids = tf.concat([token_type_ids, next_token_types], axis=-1)
output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"] output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"]
output_from_past = model(next_tokens, token_type_ids=next_token_types, past=past)["last_hidden_state"] output_from_past = model(next_tokens, token_type_ids=next_token_types, past_key_values=past_key_values)[
"last_hidden_state"
]
# select random slice # select random slice
random_slice_idx = int(ids_tensor((1,), shape_list(output_from_past)[-1])) random_slice_idx = int(ids_tensor((1,), shape_list(output_from_past)[-1]))
...@@ -181,7 +183,7 @@ class TFGPTJModelTester: ...@@ -181,7 +183,7 @@ class TFGPTJModelTester:
attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1) attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)
# first forward pass # first forward pass
output, past = model(input_ids, attention_mask=attn_mask).to_tuple() output, past_key_values = model(input_ids, attention_mask=attn_mask).to_tuple()
# create hypothetical next token and extent to next_input_ids # create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
...@@ -201,7 +203,9 @@ class TFGPTJModelTester: ...@@ -201,7 +203,9 @@ class TFGPTJModelTester:
# get two different outputs # get two different outputs
output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"]
output_from_past = model(next_tokens, past=past, attention_mask=attn_mask)["last_hidden_state"] output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[
"last_hidden_state"
]
# select random slice # select random slice
random_slice_idx = int(ids_tensor((1,), shape_list(output_from_past)[-1])) random_slice_idx = int(ids_tensor((1,), shape_list(output_from_past)[-1]))
...@@ -224,7 +228,7 @@ class TFGPTJModelTester: ...@@ -224,7 +228,7 @@ class TFGPTJModelTester:
# first forward pass # first forward pass
outputs = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, use_cache=True) outputs = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, use_cache=True)
output, past = outputs.to_tuple() output, past_key_values = outputs.to_tuple()
# create hypothetical next token and extent to next_input_ids # create hypothetical next token and extent to next_input_ids
next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size)
...@@ -240,7 +244,10 @@ class TFGPTJModelTester: ...@@ -240,7 +244,10 @@ class TFGPTJModelTester:
next_input_ids, token_type_ids=next_token_type_ids, attention_mask=next_attention_mask next_input_ids, token_type_ids=next_token_type_ids, attention_mask=next_attention_mask
)["last_hidden_state"] )["last_hidden_state"]
output_from_past = model( output_from_past = model(
next_tokens, token_type_ids=next_token_types, attention_mask=next_attention_mask, past=past next_tokens,
token_type_ids=next_token_types,
attention_mask=next_attention_mask,
past_key_values=past_key_values,
)["last_hidden_state"] )["last_hidden_state"]
self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1]) self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment