Unverified commit 739a6316 authored by Joao Gante, committed by GitHub

Generate: remove deprecated code due to `Cache` and `cache_position` being default (#31898)



* tmp commit

* shorter

* nit

* explicit kwargs

* propagate changes

* mass propagation with a few manual touches (let's see how CI behaves)

* fix cacheless case

* Update src/transformers/generation/utils.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* make fixup

---------
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent 8480fda6
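For readers skimming the diff below: the deleted branches only existed to support tuple-based (legacy) caches and a possibly-missing `cache_position`. A minimal sketch of the behaviour the commit now assumes, namely that `past_key_values` is a `Cache` object and decoding progress is tracked by a `cache_position` index tensor, could look like the following (the model id, prompt, and shapes are illustrative only, not taken from this diff):

# Hedged sketch: manual prefill + one decoding step with the Cache/cache_position defaults.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tok = AutoTokenizer.from_pretrained("bigcode/starcoder2-3b")  # illustrative model id
model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder2-3b")

inputs = tok("def hello():", return_tensors="pt")
cache = DynamicCache()  # the default cache format; a tuple of tensors is no longer assumed

# Prefill step: cache_position covers every prompt token.
cache_position = torch.arange(inputs.input_ids.shape[1])
out = model(**inputs, past_key_values=cache, cache_position=cache_position, use_cache=True)

# Next step only needs the newly generated token; the cache knows how much it has seen.
next_token = out.logits[:, -1:].argmax(-1)
cache_position = cache_position[-1:] + 1
out = model(input_ids=next_token, past_key_values=cache, cache_position=cache_position, use_cache=True)
print(cache.get_seq_length())  # number of tokens currently stored in the cache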
@@ -1072,6 +1072,7 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
             attentions=outputs.attentions,
         )
 
+    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
         self,
         input_ids,
@@ -1079,42 +1080,19 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
         attention_mask=None,
         inputs_embeds=None,
         cache_position=None,
+        position_ids=None,
         use_cache=True,
         **kwargs,
     ):
-        past_length = 0
-        # Omit tokens covered by past_key_values
+        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
+        # Exception 1: when passing input_embeds, input_ids may be missing entries
+        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        if past_key_values is not None:
-            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
-            past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length()
-            max_cache_length = (
-                torch.tensor(past_key_values.get_max_length(), device=input_ids.device)
-                if past_key_values.get_max_length() is not None
-                else None
-            )
-            cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length)
-
-            # Keep only the unprocessed tokens:
-            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            # input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            # input_ids based on the past_length.
-            elif past_length < input_ids.shape[1]:
-                input_ids = input_ids[:, past_length:]
-            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-
-            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
-            if (
-                max_cache_length is not None
-                and attention_mask is not None
-                and cache_length + input_ids.shape[1] > max_cache_length
-            ):
-                attention_mask = attention_mask[:, -max_cache_length:]
+            if inputs_embeds is not None:  # Exception 1
+                input_ids = input_ids[:, -cache_position.shape[0] :]
+            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
+                input_ids = input_ids[:, cache_position]
 
-        position_ids = kwargs.get("position_ids", None)
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
@@ -1123,37 +1101,22 @@ class Starcoder2ForCausalLM(Starcoder2PreTrainedModel):
                 position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_length == 0:
+        if inputs_embeds is not None and cache_position[0] == 0:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
-            model_inputs = {"input_ids": input_ids}
-
-        input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
-        if cache_position is None:
-            cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device)
-        elif use_cache:
-            cache_position = cache_position[-input_length:]
+            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
 
         model_inputs.update(
             {
                 "position_ids": position_ids,
-                "cache_position": cache_position,
                 "past_key_values": past_key_values,
                 "use_cache": use_cache,
                 "attention_mask": attention_mask,
+                "cache_position": cache_position,
             }
         )
         return model_inputs
 
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-        return reordered_past
-
 
 @add_start_docstrings(
     """
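The core of the new logic above is that `cache_position` alone tells `prepare_inputs_for_generation` which slice of `input_ids` is still unprocessed. A standalone sketch of that slicing rule, using toy tensors and no model, might look like this:

# Toy illustration of the cache_position-based slicing (values are made up).
import torch

input_ids = torch.tensor([[5, 6, 7, 8, 9]])   # full sequence generated so far
cache_position = torch.tensor([4])            # only the last position is unprocessed

# Default case: input_ids holds the whole sequence, keep only the positions
# the cache has not seen yet.
if input_ids.shape[1] != cache_position.shape[0]:
    step_ids = input_ids[:, cache_position]
print(step_ids)  # tensor([[9]])

# Exception 1 (inputs_embeds passed on the first step): input_ids may be missing
# entries, so only its tail of length len(cache_position) is kept.
step_ids = input_ids[:, -cache_position.shape[0] :]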
@@ -1811,15 +1811,6 @@ class WhisperForConditionalGeneration(WhisperGenerationMixin, WhisperPreTrainedModel):
             "cache_position": cache_position,
         }
 
-    @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
-        reordered_past = ()
-        for layer_past in past_key_values:
-            reordered_past += (
-                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-        return reordered_past
-
 
 class WhisperDecoderWrapper(WhisperPreTrainedModel):
     """
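The per-model `_reorder_cache` staticmethods can be deleted because beam reordering is handled by the cache object itself. A minimal sketch, assuming the `DynamicCache.update` and `DynamicCache.reorder_cache` APIs available in this transformers version (shapes are illustrative):

# Hedged sketch: reordering beams directly on a DynamicCache instead of a per-model helper.
import torch
from transformers import DynamicCache

cache = DynamicCache()
# Fill one layer with dummy states: 2 beams, 1 head, 3 tokens, head dim 4.
key = torch.randn(2, 1, 3, 4)
value = torch.randn(2, 1, 3, 4)
cache.update(key, value, layer_idx=0)

beam_idx = torch.tensor([1, 0])  # swap the two beams
cache.reorder_cache(beam_idx)    # replaces the old index_select loop removed above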
@@ -35,8 +35,8 @@ if is_torch_available():
         AutoModelForCausalLM,
         AutoTokenizer,
         DynamicCache,
+        GPT2LMHeadModel,
         LlamaConfig,
-        LlamaForCausalLM,
         SinkCache,
         StaticCache,
     )
@@ -94,7 +94,7 @@ class CacheTest(unittest.TestCase):
     def test_reorder_cache_retrocompatibility(self):
         """Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
-        legacy_reorder_fn = LlamaForCausalLM._reorder_cache  # An example of a legacy `_reorder_cache` function
+        legacy_reorder_fn = GPT2LMHeadModel._reorder_cache  # An example of a legacy `_reorder_cache` function
         legacy_cache = ()
         new_cache = DynamicCache()
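The retrocompatibility test must now reference a model that still ships a legacy `_reorder_cache` (GPT-2, since Llama's was removed in this commit). The equivalence it exercises can be sketched roughly as follows, assuming the `DynamicCache.from_legacy_cache` / `to_legacy_cache` converters (layer count and shapes are made up):

# Rough sketch of the equivalence the test asserts: the legacy staticmethod and
# DynamicCache.reorder_cache reorder beams identically.
import torch
from transformers import DynamicCache, GPT2LMHeadModel

legacy_cache = tuple(
    (torch.randn(2, 1, 3, 4), torch.randn(2, 1, 3, 4)) for _ in range(2)  # 2 layers, 2 beams
)
beam_idx = torch.tensor([1, 0])

reordered_legacy = GPT2LMHeadModel._reorder_cache(legacy_cache, beam_idx)  # staticmethod, no weights needed

new_cache = DynamicCache.from_legacy_cache(legacy_cache)
new_cache.reorder_cache(beam_idx)

for legacy_layer, new_layer in zip(reordered_legacy, new_cache.to_legacy_cache()):
    assert torch.allclose(legacy_layer[0], new_layer[0])  # keys match
    assert torch.allclose(legacy_layer[1], new_layer[1])  # values match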