Unverified Commit 948ffff4 authored by Joao Gante's avatar Joao Gante Committed by GitHub
Browse files

RWKV: raise informative exception when attempting to manipulate `past_key_values` (#28600)

parent 9efec114
......@@ -778,6 +778,24 @@ class RwkvForCausalLM(RwkvPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.head = new_embeddings
def generate(self, *args, **kwargs):
# Thin wrapper to raise exceptions when trying to generate with methods that manipulate `past_key_values`.
# RWKV is one of the few models that don't have it (it has `state` instead, which has different properties and
# usage).
try:
gen_output = super().generate(*args, **kwargs)
except AttributeError as exc:
# Expected exception: "AttributeError: '(object name)' object has no attribute 'past_key_values'"
if "past_key_values" in str(exc):
raise AttributeError(
"You tried to call `generate` with a decoding strategy that manipulates `past_key_values`. RWKV "
"doesn't have that attribute, try another generation strategy instead. For the available "
"generation strategies, check this doc: https://huggingface.co/docs/transformers/en/generation_strategies#decoding-strategies"
)
else:
raise exc
return gen_output
def prepare_inputs_for_generation(self, input_ids, state=None, inputs_embeds=None, **kwargs):
# only last token for inputs_ids if the state is passed along.
if state is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment