Unverified Commit e7b98370 authored by Arthur, committed by GitHub

[`Llama + AWQ`] fix `prepare_inputs_for_generation` 🫠 (#29381)

* use the generation config 🫠

* fixup
parent 50db7ca4
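
The change is the same one-line condition swap in both Gemma and Llama: instead of probing the first attention module for a `past_key_value` attribute, the model asks its own generation config whether a static cache was requested. A hedged restatement of the two checks follows; the `use_static_cache` name is mine and does not appear in the diff, and the AWQ failure mode is inferred from the PR title rather than spelled out in the commit message.

# Before: decide on the static-cache path by probing a module attribute.
# If a quantization integration (e.g. AWQ) has replaced or wrapped self_attn,
# the attribute may be absent and the static-cache branch is skipped.
use_static_cache = getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None

# After: read the user-facing generation config instead, which does not depend
# on how the attention modules are implemented or wrapped.
use_static_cache = self.generation_config.cache_implementation == "static"
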
@@ -1161,7 +1161,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
             if past_key_values:
                 position_ids = position_ids[:, -input_ids.shape[1] :]
 
-        if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
+        if self.generation_config.cache_implementation == "static":
             # generation with static cache
             cache_position = kwargs.get("cache_position", None)
             if cache_position is None:
@@ -1277,7 +1277,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
             if past_key_values:
                 position_ids = position_ids[:, -input_ids.shape[1] :]
 
-        if getattr(self.model.layers[0].self_attn, "past_key_value", None) is not None:
+        if self.generation_config.cache_implementation == "static":
             # generation with static cache
             cache_position = kwargs.get("cache_position", None)
             if cache_position is None:
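
For context, a minimal usage sketch (not part of the commit) showing how the patched branch is reached: the static cache is requested through the generation config, which is exactly what the new check reads. The checkpoint name is a placeholder; any Llama or Gemma checkpoint, AWQ-quantized or not, follows the same path.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # placeholder; an AWQ variant works the same way
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Requesting the static cache via the generation config is what the fixed
# condition in prepare_inputs_for_generation now checks for.
model.generation_config.cache_implementation = "static"

inputs = tokenizer("Hello, my name is", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))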