"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "758572cad8f9ebc927e9c9abc7ffdaef5aa42c2d"
Unverified commit bd5091df authored by Cyril Vallez, committed by GitHub

Reduce by 2 the memory requirement in `generate()` 🔥🔥🔥 (#30536)

* Fix contrastive_search for new cache structure, and improve performance by removing inefficient torch.stack(torch.split(x, top_k, dim=0))

* Fix _contrastive_search for non-standard cache using ellipsis slicing

* Fix all outputs.logits memory leaks for all decoding strategies!

* Fix small error in _contrastive_search()

* Make all necessary changes and revert for the new class

* Apply coding style

* Remove pipes in type hints for compatibility

* correct type hint

* apply style

* Use DynamicCache by default and solve conflicts

* Fix rebase issues

* Add `_supports_dynamic_cache_class` to models that support DynamicCache but not other caches, to make DynamicCache the default for more models

* Add a generation config option to return the legacy cache format by default, or to opt out of it

* style

* Fix case when use_cache is False

* Remove default DynamicCache in assisted_decoding if assistant_model does not support it + fix _seen_tokens when cropping cache

* Update prepare_inputs_for_generation() for case with empty DynamicCache

* Correct return of args in _assisted_decoding

* Remove EfficientDynamicCache as it is no longer needed

* Correct mistake in generation config

* Move cache logic of assisted decoding to AssistedCandidateGenerator.__init__

* change DynamicCache function names from "split" to "batch_split" for readability + apply coding style

* Remove `_supports_dynamic_cache_class` attribute after rebase

* Correct missing line lost in conflict resolution during rebasing

* Add special case for Jamba

* Fix jamba test

* Coding style

* coding style

* Correct missing import in rebasing

* Simplify _validate_model_kwargs based on removal of _supports_dynamic_cache attribute

* Simplify code paths in _contrastive_search

* coding style

* Update docstrings of cache methods

* Update prepare_inputs_for_generation() -> past_key_values are always Cache objects
parent d6276f0f
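Before the diff itself, a brief, hedged sketch of how the change surfaces to users. It is not part of the commit: the checkpoint path is a placeholder for any model whose architecture supports the new cache classes (Llama-style models, for example), and it assumes a transformers build that includes this commit. The two points shown are that generate() now builds a DynamicCache internally by default, and that the new return_legacy_cache option (default True) controls whether the returned past_key_values is converted back to the legacy tuple-of-tuples format.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

# Placeholder checkpoint: any model with _supports_cache_class (e.g. a Llama-style model)
model_id = "path/to/llama-style-checkpoint"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)

inputs = tokenizer("Reduce the memory used by generate():", return_tensors="pt")

# Default after this commit: a DynamicCache is used internally, then converted back to the
# legacy tuple format because return_legacy_cache defaults to True.
out = model.generate(**inputs, max_new_tokens=8, return_dict_in_generate=True)
print(type(out.past_key_values))  # legacy tuple of (key, value) pairs

# Opting out keeps the Cache object and skips the conversion copy.
out = model.generate(
    **inputs, max_new_tokens=8, return_dict_in_generate=True, return_legacy_cache=False
)
print(isinstance(out.past_key_values, DynamicCache))  # True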
src/transformers/cache_utils.py

@@ -377,7 +377,8 @@ class DynamicCache(Cache):
        return None

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
-        """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format."""
+        """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format. Used for
+        backward compatibility."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
@@ -385,7 +386,8 @@ class DynamicCache(Cache):
    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
-        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
+        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`. Used for
+        backward compatibility."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
@@ -393,6 +395,57 @@ class DynamicCache(Cache):
                cache.update(key_states, value_states, layer_idx)
        return cache

+    def crop(self, maximum_length: int):
+        """Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
+        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search."""
+        # In case it is negative
+        if maximum_length < 0:
+            maximum_length = self.get_seq_length() - abs(maximum_length)
+        if self.get_seq_length() <= maximum_length:
+            return
+        self._seen_tokens = maximum_length
+        for idx in range(len(self.key_cache)):
+            self.key_cache[idx] = self.key_cache[idx][..., :maximum_length, :]
+            self.value_cache[idx] = self.value_cache[idx][..., :maximum_length, :]
+
+    def batch_split(self, full_batch_size: int, split_size: int) -> List["DynamicCache"]:
+        """Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
+        `_split_model_inputs()` in `generation.utils`"""
+        out = []
+        for i in range(0, full_batch_size, split_size):
+            current_split = DynamicCache()
+            current_split._seen_tokens = self._seen_tokens
+            current_split.key_cache = [tensor[i : i + split_size] for tensor in self.key_cache]
+            current_split.value_cache = [tensor[i : i + split_size] for tensor in self.value_cache]
+            out.append(current_split)
+        return out
+
+    @classmethod
+    def from_batch_splits(cls, splits: List["DynamicCache"]) -> "DynamicCache":
+        """This is the opposite of the above `batch_split()` method. This will be used by `stack_model_outputs` in
+        `generation.utils`"""
+        cache = cls()
+        for idx in range(len(splits[0])):
+            layer_keys = torch.cat([current.key_cache[idx] for current in splits], dim=0)
+            layer_values = torch.cat([current.value_cache[idx] for current in splits], dim=0)
+            cache.update(layer_keys, layer_values, idx)
+        return cache
+
+    def batch_repeat_interleave(self, repeats: int):
+        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search."""
+        for layer_idx in range(len(self)):
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].repeat_interleave(repeats, dim=0)
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].repeat_interleave(repeats, dim=0)
+
+    def batch_select_indices(self, indices: torch.Tensor):
+        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search."""
+        for layer_idx in range(len(self)):
+            self.key_cache[layer_idx] = self.key_cache[layer_idx][indices, ...]
+            self.value_cache[layer_idx] = self.value_cache[layer_idx][indices, ...]


class QuantizedCache(DynamicCache):
    """
...
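A minimal, self-contained sketch of the DynamicCache helpers added above, exercised with random tensors. The shapes are illustrative only, and it assumes a transformers version that contains this commit:

import torch
from transformers import DynamicCache

cache = DynamicCache()
for layer_idx in range(2):  # 2 layers; [batch=4, heads=1, tokens=6, head_dim=8]
    cache.update(torch.randn(4, 1, 6, 8), torch.randn(4, 1, 6, 8), layer_idx)

cache.crop(-1)                                      # drop the last token, as assisted decoding does
print(cache.get_seq_length())                       # 5

splits = cache.batch_split(full_batch_size=4, split_size=2)
print(len(splits))                                  # 2 caches of batch size 2
merged = DynamicCache.from_batch_splits(splits)     # round-trips back to batch size 4
print(merged.key_cache[0].shape)                    # torch.Size([4, 1, 5, 8])

merged.batch_repeat_interleave(2)                   # batch 4 -> 8, as in contrastive search
merged.batch_select_indices(torch.tensor([0, 3]))   # keep two rows of the batch
print(merged.key_cache[0].shape)                    # torch.Size([2, 1, 5, 8])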
src/transformers/generation/candidate_generator.py

@@ -116,6 +116,19 @@ class AssistedCandidateGenerator(CandidateGenerator):
                    value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
                )

+        # Remove potential default DynamicCache if assistant does not support it
+        if "past_key_values" in assistant_kwargs.keys():
+            if (
+                isinstance(assistant_kwargs["past_key_values"], DynamicCache)
+                and not self.assistant_model._supports_cache_class
+            ):
+                # Cache is empty -> remove it from kwargs
+                if len(assistant_kwargs["past_key_values"]) == 0:
+                    del assistant_kwargs["past_key_values"]
+                # Cache is not empty -> convert to legacy
+                else:
+                    assistant_kwargs["past_key_values"] = assistant_kwargs["past_key_values"].to_legacy_cache()
+
        if "assistant_encoder_outputs" in model_kwargs:
            assistant_kwargs["encoder_outputs"] = model_kwargs["assistant_encoder_outputs"]
        elif assistant_model.config.is_encoder_decoder:
@@ -387,10 +400,7 @@ def _crop_past_key_values(model, past_key_values, maximum_length):
            for idx in range(len(past_key_values)):
                past_key_values[idx] = past_key_values[idx][:, :, :maximum_length, :]
    elif isinstance(past_key_values, DynamicCache):
-        for idx in range(len(past_key_values.key_cache)):
-            if past_key_values.value_cache[idx].shape[-1] != 0:
-                past_key_values.key_cache[idx] = past_key_values.key_cache[idx][:, :, :maximum_length, :]
-                past_key_values.value_cache[idx] = past_key_values.value_cache[idx][:, :, :maximum_length, :]
+        past_key_values.crop(maximum_length)
    elif past_key_values is not None:
        for idx in range(len(past_key_values)):
...
src/transformers/generation/configuration_utils.py

@@ -313,6 +313,8 @@ class GenerationConfig(PushToHubMixin):
            Arguments used in the key-value cache class can be passed in `cache_config`. Can be passed as a `Dict` and
            it will be converted to its repsective `CacheConfig` internally.
            Otherwise can be passed as a `CacheConfig` class matching the indicated `cache_implementation`.
+        return_legacy_cache (`bool`, *optional*, default to `True`):
+            Whether to return the legacy or new format of the cache when `DynamicCache` is used by default.

        > Wild card
@@ -404,6 +406,7 @@ class GenerationConfig(PushToHubMixin):
                self.cache_config = cache_config_class()
            elif isinstance(self.cache_config, dict):
                self.cache_config = cache_config_class.from_dict(self.cache_config)
+        self.return_legacy_cache = kwargs.pop("return_legacy_cache", True)

        # Prompt lookup decoding
        self.prompt_lookup_num_tokens = kwargs.pop("prompt_lookup_num_tokens", None)
...
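As a hedged usage note (assuming a build that contains this commit), the new flag can be set like any other GenerationConfig attribute; everything here other than return_legacy_cache is pre-existing API:

from transformers import GenerationConfig

gen_config = GenerationConfig(max_new_tokens=20, return_legacy_cache=False)
print(gen_config.return_legacy_cache)  # False: generate() keeps the DynamicCache instead of converting it back
# later: model.generate(**inputs, generation_config=gen_config, return_dict_in_generate=True)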
src/transformers/generation/utils.py

@@ -1448,6 +1448,16 @@ class GenerationMixin:
        else:
            return

+    def _supports_default_dynamic_cache(self) -> bool:
+        """
+        Return `True` if current model can use a `DynamicCache` instance when initializing the `past_key_values`.
+        This is mostly the same as `_supports_cache_class` attribute, but add exception for `Jamba` model which
+        uses its own `HybridMambaAttentionDynamicCache` and do not need to initialize the Cache in advance in
+        order to save memory (because no back and forth `to_legacy_cache` and `from_legacy_cache` will be performed
+        for `HybridMambaAttentionDynamicCache`).
+        """
+        return self._supports_cache_class and "jamba" not in self.__class__.__name__.lower()
+
    def _prepare_special_tokens(
        self,
        generation_config: GenerationConfig,
@@ -1709,6 +1719,7 @@ class GenerationMixin:
            input_ids_length=input_ids_length,
        )

+        use_dynamic_cache_by_default = False
        if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None:
            raise ValueError(
                "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a "
@@ -1750,6 +1761,16 @@ class GenerationMixin:
                )
                model_kwargs["past_key_values"] = cache_class(cache_config)
+        # Use DynamicCache() instance by default. This will avoid back and forth from legacy format that
+        # keeps copying the cache thus using much more memory
+        elif generation_config.cache_implementation is None and self._supports_default_dynamic_cache():
+            past = model_kwargs.get("past_key_values", None)
+            if past is None:
+                model_kwargs["past_key_values"] = DynamicCache()
+                use_dynamic_cache_by_default = True
+            elif isinstance(past, tuple):
+                model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(past)
+                use_dynamic_cache_by_default = True

        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
@@ -2018,6 +2039,11 @@ class GenerationMixin:
                **model_kwargs,
            )

+        # Convert to legacy cache if needed
+        if use_dynamic_cache_by_default and generation_config.return_legacy_cache:
+            if isinstance(result, ModelOutput) and hasattr(result, "past_key_values"):
+                if isinstance(result.past_key_values, DynamicCache):
+                    result.past_key_values = result.past_key_values.to_legacy_cache()
+
        return result
    def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool:
@@ -2185,7 +2211,10 @@ class GenerationMixin:
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            # if the first step in the loop, encode all the prefix and obtain: (1) past_key_values;
            # (2) last_hidden_states; (3) logit_for_next_step; (4) update model kwargs for the next step
-            if model_kwargs.get("past_key_values") is None:
+            if model_kwargs.get("past_key_values") is None or (
+                isinstance(model_kwargs["past_key_values"], Cache)
+                and model_kwargs["past_key_values"].get_seq_length() == 0
+            ):
                # prepare inputs
                model_kwargs["use_cache"] = True
                model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
@@ -2204,7 +2233,9 @@ class GenerationMixin:
                last_hidden_states = outputs.hidden_states[-1]

                # next logit for contrastive search to select top-k candidate tokens
-                logit_for_next_step = outputs.logits[:, -1, :]
+                # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for this first iteration
+                # (the clone itself is always small)
+                logit_for_next_step = outputs.logits[:, -1, :].clone()

                model_kwargs = self._update_model_kwargs_for_generation(
                    outputs,
@@ -2212,6 +2243,7 @@ class GenerationMixin:
                    is_encoder_decoder=self.config.is_encoder_decoder,
                    standardize_cache_format=True,
                )
+
                if not sequential:
                    # Expands model inputs top_k times, for batched forward passes (akin to beam search).
                    _, model_kwargs = self._expand_inputs_for_generation(
@@ -2261,24 +2293,27 @@ class GenerationMixin:
                else (outputs.hidden_states,)
            )

+            # This is needed to properly delete outputs.logits which may be very large for this first iteration
+            # Otherwise a reference to outputs.logits is kept all along until after the next call to self.forward()
+            del outputs

-            # Replicates the new past_key_values to match the `top_k` candidates
-            new_key_values = []
-            past = model_kwargs["past_key_values"]
-            for layer in past:
-                items = []
-                # item is either the key or the value matrix
-                for item in layer:
-                    if sequential:
-                        items.append(item.repeat_interleave(1, dim=0))
-                    else:
-                        items.append(item.repeat_interleave(top_k, dim=0))
-                new_key_values.append(tuple(items))
-            if not isinstance(past, DynamicCache):
-                past = tuple(new_key_values)
-            else:
-                for layer_idx in range(len(new_key_values)):
-                    past.key_cache[layer_idx] = new_key_values[layer_idx][0]
-                    past.value_cache[layer_idx] = new_key_values[layer_idx][1]
-            model_kwargs["past_key_values"] = past
+            if not sequential:
+                # Replicates the new past_key_values to match the `top_k` candidates
+                past = model_kwargs["past_key_values"]
+                # If it is a static cache, modify it in-place layer after layer to save memory
+                if isinstance(past, DynamicCache):
+                    past.batch_repeat_interleave(top_k)
+                else:
+                    new_key_values = []
+                    for layer in past:
+                        items = []
+                        # item is either the key or the value matrix
+                        for item in layer:
+                            items.append(item.repeat_interleave(top_k, dim=0))
+                        new_key_values.append(tuple(items))
+                    past = tuple(new_key_values)
+                model_kwargs["past_key_values"] = past

            if sequential:
@@ -2293,6 +2328,12 @@ class GenerationMixin:
                        output_hidden_states=True,
                        output_attentions=output_attentions,
                    )
+
+                    if isinstance(outputs["past_key_values"], DynamicCache):
+                        # Remove past K-V from output since we don't need to stack later
+                        outputs["past_key_values"] = None
+                        # Remove last token from past K-V since we don't want to append it at this point
+                        model_kwargs["past_key_values"].crop(-1)
+
                    all_outputs.append(outputs)
                outputs = stack_model_outputs(all_outputs)
@@ -2307,6 +2348,11 @@ class GenerationMixin:
                        output_hidden_states=True,
                        output_attentions=output_attentions,
                    )
+                    # This is essential to avoid having a last reference to the big past K-V and double the necesary memory
+                    # in the next loop
+                    del next_model_inputs

                # name is different for encoder-decoder and decoder-only models
                if self.config.is_encoder_decoder:
                    next_hidden = outputs.decoder_hidden_states[-1]
@@ -2316,7 +2362,6 @@ class GenerationMixin:
                full_hidden_states = outputs.hidden_states

            logits = outputs.logits[:, -1, :]
            context_hidden = last_hidden_states.repeat_interleave(top_k, dim=0)

            # compute the degeneration penalty and re-rank the candidates based on the degeneration penalty and the
@@ -2325,6 +2370,9 @@ class GenerationMixin:
            selected_idx = _ranking_fast(context_hidden, next_hidden, top_k_probs, penalty_alpha, top_k)
            selected_idx = selected_idx.to("cpu")

+            # This will be used instead of the previous inneficient torch.stack(torch.split())
+            augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)])
+
            # prepare for the next step: (1) next token_id; (2) past_key_values; (3) last_hidden_states for computing
            # the degeneration penalty; (4) logits for selecting next top-k candidates; (5) selected tokens scores
            # (model confidence minus degeneration penalty); (6) decoder hidden_states
@@ -2354,22 +2402,19 @@ class GenerationMixin:
            else:
                next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)

-            new_key_values = []
-            for layer in next_past_key_values:
-                items = []
-                # item is either the key or the value matrix
-                for item in layer:
-                    item = torch.stack(torch.split(item, top_k, dim=0))  # [B, K, num_head, seq_len, esz]
-                    item = item[range(batch_size), selected_idx, ...]  # [B, num_head, seq_len, esz]
-                    items += [item]
-                new_key_values += [items]
-            if not isinstance(next_past_key_values, DynamicCache):
-                next_past_key_values = tuple(new_key_values)
-            else:
-                for layer_idx in range(len(new_key_values)):
-                    next_past_key_values.key_cache[layer_idx] = new_key_values[layer_idx][0]
-                    next_past_key_values.value_cache[layer_idx] = new_key_values[layer_idx][1]
+            # Do it in-place layer per layer to save memory
+            if isinstance(next_past_key_values, DynamicCache):
+                next_past_key_values.batch_select_indices(augmented_idx)
+            else:
+                new_key_values = []
+                for layer in next_past_key_values:
+                    items = []
+                    # item is either the key or the value matrix
+                    for item in layer:
+                        items.append(item[augmented_idx, ...])
+                    new_key_values.append(tuple(items))
+                next_past_key_values = tuple(new_key_values)

            logit_for_next_step = torch.stack(torch.split(logits, top_k))[range(batch_size), selected_idx, :]
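A standalone check (plain PyTorch, illustrative shapes only) that the flat augmented_idx gather used above selects the same rows as the previous torch.stack(torch.split(...)) approach, without materialising the stacked copy:

import torch

batch_size, top_k = 3, 4
item = torch.randn(batch_size * top_k, 2, 5, 8)       # e.g. one key matrix, batch expanded by top_k
selected_idx = torch.randint(0, top_k, (batch_size,))

old = torch.stack(torch.split(item, top_k, dim=0))    # [B, K, num_head, seq_len, esz]
old = old[range(batch_size), selected_idx, ...]       # [B, num_head, seq_len, esz]

augmented_idx = torch.tensor([x + i * top_k for i, x in enumerate(selected_idx)])
new = item[augmented_idx, ...]

print(torch.equal(old, new))  # True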
@@ -2431,6 +2476,9 @@ class GenerationMixin:
        # Contrastive search works by forward looking at the next token, so we need to exclude it from
        # `past_key_values` to be consistent with the other decoding methods
        if model_kwargs.get("past_key_values") is not None:
+            if isinstance(model_kwargs["past_key_values"], DynamicCache):
+                model_kwargs["past_key_values"].crop(-1)
+            else:
                past_key_values = []
                for layer in model_kwargs["past_key_values"]:
                    layer_past_key_values = []
@@ -2588,7 +2636,9 @@ class GenerationMixin:
            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

-            next_token_logits = outputs.logits[:, -1, :]
+            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+            # (the clone itself is always small)
+            next_token_logits = outputs.logits[:, -1, :].clone()

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
@@ -2639,6 +2689,10 @@ class GenerationMixin:
            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0

+            # This is needed to properly delete outputs.logits which may be very large for first iteration
+            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+            del outputs
+
        if streamer is not None:
            streamer.end()
@@ -2846,7 +2900,9 @@ class GenerationMixin:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

-            next_token_logits = outputs.logits[:, -1, :]
+            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+            # (the clone itself is always small)
+            next_token_logits = outputs.logits[:, -1, :].clone()
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
@@ -2922,6 +2978,13 @@ class GenerationMixin:
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

+            # This is needed to properly delete outputs.logits which may be very large for first iteration
+            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+            # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
+            # (that way the memory peak does not include outputs.logits)
+            del outputs
+
            if model_kwargs.get("past_key_values", None) is not None:
                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
                    model_kwargs["past_key_values"], beam_idx
@@ -3125,7 +3188,9 @@ class GenerationMixin:
                if output_scores:
                    processed_score = torch.zeros_like(outputs.logits[:, -1, :])
                if output_logits:
-                    raw_logit_score = outputs.logits[:, -1, :]
+                    # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+                    # (the clone itself is always small)
+                    raw_logit_score = outputs.logits[:, -1, :].clone()

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
@@ -3142,6 +3207,7 @@ class GenerationMixin:
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of current group only
+                # No need to clone() the logits here as they will not retain outputs.logits at the end of the loop
                next_token_logits = outputs.logits[batch_group_indices, -1, :]
                next_token_scores = nn.functional.log_softmax(
@@ -3231,6 +3297,13 @@ class GenerationMixin:
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

+            # This is needed to properly delete outputs.logits which may be very large for first iteration
+            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+            # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
+            # (that way the memory peak does not include outputs.logits)
+            del outputs
+
            if model_kwargs.get("past_key_values", None) is not None:
                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
                    model_kwargs["past_key_values"], reordering_indices
@@ -3393,7 +3466,9 @@ class GenerationMixin:
                cur_len = cur_len + 1
                continue  # don't waste resources running the code we don't need

-            next_token_logits = outputs.logits[:, -1, :]
+            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+            # (the clone itself is always small)
+            next_token_logits = outputs.logits[:, -1, :].clone()
            next_token_scores = nn.functional.log_softmax(
                next_token_logits, dim=-1
            )  # (batch_size * num_beams, vocab_size)
@@ -3461,6 +3536,13 @@ class GenerationMixin:
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

+            # This is needed to properly delete outputs.logits which may be very large for first iteration
+            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+            # IMPORTANT: Note that this should appear BEFORE the call to _reorder_cache() to save the maximum memory
+            # (that way the memory peak does not include outputs.logits)
+            del outputs
+
            if model_kwargs.get("past_key_values", None) is not None:
                model_kwargs["past_key_values"] = self._temporary_reorder_cache(
                    model_kwargs["past_key_values"], beam_idx
@@ -3597,6 +3679,13 @@ class GenerationMixin:
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

+        # This is needed if return_dict_in_generate is True
+        if isinstance(model_kwargs.get("past_key_values", None), DynamicCache):
+            if len(model_kwargs["past_key_values"]) == 0:
+                start_from_empty_dynamic_cache = True
+            else:
+                start_from_empty_dynamic_cache = False
+
        this_peer_finished = False
        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
            cur_len = input_ids.shape[-1]
@@ -3709,8 +3798,10 @@ class GenerationMixin:
                if output_logits:
                    raw_logits += (next_token_logits,)

-                if "past_key_values" not in model_kwargs:
+                if "past_key_values" not in model_kwargs or start_from_empty_dynamic_cache:
                    added_len = new_cur_len
+                    # set it to false for other iterations
+                    start_from_empty_dynamic_cache = False
                else:
                    added_len = n_matches + 1
@@ -3909,6 +4000,9 @@ def _split(data, full_batch_size: int, split_size: int = None):
        return [None] * (full_batch_size // split_size)
    if isinstance(data, torch.Tensor):
        return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
+    # New cache format
+    elif isinstance(data, DynamicCache):
+        return data.batch_split(full_batch_size, split_size)
    elif isinstance(data, tuple):
        # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
        if isinstance(data[0], tuple):
@@ -4012,6 +4106,9 @@ def stack_model_outputs(model_outputs: List[ModelOutput]) -> ModelOutput:
            return None
        if isinstance(data[0], torch.Tensor):
            return torch.cat(data, dim=0)
+        # New cache format
+        elif isinstance(data[0], DynamicCache):
+            return DynamicCache.from_batch_splits(data)
        elif isinstance(data[0], tuple):
            # If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
            if isinstance(data[0][0], tuple):
...
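The repeated .clone() and del outputs edits above all target the same mechanism: outputs.logits[:, -1, :] is a view, so holding on to it (or to the outputs object) keeps the full [batch, prompt_len, vocab] logits buffer alive through the next forward pass. A generic, transformers-free illustration with made-up sizes (untyped_storage() needs a recent PyTorch):

import torch

logits = torch.randn(1, 4096, 50_000)      # stand-in for outputs.logits after the prompt forward pass
view = logits[:, -1, :]                    # a view: shares (and therefore pins) the whole ~780 MiB buffer
clone = logits[:, -1, :].clone()           # small independent copy (~0.2 MiB)

print(view.untyped_storage().nbytes())     # same storage size as the full logits tensor
print(clone.untyped_storage().nbytes())    # just one vocab row

del logits, view                           # with only `clone` kept, the big buffer can be freed;
                                           # keeping `view` or the whole `outputs` would keep it alive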
src/transformers/models/cohere/modeling_cohere.py

@@ -1167,7 +1167,7 @@ class CohereForCausalLM(CoherePreTrainedModel):
    ):
        past_length = 0
        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
            max_cache_length = (
                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
@@ -1175,10 +1175,6 @@ class CohereForCausalLM(CoherePreTrainedModel):
                else None
            )
            cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
-            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1208,7 +1204,7 @@ class CohereForCausalLM(CoherePreTrainedModel):
            position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
...
src/transformers/models/dbrx/modeling_dbrx.py

@@ -1443,7 +1443,7 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
    ):
        past_length = 0
        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
            max_cache_length = (
                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
@@ -1451,10 +1451,6 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
                else None
            )
            cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
-            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1484,7 +1480,7 @@ class DbrxForCausalLM(DbrxPreTrainedModel):
            position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
...
src/transformers/models/gemma/modeling_gemma.py

@@ -1163,7 +1163,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
    ):
        past_length = 0
        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
            max_cache_length = (
                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
@@ -1171,10 +1171,6 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
                else None
            )
            cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
-            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1204,7 +1200,7 @@ class GemmaForCausalLM(GemmaPreTrainedModel):
            position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
...
src/transformers/models/idefics2/modeling_idefics2.py

@@ -1876,15 +1876,13 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
+        past_length = 0
        # Omit tokens covered by past_key_values
        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1915,7 +1913,7 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
            position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
...
src/transformers/models/llama/modeling_llama.py

@@ -1218,7 +1218,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
    ):
        past_length = 0
        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
            max_cache_length = (
                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
@@ -1226,10 +1226,6 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
                else None
            )
            cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
-            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1259,7 +1255,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
            position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
...
src/transformers/models/mistral/modeling_mistral.py

@@ -1240,10 +1240,10 @@ class MistralForCausalLM(MistralPreTrainedModel):
        use_cache=True,
        **kwargs,
    ):
-        # Omit tokens covered by past_key_values
        past_length = 0
+        # Omit tokens covered by past_key_values
        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
            max_cache_length = (
                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
@@ -1251,10 +1251,6 @@ class MistralForCausalLM(MistralPreTrainedModel):
                else None
            )
            cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
-            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1294,7 +1290,7 @@ class MistralForCausalLM(MistralPreTrainedModel):
            attention_mask = attention_mask[:, -past_key_values.max_cache_len :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}
...
src/transformers/models/mixtral/modeling_mixtral.py

@@ -1407,15 +1407,13 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
        output_router_logits=False,
        **kwargs,
    ):
+        past_length = 0
        # Omit tokens covered by past_key_values
        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1446,7 +1444,7 @@ class MixtralForCausalLM(MixtralPreTrainedModel):
            position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
...
src/transformers/models/olmo/modeling_olmo.py

@@ -1198,7 +1198,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
    ):
        past_length = 0
        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
            max_cache_length = (
                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
@@ -1206,10 +1206,6 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
                else None
            )
            cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
-            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1239,7 +1235,7 @@ class OlmoForCausalLM(OlmoPreTrainedModel):
            position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
...
src/transformers/models/persimmon/modeling_persimmon.py

@@ -832,14 +832,12 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
+        past_length = 0
        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -870,7 +868,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel):
            position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
...
src/transformers/models/phi/modeling_phi.py

@@ -1211,14 +1211,12 @@ class PhiForCausalLM(PhiPreTrainedModel):
    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
+        past_length = 0
        if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
            cache_length = past_key_values.get_seq_length()
            past_length = past_key_values.seen_tokens
            max_cache_length = past_key_values.get_max_length()
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1249,7 +1247,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
            position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}
...
@@ -1299,14 +1299,12 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
+        past_length = 0
         if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
+            cache_length = past_key_values.get_seq_length()
+            past_length = past_key_values.seen_tokens
+            max_cache_length = past_key_values.get_max_length()

             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1337,7 +1335,7 @@ class Phi3ForCausalLM(Phi3PreTrainedModel):
                 position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids}
...
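The else-branch removed in each of these hunks handled the legacy tuple-of-tuples cache, whose length was read off `past_key_values[0][0].shape[2]`. A caller that still holds such a tuple can convert it and obtain the same number through the `Cache` interface; a minimal sketch with toy tensors:

import torch
from transformers import DynamicCache

# Legacy format: one (key, value) pair per layer, each of shape
# (batch, num_heads, seq_len, head_dim) -- toy shapes here.
legacy = tuple(
    (torch.randn(1, 8, 5, 64), torch.randn(1, 8, 5, 64)) for _ in range(2)
)
assert legacy[0][0].shape[2] == 5                    # old way: read seq_len from the key tensor

cache = DynamicCache.from_legacy_cache(legacy)
assert cache.get_seq_length() == 5                   # new way: ask the Cache object
assert cache.to_legacy_cache()[0][0].shape[2] == 5   # round-trips back to tuples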
@@ -1199,15 +1199,13 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
+        past_length = 0
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
+            cache_length = past_key_values.get_seq_length()
+            past_length = past_key_values.seen_tokens
+            max_cache_length = past_key_values.get_max_length()

             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1238,7 +1236,7 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
                 position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids}
...
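The "Keep only the unprocessed tokens" comments refer to cropping logic (elided from these hunks) that trims `input_ids` down to the tokens the cache has not yet seen. Roughly, and only as a standalone sketch of the idea rather than the exact library code, it behaves like this:

import torch

def crop_unprocessed(input_ids, attention_mask, past_length):
    # 1 - attention_mask longer than input_ids: some inputs are passed exclusively via
    #     the cache (e.g. when inputs_embeds were used), keep only the trailing unseen part.
    if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
        return input_ids[:, -(attention_mask.shape[1] - past_length):]
    # 2 - cache shorter than input_ids: drop the prefix the cache already covers.
    if past_length < input_ids.shape[1]:
        return input_ids[:, past_length:]
    # 3 - otherwise assume input_ids already contains only unprocessed tokens.
    return input_ids

input_ids = torch.tensor([[11, 12, 13, 14]])
attention_mask = torch.ones(1, 4, dtype=torch.long)
print(crop_unprocessed(input_ids, attention_mask, past_length=3))  # tensor([[14]])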
@@ -1394,15 +1394,13 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
+        past_length = 0
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
+            cache_length = past_key_values.get_seq_length()
+            past_length = past_key_values.seen_tokens
+            max_cache_length = past_key_values.get_max_length()

             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1433,7 +1431,7 @@ class Qwen2MoeForCausalLM(Qwen2MoePreTrainedModel):
                 position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids}
...
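`max_cache_length` only matters for bounded cache classes; for the default `DynamicCache` it is None, so the related attention-mask truncation is skipped. A sketch of the difference, assuming a fixed-capacity cache such as `StaticCache` with the constructor arguments of the transformers version this diff targets (the config is a default `PhiConfig`, used only to size the cache buffers):

import torch
from transformers import DynamicCache, StaticCache, PhiConfig

dynamic = DynamicCache()
print(dynamic.get_max_length())   # None -> no capacity limit, nothing to truncate

# A fixed-capacity cache reports its limit, which is what max_cache_length guards.
config = PhiConfig()              # default config, only used to size the preallocated buffers
static = StaticCache(config, max_batch_size=1, max_cache_len=128, device="cpu", dtype=torch.float32)
print(static.get_max_length())    # 128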
@@ -1208,14 +1208,12 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
+        past_length = 0
         if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
+            cache_length = past_key_values.get_seq_length()
+            past_length = past_key_values.seen_tokens
+            max_cache_length = past_key_values.get_max_length()

             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1246,7 +1244,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
                 position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids}
...
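Taken together, these per-model edits keep the `inputs_embeds` generation path working now that a `Cache` object is always present. A hedged end-to-end sketch (the checkpoint name is only an example of a `PhiForCausalLM` model; any supported causal LM works the same way):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "microsoft/phi-2"  # example checkpoint, assumed available locally or via the Hub
tok = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name, torch_dtype=torch.float32)

inputs = tok("The cache refactor", return_tensors="pt")
embeds = model.get_input_embeddings()(inputs.input_ids)

# The first step consumes inputs_embeds (past_length == 0); later steps feed the new token ids.
out = model.generate(inputs_embeds=embeds, attention_mask=inputs.attention_mask, max_new_tokens=5)
print(tok.batch_decode(out, skip_special_tokens=True))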
@@ -1182,15 +1182,13 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):
+        past_length = 0
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
-            if isinstance(past_key_values, Cache):
-                cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
-            else:
-                cache_length = past_length = past_key_values[0][0].shape[2]
-                max_cache_length = None
+            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
+            cache_length = past_key_values.get_seq_length()
+            past_length = past_key_values.seen_tokens
+            max_cache_length = past_key_values.get_max_length()

             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
@@ -1221,7 +1219,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
                 position_ids = position_ids[:, -input_ids.shape[1] :]

         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
+        if inputs_embeds is not None and past_length == 0:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids}
...