Unverified Commit a3fb96a4 authored by Joao Gante, committed by GitHub

Generate: fix assisted generation with `past_key_values` passed as kwargs (#31644)

parent 492ee17e
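
For context, the failure this commit addresses shows up when a user passes their own cache to `generate()` while also requesting assisted decoding. A minimal sketch of that usage follows; the checkpoints and prompt are illustrative placeholders, not taken from the commit:

```python
# Sketch of the scenario addressed by this fix: assisted generation with a
# user-supplied `past_key_values`. Checkpoints and prompt are placeholders.
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
assistant = AutoModelForCausalLM.from_pretrained("distilgpt2")  # smaller draft model

inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Before this fix, passing `past_key_values` explicitly could break assisted
# decoding because the candidate generator tried to copy/convert the main
# model's cache for the assistant; after the fix it simply skips it.
outputs = model.generate(
    **inputs,
    assistant_model=assistant,
    past_key_values=DynamicCache(),
    max_new_tokens=20,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```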
@@ -395,21 +395,21 @@ class DynamicCache(Cache):
             cache.update(key_states, value_states, layer_idx)
         return cache
 
-    def crop(self, maximum_length: int):
-        """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
-        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
+    def crop(self, max_length: int):
+        """Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
+        negative to remove `max_length` tokens. This is used in assisted decoding and contrastive search."""
 
         # In case it is negative
-        if maximum_length < 0:
-            maximum_length = self.get_seq_length() - abs(maximum_length)
+        if max_length < 0:
+            max_length = self.get_seq_length() - abs(max_length)
 
-        if self.get_seq_length() <= maximum_length:
+        if self.get_seq_length() <= max_length:
             return
 
-        self._seen_tokens = maximum_length
+        self._seen_tokens = max_length
         for idx in range(len(self.key_cache)):
-            self.key_cache[idx] = self.key_cache[idx][..., :maximum_length, :]
-            self.value_cache[idx] = self.value_cache[idx][..., :maximum_length, :]
+            self.key_cache[idx] = self.key_cache[idx][..., :max_length, :]
+            self.value_cache[idx] = self.value_cache[idx][..., :max_length, :]
 
     def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
         """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
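
The renamed `crop()` keeps its behavior: a non-negative `max_length` truncates the cache to that many tokens, while a negative value drops that many tokens from the end, which is how assisted decoding rolls back rejected draft tokens. A minimal sketch of that semantics (tensor shapes are made up for illustration):

```python
# Sketch of DynamicCache.crop() semantics; shapes are illustrative.
import torch
from transformers import DynamicCache

cache = DynamicCache()
key = value = torch.zeros(1, 4, 10, 8)  # (batch, num_heads, seq_len=10, head_dim)
cache.update(key, value, layer_idx=0)

cache.crop(-3)                   # negative: drop 3 tokens -> keep 10 - 3 = 7
print(cache.get_seq_length())    # 7

cache.crop(5)                    # non-negative: truncate to the first 5 tokens
print(cache.get_seq_length())    # 5
```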
@@ -111,24 +111,11 @@ class AssistedCandidateGenerator(CandidateGenerator):
         # Prepare the kwargs for the assistant model
         assistant_kwargs = {}
         for key, value in model_kwargs.items():  # deepcopy crashes if we attempt to copy encoder outputs with grads
-            if key not in ("encoder_outputs", "assistant_encoder_outputs"):
+            if key not in ("encoder_outputs", "assistant_encoder_outputs", "past_key_values"):
                 assistant_kwargs[key] = (
                     value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
                 )
 
-        # Remove potential default DynamicCache if assistant does not support it
-        if "past_key_values" in assistant_kwargs.keys():
-            if (
-                isinstance(assistant_kwargs["past_key_values"], DynamicCache)
-                and not self.assistant_model._supports_cache_class
-            ):
-                # Cache is empty -> remove it from kwargs
-                if len(assistant_kwargs["past_key_values"]) == 0:
-                    del assistant_kwargs["past_key_values"]
-                # Cache is not empty -> convert to legacy
-                else:
-                    assistant_kwargs["past_key_values"] = assistant_kwargs["past_key_values"].to_legacy_cache()
-
         if "assistant_encoder_outputs" in model_kwargs:
             assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
         elif assistant_model.config.is_encoder_decoder:
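
Read on its own, the kwarg preparation above boils down to the following standalone helper (a sketch of the logic, not the library's API; the function name is hypothetical): tensors are detached and moved to the assistant's device, other values are deep-copied, and the main model's cache is left out entirely so the assistant builds its own.

```python
import copy
import torch


def copy_kwargs_for_assistant(model_kwargs: dict, device: torch.device) -> dict:
    """Sketch of the filtering above: skip encoder outputs and the main model's
    cache, detach tensors, deep-copy everything else."""
    skipped = ("encoder_outputs", "assistant_encoder_outputs", "past_key_values")
    assistant_kwargs = {}
    for key, value in model_kwargs.items():
        if key in skipped:
            continue
        assistant_kwargs[key] = (
            value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
        )
    return assistant_kwargs
```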
@@ -363,15 +350,15 @@ class PromptLookupCandidateGenerator(CandidateGenerator):
         return
 
 
-def _crop_past_key_values(model, past_key_values, maximum_length):
+def _crop_past_key_values(model, past_key_values, max_length):
     """Crops the past key values up to a certain maximum length."""
     new_past = []
     if model.config.is_encoder_decoder:
         for idx in range(len(past_key_values)):
             new_past.append(
                 (
-                    past_key_values[idx][0][:, :, :maximum_length, :],
-                    past_key_values[idx][1][:, :, :maximum_length, :],
+                    past_key_values[idx][0][:, :, :max_length, :],
+                    past_key_values[idx][1][:, :, :max_length, :],
                     past_key_values[idx][2],
                     past_key_values[idx][3],
                 )
@@ -384,8 +371,8 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
         for idx in range(len(past_key_values)):
             new_past.append(
                 (
-                    past_key_values[idx][0][:, :, :maximum_length],
-                    past_key_values[idx][1][:, :maximum_length, :],
+                    past_key_values[idx][0][:, :, :max_length],
+                    past_key_values[idx][1][:, :max_length, :],
                 )
             )
         past_key_values = tuple(new_past)
@@ -395,19 +382,19 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
     ):
         if model.config.multi_query:
             for idx in range(len(past_key_values)):
-                past_key_values[idx] = past_key_values[idx][:, :maximum_length, :]
+                past_key_values[idx] = past_key_values[idx][:, :max_length, :]
         else:
             for idx in range(len(past_key_values)):
-                past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
+                past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
     elif isinstance(past_key_values, DynamicCache):
-        past_key_values.crop(maximum_length)
+        past_key_values.crop(max_length)
     elif past_key_values is not None:
         for idx in range(len(past_key_values)):
             new_past.append(
                 (
-                    past_key_values[idx][0][:, :, :maximum_length, :],
-                    past_key_values[idx][1][:, :, :maximum_length, :],
+                    past_key_values[idx][0][:, :, :max_length, :],
+                    past_key_values[idx][1][:, :, :max_length, :],
                 )
             )
         past_key_values = tuple(new_past)
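
`_crop_past_key_values` handles several cache layouts; for the common decoder-only legacy layout (a tuple of `(key, value)` pairs shaped `(batch, num_heads, seq_len, head_dim)`), the final `elif` branch amounts to the slicing below (a standalone sketch with made-up shapes, not library code):

```python
import torch

# Legacy cache: one (key, value) pair per layer, with seq_len on dim -2.
legacy_cache = tuple(
    (torch.zeros(1, 4, 10, 8), torch.zeros(1, 4, 10, 8)) for _ in range(2)
)

max_length = 6
cropped = tuple(
    (key[:, :, :max_length, :], value[:, :, :max_length, :]) for key, value in legacy_cache
)
print(cropped[0][0].shape)  # torch.Size([1, 4, 6, 8])
```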
@@ -3697,11 +3697,10 @@ class GenerationMixin:
         model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
 
         # This is needed if return_dict_in_generate is True
         start_from_empty_dynamic_cache = False
-        if len(model_kwargs["past_key_values"]) == 0:
-            start_from_empty_dynamic_cache = True
-        else:
-            start_from_empty_dynamic_cache = False
+        if isinstance(model_kwargs.get("past_key_values", None), DynamicCache):
+            if len(model_kwargs["past_key_values"]) == 0:
+                start_from_empty_dynamic_cache = True
 
         this_peer_finished = False
         while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
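
The reworked check only treats the run as starting from an empty cache when `past_key_values` is a `DynamicCache` with nothing recorded; a legacy tuple or a pre-filled cache leaves the flag `False`. An equivalent standalone boolean, as a sketch with a hypothetical helper name:

```python
from transformers import DynamicCache


def starts_from_empty_dynamic_cache(model_kwargs: dict) -> bool:
    # Mirrors the check above: empty DynamicCache -> True, anything else -> False.
    past = model_kwargs.get("past_key_values", None)
    return isinstance(past, DynamicCache) and len(past) == 0


print(starts_from_empty_dynamic_cache({"past_key_values": DynamicCache()}))  # True
print(starts_from_empty_dynamic_cache({}))                                   # False
```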