Unverified Commit 7ad784ae authored by Raushan Turganbay, committed by GitHub

Gemma2: add cache warning (#32279)



* gemma2 fallback to dynamic cache

* Update src/transformers/models/gemma2/modeling_gemma2.py
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update src/transformers/models/gemma2/modeling_gemma2.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* raise error and don't fall back to dynamic cache

* prev will break most forward calls/tests

* Update src/transformers/models/gemma2/modeling_gemma2.py
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* update

* fix copies

---------
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parent a30c865f
@@ -398,6 +398,7 @@ A [`Constraint`] can be used to force the generation to include specific tokens

 [[autodoc]] HybridCache
     - update
+    - get_seq_length
     - reset

 [[autodoc]] SlidingWindowCache
@@ -30,6 +30,12 @@ Tips:

 - The original checkpoints can be converted using the conversion script `src/transformers/models/Gemma2/convert_Gemma2_weights_to_hf.py`

+<Tip warning={true}>
+
+- Gemma2 uses sliding window attention every second layer, which makes it unsuitable for typical kv caching with [`~DynamicCache`] or tuples of tensors. To enable caching in Gemma2, you must initialize a [`~HybridCache`] instance and pass it as `past_key_values` to the forward call. Note that you also have to prepare `cache_position` if the `past_key_values` already contains previous keys and values.
+
+</Tip>
+
 This model was contributed by [Arthur Zucker](https://huggingface.co/ArthurZ), [Pedro Cuenca](https://huggingface.co/pcuenq) and [Tom Arsen]().
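As a rough illustration of what the tip above asks for, here is a minimal sketch of a single forward call with an explicit `HybridCache`. The checkpoint id `google/gemma-2-2b` and the `max_cache_len` value are illustrative assumptions, not part of this change:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache

model_id = "google/gemma-2-2b"  # illustrative checkpoint; any Gemma2 checkpoint works the same way
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("The capital of France is", return_tensors="pt")

# Gemma2 cannot use DynamicCache or tuples of tensors, so build a HybridCache
# up front and pass it as `past_key_values`.
past_key_values = HybridCache(
    model.config,
    max_batch_size=1,
    max_cache_len=128,  # assumed upper bound on prompt + generated tokens
    device=model.device,
    dtype=model.dtype,
)

# With an explicit cache, `cache_position` must be passed as well; for a fresh
# cache it simply spans the prompt.
cache_position = torch.arange(inputs["input_ids"].shape[1], device=model.device)
outputs = model(**inputs, past_key_values=past_key_values, cache_position=cache_position, use_cache=True)
```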
@@ -807,7 +807,26 @@ class Gemma2Model(Gemma2PreTrainedModel):
             inputs_embeds = self.embed_tokens(input_ids)

         if cache_position is None:
-            cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            if past_key_values is None:
+                cache_position = torch.arange(0, inputs_embeds.shape[1], device=inputs_embeds.device)
+            else:
+                raise ValueError("When `past_key_values` is passed, `cache_position` must be too")
+
+        # Probably a forward call with caching, so we set up the cache for this one call only
+        if use_cache and past_key_values is None and not self.training:
+            logger.warning_once(
+                "You are calling the model with `use_cache=True` but didn't pass `past_key_values` while not training. "
+                "If you want to compute with a cache, make sure to pass an instance of `HybridCache`. An empty `HybridCache` instance "
+                "will be created for this call. See for more: https://huggingface.co/docs/transformers/main/en/internal/generation_utils#transformers.HybridCache"
+            )
+            batch_size, seq_len, _ = inputs_embeds.shape
+            past_key_values = HybridCache(
+                self.config,
+                max_batch_size=batch_size,
+                max_cache_len=seq_len,
+                device=self.device,
+                dtype=inputs_embeds.dtype,
+            )
+
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
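To see how the new checks interact in practice, here is a hedged sketch of driving the forward pass manually across two calls with a reused `HybridCache`, updating `cache_position` between calls as the new `ValueError` requires. The checkpoint id and cache length are again illustrative assumptions:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache

model_id = "google/gemma-2-2b"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

input_ids = tokenizer("Hello", return_tensors="pt").input_ids

# Calling `model(input_ids, use_cache=True)` without `past_key_values` would now
# hit the warning above and build a one-call cache; a reusable cache is built here instead.
cache = HybridCache(
    model.config,
    max_batch_size=1,
    max_cache_len=64,  # assumed upper bound on prompt + generated tokens
    device=model.device,
    dtype=model.dtype,
)

# Prefill: the cache is empty, so cache_position spans the whole prompt.
cache_position = torch.arange(input_ids.shape[1], device=model.device)
out = model(input_ids, past_key_values=cache, cache_position=cache_position, use_cache=True)

# Decode one more token: the cache now holds the prompt, so cache_position must
# point at the next slot; passing the cache without it raises the new ValueError.
next_token = out.logits[:, -1:].argmax(dim=-1)
cache_position = cache_position[-1:] + 1
out = model(next_token, past_key_values=cache, cache_position=cache_position, use_cache=True)
```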