Unverified Commit b020a736 authored by Yusuke Mori, committed by GitHub

Update `past_key_values` in GPT-2 (#9596)



* Update past_key_values in gpt2 (#9391)

* Update generation_utils, and rename some items

* Update modeling_gpt2 to avoid an error in gradient_checkpointing

* Remove 'reorder_cache' from util and add variations to XLNet, TransfoXL, GPT-2

* Change the location of '_reorder_cache' in modeling files

* Add '_reorder_cache' in modeling_ctrl

* Fix a bug of my last commit in CTRL

* Add '_reorder_cache' to GPT2DoubleHeadsModel

* Manage 'use_cache' in config of test_modeling_gpt2

* Clean up the doc string

* Update src/transformers/models/gpt2/modeling_gpt2.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix the doc string (GPT-2, CTRL)

* Improve gradient_checkpointing behavior

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 97b787fb
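For context while reading the diff: the change makes GPT-2 (and CTRL) cache each layer's keys and values as a tuple of `(key, value)` tensors instead of a single stacked tensor, and moves `_reorder_cache` out of `GenerationMixin` into the individual models. The sketch below is illustrative only (the toy sizes are not part of the diff); it shows the new cache layout and the reordering beam search performs through `_reorder_cache`:

    # Illustrative sketch of the new past_key_values layout (not part of the diff).
    import torch

    batch_size, num_heads, seq_len, head_dim, n_layer = 2, 12, 5, 64, 3
    past_key_values = tuple(
        (
            torch.randn(batch_size, num_heads, seq_len, head_dim),  # key
            torch.randn(batch_size, num_heads, seq_len, head_dim),  # value
        )
        for _ in range(n_layer)
    )

    # During beam search, beam_idx holds the indices of the surviving beams;
    # _reorder_cache re-indexes the batch dimension of every cached tensor.
    beam_idx = torch.tensor([1, 0])
    reordered = tuple(
        tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
        for layer_past in past_key_values
    )
    assert reordered[0][0].shape == past_key_values[0][0].shape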
@@ -503,18 +503,10 @@ class GenerationMixin:
         return model_kwargs

-    @staticmethod
-    def _reorder_cache(past: Tuple[torch.Tensor], beam_idx: torch.Tensor) -> Tuple[torch.Tensor]:
-        """
-        This function is used to re-order the :obj:`past_key_values` or :obj:`mems` cache if
-        :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
-        called. This is required to match :obj:`past_key_values` or :obj:`mems` with the correct beam_idx at every
-        generation step.
-
-        For custom re-ordering of :obj:`past_key_values` or :obj:`mems`, the function should be implemented in
-        subclasses of :class:`~transformers.PreTrainedModel`.
-        """
-        return tuple(layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in past)
+    def _reorder_cache(self, past, beam_idx):
+        raise NotImplementedError(
+            f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to enable beam search for {self.__class__}"
+        )

     def _get_logits_warper(
         self, top_k: int = None, top_p: float = None, temperature: float = None, num_beams: int = None
@@ -774,7 +774,7 @@ class BartEncoder(BartPretrainedModel):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False):
+                if getattr(self.config, "gradient_checkpointing", False) and self.training:

                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -993,11 +993,13 @@ class BartDecoder(BartPretrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                 if use_cache:
-                    raise ValueError(
-                        "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
                     )
+                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -539,7 +539,14 @@ class BertEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                    )
+                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -733,7 +733,7 @@ class BlenderbotEncoder(BlenderbotPreTrainedModel):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False):
+                if getattr(self.config, "gradient_checkpointing", False) and self.training:

                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -955,11 +955,13 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                 if use_cache:
-                    raise ValueError(
-                        "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
                     )
+                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -735,7 +735,7 @@ class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False):
+                if getattr(self.config, "gradient_checkpointing", False) and self.training:

                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -955,11 +955,13 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                 if use_cache:
-                    raise ValueError(
-                        "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
                     )
+                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -15,6 +15,8 @@
 # limitations under the License.
 """ PyTorch CTRL model."""

+from typing import Tuple
+
 import numpy as np
 import torch
 import torch.nn as nn
@@ -262,7 +264,7 @@ CTRL_INPUTS_DOCSTRING = r"""
             details.

             `What are input IDs? <../glossary.html#input-ids>`__
-        past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
+        past_key_values (:obj:`Tuple[Tuple[torch.FloatTensor]]` of length :obj:`config.n_layers`):
             Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model (see
             :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
             have their past given to this model should not be passed as input ids as they have already been computed.
@@ -389,7 +391,7 @@ class CTRLModel(CTRLPreTrainedModel):
         if past_key_values is None:
             past_length = 0
-            past_key_values = [None] * len(self.h)
+            past_key_values = tuple([None] * len(self.h))
         else:
             past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
@@ -575,6 +577,18 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
             attentions=transformer_outputs.attentions,
         )

+    @staticmethod
+    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the :obj:`past_key_values` cache if
+        :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
+        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
+        """
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past
+        )
+

 @add_start_docstrings(
     """
@@ -536,7 +536,14 @@ class ElectraEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                    )
+                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -17,7 +17,7 @@
 import os
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple

 import torch
 import torch.nn as nn
@@ -233,7 +233,7 @@ class Attention(nn.Module):
             value = torch.cat((past_value, value), dim=-2)

         if use_cache is True:
-            present = torch.stack((key.transpose(-2, -1), value))  # transpose to have same shapes for stacking
+            present = (key.transpose(-2, -1), value)  # transpose to have same shapes
         else:
             present = None
@@ -370,9 +370,9 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
         mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
             Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
-        past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
-            List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
-            batch_size, num_heads, sequence_length, embed_size_per_head)`).
+        past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
+            Tuple of length :obj:`config.n_layers`, containing tuples of tensors of shape :obj:`(batch_size, num_heads,
+            sequence_length, embed_size_per_head)`).
             Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
             :obj:`past_key_values` input) to speed up sequential decoding.
@@ -393,7 +393,7 @@ class GPT2DoubleHeadsModelOutput(ModelOutput):
     mc_loss: Optional[torch.FloatTensor] = None
     logits: torch.FloatTensor = None
     mc_logits: torch.FloatTensor = None
-    past_key_values: Optional[List[torch.FloatTensor]] = None
+    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
     attentions: Optional[Tuple[torch.FloatTensor]] = None
@@ -419,7 +419,7 @@ GPT2_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
             :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
-            ``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
+            ``past_key_values[0][0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
             sequence tokens in the vocabulary.

             If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
@@ -430,7 +430,7 @@ GPT2_INPUTS_DOCSTRING = r"""
             details.

             `What are input IDs? <../glossary.html#input-ids>`__
-        past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
+        past_key_values (:obj:`Tuple[Tuple[torch.Tensor]]` of length :obj:`config.n_layers`):
             Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
             :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
             have their past given to this model should not be passed as ``input_ids`` as they have already been
@@ -640,7 +640,7 @@ class GPT2Model(GPT2PreTrainedModel):
         if past_key_values is None:
             past_length = 0
-            past_key_values = [None] * len(self.h)
+            past_key_values = tuple([None] * len(self.h))
         else:
             past_length = past_key_values[0][0].size(-2)
         if position_ids is None:
@@ -708,7 +708,7 @@ class GPT2Model(GPT2PreTrainedModel):
                 torch.cuda.set_device(hidden_states.device)
                 # Ensure layer_past is on same device as hidden_states (might not be correct)
                 if layer_past is not None:
-                    layer_past = layer_past.to(hidden_states.device)
+                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
                 # Ensure that attention_mask is always on the same device as hidden_states
                 if attention_mask is not None:
                     attention_mask = attention_mask.to(hidden_states.device)
@@ -717,19 +717,25 @@ class GPT2Model(GPT2PreTrainedModel):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                    )
+                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
-                        # checkpointing only works with tuple returns, not with lists
-                        return tuple(output for output in module(*inputs, use_cache, output_attentions))
+                        # None for past_key_value
+                        return module(*inputs, use_cache, output_attentions)

                     return custom_forward

                 outputs = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(block),
                     hidden_states,
-                    layer_past,
+                    None,
                     attention_mask,
                     head_mask[i],
                     encoder_hidden_states,
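Behavior note on the hunk above: when `config.gradient_checkpointing` is enabled and the model is training, an incoming `use_cache=True` is now downgraded with a warning instead of breaking the backward pass. A minimal sketch of what a caller observes (the tiny config sizes are only illustrative, not from the diff):

    # Illustrative sketch: during training with gradient checkpointing enabled,
    # the forward pass warns and returns no cache instead of raising.
    import torch
    from transformers import GPT2Config, GPT2LMHeadModel

    config = GPT2Config(n_layer=2, n_head=2, n_embd=64, gradient_checkpointing=True, use_cache=True)
    model = GPT2LMHeadModel(config)
    model.train()

    input_ids = torch.tensor([[10, 20, 30]])
    outputs = model(input_ids, labels=input_ids)  # logs the "`use_cache = True` is incompatible..." warning
    assert outputs.past_key_values is None  # the cache was disabled for this forward pass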
@@ -932,6 +938,18 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             cross_attentions=transformer_outputs.cross_attentions,
         )

+    @staticmethod
+    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the :obj:`past_key_values` cache if
+        :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
+        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
+        """
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past
+        )
+

 @add_start_docstrings(
     """
@@ -1095,6 +1113,18 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
             attentions=transformer_outputs.attentions,
         )

+    @staticmethod
+    def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
+        """
+        This function is used to re-order the :obj:`past_key_values` cache if
+        :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
+        called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
+        """
+        return tuple(
+            tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
+            for layer_past in past
+        )
+

 @add_start_docstrings(
     """
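The `_reorder_cache` overrides above are exercised by beam search in `generate`: at every decoding step the surviving beams are re-selected and the cache has to follow them. A short illustrative example (checkpoint name and generation parameters are arbitrary choices, not taken from the diff):

    # Illustrative: beam search reorders past_key_values through GPT2LMHeadModel._reorder_cache.
    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    input_ids = tokenizer("The past key values cache", return_tensors="pt").input_ids
    output_ids = model.generate(input_ids, num_beams=4, max_length=20, early_stopping=True)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))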
@@ -466,7 +466,14 @@ class LayoutLMEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                    )
+                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1695,7 +1695,7 @@ class LEDEncoder(LEDPreTrainedModel):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False):
+                if getattr(self.config, "gradient_checkpointing", False) and self.training:

                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -1920,11 +1920,13 @@ class LEDDecoder(LEDPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                 if use_cache:
-                    raise ValueError(
-                        "When using `gradient_checkpointing`, make sure that `use_cache=False` and `config.use_cache=False`."
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
                     )
+                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1226,7 +1226,7 @@ class LongformerEncoder(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)

-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -742,7 +742,7 @@ class MarianEncoder(MarianPreTrainedModel):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False):
+                if getattr(self.config, "gradient_checkpointing", False) and self.training:

                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -958,11 +958,13 @@ class MarianDecoder(MarianPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                 if use_cache:
-                    raise ValueError(
-                        "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
                     )
+                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -780,7 +780,7 @@ class MBartEncoder(MBartPreTrainedModel):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False):
+                if getattr(self.config, "gradient_checkpointing", False) and self.training:

                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -1002,11 +1002,13 @@ class MBartDecoder(MBartPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                 if use_cache:
-                    raise ValueError(
-                        "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
                     )
+                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -746,7 +746,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False):
+                if getattr(self.config, "gradient_checkpointing", False) and self.training:

                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -966,11 +966,13 @@ class PegasusDecoder(PegasusPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                 if use_cache:
-                    raise ValueError(
-                        "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
                     )
+                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -479,7 +479,14 @@ class RobertaEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+                if use_cache:
+                    logger.warn(
+                        "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                    )
+                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1137,6 +1137,15 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
         self.crit.cutoff_ends = [0] + new_cutoffs
         self.crit.n_token = new_num_tokens

+    @staticmethod
+    def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
+        """
+        This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PretrainedModel.beam_search` or
+        :meth:`~transformers.PretrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the
+        correct beam_idx at every generation step.
+        """
+        return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems]
+

 @add_start_docstrings(
     """
@@ -1462,6 +1462,15 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
             attentions=transformer_outputs.attentions,
         )

+    @staticmethod
+    def _reorder_cache(mems: List[torch.Tensor], beam_idx: torch.Tensor) -> List[torch.Tensor]:
+        """
+        This function is used to re-order the :obj:`mems` cache if :meth:`~transformers.PretrainedModel.beam_search` or
+        :meth:`~transformers.PretrainedModel.beam_sample` is called. This is required to match :obj:`mems` with the
+        correct beam_idx at every generation step.
+        """
+        return [layer_past.index_select(1, beam_idx.to(layer_past.device)) for layer_past in mems]
+

 @add_start_docstrings(
     """
@@ -526,7 +526,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2182,7 +2182,7 @@ class {{cookiecutter.camelcase_modelname}}Encoder({{cookiecutter.camelcase_model
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False):
+                if getattr(self.config, "gradient_checkpointing", False) and self.training:

                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -2374,11 +2374,11 @@ class {{cookiecutter.camelcase_modelname}}Decoder({{cookiecutter.camelcase_model
             past_key_value = past_key_values[idx] if past_key_values is not None else None

-            if getattr(self.config, "gradient_checkpointing", False):
+            if getattr(self.config, "gradient_checkpointing", False) and self.training:
                 if use_cache:
-                    raise ValueError(
-                        "When using `gradient_checkpointing, make sure that `use_cache=False` and `config.use_cache=False`."
-                    )
+                    logger.warn("`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`...")
+                    use_cache = False

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -131,6 +131,7 @@ class GPT2ModelTester:
             n_ctx=self.max_position_embeddings,
             # type_vocab_size=self.type_vocab_size,
             # initializer_range=self.initializer_range,
+            use_cache=not gradient_checkpointing,
             bos_token_id=self.bos_token_id,
             eos_token_id=self.eos_token_id,
             pad_token_id=self.pad_token_id,