Unverified commit fdb12080, authored by Sylvain Gugger, committed by GitHub

Fix cache for GPT-Neo-X (#17764)

* Fix cache for GPT-Neo-X

* Add more tests

parent a2d34b7c
@@ -2458,7 +2458,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
         if offload_state_dict:
             # Load back temporarily offloaded state dict
-            load_offloaded_weights(model, state_dict_index, state_dict_folder)
+            load_offloaded_weights(model_to_load, state_dict_index, state_dict_folder)
             shutil.rmtree(state_dict_folder)

         if len(error_msgs) > 0:
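The one-word change matters because `model` and `model_to_load` are not always the same object: when a base-model checkpoint is loaded into a model with a head, `model_to_load` is the base submodule whose parameter names actually match the checkpoint keys, so the temporarily offloaded weights have to be loaded back into it. A minimal sketch of that distinction, using GPT-NeoX purely as an illustration (the tiny random config is assumed, not part of the commit):

from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

# Tiny random config, for illustration only.
config = GPTNeoXConfig(
    vocab_size=128, hidden_size=64, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=256,
)
model = GPTNeoXForCausalLM(config)

# The head model prefixes every base parameter with "gpt_neox.".
print(next(iter(model.state_dict())))          # gpt_neox.embed_in.weight

# The base submodule uses the un-prefixed names a base-model checkpoint
# contains; this is the role `model_to_load` plays during loading, so
# targeting `model` instead would look parameters up under the wrong names.
model_to_load = getattr(model, model.base_model_prefix)
print(next(iter(model_to_load.state_dict())))  # embed_in.weight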
@@ -143,7 +143,7 @@ class GPTNeoXAttention(nn.Module):
             past_value = layer_past[1]
             key = torch.cat((past_key, key), dim=-2)
             value = torch.cat((past_value, value), dim=-2)
-        present = None if use_cache else (key, value)
+        present = (key, value) if use_cache else None

         # Compute attention
         attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
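The original line had the conditional inverted, so the attention layer returned `present = None` exactly when the caller asked for a cache. A hedged end-to-end sketch of the fixed behaviour (tiny random config assumed, not code from the commit): with `use_cache=True`, `past_key_values` is now populated, and feeding one new token plus the cache matches a full forward pass.

import torch
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

config = GPTNeoXConfig(
    vocab_size=128, hidden_size=64, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=256,
)
model = GPTNeoXForCausalLM(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 8))

with torch.no_grad():
    # Before the fix this was None whenever use_cache=True.
    prefix = model(input_ids[:, :-1], use_cache=True)
    assert prefix.past_key_values is not None

    # Feed only the last token plus the cache, and compare against
    # running the whole sequence in a single uncached pass.
    cached = model(input_ids[:, -1:], past_key_values=prefix.past_key_values, use_cache=True)
    full = model(input_ids, use_cache=True)

torch.testing.assert_close(cached.logits[:, -1], full.logits[:, -1], atol=1e-4, rtol=1e-4)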
@@ -218,6 +218,14 @@ class GPTNeoXModelTest(ModelTesterMixin, unittest.TestCase):
         self.model_tester.create_and_check_model_as_decoder(config, input_ids, input_mask)

+    def test_decoder_model_past_large_inputs(self):
+        config, input_ids, input_mask, token_labels = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_decoder_model_past_large_inputs(config, input_ids, input_mask)
+
+    def test_model_for_causal_lm(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
+
     @slow
     def test_model_from_pretrained(self):
         for model_name in GPT_NEOX_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
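The new tests delegate to tester methods whose bodies fall outside this hunk; in spirit they check that cached and uncached decoding agree on longer inputs. A practical corollary, sketched here with an assumed tiny random config (not the repository's tester code): greedy generation should now produce identical tokens with and without the KV cache.

import torch
from transformers import GPTNeoXConfig, GPTNeoXForCausalLM

torch.manual_seed(0)
config = GPTNeoXConfig(
    vocab_size=128, hidden_size=64, num_hidden_layers=2,
    num_attention_heads=4, intermediate_size=256,
)
model = GPTNeoXForCausalLM(config).eval()
prompt = torch.randint(0, config.vocab_size, (1, 5))

# Greedy decoding must yield the same tokens whether or not the KV cache
# is used; before the fix the cached path had no past to reuse.
with_cache = model.generate(prompt, max_new_tokens=8, do_sample=False, use_cache=True)
without_cache = model.generate(prompt, max_new_tokens=8, do_sample=False, use_cache=False)
assert torch.equal(with_cache, without_cache)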