"...pdfjs-dist/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "3a42ebbf5781d6c6408324edeac9d704ca41e6b6"
Unverified Commit bd5091df authored by Cyril Vallez's avatar Cyril Vallez Committed by GitHub
Browse files

Reduce by 2 the memory requirement in `generate()` 🔥🔥🔥 (#30536)

* Fix contrastive_search for new cache structure, and improve performance by removing inefficient torch.stack(torch.split(x, top_k, dim=0))

* Fix _contrastive_search for non-standard cache using ellipsis slicing

* Fix all outputs.logits memory leaks for all decoding strategies!

* Fix small error in _contrastive_search()

* Make all necessary change and revert for the new class

* Apply coding style

* Remove pipes in type hints for compatibility

* correct type hint

* apply style

* Use DynamicCache by default and solve conflicts

* Fix rebase issues

* Add `_supports_dynamic_cache_class` in models for models that support DynamicCache but not other caches to make DynamicCache the default for more models

* Create generation config to return legacy format by default, or to choose not to

* style

* Fix case when use_cache is False

* Remove default DynamicCache in assisted_decoding if assistant_model does not support it + fix _seen_tokens when cropping cache

* Update prepare_inputs_for_generation() for case with empty DynamicCache

* Correct return of args in _assisted_decoding

* Remove EfficientDynamicCache as it is no longer needed

* Correct mistake in generation config

* Move cache logic of assisted decoding to AssistedCandidateGenerator.__init__

* change DynamicCache function names from "split" to "batch_split" for readability + apply coding style

* Remove `_supports_dynamic_cache_class` attribute after rebase

* Correct missing line lost in conflict resolution during rebasing

* Add special case for Jamba

* Fix jamba test

* Coding style

* coding style

* Correct missing import in rebasing

* Simplify _validate_model_kwargs based on removal of _supports_dynamic_cache attribute

* Simplify code paths in _contrastive_search

* coding style

* Update docstrings of cache methods

* Update prepare_inputs_for_generation() -> past_key_values are always Cache objects
parent d6276f0f
......@@ -377,7 +377,8 @@ class DynamicCache(Cache):
return None
def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
"""Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
backward compatibility."""
legacy_cache = ()
for layer_idx in range(len(self)):
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
......@@ -385,7 +386,8 @@ class DynamicCache(Cache):
@classmethod
def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
"""Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
backward compatibility."""
cache = cls()
if past_key_values is not None:
for layer_idx in range(len(past_key_values)):
......@@ -393,6 +395,57 @@ class DynamicCache(Cache):
cache.update(key_states, value_states, layer_idx)
return cache
def crop(self, maximum_length: int):
    """Crop the past key values up to a new `maximum_length` in terms of tokens.

    `maximum_length` can also be negative, in which case `abs(maximum_length)` tokens are removed from the end
    of the cache. This is used in assisted decoding and contrastive search.

    Args:
        maximum_length (`int`):
            Number of tokens to keep when non-negative; number of trailing tokens to drop when negative.
            Asking to drop more tokens than the cache currently holds leaves the cache empty.

    Returns:
        `None`. The cache is modified in place, and nothing is changed when it already holds at most
        `maximum_length` tokens.
    """
    current_length = self.get_seq_length()

    if maximum_length < 0:
        # Clamp at 0: a value that is still negative after the conversion would be written to `_seen_tokens`
        # and would be re-interpreted as a "from the end" bound by the slices below.
        maximum_length = max(current_length + maximum_length, 0)

    if current_length <= maximum_length:
        return

    self._seen_tokens = maximum_length
    # Slice layer by layer on the second-to-last dimension; the ellipsis keeps any leading dimensions untouched
    for idx in range(len(self.key_cache)):
        self.key_cache[idx] = self.key_cache[idx][..., :maximum_length, :]
        self.value_cache[idx] = self.value_cache[idx][..., :maximum_length, :]
def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
    """Cut this cache along the batch dimension into consecutive `DynamicCache` chunks.

    Every chunk covers `split_size` consecutive batch entries (the last one may be shorter) and inherits the
    `_seen_tokens` counter of the current instance. This will be used by `_split_model_inputs()` in
    `generation.utils`.
    """
    chunks = []
    for start in range(0, full_batch_size, split_size):
        stop = start + split_size
        chunk = DynamicCache()
        chunk._seen_tokens = self._seen_tokens
        # Slice every layer on dim 0 (batch) for both keys and values
        chunk.key_cache = [layer_keys[start:stop] for layer_keys in self.key_cache]
        chunk.value_cache = [layer_values[start:stop] for layer_values in self.value_cache]
        chunks.append(chunk)
    return chunks
@classmethod
def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
    """Rebuild a single cache from chunks split along the batch dimension (the inverse of `batch_split()`).

    For each layer, the keys (resp. values) of all chunks are concatenated on dim 0 and pushed into a fresh
    instance through `update()`. This will be used by `stack_model_outputs` in `generation.utils`.
    """
    merged = cls()
    # The first chunk determines how many layers are merged
    num_layers = len(splits[0])
    for layer_idx in range(num_layers):
        merged_keys = torch.cat([chunk.key_cache[layer_idx] for chunk in splits], dim=0)
        merged_values = torch.cat([chunk.value_cache[layer_idx] for chunk in splits], dim=0)
        merged.update(merged_keys, merged_values, layer_idx)
    return merged
def batch_repeat_interleave(self, repeats: int):
    """Repeat every batch entry of the cache `repeats` times, in place. Used in contrastive search."""
    for layer_idx in range(len(self)):
        # Keys first, then values. Only the containing lists are bound to a name, so each original tensor
        # loses its reference as soon as its slot is overwritten.
        for layer_store in (self.key_cache, self.value_cache):
            layer_store[layer_idx] = layer_store[layer_idx].repeat_interleave(repeats, dim=0)
def batch_select_indices(self, indices: torch.Tensor):
    """Keep only the batch entries listed in `indices`, in place. Used in contrastive search."""
    for layer_idx in range(len(self)):
        # Keys first, then values; indexing happens on the first dimension only
        for layer_store in (self.key_cache, self.value_cache):
            layer_store[layer_idx] = layer_store[layer_idx][indices, ...]
class QuantizedCache(DynamicCache):
"""
......
......@@ -116,6 +116,19 @@ class AssistedCandidateGenerator(CandidateGenerator):
value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
)
# Remove potential default DynamicCache if assistant does not support it
if "past_key_values" in assistant_kwargs.keys():
if (
isinstance(assistant_kwargs["past_key_values"], DynamicCache)
and not self.assistant_model._supports_cache_class
):
# Cache is empty -> remove it from kwargs
if len(assistant_kwargs["past_key_values"]) == 0:
del assistant_kwargs["past_key_values"]
# Cache is not empty -> convert to legacy
else:
assistant_kwargs["past_key_values"] = assistant_kwargs["past_key_values"].to_legacy_cache()
if "assistant_encoder_outputs" in model_kwargs:
assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
elif assistant_model.config.is_encoder_decoder:
......@@ -387,10 +400,7 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
for idx in range(len(past_key_values)):
past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
elif isinstance(past_key_values, DynamicCache):
for idx in range(len(past_key_values.key_cache)):
if past_key_values.value_cache[idx].shape[-1] != 0:
past_key_values.key_cache[idx] = past_key_values.key_cache[idx][:, :, :maximum_length, :]
past_key_values.value_cache[idx] = past_key_values.value_cache[idx][:, :, :maximum_length, :]
past_key_values.crop(maximum_length)
elif past_key_values is not None:
for idx in range(len(past_key_values)):
......
......@@ -313,6 +313,8 @@ class GenerationConfig(PushToHubMixin):
Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
it will be converted to its respective `CacheConfig` internally.
Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
return_legacy_cache (`bool`, *optional*, defaults to `True`):
Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.
> Wild card
......@@ -404,6 +406,7 @@ class GenerationConfig(PushToHubMixin):
self.cache_config = cache_config_class()
elif isinstance(self.cache_config, dict):
self.cache_config = cache_config_class.from_dict(self.cache_config)
self.return_legacy_cache = kwargs.pop("return_legacy_cache", True)
# Prompt lookup decoding
self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
......
......@@ -1448,6 +1448,16 @@ class GenerationMixin:
else:
return
def _supports_default_dynamic_cache(self) -> bool:
"""
Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
This is mostly the same as `_supports_cache_class` attribute, but add exception for `Jamba` model which
uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in
order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed
for `HybridMambaAttentionDynamicCache`).
"""
return self._supports_cache_class and "jamba" not in self.__class__.__name__.lower()
def _prepare_special_tokens(
self,
generation_config: GenerationConfig,
......@@ -1709,6 +1719,7 @@ class GenerationMixin:
input_ids_length=input_ids_length,
)
use_dynamic_cache_by_default = False
if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
raise ValueError(
"Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
......@@ -1750,6 +1761,16 @@ class GenerationMixin:
)
model_kwargs["past_key_values"] = cache_class(cache_config)
# Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
# keeps copying the cache thus using much more memory
elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
past = model_kwargs.get("past_key_values", None)
if past is None:
model_kwargs["past_key_values"] = DynamicCache()
use_dynamic_cache_by_default = True
elif isinstance(past, tuple):
model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past)
use_dynamic_cache_by_default = True
self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
......@@ -2018,6 +2039,11 @@ class GenerationMixin:
**model_kwargs,
)
# Convert to legacy cache if needed
if use_dynamic_cache_by_default and generation_config.return_legacy_cache:
if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"):
if isinstance(result.past_key_values, DynamicCache):
result.past_key_values = result.past_key_values.to_legacy_cache()
return result
def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
......@@ -2185,7 +2211,10 @@ class GenerationMixin:
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
# if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
# (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
if model_kwargs.get("past_key_values") is None:
if model_kwargs.get("past_key_values") is None or (
isinstance(model_kwargs["past_key_values"], Cache)
and model_kwargs["past_key_values"].get_seq_length() == 0
):
# prepare inputs
model_kwargs["use_cache"] = True
model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
......@@ -2204,7 +2233,9 @@ class GenerationMixin:
last_hidden_states = outputs.hidden_states[-1]
# next logit for contrastive search to select top-k candidate tokens
logit_for_next_step = outputs.logits[:, -1, :]
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
# (the clone itself is always small)
logit_for_next_step = outputs.logits[:, -1, :].clone()
model_kwargs = self._update_model_kwargs_for_generation(
outputs,
......@@ -2212,6 +2243,7 @@ class GenerationMixin:
is_encoder_decoder=self.config.is_encoder_decoder,
standardize_cache_format=True,
)
if not sequential:
# Expands model inputs top_k times, for batched forward passes (akin to beam search).
_, model_kwargs = self._expand_inputs_for_generation(
......@@ -2261,25 +2293,28 @@ class GenerationMixin:
else (outputs.hidden_states,)
)
# Replicates the new past_key_values to match the `top_k` candidates
new_key_values = []
past = model_kwargs["past_key_values"]
for layer in past:
items = []
# item is either the key or the value matrix
for item in layer:
if sequential:
items.append(item.repeat_interleave(1, dim=0))
else:
items.append(item.repeat_interleave(top_k, dim=0))
new_key_values.append(tuple(items))
if not isinstance(past, DynamicCache):
past = tuple(new_key_values)
else:
for layer_idx in range(len(new_key_values)):
past.key_cache[layer_idx] = new_key_values[layer_idx][0]
past.value_cache[layer_idx] = new_key_values[layer_idx][1]
model_kwargs["past_key_values"] = past
# This is needed to properly delete outputs.logits which may be very large for this first iteration
# Otherwise a reference to outputs.logits is kept all along until after the next call to self.forward()
del outputs
if not sequential:
# Replicates the new past_key_values to match the `top_k` candidates
past = model_kwargs["past_key_values"]
# If it is a dynamic cache, modify it in-place layer after layer to save memory
if isinstance(past, DynamicCache):
past.batch_repeat_interleave(top_k)
else:
new_key_values = []
for layer in past:
items = []
# item is either the key or the value matrix
for item in layer:
items.append(item.repeat_interleave(top_k, dim=0))
new_key_values.append(tuple(items))
past = tuple(new_key_values)
model_kwargs["past_key_values"] = past
if sequential:
all_outputs = []
......@@ -2293,6 +2328,12 @@ class GenerationMixin:
output_hidden_states=True,
output_attentions=output_attentions,
)
if isinstance(outputs["past_key_values"], DynamicCache):
# Remove past K-V from output since we don't need to stack later
outputs["past_key_values"] = None
# Remove last token from past K-V since we don't want to append it at this point
model_kwargs["past_key_values"].crop(-1)
all_outputs.append(outputs)
outputs = stack_model_outputs(all_outputs)
......@@ -2307,6 +2348,11 @@ class GenerationMixin:
output_hidden_states=True,
output_attentions=output_attentions,
)
# This is essential to avoid keeping a last reference to the big past K-V and doubling the necessary memory
# in the next loop
del next_model_inputs
# name is different for encoder-decoder and decoder-only models
if self.config.is_encoder_decoder:
next_hidden = outputs.decoder_hidden_states[-1]
......@@ -2316,7 +2362,6 @@ class GenerationMixin:
full_hidden_states = outputs.hidden_states
logits = outputs.logits[:, -1, :]
context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)
# compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
......@@ -2325,6 +2370,9 @@ class GenerationMixin:
selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)
selected_idx = selected_idx.to("cpu")
# This will be used instead of the previous inefficient torch.stack(torch.split())
augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)])
# prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
# the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
# (model confidence minus degeneration penalty); (6) decoder hidden_states
......@@ -2354,22 +2402,19 @@ class GenerationMixin:
else:
next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
new_key_values = []
for layer in next_past_key_values:
items = []
# item is either the key or the value matrix
for item in layer:
item = torch.stack(torch.split(item, top_k, dim=0)) # [B, K, num_head, seq_len, esz]
item = item[range(batch_size), selected_idx, ...] # [B, num_head, seq_len, esz]
items += [item]
new_key_values += [items]
if not isinstance(next_past_key_values, DynamicCache):
next_past_key_values = tuple(new_key_values)
# Do it in-place layer per layer to save memory
if isinstance(next_past_key_values, DynamicCache):
next_past_key_values.batch_select_indices(augmented_idx)
else:
for layer_idx in range(len(new_key_values)):
next_past_key_values.key_cache[layer_idx] = new_key_values[layer_idx][0]
next_past_key_values.value_cache[layer_idx] = new_key_values[layer_idx][1]
new_key_values = []
for layer in next_past_key_values:
items = []
# item is either the key or the value matrix
for item in layer:
items.append(item[augmented_idx, ...])
new_key_values.append(tuple(items))
next_past_key_values = tuple(new_key_values)
logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
......@@ -2431,13 +2476,16 @@ class GenerationMixin:
# Contrastive search works by forward looking at the next token, so we need to exclude it from
# `past_key_values` to be consistent with the other decoding methods
if model_kwargs.get("past_key_values") is not None:
past_key_values = []
for layer in model_kwargs["past_key_values"]:
layer_past_key_values = []
for item in layer:
layer_past_key_values.append(item[..., :-1, :])
past_key_values.append(tuple(layer_past_key_values))
model_kwargs["past_key_values"] = tuple(past_key_values)
if isinstance(model_kwargs["past_key_values"], DynamicCache):
model_kwargs["past_key_values"].crop(-1)
else:
past_key_values = []
for layer in model_kwargs["past_key_values"]:
layer_past_key_values = []
for item in layer:
layer_past_key_values.append(item[..., :-1, :])
past_key_values.append(tuple(layer_past_key_values))
model_kwargs["past_key_values"] = tuple(past_key_values)
if self.config.is_encoder_decoder:
return GenerateEncoderDecoderOutput(
......@@ -2588,7 +2636,9 @@ class GenerationMixin:
if synced_gpus and this_peer_finished:
continue # don't waste resources running the code we don't need
next_token_logits = outputs.logits[:, -1, :]
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone()
# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
......@@ -2639,6 +2689,10 @@ class GenerationMixin:
unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
this_peer_finished = unfinished_sequences.max() == 0
# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
del outputs
if streamer is not None:
streamer.end()
......@@ -2846,7 +2900,9 @@ class GenerationMixin:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
next_token_logits = outputs.logits[:, -1, :]
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone()
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
......@@ -2922,6 +2978,13 @@ class GenerationMixin:
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
# (that way the memory peak does not include outputs.logits)
del outputs
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], beam_idx
......@@ -3125,7 +3188,9 @@ class GenerationMixin:
if output_scores:
processed_score = torch.zeros_like(outputs.logits[:, -1, :])
if output_logits:
raw_logit_score = outputs.logits[:, -1, :]
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
raw_logit_score = outputs.logits[:, -1, :].clone()
for beam_group_idx in range(num_beam_groups):
group_start_idx = beam_group_idx * num_sub_beams
......@@ -3142,6 +3207,7 @@ class GenerationMixin:
group_input_ids = input_ids[batch_group_indices]
# select outputs of beams of current group only
# No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
next_token_logits = outputs.logits[batch_group_indices, -1, :]
next_token_scores = nn.functional.log_softmax(
......@@ -3231,6 +3297,13 @@ class GenerationMixin:
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
# (that way the memory peak does not include outputs.logits)
del outputs
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], reordering_indices
......@@ -3393,7 +3466,9 @@ class GenerationMixin:
cur_len = cur_len + 1
continue # don't waste resources running the code we don't need
next_token_logits = outputs.logits[:, -1, :]
# Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
# (the clone itself is always small)
next_token_logits = outputs.logits[:, -1, :].clone()
next_token_scores = nn.functional.log_softmax(
next_token_logits, dim=-1
) # (batch_size * num_beams, vocab_size)
......@@ -3461,6 +3536,13 @@ class GenerationMixin:
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
)
# This is needed to properly delete outputs.logits which may be very large for first iteration
# Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
# IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
# (that way the memory peak does not include outputs.logits)
del outputs
if model_kwargs.get("past_key_values", None) is not None:
model_kwargs["past_key_values"] = self._temporary_reorder_cache(
model_kwargs["past_key_values"], beam_idx
......@@ -3597,6 +3679,13 @@ class GenerationMixin:
unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
# This is needed if return_dict_in_generate is True
if isinstance(model_kwargs.get("past_key_values", None), DynamicCache):
if len(model_kwargs["past_key_values"]) == 0:
start_from_empty_dynamic_cache = True
else:
start_from_empty_dynamic_cache = False
this_peer_finished = False
while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
cur_len = input_ids.shape[-1]
......@@ -3709,8 +3798,10 @@ class GenerationMixin:
if output_logits:
raw_logits += (next_token_logits,)
if "past_key_values" not in model_kwargs:
if "past_key_values" not in model_kwargs or start_from_empty_dynamic_cache:
added_len = new_cur_len
# set it to false for other iterations
start_from_empty_dynamic_cache = False
else:
added_len = n_matches + 1
......@@ -3909,6 +4000,9 @@ def _split(data, full_batch_size: int, split_size: int = None):
return [None] * (full_batch_size // split_size)
if isinstance(data, torch.Tensor):
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
# New cache format
elif isinstance(data, DynamicCache):
return data.batch_split(full_batch_size, split_size)
elif isinstance(data, tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0], tuple):
......@@ -4012,6 +4106,9 @@ def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
return None
if isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
# New cache format
elif isinstance(data[0], DynamicCache):
return DynamicCache.from_batch_splits(data)
elif isinstance(data[0], tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0][0], tuple):
......
......@@ -1167,18 +1167,14 @@ class CohereForCausalLM(CoherePreTrainedModel):
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
......@@ -1208,7 +1204,7 @@ class CohereForCausalLM(CoherePreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
......
......@@ -1443,18 +1443,14 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
......@@ -1484,7 +1480,7 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
......
......@@ -1163,18 +1163,14 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
......@@ -1204,7 +1200,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
......
......@@ -1876,15 +1876,13 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
past_length = 0
# Omit tokens covered by past_key_values
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
......@@ -1915,7 +1913,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
......
......@@ -1218,18 +1218,14 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
......@@ -1259,7 +1255,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
......
......@@ -1240,21 +1240,17 @@ class MistralForCausalLM(MistralPreTrainedModel):
use_cache=True,
**kwargs,
):
# Omit tokens covered by past_key_values
past_length = 0
# Omit tokens covered by past_key_values
if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
......@@ -1294,7 +1290,7 @@ class MistralForCausalLM(MistralPreTrainedModel):
attention_mask = attention_mask[:, -past_key_values.max_cache_len :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids.contiguous()}
......
......@@ -1407,15 +1407,13 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
output_router_logits=False,
**kwargs,
):
past_length = 0
# Omit tokens covered by past_key_values
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
......@@ -1446,7 +1444,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
......
......@@ -1198,18 +1198,14 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
max_cache_length = (
torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
if past_key_values.get_max_length() is not None
else None
)
cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
......@@ -1239,7 +1235,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
# The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
......
......@@ -832,14 +832,12 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
......@@ -870,7 +868,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
......
......@@ -1211,14 +1211,12 @@ class PhiForCausalLM(PhiPreTrainedModel):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
......@@ -1249,7 +1247,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
......
......@@ -1299,14 +1299,12 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
......@@ -1337,7 +1335,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
......
......@@ -1199,15 +1199,13 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
past_length = 0
# Omit tokens covered by past_key_values
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
......@@ -1238,7 +1236,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
......
......@@ -1394,15 +1394,13 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
past_length = 0
# Omit tokens covered by past_key_values
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
......@@ -1433,7 +1431,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
......
......@@ -1208,14 +1208,12 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
past_length = 0
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
......@@ -1246,7 +1244,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
......
......@@ -1182,15 +1182,13 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
past_length = 0
# Omit tokens covered by past_key_values
if past_key_values is not None:
if isinstance(past_key_values, Cache):
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
else:
cache_length = past_length = past_key_values[0][0].shape[2]
max_cache_length = None
# Past key values are always initialized with a `Cache` object -> no need for if-else anymore
cache_length = past_key_values.get_seq_length()
past_length = past_key_values.seen_tokens
max_cache_length = past_key_values.get_max_length()
# Keep only the unprocessed tokens:
# 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
......@@ -1221,7 +1219,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
position_ids = position_ids[:, -input_ids.shape[1] :]
# if `inputs_embeds` are passed, we only want to use them in the 1st generation step
if inputs_embeds is not None and past_key_values is None:
if inputs_embeds is not None and past_length == 0:
model_inputs = {"inputs_embeds": inputs_embeds}
else:
model_inputs = {"input_ids": input_ids}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment