Unverified Commit 1fd60fec authored by Joao Gante, committed by GitHub

RWKV: enable generation tests (#31490)

* add rwkv tests

* has_attentions set in individual tests
parent d28e647f
@@ -625,6 +625,9 @@ class RwkvModel(RwkvPreTrainedModel):
         use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
+        if attention_mask is None:
+            logger.warning_once("`attention_mask` was passed, but it is unused in this model.")
+
         if self.training == self.layers_are_rescaled:
             self._rescale_layers()
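For context, a small sketch of the warning helper used in the added lines (the logger setup below is illustrative; modeling_rwkv.py already defines its module-level `logger`): transformers' logging utilities expose a `warning_once` method, so the message is emitted only the first time it fires rather than on every forward pass.

    from transformers.utils import logging

    logger = logging.get_logger(__name__)

    # The same message is deduplicated: only the first call actually logs.
    logger.warning_once("`attention_mask` was passed, but it is unused in this model.")
    logger.warning_once("`attention_mask` was passed, but it is unused in this model.")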
@@ -765,24 +768,6 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
     def set_output_embeddings(self, new_embeddings):
         self.head = new_embeddings
 
-    def generate(self, *args, **kwargs):
-        # Thin wrapper to raise exceptions when trying to generate with methods that manipulate `past_key_values`.
-        # RWKV is one of the few models that don't have it (it has `state` instead, which has different properties and
-        # usage).
-        try:
-            gen_output = super().generate(*args, **kwargs)
-        except AttributeError as exc:
-            # Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'"
-            if "past_key_values" in str(exc):
-                raise AttributeError(
-                    "You tried to call `generate` with a decoding strategy that manipulates `past_key_values`. RWKV "
-                    "doesn't have that attribute, try another generation strategy instead. For the available "
-                    "generation strategies, check this doc: https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
-                )
-            else:
-                raise exc
-        return gen_output
-
     def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs):
         # only last token for inputs_ids if the state is passed along.
         if state is not None:
...
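With the thin `generate` override removed, RWKV goes through the standard `GenerationMixin.generate` path and caches its recurrent `state` instead of `past_key_values`. A minimal usage sketch (the checkpoint name is an example; any RWKV checkpoint should behave the same):

    from transformers import AutoTokenizer, RwkvForCausalLM

    tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
    model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")

    inputs = tokenizer("The RWKV architecture is", return_tensors="pt")
    # Plain generate call; no RWKV-specific wrapper is involved anymore.
    output_ids = model.generate(**inputs, max_new_tokens=20, do_sample=False)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))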
@@ -464,6 +464,8 @@ class GenerationTesterMixin:
         if not hasattr(config, "use_cache"):
             self.skipTest("This model doesn't support caching")
+        if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
+            self.skipTest("Won't fix: model with non-standard dictionary output shapes")
 
         config.use_cache = True
         config.is_decoder = True
@@ -624,6 +626,8 @@ class GenerationTesterMixin:
         if not hasattr(config, "use_cache"):
             self.skipTest("This model doesn't support caching")
+        if any(model_name in model_class.__name__.lower() for model_name in ["rwkv"]):
+            self.skipTest("Won't fix: model with non-standard dictionary output shapes")
 
         model = model_class(config).to(torch_device).eval()
         logits_process_kwargs, _ = self._get_logits_processor_and_warper_kwargs(
...
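The guard added in both hunks is a plain lowercase substring match on the model class name, so every RWKV class (e.g. `RwkvForCausalLM`) is caught. A tiny standalone illustration of the check:

    skip_models = ["rwkv"]

    for cls_name in ["RwkvForCausalLM", "RwkvModel", "LlamaForCausalLM"]:
        should_skip = any(model_name in cls_name.lower() for model_name in skip_models)
        print(f"{cls_name}: {'skip' if should_skip else 'run'}")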
@@ -269,7 +269,7 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
     pipeline_model_mapping = (
         {"feature-extraction": RwkvModel, "text-generation": RwkvForCausalLM} if is_torch_available() else {}
     )
-    # all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else ()
+    all_generative_model_classes = (RwkvForCausalLM,) if is_torch_available() else ()
     fx_compatible = False
     test_missing_keys = False
     test_model_parallel = False
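Un-commenting `all_generative_model_classes` is what actually opts RWKV into the generation test suite: `GenerationTesterMixin` loops over that tuple in each test, so an empty tuple means those tests silently do nothing. A rough sketch of the pattern (not the exact upstream code):

    class GenerationTesterMixin:
        all_generative_model_classes = ()  # subclasses list their generative model classes here

        def test_greedy_generate(self):
            for model_class in self.all_generative_model_classes:
                ...  # build model_class(config), call `generate`, and assert on the output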
@@ -422,6 +422,52 @@ class RwkvModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
         model = RwkvModel.from_pretrained(model_name)
         self.assertIsNotNone(model)
 
+    def test_beam_sample_generate_dict_output(self):
+        # This model has a custom attention output shape AND config flags, let's skip those checks
+        old_has_attentions = self.has_attentions
+        self.has_attentions = False
+        super().test_beam_sample_generate_dict_output()
+        self.has_attentions = old_has_attentions
+
+    def test_beam_search_generate_dict_output(self):
+        # This model has a custom attention output shape AND config flags, let's skip those checks
+        old_has_attentions = self.has_attentions
+        self.has_attentions = False
+        super().test_beam_search_generate_dict_output()
+        self.has_attentions = old_has_attentions
+
+    def test_constrained_beam_search_generate_dict_output(self):
+        # This model has a custom attention output shape AND config flags, let's skip those checks
+        old_has_attentions = self.has_attentions
+        self.has_attentions = False
+        super().test_constrained_beam_search_generate_dict_output()
+        self.has_attentions = old_has_attentions
+
+    def test_greedy_generate_dict_outputs(self):
+        # This model has a custom attention output shape AND config flags, let's skip those checks
+        old_has_attentions = self.has_attentions
+        self.has_attentions = False
+        super().test_greedy_generate_dict_outputs()
+        self.has_attentions = old_has_attentions
+
+    def test_group_beam_search_generate_dict_output(self):
+        # This model has a custom attention output shape AND config flags, let's skip those checks
+        old_has_attentions = self.has_attentions
+        self.has_attentions = False
+        super().test_group_beam_search_generate_dict_output()
+        self.has_attentions = old_has_attentions
+
+    def test_sample_generate_dict_output(self):
+        # This model has a custom attention output shape AND config flags, let's skip those checks
+        old_has_attentions = self.has_attentions
+        self.has_attentions = False
+        super().test_sample_generate_dict_output()
+        self.has_attentions = old_has_attentions
+
+    @unittest.skip("This model doesn't support padding")
+    def test_left_padding_compatibility(self):
+        pass
+
     @unittest.skipIf(
         not is_torch_greater_or_equal_than_2_0, reason="See https://github.com/huggingface/transformers/pull/24204"
...
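The six overrides above repeat the same save/restore dance around `self.has_attentions`. If that pattern keeps growing, it could be factored into a small context manager, for example (a sketch, not part of this PR):

    from contextlib import contextmanager

    @contextmanager
    def attention_checks_disabled(test_case):
        # Temporarily turn off the attention-shape assertions in the shared generation tests.
        old_has_attentions = test_case.has_attentions
        test_case.has_attentions = False
        try:
            yield
        finally:
            test_case.has_attentions = old_has_attentions

    # Usage inside an override:
    #     def test_greedy_generate_dict_outputs(self):
    #         with attention_checks_disabled(self):
    #             super().test_greedy_generate_dict_outputs()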