"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "4cb5ffa93d400636b6809563dc806b64a9b9550d"
Unverified commit 4fda78c3 authored by Zhakshylyk Nurlanov, committed by GitHub
Browse files

Fix `cache_position` initialisation for generation with `use_cache=False` (#30485)



* Fix cache_position init for generation

* Update src/transformers/generation/utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Fix cache position update

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 54a2361a
......@@ -667,7 +667,11 @@ class GenerationMixin:
dim=-1,
)
if "cache_position" in model_kwargs and model_kwargs["cache_position"] is not None:
if (
model_kwargs.get("use_cache", True)
and "cache_position" in model_kwargs
and model_kwargs["cache_position"] is not None
):
model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + num_new_tokens
return model_kwargs
......@@ -1293,6 +1297,10 @@ class GenerationMixin:
def _get_initial_cache_position(self, input_ids, model_kwargs):
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
if not model_kwargs.get("use_cache", True):
model_kwargs["cache_position"] = None
return model_kwargs
past_length = 0
if "past_key_values" in model_kwargs:
if isinstance(model_kwargs["past_key_values"], Cache):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.