Unverified Commit b020a736 authored by Yusuke Mori's avatar Yusuke Mori Committed by GitHub
Browse files

Update `past_key_values` in GPT-2 (#9596)



* Update past_key_values in gpt2 (#9391)

* Update generation_utils, and rename some items

* Update modeling_gpt2 to avoid an error in gradient_checkpointing

* Remove 'reorder_cache' from util and add variations to XLNet, TransfoXL, GPT-2

* Change the location of '_reorder_cache' in modeling files

* Add '_reorder_cache' in modeling_ctrl

* Fix a bug of my last commit in CTRL

* Add '_reorder_cache' to GPT2DoubleHeadsModel

* Manage 'use_cache' in config of test_modeling_gpt2

* Clean up the doc string

* Update src/transformers/models/gpt2/modeling_gpt2.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix the doc string (GPT-2, CTRL)

* improve gradient_checkpointing_behavior
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
parent 97b787fb
......@@ -503,18 +503,10 @@ class GenerationMixin:
return model_kwargs
@staticmethod
def _reorder_cache(past: Tuple[torch.Tensor], beam_idx: torch.Tensor) -> Tuple[torch.Tensor]:
"""
This function is used to re-order the :obj:`past_key_values` or :obj:`mems` cache if
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every
generation step.
For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
subclasses of :class:`~transformers.PreTrainedModel`.
"""
return tuple(layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in past)
def _reorder_cache(self, past, beam_idx):
raise NotImplementedError(
f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}"
)
def _get_logits_warper(
self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
......
......@@ -774,7 +774,7 @@ class BartEncoder(BartPretrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -993,11 +993,13 @@ class BartDecoder(BartPretrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
raise ValueError(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -539,7 +539,14 @@ class BertEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -733,7 +733,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -955,11 +955,13 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
raise ValueError(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -735,7 +735,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -955,11 +955,13 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
raise ValueError(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -15,6 +15,8 @@
# limitations under the License.
""" PyTorch CTRL model."""
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
......@@ -262,7 +264,7 @@ CTRL_INPUTS_DOCSTRING = r"""
details.
`What are input IDs? <../glossary.html#input-ids>`__
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
past_key_values (:obj:`Tuple[Tuple[torch.FloatTensor]]` of length :obj:`config.n_layers`):
Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
have their past given to this model should not be passed as input ids as they have already been computed.
......@@ -389,7 +391,7 @@ class CTRLModel(CTRLPreTrainedModel):
if past_key_values is None:
past_length = 0
past_key_values = [None] * len(self.h)
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
......@@ -575,6 +577,18 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
attentions=transformer_outputs.attentions,
)
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
@add_start_docstrings(
"""
......
......@@ -536,7 +536,14 @@ class ElectraEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -17,7 +17,7 @@
import os
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import Optional, Tuple
import torch
import torch.nn as nn
......@@ -233,7 +233,7 @@ class Attention(nn.Module):
value = torch.cat((past_value, value), dim=-2)
if use_cache is True:
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
present = (key.transpose(-2, -1), value) # transpose to have same shapes
else:
present = None
......@@ -370,9 +370,9 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
batch_size, num_heads, sequence_length, embed_size_per_head)`).
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
Tuple of length :obj:`config.n_layers`, containing tuples of tensors of shape :obj:`(batch_size, num_heads,
sequence_length, embed_size_per_head)`).
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
:obj:`past_key_values` input) to speed up sequential decoding.
......@@ -393,7 +393,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
mc_loss: Optional[torch.FloatTensor] = None
logits: torch.FloatTensor = None
mc_logits: torch.FloatTensor = None
past_key_values: Optional[List[torch.FloatTensor]] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
......@@ -419,7 +419,7 @@ GPT2_INPUTS_DOCSTRING = r"""
Args:
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
:obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
``past_key_values[0][0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
sequence tokens in the vocabulary.
If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
......@@ -430,7 +430,7 @@ GPT2_INPUTS_DOCSTRING = r"""
details.
`What are input IDs? <../glossary.html#input-ids>`__
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers`):
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
have their past given to this model should not be passed as ``input_ids`` as they have already been
......@@ -640,7 +640,7 @@ class GPT2Model(GPT2PreTrainedModel):
if past_key_values is None:
past_length = 0
past_key_values = [None] * len(self.h)
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
......@@ -708,7 +708,7 @@ class GPT2Model(GPT2PreTrainedModel):
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = layer_past.to(hidden_states.device)
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
......@@ -717,19 +717,25 @@ class GPT2Model(GPT2PreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
# checkpointing only works with tuple returns, not with lists
return tuple(output for output in module(*inputs, use_cache, output_attentions))
# None for past_key_value
return module(*inputs, use_cache, output_attentions)
return custom_forward
outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
layer_past,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
......@@ -932,6 +938,18 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
cross_attentions=transformer_outputs.cross_attentions,
)
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
@add_start_docstrings(
"""
......@@ -1095,6 +1113,18 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
attentions=transformer_outputs.attentions,
)
@staticmethod
def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past
)
@add_start_docstrings(
"""
......
......@@ -466,7 +466,14 @@ class LayoutLMEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -1695,7 +1695,7 @@ class LEDEncoder(LEDPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None, None)
else:
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -1920,11 +1920,13 @@ class LEDDecoder(LEDPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
raise ValueError(
"When using `gradient_checkpointing`, make sure that `use_cache=False` and `config.use_cache=False`."
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -1226,7 +1226,7 @@ class LongformerEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -742,7 +742,7 @@ class MarianEncoder(MarianPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -958,11 +958,13 @@ class MarianDecoder(MarianPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
raise ValueError(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -780,7 +780,7 @@ class MBartEncoder(MBartPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -1002,11 +1002,13 @@ class MBartDecoder(MBartPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
raise ValueError(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -746,7 +746,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -966,11 +966,13 @@ class PegasusDecoder(PegasusPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
raise ValueError(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -479,7 +479,14 @@ class RobertaEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
logger.warn(
"`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -1137,6 +1137,15 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
self.crit.cutoff_ends = [0] + new_cutoffs
self.crit.n_token = new_num_tokens
@staticmethod
def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
"""
This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PretrainedModel.beam_search` or
:meth:`~transformers.PretrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the
correct beam_idx at every generation step.
"""
return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems]
@add_start_docstrings(
"""
......
......@@ -1462,6 +1462,15 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
attentions=transformer_outputs.attentions,
)
@staticmethod
def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
"""
This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PretrainedModel.beam_search` or
:meth:`~transformers.PretrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the
correct beam_idx at every generation step.
"""
return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems]
@add_start_docstrings(
"""
......
......@@ -526,7 +526,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -2182,7 +2182,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -2374,11 +2374,11 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if use_cache:
raise ValueError(
"When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
)
logger.warn("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...")
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
......
......@@ -131,6 +131,7 @@ class GPT2ModelTester:
n_ctx=self.max_position_embeddings,
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range,
use_cache=not gradient_checkpointing,
bos_token_id=self.bos_token_id,
eos_token_id=self.eos_token_id,
pad_token_id=self.pad_token_id,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment