Unverified commit 938cb047, authored by Joao Gante, committed by GitHub

Generate: add Bloom fixes for contrastive search (#20213)

parent fda12563
@@ -672,8 +672,7 @@ class GenerationMixin:
         return input_ids, model_kwargs
 
-    @staticmethod
-    def _extract_past_from_model_output(outputs: ModelOutput):
+    def _extract_past_from_model_output(self, outputs: ModelOutput, standardize_cache_format: bool = False):
         past = None
         if "past_key_values" in outputs:
             past = outputs.past_key_values
@@ -681,13 +680,24 @@ class GenerationMixin:
             past = outputs.mems
         elif "past_buckets_states" in outputs:
             past = outputs.past_buckets_states
+
+        # Bloom fix: standardizes the cache format when requested
+        if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
+            batch_size = outputs.logits.shape[0]
+            past = self._convert_to_standard_cache(past, batch_size=batch_size)
         return past
 
     def _update_model_kwargs_for_generation(
-        self, outputs: ModelOutput, model_kwargs: Dict[str, Any], is_encoder_decoder: bool = False
+        self,
+        outputs: ModelOutput,
+        model_kwargs: Dict[str, Any],
+        is_encoder_decoder: bool = False,
+        standardize_cache_format: bool = False,
     ) -> Dict[str, Any]:
         # update past
-        model_kwargs["past"] = self._extract_past_from_model_output(outputs)
+        model_kwargs["past"] = self._extract_past_from_model_output(
+            outputs, standardize_cache_format=standardize_cache_format
+        )
 
         # update token_type_ids with last value
         if "token_type_ids" in model_kwargs:
@@ -1939,7 +1949,10 @@ class GenerationMixin:
             logit_for_next_step = outputs.logits[:, -1, :]
 
             model_kwargs = self._update_model_kwargs_for_generation(
-                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+                standardize_cache_format=True,
             )
 
             # Expands model inputs top_k times, for batched forward passes (akin to beam search).
@@ -2001,7 +2014,7 @@ class GenerationMixin:
             outputs = self(
                 **next_model_inputs, return_dict=True, output_hidden_states=True, output_attentions=output_attentions
             )
-            next_past_key_values = self._extract_past_from_model_output(outputs)
+            next_past_key_values = self._extract_past_from_model_output(outputs, standardize_cache_format=True)
 
             logits = outputs.logits[:, -1, :]
             # name is different for encoder-decoder and decoder-only models
...
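Contrastive search expands the model inputs top_k times and then gathers one candidate's cache entries per batch item, which only works when the first dimension of every cache tensor is the batch. Bloom's fused [batch_size * num_heads, ...] layout breaks that assumption, which is why the hunks above request standardize_cache_format=True during contrastive search. The snippet below is a minimal sketch with made-up sizes, not code from this commit, contrasting the two layouts:

import torch

# Hypothetical sizes: 2 sequences, 4 heads, head_dim 8, 5 cached tokens.
batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 5

# Bloom's fused layout: the batch and head dimensions share dim 0.
bloom_key = torch.randn(batch_size * num_heads, head_dim, seq_length)

# The standard layout keeps the batch dimension on its own, so indexing dim 0
# selects whole sequences; this is what contrastive search relies on.
standard_key = bloom_key.view(batch_size, num_heads, head_dim, seq_length)

# Pick one candidate per batch item along dim 0, only meaningful when batch-first.
selected = standard_key.index_select(0, torch.tensor([1, 0]))
print(selected.shape)  # torch.Size([2, 4, 8, 5])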
@@ -506,6 +506,45 @@ class BloomPreTrainedModel(PreTrainedModel):
         if isinstance(module, BloomModel):
             module.gradient_checkpointing = value
 
+    @staticmethod
+    def _convert_to_standard_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
+        num_heads, ...]))
+        """
+        batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        num_heads = batch_size_times_num_heads // batch_size
+        # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
+        # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
+    @staticmethod
+    def _convert_to_bloom_cache(
+        past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
+    ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
+        """
+        Converts the cache to the format expected by Bloom, i.e. to tuple(tuple([batch_size * num_heads, ...]))
+        """
+        batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
+        batch_size_times_num_heads = batch_size * num_heads
+        # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
+        # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
+        return tuple(
+            (
+                layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
+                layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
+            )
+            for layer_past in past_key_value
+        )
+
 
 BLOOM_START_DOCSTRING = r"""
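The two helpers added above only reshape with .view, so the conversion is a metadata change with no data copy. As a rough sketch of how they behave at this point in the codebase (the sizes below are made up, and the helpers are private, so they may change in later releases):

import torch
from transformers import BloomForCausalLM

# Fake single-layer cache in Bloom's fused layout:
#   key:   [batch_size * num_heads, head_dim, seq_length]
#   value: [batch_size * num_heads, seq_length, head_dim]
batch_size, num_heads, head_dim, seq_length = 2, 16, 64, 3
bloom_cache = (
    (
        torch.randn(batch_size * num_heads, head_dim, seq_length),
        torch.randn(batch_size * num_heads, seq_length, head_dim),
    ),
)

# Both helpers are static methods, so no checkpoint needs to be loaded.
standard = BloomForCausalLM._convert_to_standard_cache(bloom_cache, batch_size=batch_size)
assert standard[0][0].shape == (batch_size, num_heads, head_dim, seq_length)

# Converting back restores the fused layout.
restored = BloomForCausalLM._convert_to_bloom_cache(standard)
assert restored[0][0].shape == (batch_size * num_heads, head_dim, seq_length)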
@@ -811,6 +850,10 @@ class BloomForCausalLM(BloomPreTrainedModel):
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
+            # the cache may be in the standard format (e.g. in contrastive search), convert to Bloom's format if needed
+            if past[0][0].shape[0] == input_ids.shape[0]:
+                past = self._convert_to_bloom_cache(past)
+
         return {
             "input_ids": input_ids,
             "past_key_values": past,
@@ -896,9 +939,8 @@ class BloomForCausalLM(BloomPreTrainedModel):
             attentions=transformer_outputs.attentions,
         )
 
-    @staticmethod
     def _reorder_cache(
-        past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
+        self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
     ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
         """
         This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
@@ -907,28 +949,20 @@ class BloomForCausalLM(BloomPreTrainedModel):
         Output shares the same memory storage as `past`.
         """
-        batch_size_times_num_heads, head_dim, seq_length = past[0][0].shape
-        batch_size = len(beam_idx)
-        num_heads = batch_size_times_num_heads // batch_size
+        standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
+
         # Get a copy of `beam_idx` on all the devices where we need those indices.
         device_to_beam_idx = {
             past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
         }
-        # key: layer_past[0] [batch_size * num_heads, head_dim, seq_length]
-        # value: layer_past[1] [batch_size * num_heads, seq_length, head_dim]
-        return tuple(
+        reordered_past = tuple(
             (
-                layer_past[0]
-                .view(batch_size, num_heads, head_dim, seq_length)
-                .index_select(0, device_to_beam_idx[layer_past[0].device])
-                .view(batch_size_times_num_heads, head_dim, seq_length),
-                layer_past[1]
-                .view(batch_size, num_heads, seq_length, head_dim)
-                .index_select(0, device_to_beam_idx[layer_past[0].device])
-                .view(batch_size_times_num_heads, seq_length, head_dim),
+                layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
+                layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
             )
-            for layer_past in past
+            for layer_past in standardized_past
         )
+        return self._convert_to_bloom_cache(reordered_past)
 
 
 @add_start_docstrings(
...
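After this change, _reorder_cache standardizes the cache, reorders whole sequences with a single index_select per tensor, and converts back, instead of wrapping every index_select in view calls. The snippet below is a toy round trip with made-up sizes, not the method itself:

import torch

batch_size, num_heads, head_dim, seq_length = 2, 4, 8, 3
beam_idx = torch.tensor([1, 1])  # e.g. both beams continue from batch item 1

key = torch.randn(batch_size * num_heads, head_dim, seq_length)

# standardize: [batch_size * num_heads, ...] -> [batch_size, num_heads, ...]
key_std = key.view(batch_size, num_heads, head_dim, seq_length)

# reorder whole sequences along dim 0, as index_select does in _reorder_cache
key_reordered = key_std.index_select(0, beam_idx)

# back to Bloom's fused layout
key_bloom = key_reordered.view(batch_size * num_heads, head_dim, seq_length)
assert torch.equal(key_bloom[:num_heads], key[num_heads:])  # both entries now copy item 1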
@@ -1411,9 +1411,8 @@ class GenerationTesterMixin:
         # check `generate()` and `contrastive_search()` are equal
         for model_class in self.all_generative_model_classes:
-            # TODO: Fix Bloom. Bloom fails because `past` has a different shape.
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
-            if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 return
 
             config, input_ids, attention_mask, max_length = self._get_input_ids_and_config()
@@ -1434,9 +1433,8 @@ class GenerationTesterMixin:
     def test_contrastive_generate_dict_outputs_use_cache(self):
         for model_class in self.all_generative_model_classes:
-            # TODO: Fix Bloom. Bloom fails because `past` has a different shape.
             # won't fix: FSMT and Reformer have a different cache variable type (and format).
-            if any(model_name in model_class.__name__.lower() for model_name in ["bloom", "fsmt", "reformer"]):
+            if any(model_name in model_class.__name__.lower() for model_name in ["fsmt", "reformer"]):
                 return
 
             # enable cache
...
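With the Bloom skip removed, these tests now exercise contrastive search on Bloom as well. Outside the test suite, contrastive search is enabled by passing penalty_alpha and top_k to generate(); the checkpoint and prompt below are placeholders, not taken from the tests:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m")

inputs = tokenizer("DeepMind Company is", return_tensors="pt")
# penalty_alpha > 0 together with top_k > 1 selects the contrastive search decoding path.
outputs = model.generate(**inputs, penalty_alpha=0.6, top_k=4, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))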