Unverified commit 4fda78c3, authored by Zhakshylyk Nurlanov, committed by GitHub
Browse files

Fix `cache_position` initialisation for generation with `use_cache=False` (#30485)



* Fix cache_position init for generation

* Update src/transformers/generation/utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix cache position update

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 54a2361a
...@@ -667,7 +667,11 @@ class GenerationMixin: ...@@ -667,7 +667,11 @@ class GenerationMixin:
dim=-1, dim=-1,
) )
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None: if (
model_kwargs.get("use_cache", True)
and "cache_position" in model_kwargs
and model_kwargs["cache_position"] is not None
):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
return model_kwargs return model_kwargs
...@@ -1293,6 +1297,10 @@ class GenerationMixin: ...@@ -1293,6 +1297,10 @@ class GenerationMixin:
def _get_initial_cache_position(self, input_ids, model_kwargs): def _get_initial_cache_position(self, input_ids, model_kwargs):
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
if not model_kwargs.get("use_cache", True):
model_kwargs["cache_position"] = None
return model_kwargs
past_length = 0 past_length = 0
if "past_key_values" in model_kwargs: if "past_key_values" in model_kwargs:
if isinstance(model_kwargs["past_key_values"], Cache): if isinstance(model_kwargs["past_key_values"], Cache):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment