Unverified Commit f7396876, authored by Raushan Turganbay and committed by GitHub

🚨 Bloom support for cache class (#31445)



* bloom dynamic cache

* bloom follows standard cache format

* no skips for bloom anymore

* use cache position when possible

* clean up

* codestyle

* Update src/transformers/models/bloom/modeling_bloom.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/bloom/modeling_bloom.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/bloom/modeling_bloom.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* pr comments

* isinstance fix

* address comments

* make musicgen test happy

* [run-slow] bloom

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 44f6fdd7
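For context on what "bloom follows standard cache format" means, here is a minimal sketch, not the library code: before this PR Bloom cached keys as `(batch_size * num_heads, head_dim, seq_len)` and values as `(batch_size * num_heads, seq_len, head_dim)`, while afterwards it uses the common `(batch_size, num_heads, seq_len, head_dim)` layout that the generic cache utilities expect. The helper name below is hypothetical.

```python
import torch

def bloom_layer_to_standard(key, value, batch_size):
    """Illustrative reshape of one pre-PR Bloom (key, value) pair into the standard 4D layout."""
    bh, head_dim, seq_len = key.shape                   # (batch * num_heads, head_dim, seq_len)
    num_heads = bh // batch_size
    key = key.view(batch_size, num_heads, head_dim, seq_len).permute(0, 1, 3, 2)
    value = value.view(batch_size, num_heads, seq_len, head_dim)
    return key, value                                   # each (batch, num_heads, seq_len, head_dim)

# Toy shapes only: batch=2, heads=4, head_dim=8, seq_len=5.
k, v = torch.randn(2 * 4, 8, 5), torch.randn(2 * 4, 5, 8)
k_std, v_std = bloom_layer_to_standard(k, v, batch_size=2)
assert k_std.shape == v_std.shape == (2, 4, 5, 8)
```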
@@ -378,19 +378,7 @@ def _crop_past_key_values(model, past_key_values, max_length):
)
)
past_key_values = tuple(new_past)
# bloom is special
elif "bloom" in model.__class__.__name__.lower() or (
model.config.architectures is not None and "bloom" in model.config.architectures[0].lower()
):
for idx in range(len(past_key_values)):
new_past.append(
(
past_key_values[idx][0][:, :, :max_length],
past_key_values[idx][1][:, :max_length, :],
)
)
past_key_values = tuple(new_past)
# gptbigcode is too
# gptbigcode is special and stores kv in shape (batch_size, seq_len, dim), if it's a multi_query model
elif "gptbigcode" in model.__class__.__name__.lower() or (
model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
):
@@ -402,7 +390,6 @@ def _crop_past_key_values(model, past_key_values, max_length):
past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
elif isinstance(past_key_values, DynamicCache):
past_key_values.crop(max_length)
elif past_key_values is not None:
for idx in range(len(past_key_values)):
new_past.append(
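The branch deleted in the hunk above existed only because of Bloom's fused layout. A rough sketch, assuming the standard layout and a recent `transformers` release that ships `DynamicCache`, of what the remaining model-agnostic paths do when cropping a cache:

```python
import torch
from transformers import DynamicCache

def crop_legacy_cache(past_key_values, max_length):
    # With the standard (batch, num_heads, seq_len, head_dim) layout, cropping
    # only ever trims the sequence axis (dim 2) of both keys and values.
    return tuple(
        (key[:, :, :max_length, :], value[:, :, :max_length, :])
        for key, value in past_key_values
    )

# Cache-class path: DynamicCache implements the same trimming itself.
cache = DynamicCache()
cache.update(torch.randn(1, 4, 10, 8), torch.randn(1, 4, 10, 8), layer_idx=0)
cache.crop(6)
assert cache.get_seq_length() == 6
```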
@@ -639,7 +639,7 @@ class GenerationMixin:
return input_ids, model_kwargs
def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
def _extract_past_from_model_output(self, outputs: ModelOutput):
past_key_values = None
cache_name = "past_key_values"
if "past_key_values" in outputs:
@@ -652,10 +652,6 @@ class GenerationMixin:
past_key_values = outputs.cache_params
cache_name = "cache_params"
# Bloom fix: standardizes the cache format when requested
if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
batch_size = outputs.logits.shape[0]
past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size)
return cache_name, past_key_values
def _update_model_kwargs_for_generation(
@@ -663,13 +659,10 @@ class GenerationMixin:
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
num_new_tokens: int = 1,
) -> Dict[str, Any]:
# update past_key_values keeping its naming used in model code
cache_name, cache = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)
cache_name, cache = self._extract_past_from_model_output(outputs)
model_kwargs[cache_name] = cache
if getattr(outputs, "state", None) is not None:
model_kwargs["state"] = outputs.state
@@ -2558,7 +2551,6 @@ class GenerationMixin:
outputs,
model_kwargs,
is_encoder_decoder=self.config.is_encoder_decoder,
standardize_cache_format=True,
)
if not sequential:
@@ -2723,7 +2715,7 @@ class GenerationMixin:
next_past_key_values = selected_outputs["past_key_values"]
else:
_, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
_, next_past_key_values = self._extract_past_from_model_output(outputs)
# Do it in-place layer per layer to save memory
if isinstance(next_past_key_values, DynamicCache) or (
isinstance(next_past_key_values, EncoderDecoderCache)
@@ -3033,7 +3025,7 @@ class GenerationMixin:
past_key_values = self._reorder_cache(past_key_values, beam_idx)
# Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
# cache format is standardized, to avoid adding complexity to the codebase.
elif "bloom" in model_class or "gptbigcode" in model_class:
elif "gptbigcode" in model_class:
if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)):
raise ValueError(
f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
@@ -3161,7 +3153,6 @@ class GenerationMixin:
for model_name in [
"fsmt",
"reformer",
"bloom",
"ctrl",
"gpt_bigcode",
"transo_xl",
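Regarding the `_reorder_cache` hunk above: Bloom previously needed its own reordering because batch and heads were fused in dim 0; with the standard layout, every model's cache can be reordered by gathering along the batch dimension. An illustration of that generic behaviour, not the library code:

```python
import torch

def reorder_legacy_cache(past_key_values, beam_idx):
    # Gather every cached key/value along the batch dimension (dim 0), which is
    # all beam search needs once a model uses the standard cache layout.
    return tuple(
        tuple(t.index_select(0, beam_idx.to(t.device)) for t in layer)
        for layer in past_key_values
    )

layer = (torch.randn(3, 4, 5, 8), torch.randn(3, 4, 5, 8))   # 3 beams
reordered = reorder_legacy_cache((layer,), beam_idx=torch.tensor([2, 0, 1]))
assert reordered[0][0].shape == (3, 4, 5, 8)
```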
@@ -2568,13 +2568,10 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
outputs: ModelOutput,
model_kwargs: Dict[str, Any],
is_encoder_decoder: bool = False,
standardize_cache_format: bool = False,
model_inputs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
# update past_key_values
cache_name, cache = self._extract_past_from_model_output(
outputs, standardize_cache_format=standardize_cache_format
)
cache_name, cache = self._extract_past_from_model_output(outputs)
model_kwargs[cache_name] = cache
if getattr(outputs, "state", None) is not None:
@@ -252,7 +252,6 @@ class PersimmonAttention(nn.Module):
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
# Copied from transformers.models.bloom.modeling_bloom.BloomAttention._split_heads
def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
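The `# Copied from ... BloomAttention._split_heads` marker is dropped above because Bloom's attention was rewritten in this PR; Persimmon keeps its own copy. A hand-written sketch of the `_split_heads` idea, not the exact Persimmon code: the fused QKV projection is viewed, never copied, into query, key, and value tensors that share storage with the input.

```python
import torch

def split_heads(fused_qkv, num_heads, head_dim):
    # fused_qkv: (batch, seq_len, 3 * num_heads * head_dim)
    batch, seq_len, _ = fused_qkv.shape
    fused_qkv = fused_qkv.view(batch, seq_len, num_heads, 3, head_dim)
    # Basic indexing returns views, so no data is copied here.
    return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]

q, k, v = split_heads(torch.randn(2, 5, 3 * 4 * 8), num_heads=4, head_dim=8)
assert q.shape == k.shape == v.shape == (2, 5, 4, 8)
```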
@@ -1096,7 +1096,6 @@ class GenerationTesterMixin:
if any(
model_name in model_class.__name__.lower()
for model_name in [
"bloom",
"ctrl",
"gptbigcode",
"transo_xl",
@@ -1878,7 +1877,7 @@ class GenerationTesterMixin:
# 2. Some old models still return `output.past_key_values` even without `use_cache=True`
# 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
# complete
models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
has_standard_cache = not any(
model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
)
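A hypothetical end-to-end check of the new behaviour (checkpoint name and call are illustrative, assuming a `transformers` version that includes this change): Bloom's forward pass now accepts a generic `Cache` object instead of its old fused tuples.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

tok = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

inputs = tok("The quick brown fox", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, past_key_values=DynamicCache(), use_cache=True)

# The returned cache is a standard Cache object, so generate() and the shared
# generation utilities can drive Bloom without model-specific conversions.
print(type(out.past_key_values).__name__, out.past_key_values.get_seq_length())
```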