Unverified Commit 27d46397 authored by Sylvain Gugger, committed by GitHub

Make gradient_checkpointing a training argument (#13657)



* Make gradient_checkpointing a training argument

* Update src/transformers/modeling_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/configuration_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Fix tests

* Style

* document Gradient Checkpointing as a performance feature

* Small rename

* PoC for not using the config

* Adapt BC to new PoC

* Forgot to save

* Rollout changes to all other models

* Fix typo
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
parent 75f6641e
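Net effect of the diff below: gradient checkpointing is no longer read from the model config inside each forward pass but from a flag on the module itself, toggled from the outside. A rough usage sketch of the new entry points this change is built around (it assumes the `gradient_checkpointing_enable()`/`gradient_checkpointing_disable()` helpers in `modeling_utils.py` and the `gradient_checkpointing` field of `TrainingArguments` that the commit message and PR title refer to; those names are taken on trust from the PR, not shown in this excerpt):

```python
from transformers import AutoModelForSeq2SeqLM, Trainer, TrainingArguments

model = AutoModelForSeq2SeqLM.from_pretrained("google/pegasus-xsum")

# New style: toggle checkpointing on the model, not on the config.
model.gradient_checkpointing_enable()   # assumed helper from modeling_utils.py
model.gradient_checkpointing_disable()

# Or let the Trainer handle it through the new training argument:
args = TrainingArguments(
    output_dir="out",
    gradient_checkpointing=True,  # replaces the removed config.gradient_checkpointing
)
trainer = Trainer(model=model, args=args)
```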
@@ -78,8 +78,6 @@ class PegasusConfig(PretrainedConfig):
        decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
            The LayerDrop probability for the decoder. See the `LayerDrop paper <see
            https://arxiv.org/abs/1909.11556>`__ for more details.
-       gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-           If True, use gradient checkpointing to save memory at the expense of slower backward pass.
        scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Scale embeddings by diving by sqrt(d_model).
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
@@ -128,7 +126,6 @@ class PegasusConfig(PretrainedConfig):
        decoder_start_token_id=0,
        classifier_dropout=0.0,
        scale_embedding=False,
-       gradient_checkpointing=False,
        pad_token_id=0,
        eos_token_id=1,
        forced_eos_token_id=1,
@@ -153,7 +150,6 @@ class PegasusConfig(PretrainedConfig):
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
-       self.gradient_checkpointing = gradient_checkpointing
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        super().__init__(
            pad_token_id=pad_token_id,
...
@@ -466,6 +466,7 @@ class PegasusDecoderLayer(nn.Module):
 class PegasusPreTrainedModel(PreTrainedModel):
    config_class = PegasusConfig
    base_model_prefix = "model"
+   supports_gradient_checkpointing = True

    def _init_weights(self, module):
        std = self.config.init_std
@@ -480,6 +481,10 @@ class PegasusPreTrainedModel(PreTrainedModel):
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

+   def _set_gradient_checkpointing(self, module, value=False):
+       if isinstance(module, (PegasusDecoder, PegasusEncoder)):
+           module.gradient_checkpointing = value
+

PEGASUS_START_DOCSTRING = r"""
    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
@@ -646,6 +651,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.init_weights()
+       self.gradient_checkpointing = False

    def resize_position_embeddings(self, new_num_position_embeddings: int):
        """
@@ -770,7 +776,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                layer_outputs = (None, None)
            else:
-               if getattr(self.config, "gradient_checkpointing", False) and self.training:
+               if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
@@ -840,6 +846,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.init_weights()
+       self.gradient_checkpointing = False

    def get_input_embeddings(self):
        return self.embed_tokens
@@ -1040,12 +1047,11 @@ class PegasusDecoder(PegasusPreTrainedModel):
            past_key_value = past_key_values[idx] if past_key_values is not None else None

-           if getattr(self.config, "gradient_checkpointing", False) and self.training:
+           if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
-                       "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                       "`use_cache=False`..."
+                       "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False
...
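Every model touched below follows the same recipe as Pegasus above: the pretrained-model class declares `supports_gradient_checkpointing = True` and implements `_set_gradient_checkpointing`, each encoder/decoder gets its own `gradient_checkpointing` flag initialized to `False`, and the forward pass consults that flag instead of `self.config`. The decoder hunks also keep forcing `use_cache = False` under checkpointing, since cached key/values are pointless when the forward pass will be recomputed during backward; only the warning text changes so it no longer mentions the removed config flag. A minimal, self-contained sketch of the wiring, with purely illustrative class names (`SimpleEncoder` and `SimpleModel` are not from the PR):

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class SimpleEncoder(nn.Module):
    def __init__(self, hidden_size=64, num_layers=4):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        # Module-level flag, mirroring what the diff adds to each encoder/decoder.
        self.gradient_checkpointing = False

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                # Recompute this layer's activations during backward instead of storing them.
                hidden_states = checkpoint(layer, hidden_states)
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


class SimpleModel(nn.Module):
    supports_gradient_checkpointing = True  # class-level capability flag, as in the diff

    def __init__(self):
        super().__init__()
        self.encoder = SimpleEncoder()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, SimpleEncoder):
            module.gradient_checkpointing = value

    def gradient_checkpointing_enable(self):
        # Walk all submodules and flip the flag wherever it exists.
        self.apply(lambda m: self._set_gradient_checkpointing(m, value=True))


model = SimpleModel()
model.gradient_checkpointing_enable()
out = model(torch.randn(2, 64, requires_grad=True))
out.sum().backward()
```

Because `nn.Module.apply` visits every submodule, one `_set_gradient_checkpointing` implementation is enough to reach every flag, no matter how deeply the encoder or decoder is nested.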
@@ -92,8 +92,6 @@ class ProphetNetConfig(PretrainedConfig):
            smoothing is performed.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
-       gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-           If True, use gradient checkpointing to save memory at the expense of slower backward pass.
    """
    model_type = "prophetnet"
    keys_to_ignore_at_inference = ["past_key_values"]
@@ -124,7 +122,6 @@ class ProphetNetConfig(PretrainedConfig):
        num_buckets=32,
        relative_max_distance=128,
        disable_ngram_loss=False,
-       gradient_checkpointing=False,
        eps=0.0,
        use_cache=True,
        pad_token_id=0,
@@ -158,9 +155,6 @@ class ProphetNetConfig(PretrainedConfig):
        self.use_cache = use_cache

-       # 4 Training Args (should be removed soon)
-       self.gradient_checkpointing = gradient_checkpointing
-
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
...
@@ -547,6 +547,7 @@ class ProphetNetDecoderLMOutput(ModelOutput):
 class ProphetNetPreTrainedModel(PreTrainedModel):
    config_class = ProphetNetConfig
    base_model_prefix = "prophetnet"
+   supports_gradient_checkpointing = True

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
@@ -558,6 +559,10 @@ class ProphetNetPreTrainedModel(PreTrainedModel):
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

+   def _set_gradient_checkpointing(self, module, value=False):
+       if isinstance(module, (ProphetNetDecoder, ProphetNetEncoder)):
+           module.gradient_checkpointing = value
+
    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id
@@ -1262,6 +1267,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
        self.layers = nn.ModuleList([ProphetNetEncoderLayer(config) for _ in range(config.num_encoder_layers)])

        self.init_weights()
+       self.gradient_checkpointing = False

    def get_input_embeddings(self):
        return self.word_embeddings
@@ -1337,7 +1343,7 @@ class ProphetNetEncoder(ProphetNetPreTrainedModel):
            if output_hidden_states:
                encoder_hidden_states = encoder_hidden_states + (hidden_states,)

-           if getattr(self.config, "gradient_checkpointing", False) and self.training:
+           if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
@@ -1406,6 +1412,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
        self.embeddings_layer_norm = LayerNorm(config.hidden_size)

        self.init_weights()
+       self.gradient_checkpointing = False

    def get_input_embeddings(self):
        return self.word_embeddings
@@ -1566,12 +1573,11 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
            past_key_value = past_key_values[idx] if past_key_values is not None else None

-           if getattr(self.config, "gradient_checkpointing", False) and self.training:
+           if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
-                       "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                       "`use_cache=False`..."
+                       "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False
...
@@ -76,8 +76,6 @@ class RemBertConfig(PretrainedConfig):
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if ``config.is_decoder=True``.
-       gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-           If True, use gradient checkpointing to save memory at the expense of slower backward pass.

    Example::
...
@@ -501,6 +501,7 @@ class RemBertEncoder(nn.Module):
        self.embedding_hidden_mapping_in = nn.Linear(config.input_embedding_size, config.hidden_size)
        self.layer = nn.ModuleList([RemBertLayer(config) for _ in range(config.num_hidden_layers)])
+       self.gradient_checkpointing = False

    def forward(
        self,
@@ -528,12 +529,11 @@ class RemBertEncoder(nn.Module):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

-           if getattr(self.config, "gradient_checkpointing", False) and self.training:
+           if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warn(
-                       "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                       "`use_cache=False`..."
+                       "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False
@@ -648,6 +648,7 @@ class RemBertPreTrainedModel(PreTrainedModel):
    config_class = RemBertConfig
    load_tf_weights = load_tf_weights_in_rembert
    base_model_prefix = "rembert"
+   supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
@@ -666,6 +667,10 @@ class RemBertPreTrainedModel(PreTrainedModel):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

+   def _set_gradient_checkpointing(self, module, value=False):
+       if isinstance(module, RemBertEncoder):
+           module.gradient_checkpointing = value
+

REMBERT_START_DOCSTRING = r"""
    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
...
@@ -469,6 +469,7 @@ class RobertaEncoder(nn.Module):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)])
+       self.gradient_checkpointing = False

    def forward(
        self,
@@ -495,12 +496,11 @@ class RobertaEncoder(nn.Module):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

-           if getattr(self.config, "gradient_checkpointing", False) and self.training:
+           if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
-                       "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                       "`use_cache=False`..."
+                       "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False
@@ -585,6 +585,7 @@ class RobertaPreTrainedModel(PreTrainedModel):
    config_class = RobertaConfig
    base_model_prefix = "roberta"
+   supports_gradient_checkpointing = True

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
@@ -603,6 +604,10 @@ class RobertaPreTrainedModel(PreTrainedModel):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

+   def _set_gradient_checkpointing(self, module, value=False):
+       if isinstance(module, RobertaEncoder):
+           module.gradient_checkpointing = value
+
    def update_keys_to_ignore(self, config, del_keys_to_ignore):
        """Remove some keys from ignore list"""
        if not config.tie_word_embeddings:
...
@@ -80,8 +80,6 @@ class RoFormerConfig(PretrainedConfig):
            relevant if ``config.is_decoder=True``.
        rotary_value (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not apply rotary position embeddings on value layer.
-       gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-           If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.

    Example::
@@ -114,7 +112,6 @@ class RoFormerConfig(PretrainedConfig):
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
-       gradient_checkpointing=False,
        rotary_value=False,
        use_cache=True,
        **kwargs
@@ -134,6 +131,5 @@ class RoFormerConfig(PretrainedConfig):
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
-       self.gradient_checkpointing = gradient_checkpointing
        self.rotary_value = rotary_value
        self.use_cache = use_cache
@@ -551,6 +551,7 @@ class RoFormerEncoder(nn.Module):
            config.max_position_embeddings, config.hidden_size // config.num_attention_heads
        )
        self.layer = nn.ModuleList([RoFormerLayer(config) for _ in range(config.num_hidden_layers)])
+       self.gradient_checkpointing = False

    def forward(
        self,
@@ -580,12 +581,11 @@ class RoFormerEncoder(nn.Module):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

-           if getattr(self.config, "gradient_checkpointing", False) and self.training:
+           if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
-                       "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                       "`use_cache=False`..."
+                       "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False
@@ -705,6 +705,7 @@ class RoFormerPreTrainedModel(PreTrainedModel):
    config_class = RoFormerConfig
    load_tf_weights = load_tf_weights_in_roformer
    base_model_prefix = "roformer"
+   supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = []
    _keys_to_ignore_on_load_unexpected = [
        r"roformer\.embeddings_project\.weight",
@@ -729,6 +730,10 @@ class RoFormerPreTrainedModel(PreTrainedModel):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

+   def _set_gradient_checkpointing(self, module, value=False):
+       if isinstance(module, RoFormerEncoder):
+           module.gradient_checkpointing = value
+

ROFORMER_START_DOCSTRING = r"""
    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
...
@@ -134,7 +134,6 @@ class Speech2TextConfig(PretrainedConfig):
        decoder_start_token_id=2,
        classifier_dropout=0.0,
        scale_embedding=True,
-       gradient_checkpointing=False,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
@@ -165,7 +164,6 @@ class Speech2TextConfig(PretrainedConfig):
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = encoder_layers
-       self.gradient_checkpointing = gradient_checkpointing
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
...
@@ -531,6 +531,7 @@ class Speech2TextDecoderLayer(nn.Module):
 class Speech2TextPreTrainedModel(PreTrainedModel):
    config_class = Speech2TextConfig
    base_model_prefix = "model"
+   supports_gradient_checkpointing = True

    def _init_weights(self, module):
        std = self.config.init_std
@@ -543,6 +544,10 @@ class Speech2TextPreTrainedModel(PreTrainedModel):
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

+   def _set_gradient_checkpointing(self, module, value=False):
+       if isinstance(module, (Speech2TextDecoder, Speech2TextEncoder)):
+           module.gradient_checkpointing = value
+
    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
@@ -711,6 +716,7 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel):
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.init_weights()
+       self.gradient_checkpointing = False

    def forward(
        self,
@@ -795,7 +801,7 @@ class Speech2TextEncoder(Speech2TextPreTrainedModel):
            if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                layer_outputs = (None, None)
            else:
-               if getattr(self.config, "gradient_checkpointing", False) and self.training:
+               if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
@@ -863,6 +869,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
        self.layer_norm = nn.LayerNorm(config.d_model)

        self.init_weights()
+       self.gradient_checkpointing = False

    def get_input_embeddings(self):
        return self.embed_tokens
@@ -1032,11 +1039,11 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel):
            past_key_value = past_key_values[idx] if past_key_values is not None else None

-           if getattr(self.config, "gradient_checkpointing", False) and self.training:
+           if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
-                       "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                       "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
                    )
                    use_cache = False
...
@@ -108,7 +108,6 @@ class Speech2Text2Config(PretrainedConfig):
        decoder_start_token_id=2,
        classifier_dropout=0.0,
        scale_embedding=True,
-       gradient_checkpointing=False,
        pad_token_id=1,
        bos_token_id=0,
        eos_token_id=2,
@@ -130,7 +129,6 @@ class Speech2Text2Config(PretrainedConfig):
        self.classifier_dropout = classifier_dropout
        self.use_cache = use_cache
        self.num_hidden_layers = decoder_layers
-       self.gradient_checkpointing = gradient_checkpointing
        self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
        self.max_source_positions = max_source_positions
        self.max_target_positions = max_target_positions
...
@@ -407,6 +407,7 @@ class Speech2Text2DecoderLayer(nn.Module):
 class Speech2Text2PreTrainedModel(PreTrainedModel):
    config_class = Speech2Text2Config
    base_model_prefix = "model"
+   supports_gradient_checkpointing = True

    def _init_weights(self, module):
        std = self.config.init_std
@@ -419,6 +420,10 @@ class Speech2Text2PreTrainedModel(PreTrainedModel):
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()

+   def _set_gradient_checkpointing(self, module, value=False):
+       if isinstance(module, Speech2Text2Decoder):
+           module.gradient_checkpointing = value
+

SPEECH_TO_TEXT_2_START_DOCSTRING = r"""
    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
@@ -465,6 +470,7 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel):
        self.layers = nn.ModuleList([Speech2Text2DecoderLayer(config) for _ in range(config.decoder_layers)])

        self.init_weights()
+       self.gradient_checkpointing = False

    def get_input_embeddings(self):
        return self.embed_tokens
@@ -635,11 +641,11 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel):
            past_key_value = past_key_values[idx] if past_key_values is not None else None

-           if getattr(self.config, "gradient_checkpointing", False) and self.training:
+           if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
-                       "`use_cache = True` is incompatible with `config.gradient_checkpointing = True`. Setting `use_cache = False`..."
+                       "`use_cache = True` is incompatible with gradient checkpointing. Setting `use_cache = False`..."
                    )
                    use_cache = False
...
@@ -71,8 +71,6 @@ class SplinterConfig(PretrainedConfig):
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if ``config.is_decoder=True``.
-       gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-           If :obj:`True`, use gradient checkpointing to save memory at the expense of slower backward pass.
        question_token_id (:obj:`int`, `optional`, defaults to 104):
            The id of the ``[QUESTION]`` token.
...
@@ -409,6 +409,7 @@ class SplinterEncoder(nn.Module):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([SplinterLayer(config) for _ in range(config.num_hidden_layers)])
+       self.gradient_checkpointing = False

    def forward(
        self,
@@ -435,12 +436,11 @@ class SplinterEncoder(nn.Module):
            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

-           if getattr(self.config, "gradient_checkpointing", False) and self.training:
+           if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
-                       "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                       "`use_cache=False`..."
+                       "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False
@@ -509,6 +509,7 @@ class SplinterPreTrainedModel(PreTrainedModel):
    config_class = SplinterConfig
    base_model_prefix = "splinter"
+   supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
@@ -528,6 +529,10 @@ class SplinterPreTrainedModel(PreTrainedModel):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

+   def _set_gradient_checkpointing(self, module, value=False):
+       if isinstance(module, SplinterEncoder):
+           module.gradient_checkpointing = value
+

SPLINTER_START_DOCSTRING = r"""
    This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class. Use
...
@@ -77,8 +77,6 @@ class T5Config(PretrainedConfig):
            the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`.
        use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not the model should return the last key/values attentions (not used by all models).
-       gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-           If True, use gradient checkpointing to save memory at the expense of slower backward pass.
    """
    model_type = "t5"
    keys_to_ignore_at_inference = ["past_key_values"]
@@ -102,7 +100,6 @@ class T5Config(PretrainedConfig):
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
-       gradient_checkpointing=False,
        **kwargs
    ):
        self.vocab_size = vocab_size
@@ -120,7 +117,6 @@ class T5Config(PretrainedConfig):
        self.initializer_factor = initializer_factor
        self.feed_forward_proj = feed_forward_proj
        self.use_cache = use_cache
-       self.gradient_checkpointing = gradient_checkpointing
        super().__init__(
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
...
@@ -325,7 +325,7 @@ class T5Attention(nn.Module):
        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
-       self.gradient_checkpointing = getattr(config, "gradient_checkpointing", False)
+       self.gradient_checkpointing = False

    def prune_heads(self, heads):
        if len(heads) == 0:
@@ -489,7 +489,7 @@ class T5Attention(nn.Module):
            position_bias = torch.zeros(
                (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
            )
-           if self.training and self.gradient_checkpointing:
+           if self.gradient_checkpointing and self.training:
                position_bias.requires_grad = True
        else:
            position_bias = self.compute_bias(real_seq_length, key_length)
@@ -715,6 +715,7 @@ class T5PreTrainedModel(PreTrainedModel):
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"
    is_parallelizable = True
+   supports_gradient_checkpointing = True

    @property
    def dummy_inputs(self):
@@ -769,6 +770,10 @@ class T5PreTrainedModel(PreTrainedModel):
            if module.has_relative_attention_bias:
                module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))

+   def _set_gradient_checkpointing(self, module, value=False):
+       if isinstance(module, (T5Attention, T5Stack)):
+           module.gradient_checkpointing = value
+
    def _shift_right(self, input_ids):
        decoder_start_token_id = self.config.decoder_start_token_id
        pad_token_id = self.config.pad_token_id
@@ -813,6 +818,7 @@ class T5Stack(T5PreTrainedModel):
        # Model parallel
        self.model_parallel = False
        self.device_map = None
+       self.gradient_checkpointing = False

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
@@ -968,11 +974,10 @@ class T5Stack(T5PreTrainedModel):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

-           if getattr(self.config, "gradient_checkpointing", False) and self.training:
+           if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warn(
-                       "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                       "`use_cache=False`..."
+                       "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False
...
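T5 is the one model in this commit where the flag lives on two different module types: `T5Stack` drives the checkpointed calls over its blocks, while `T5Attention` keeps its own copy so the placeholder `position_bias` can be marked `requires_grad` during checkpointed training. A single `_set_gradient_checkpointing` matching both types is still enough, because `nn.Module.apply` walks every submodule recursively. A tiny illustration with hypothetical module names (not from the PR):

```python
from torch import nn


class InnerBlock(nn.Module):  # stands in for T5Attention in the real diff
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False


class OuterStack(nn.Module):  # stands in for T5Stack in the real diff
    def __init__(self):
        super().__init__()
        self.block = InnerBlock()
        self.gradient_checkpointing = False


def _set_gradient_checkpointing(module, value=False):
    # Same shape as the method added in the diff: one check covers both module types.
    if isinstance(module, (InnerBlock, OuterStack)):
        module.gradient_checkpointing = value


stack = OuterStack()
stack.apply(lambda m: _set_gradient_checkpointing(m, value=True))
assert stack.gradient_checkpointing and stack.block.gradient_checkpointing
```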
@@ -73,8 +73,6 @@ class TapasConfig(PretrainedConfig):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
-       gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-           Whether to use gradient checkpointing to save memory at the expense of a slower backward pass.
        positive_label_weight (:obj:`float`, `optional`, defaults to 10.0):
            Weight for positive labels.
        num_aggregation_labels (:obj:`int`, `optional`, defaults to 0):
@@ -159,7 +157,6 @@ class TapasConfig(PretrainedConfig):
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
-       gradient_checkpointing=False,
        positive_label_weight=10.0,
        num_aggregation_labels=0,
        aggregation_loss_weight=1.0,
@@ -202,7 +199,6 @@ class TapasConfig(PretrainedConfig):
        self.type_vocab_sizes = type_vocab_sizes
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
-       self.gradient_checkpointing = gradient_checkpointing

        # Fine-tuning task hyperparameters
        self.positive_label_weight = positive_label_weight
...
@@ -627,6 +627,7 @@ class TapasEncoder(nn.Module):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([TapasLayer(config) for _ in range(config.num_hidden_layers)])
+       self.gradient_checkpointing = False

    def forward(
        self,
@@ -649,7 +650,7 @@ class TapasEncoder(nn.Module):
            layer_head_mask = head_mask[i] if head_mask is not None else None

-           if getattr(self.config, "gradient_checkpointing", False):
+           if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
@@ -763,6 +764,7 @@ class TapasPreTrainedModel(PreTrainedModel):
    config_class = TapasConfig
    base_model_prefix = "tapas"
+   supports_gradient_checkpointing = True

    # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
    def _init_weights(self, module):
@@ -781,6 +783,10 @@ class TapasPreTrainedModel(PreTrainedModel):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

+   def _set_gradient_checkpointing(self, module, value=False):
+       if isinstance(module, TapasEncoder):
+           module.gradient_checkpointing = value
+

TAPAS_START_DOCSTRING = r"""
    This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
...
@@ -398,6 +398,7 @@ class VisualBertEncoder(nn.Module):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([VisualBertLayer(config) for _ in range(config.num_hidden_layers)])
+       self.gradient_checkpointing = False

    def forward(
        self,
@@ -417,7 +418,7 @@ class VisualBertEncoder(nn.Module):
            layer_head_mask = head_mask[i] if head_mask is not None else None

-           if getattr(self.config, "gradient_checkpointing", False) and self.training:
+           if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
@@ -532,6 +533,7 @@ class VisualBertPreTrainedModel(PreTrainedModel):
    config_class = VisualBertConfig
    base_model_prefix = "visual_bert"
+   supports_gradient_checkpointing = True
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def _init_weights(self, module):
@@ -547,6 +549,10 @@ class VisualBertPreTrainedModel(PreTrainedModel):
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

+   def _set_gradient_checkpointing(self, module, value=False):
+       if isinstance(module, VisualBertEncoder):
+           module.gradient_checkpointing = value
+

@dataclass
class VisualBertForPreTrainingOutput(ModelOutput):
...