"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "6f79d264422245d88c7a34032c1a8254a0c65752"
Unverified Commit f7396876 authored by Raushan Turganbay, committed by GitHub

🚨 Bloom support for cache class (#31445)



* bloom dynamic cache

* bloom follows standard cache format

* no skips for bloom anymore

* use cache position when possible

* clean up

* codestyle

* Update src/transformers/models/bloom/modeling_bloom.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/bloom/modeling_bloom.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/bloom/modeling_bloom.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* pr comments

* isinstance fix

* address comments

* make musicgen test happy

* [run-slow] bloom

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 44f6fdd7
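
With this change Bloom follows the standard cache format, and `generate` can consume a cache class directly. A minimal sketch of what that enables, assuming a public BLOOM checkpoint (the checkpoint name and generation settings below are illustrative, not part of this commit):

```python
from transformers import AutoTokenizer, BloomForCausalLM, DynamicCache

# Illustrative checkpoint; any BLOOM checkpoint should behave the same way.
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = BloomForCausalLM.from_pretrained("bigscience/bloom-560m")

inputs = tokenizer("The capital of France is", return_tensors="pt")

# Pass a cache class instance instead of relying on Bloom's old tuple format.
outputs = model.generate(**inputs, past_key_values=DynamicCache(), max_new_tokens=10)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```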
@@ -378,19 +378,7 @@ def _crop_past_key_values(model, past_key_values, max_length):
                 )
             )
         past_key_values = tuple(new_past)
-    # bloom is special
-    elif "bloom" in model.__class__.__name__.lower() or (
-        model.config.architectures is not None and "bloom" in model.config.architectures[0].lower()
-    ):
-        for idx in range(len(past_key_values)):
-            new_past.append(
-                (
-                    past_key_values[idx][0][:, :, :max_length],
-                    past_key_values[idx][1][:, :max_length, :],
-                )
-            )
-        past_key_values = tuple(new_past)
-    # gptbigcode is too
+    # gptbigcode is special and stores kv in shape (batch_size, seq_len, dim), if it's a multi_query model
     elif "gptbigcode" in model.__class__.__name__.lower() or (
         model.config.architectures is not None and "gptbigcode" in model.config.architectures[0].lower()
     ):
@@ -402,7 +390,6 @@ def _crop_past_key_values(model, past_key_values, max_length):
                 past_key_values[idx] = past_key_values[idx][:, :, :max_length, :]
     elif isinstance(past_key_values, DynamicCache):
         past_key_values.crop(max_length)
-
     elif past_key_values is not None:
         for idx in range(len(past_key_values)):
             new_past.append(
...
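
The Bloom-specific cropping branch removed above is covered by the generic `DynamicCache` path, which crops itself in place. A rough sketch of that path with dummy tensors (shapes are made up for illustration):

```python
import torch
from transformers import DynamicCache

# One layer of dummy key/value states: (batch=1, num_heads=4, seq_len=6, head_dim=8).
cache = DynamicCache()
cache.update(torch.randn(1, 4, 6, 8), torch.randn(1, 4, 6, 8), layer_idx=0)

# Assisted generation now crops the whole cache in one call instead of slicing
# Bloom's old (batch*heads, head_dim, seq_len) key tensors by hand.
cache.crop(4)
print(cache.get_seq_length())  # 4
```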
@@ -639,7 +639,7 @@ class GenerationMixin:

         return input_ids, model_kwargs

-    def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
+    def _extract_past_from_model_output(self, outputs: ModelOutput):
         past_key_values = None
         cache_name = "past_key_values"
         if "past_key_values" in outputs:
@@ -652,10 +652,6 @@ class GenerationMixin:
             past_key_values = outputs.cache_params
             cache_name = "cache_params"

-        # Bloom fix: standardizes the cache format when requested
-        if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
-            batch_size = outputs.logits.shape[0]
-            past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size)
         return cache_name, past_key_values

     def _update_model_kwargs_for_generation(
@@ -663,13 +659,10 @@ class GenerationMixin:
         outputs: ModelOutput,
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
-        standardize_cache_format: bool = False,
         num_new_tokens: int = 1,
     ) -> Dict[str, Any]:
         # update past_key_values keeping its naming used in model code
-        cache_name, cache = self._extract_past_from_model_output(
-            outputs, standardize_cache_format=standardize_cache_format
-        )
+        cache_name, cache = self._extract_past_from_model_output(outputs)
         model_kwargs[cache_name] = cache
         if getattr(outputs, "state", None) is not None:
             model_kwargs["state"] = outputs.state
@@ -2558,7 +2551,6 @@ class GenerationMixin:
                 outputs,
                 model_kwargs,
                 is_encoder_decoder=self.config.is_encoder_decoder,
-                standardize_cache_format=True,
             )

             if not sequential:
@@ -2723,7 +2715,7 @@ class GenerationMixin:
                 next_past_key_values = selected_outputs["past_key_values"]

             else:
-                _, next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
+                _, next_past_key_values = self._extract_past_from_model_output(outputs)
                 # Do it in-place layer per layer to save memory
                 if isinstance(next_past_key_values, DynamicCache) or (
                     isinstance(next_past_key_values, EncoderDecoderCache)
@@ -3033,7 +3025,7 @@ class GenerationMixin:
             past_key_values = self._reorder_cache(past_key_values, beam_idx)
         # Exception 2: models with different cache formats. These are limited to `DynamicCache` until their
         # cache format is standardized, to avoid adding complexity to the codebase.
-        elif "bloom" in model_class or "gptbigcode" in model_class:
+        elif "gptbigcode" in model_class:
             if not isinstance(past_key_values, (DynamicCache, EncoderDecoderCache)):
                 raise ValueError(
                     f"Using an unsupported cache format with {model_class}. Currently, it only supports the "
@@ -3161,7 +3153,6 @@ class GenerationMixin:
             for model_name in [
                 "fsmt",
                 "reformer",
-                "bloom",
                 "ctrl",
                 "gpt_bigcode",
                 "transo_xl",
...
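
Dropping `standardize_cache_format` works because Bloom's legacy tuple cache now uses the same per-layer `(batch_size, num_heads, seq_len, head_dim)` layout as other models, so the generic converters apply without a model-specific `_convert_to_standard_cache` step. A hedged sketch with dummy tensors (shapes are illustrative):

```python
import torch
from transformers import DynamicCache

# Stand-in for the tuple-of-tuples cache a Bloom forward pass returns with
# use_cache=True: 2 layers, each a (key, value) pair in the standard layout.
legacy = tuple(
    (torch.randn(1, 16, 5, 64), torch.randn(1, 16, 5, 64)) for _ in range(2)
)

cache = DynamicCache.from_legacy_cache(legacy)  # no Bloom-specific reshaping needed
roundtrip = cache.to_legacy_cache()
assert roundtrip[0][0].shape == (1, 16, 5, 64)
```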
@@ -2568,13 +2568,10 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel):
         outputs: ModelOutput,
         model_kwargs: Dict[str, Any],
         is_encoder_decoder: bool = False,
-        standardize_cache_format: bool = False,
         model_inputs: Optional[Dict[str, Any]] = None,
     ) -> Dict[str, Any]:
         # update past_key_values
-        cache_name, cache = self._extract_past_from_model_output(
-            outputs, standardize_cache_format=standardize_cache_format
-        )
+        cache_name, cache = self._extract_past_from_model_output(outputs)
         model_kwargs[cache_name] = cache
         if getattr(outputs, "state", None) is not None:
...
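
Subclasses that override `_update_model_kwargs_for_generation`, as Musicgen Melody does here, have to drop the `standardize_cache_format` argument and call the extractor with the outputs alone. A hypothetical minimal override, shown only to illustrate the new signature (the class name and its omission of other kwargs are illustrative, not part of this commit):

```python
from typing import Any, Dict

from transformers import GenerationMixin
from transformers.utils import ModelOutput


class MyCustomGenerationMixin(GenerationMixin):  # hypothetical, for illustration only
    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        # The cache comes straight out of the model outputs; no standardization step remains.
        cache_name, cache = self._extract_past_from_model_output(outputs)
        model_kwargs[cache_name] = cache
        return model_kwargs
```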
@@ -252,7 +252,6 @@ class PersimmonAttention(nn.Module):
         else:
             raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

-    # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._split_heads
     def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
...
@@ -1096,7 +1096,6 @@ class GenerationTesterMixin:
             if any(
                 model_name in model_class.__name__.lower()
                 for model_name in [
-                    "bloom",
                     "ctrl",
                     "gptbigcode",
                     "transo_xl",
@@ -1878,7 +1877,7 @@ class GenerationTesterMixin:
         # 2. Some old models still return `output.past_key_values` even without `use_cache=True`
         # 3. TODO (joao): A few models have different formats/types, skipping those until the cache refactor is
         #    complete
-        models_without_standard_cache = ("bloom", "ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
+        models_without_standard_cache = ("ctrl", "fsmt", "gptbigcode", "mega", "reformer", "jamba", "mamba")
         has_standard_cache = not any(
             model_name in config.__class__.__name__.lower() for model_name in models_without_standard_cache
         )
...