Unverified Commit f26e4073 authored by Joao Gante, committed by GitHub

Cache: models return input cache type (#30716)

parent 71c19850
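Summary of the change: when `use_cache=True`, the models touched below now return `past_key_values` in the same format they received it. A `Cache` instance passes through untouched, while legacy tuple inputs (or `None`) are tracked with a new `return_legacy_cache` flag and converted back with `to_legacy_cache()` for backward compatibility. A minimal sketch of the resulting behavior, assuming a tiny illustrative checkpoint (`hf-internal-testing/tiny-random-LlamaForCausalLM`) that is not part of this diff:

```python
# Minimal sketch of the behavior this commit introduces (illustrative, not part of the diff).
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # assumed tiny checkpoint, for illustration only
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Passing a `Cache` instance: the same cache type should come back out.
cache_out = model(**inputs, past_key_values=DynamicCache(), use_cache=True).past_key_values
assert isinstance(cache_out, DynamicCache)

# Passing no cache (or a legacy tuple of tuples): the legacy tuple format is
# still returned, via `return_legacy_cache` / `to_legacy_cache()`, for BC.
legacy_out = model(**inputs, use_cache=True).past_key_values
assert isinstance(legacy_out, tuple)
```

Previously the conversion to the legacy format happened whenever the internal cache was a `Cache`/`DynamicCache`, regardless of what the caller passed in; the per-model hunks below replace that with the `return_legacy_cache` flag.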
@@ -881,7 +881,9 @@ class CohereModel(CoherePreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)
         past_seen_tokens = 0
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
         if cache_position is None:
@@ -943,11 +945,10 @@ class CohereModel(CoherePreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
......
@@ -1115,7 +1115,9 @@ class DbrxModel(DbrxPreTrainedModel):
         inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
         if cache_position is None:
@@ -1182,13 +1184,10 @@ class DbrxModel(DbrxPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache()
-                if isinstance(next_decoder_cache, DynamicCache)
-                else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
         if not return_dict:
             return tuple(
                 v
......
@@ -865,7 +865,9 @@ class GemmaModel(GemmaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
         if cache_position is None:
@@ -933,13 +935,10 @@ class GemmaModel(GemmaPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache()
-                if isinstance(next_decoder_cache, DynamicCache)
-                else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
......
@@ -960,7 +960,9 @@ class LlamaModel(LlamaPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
         if cache_position is None:
@@ -1021,13 +1023,10 @@ class LlamaModel(LlamaPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache()
-                if isinstance(next_decoder_cache, DynamicCache)
-                else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
......
@@ -938,7 +938,9 @@ class OlmoModel(OlmoPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
+        return_legacy_cache = False
         if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
+            return_legacy_cache = True
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
         if cache_position is None:
@@ -999,13 +1001,10 @@ class OlmoModel(OlmoPreTrainedModel):
         if output_hidden_states:
             all_hidden_states += (hidden_states,)
-        next_cache = None
-        if use_cache:
-            next_cache = (
-                next_decoder_cache.to_legacy_cache()
-                if isinstance(next_decoder_cache, DynamicCache)
-                else next_decoder_cache
-            )
+        next_cache = next_decoder_cache if use_cache else None
+        if return_legacy_cache:
+            next_cache = next_cache.to_legacy_cache()
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
......
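Because the returned cache keeps its input type, it can be fed straight back into the next forward pass with no format conversion. A rough incremental-decoding sketch (again using the illustrative tiny checkpoint from above; greedy decoding, no attention-mask handling) of how the returned `DynamicCache` could be reused:

```python
# Rough sketch of reusing the returned cache across forward passes (illustrative only).
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # assumed tiny checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

input_ids = tokenizer("Hello", return_tensors="pt").input_ids
cache = DynamicCache()

for _ in range(5):
    out = model(input_ids=input_ids, past_key_values=cache, use_cache=True)
    cache = out.past_key_values                 # stays a DynamicCache after this commit
    assert isinstance(cache, DynamicCache)
    input_ids = out.logits[:, -1:].argmax(-1)   # greedily pick the next token, feed only it next
```

The test changes that follow drop the per-model `test_new_cache_format` skips, presumably because the common test now passes once the models return the input cache type.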
@@ -16,8 +16,6 @@
 import unittest
-from parameterized import parameterized
 from transformers import CohereConfig, is_torch_available
 from transformers.testing_utils import (
     require_bitsandbytes,
@@ -296,11 +294,6 @@ class CohereModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
     def test_config(self):
         self.config_tester.run_common_tests()
-    @unittest.skip("TODO @gante fix this for Cohere")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
     def test_model(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_model(*config_and_inputs)
......
@@ -17,8 +17,6 @@
 import unittest
-from parameterized import parameterized
 from transformers import DbrxConfig, is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
@@ -357,11 +355,6 @@ class DbrxModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     def test_tied_weights_keys(self):
         pass
-    @unittest.skip("TODO @gante fix this for Llama")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
 @require_torch
 class DbrxModelIntegrationTest(unittest.TestCase):
......
@@ -17,7 +17,6 @@ import tempfile
 import unittest
 import pytest
-from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
 from transformers.testing_utils import (
@@ -367,11 +366,6 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
         result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
         self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))
-    @unittest.skip("TODO @gante fix this for Llama")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
     @unittest.skip("Gemma buffers include complex numbers, which breaks this test")
     def test_save_load_fast_init_from_base(self):
         pass
......
@@ -591,11 +591,6 @@ class LlamaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
                 msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
             )
-    @unittest.skip("TODO @gante fix this for Llama")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
 @require_torch_gpu
 class LlamaIntegrationTest(unittest.TestCase):
......
@@ -353,11 +353,6 @@ class OlmoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         # The output should be different for long inputs
         self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))
-    @unittest.skip("TODO @gante fix this for OLMo")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
 @require_torch
 class OlmoIntegrationTest(unittest.TestCase):
......
@@ -15,8 +15,6 @@
 """ Testing suite for the PyTorch RecurrentGemma model. """
 import unittest
-from parameterized import parameterized
 from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
 from transformers.testing_utils import (
     require_bitsandbytes,
@@ -330,11 +328,6 @@ class RecurrentGemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
         config_and_inputs[0].position_embedding_type = type
         self.model_tester.create_and_check_model(*config_and_inputs)
-    @unittest.skip("Recurrent gemma does not use legacy cache")
-    @parameterized.expand([(1, False), (1, True), (4, False)])
-    def test_new_cache_format(self, num_beams, do_sample):
-        pass
     def test_save_load_fast_init_from_base(self):
         pass
......