"...git@developer.sourcefind.cn:modelzoo/qwen-vl_pytorch.git" did not exist on "5e887c2c06894d3d607a4f17b096462bd131a6eb"
Unverified Commit 27d46397 authored by Sylvain Gugger, committed by GitHub

Make gradient_checkpointing a training argument (#13657)



* Make gradient_checkpointing a training argument

* Update src/transformers/modeling_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/configuration_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Fix tests

* Style

* document Gradient Checkpointing as a performance feature

* Small rename

* PoC for not using the config

* Adapt BC to new PoC

* Forgot to save

* Rollout changes to all other models

* Fix typo
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
parent 75f6641e
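In practical terms, the change moves the switch from the model config to runtime toggles on the model and the Trainer. A minimal usage sketch, assuming the `gradient_checkpointing_enable()` helper and the `gradient_checkpointing` training argument that the PR title refers to (neither appears in the hunks below):

from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

# Either flip checkpointing directly on the model...
model.gradient_checkpointing_enable()

# ...or let the Trainer do it via the new training argument.
args = TrainingArguments(output_dir="out", gradient_checkpointing=True)
trainer = Trainer(model=model, args=args)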
......@@ -525,6 +525,7 @@ class HubertEncoder(nn.Module):
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
......@@ -564,7 +565,7 @@ class HubertEncoder(nn.Module):
skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
# create gradient checkpointing function
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -612,6 +613,7 @@ class HubertEncoderStableLayerNorm(nn.Module):
self.layers = nn.ModuleList(
[HubertEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
def forward(
self,
......@@ -651,7 +653,7 @@ class HubertEncoderStableLayerNorm(nn.Module):
if not skip_the_layer or deepspeed_zero3_is_enabled:
# under deepspeed zero3 all gpus must run in sync
# XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
# create gradient checkpointing function
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -698,6 +700,7 @@ class HubertPreTrainedModel(PreTrainedModel):
config_class = HubertConfig
base_model_prefix = "hubert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
......@@ -725,6 +728,10 @@ class HubertPreTrainedModel(PreTrainedModel):
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (HubertEncoder, HubertEncoderStableLayerNorm)):
module.gradient_checkpointing = value
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
"""
Computes the output length of the convolutional layers
......
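The Hubert hunk above shows the complete per-model pattern: a `supports_gradient_checkpointing` class flag, a `_set_gradient_checkpointing` hook, and a plain boolean attribute on each encoder. A sketch of how the shared base class might drive those hooks (an assumption about the modeling_utils side, which is not part of this diff):

from functools import partial

import torch.nn as nn


class CheckpointingMixin(nn.Module):
    # Stand-in for the PreTrainedModel side; the names mirror the diff, the body is assumed.
    supports_gradient_checkpointing = False

    def gradient_checkpointing_enable(self):
        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        # nn.Module.apply visits the model and every submodule, so each model's
        # _set_gradient_checkpointing hook decides which modules get the flag flipped
        # (e.g. HubertEncoder / HubertEncoderStableLayerNorm above).
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def gradient_checkpointing_disable(self):
        if self.supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing, value=False))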
......@@ -579,17 +579,13 @@ class IBertEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
raise NotImplementedError("gradient checkpointing is not currently supported")
else:
layer_outputs = layer_module(
hidden_states,
hidden_states_scaling_factor,
attention_mask,
layer_head_mask,
output_attentions,
)
layer_outputs = layer_module(
hidden_states,
hidden_states_scaling_factor,
attention_mask,
layer_head_mask,
output_attentions,
)
hidden_states = layer_outputs[0]
if output_attentions:
......
......@@ -71,8 +71,6 @@ class LayoutLMConfig(BertConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
max_2d_position_embeddings (:obj:`int`, `optional`, defaults to 1024):
The maximum value that the 2D position embedding might ever be used with. Typically set this to something large
just in case (e.g., 1024).
......@@ -108,7 +106,6 @@ class LayoutLMConfig(BertConfig):
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
max_2d_position_embeddings=1024,
**kwargs
):
......@@ -126,7 +123,6 @@ class LayoutLMConfig(BertConfig):
initializer_range=initializer_range,
layer_norm_eps=layer_norm_eps,
pad_token_id=pad_token_id,
gradient_checkpointing=gradient_checkpointing,
**kwargs,
)
self.max_2d_position_embeddings = max_2d_position_embeddings
......
......@@ -442,6 +442,7 @@ class LayoutLMEncoder(nn.Module):
super().__init__()
self.config = config
self.layer = nn.ModuleList([LayoutLMLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
......@@ -468,12 +469,11 @@ class LayoutLMEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......@@ -609,6 +609,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
config_class = LayoutLMConfig
pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST
base_model_prefix = "layoutlm"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
......@@ -627,6 +628,10 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LayoutLMEncoder):
module.gradient_checkpointing = value
LAYOUTLM_START_DOCSTRING = r"""
The LayoutLM model was proposed in `LayoutLM: Pre-training of Text and Layout for Document Image Understanding
......
......@@ -378,6 +378,8 @@ class LayoutLMv2Encoder(nn.Module):
self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
self.gradient_checkpointing = False
def _calculate_1d_position_embeddings(self, hidden_states, position_ids):
rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
rel_pos = relative_position_bucket(
......@@ -443,7 +445,7 @@ class LayoutLMv2Encoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -502,6 +504,7 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):
config_class = LayoutLMv2Config
pretrained_model_archive_map = LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST
base_model_prefix = "layoutlmv2"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
......@@ -520,6 +523,10 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LayoutLMv2Encoder):
module.gradient_checkpointing = value
def my_convert_sync_batchnorm(module, process_group=None):
# same as `nn.modules.SyncBatchNorm.convert_sync_batchnorm` but allowing converting from `detectron2.layers.FrozenBatchNorm2d`
......
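Several hunks are truncated right at the `create_custom_forward` closure, so here is a self-contained sketch of the branch that the new `self.gradient_checkpointing and self.training` guard protects (toy layers, not the transformers code; the closure exists because the real encoders also need to bind non-tensor arguments such as `output_attentions` while tensors are passed positionally for autograd):

import torch
import torch.nn as nn
import torch.utils.checkpoint


class TinyEncoder(nn.Module):
    def __init__(self, hidden=32, num_layers=2):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(num_layers)])
        self.gradient_checkpointing = False  # same flag the diff adds to each encoder

    def forward(self, hidden_states):
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                # activations inside `layer` are freed now and recomputed during backward
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer), hidden_states
                )
            else:
                hidden_states = layer(hidden_states)
        return hidden_states


model = TinyEncoder()
model.gradient_checkpointing = True  # what _set_gradient_checkpointing flips per module
model.train()
out = model(torch.randn(4, 32, requires_grad=True))
out.sum().backward()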
......@@ -82,8 +82,6 @@ class LEDConfig(PretrainedConfig):
https://arxiv.org/abs/1909.11556>`__ for more details.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models)
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
Example::
......@@ -132,7 +130,6 @@ class LEDConfig(PretrainedConfig):
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
gradient_checkpointing=False,
attention_window: Union[List[int], int] = 512,
**kwargs
):
......@@ -157,7 +154,6 @@ class LEDConfig(PretrainedConfig):
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.attention_window = attention_window
self.gradient_checkpointing = gradient_checkpointing
super().__init__(
pad_token_id=pad_token_id,
......
......@@ -1077,6 +1077,7 @@ class LEDClassificationHead(nn.Module):
class LEDPreTrainedModel(PreTrainedModel):
config_class = LEDConfig
base_model_prefix = "led"
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
......@@ -1089,6 +1090,10 @@ class LEDPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (LEDDecoder, LEDEncoder)):
module.gradient_checkpointing = value
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
......@@ -1625,6 +1630,7 @@ class LEDEncoder(LEDPreTrainedModel):
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.init_weights()
self.gradient_checkpointing = False
def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor):
# longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
......@@ -1809,7 +1815,7 @@ class LEDEncoder(LEDPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None, None)
else:
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -1894,6 +1900,7 @@ class LEDDecoder(LEDPreTrainedModel):
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.init_weights()
self.gradient_checkpointing = False
def forward(
self,
......@@ -2061,12 +2068,11 @@ class LEDDecoder(LEDPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -1231,6 +1231,7 @@ class LongformerEncoder(nn.Module):
super().__init__()
self.config = config
self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
......@@ -1259,7 +1260,7 @@ class LongformerEncoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -1363,6 +1364,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
config_class = LongformerConfig
base_model_prefix = "longformer"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
......@@ -1381,6 +1383,10 @@ class LongformerPreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LongformerEncoder):
module.gradient_checkpointing = value
LONGFORMER_START_DOCSTRING = r"""
......
......@@ -68,8 +68,6 @@ class LukeConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
use_entity_aware_attention (:obj:`bool`, defaults to :obj:`True`):
Whether or not the model should use the entity-aware self-attention mechanism proposed in `LUKE: Deep
Contextualized Entity Representations with Entity-aware Self-attention (Yamada et al.)
......@@ -106,7 +104,6 @@ class LukeConfig(PretrainedConfig):
type_vocab_size=2,
initializer_range=0.02,
layer_norm_eps=1e-12,
gradient_checkpointing=False,
use_entity_aware_attention=True,
pad_token_id=1,
bos_token_id=0,
......@@ -130,5 +127,4 @@ class LukeConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.use_entity_aware_attention = use_entity_aware_attention
......@@ -579,6 +579,7 @@ class LukeEncoder(nn.Module):
super().__init__()
self.config = config
self.layer = nn.ModuleList([LukeLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
......@@ -600,7 +601,7 @@ class LukeEncoder(nn.Module):
all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
if getattr(self.config, "gradient_checkpointing", False):
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -681,6 +682,7 @@ class LukePreTrainedModel(PreTrainedModel):
config_class = LukeConfig
base_model_prefix = "luke"
supports_gradient_checkpointing = True
def _init_weights(self, module: nn.Module):
"""Initialize the weights"""
......@@ -699,6 +701,10 @@ class LukePreTrainedModel(PreTrainedModel):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, LukeEncoder):
module.gradient_checkpointing = value
LUKE_START_DOCSTRING = r"""
......
......@@ -79,8 +79,6 @@ class M2M100Config(PretrainedConfig):
https://arxiv.org/abs/1909.11556>`__ for more details.
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether or not the model should return the last key/values attentions (not used by all models).
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
Example::
......@@ -121,7 +119,6 @@ class M2M100Config(PretrainedConfig):
init_std=0.02,
decoder_start_token_id=2,
scale_embedding=True,
gradient_checkpointing=False,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
......@@ -145,7 +142,6 @@ class M2M100Config(PretrainedConfig):
self.decoder_layerdrop = decoder_layerdrop
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
super().__init__(
......
......@@ -520,6 +520,7 @@ class M2M100DecoderLayer(nn.Module):
class M2M100PreTrainedModel(PreTrainedModel):
config_class = M2M100Config
base_model_prefix = "model"
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
......@@ -532,6 +533,10 @@ class M2M100PreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (M2M100Decoder, M2M100Encoder)):
module.gradient_checkpointing = value
M2M_100_START_DOCSTRING = r"""
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
......@@ -693,6 +698,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
self.layer_norm = nn.LayerNorm(config.d_model)
self.init_weights()
self.gradient_checkpointing = False
def forward(
self,
......@@ -787,7 +793,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -857,6 +863,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
self.layer_norm = nn.LayerNorm(config.d_model)
self.init_weights()
self.gradient_checkpointing = False
def forward(
self,
......@@ -1013,12 +1020,11 @@ class M2M100Decoder(M2M100PreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -78,8 +78,6 @@ class MarianConfig(PretrainedConfig):
decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the decoder. See the `LayerDrop paper <https://arxiv.org/abs/1909.11556>`__ for more details.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
Scale embeddings by dividing by sqrt(d_model).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
......@@ -128,7 +126,6 @@ class MarianConfig(PretrainedConfig):
decoder_start_token_id=58100,
classifier_dropout=0.0,
scale_embedding=False,
gradient_checkpointing=False,
pad_token_id=58100,
eos_token_id=0,
forced_eos_token_id=0,
......@@ -153,7 +150,6 @@ class MarianConfig(PretrainedConfig):
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
super().__init__(
pad_token_id=pad_token_id,
......
......@@ -466,6 +466,7 @@ class MarianDecoderLayer(nn.Module):
class MarianPreTrainedModel(PreTrainedModel):
config_class = MarianConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
......@@ -480,6 +481,10 @@ class MarianPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (MarianDecoder, MarianEncoder)):
module.gradient_checkpointing = value
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
......@@ -656,6 +661,7 @@ class MarianEncoder(MarianPreTrainedModel):
)
self.layers = nn.ModuleList([MarianEncoderLayer(config) for _ in range(config.encoder_layers)])
self.init_weights()
self.gradient_checkpointing = False
def forward(
self,
......@@ -750,7 +756,7 @@ class MarianEncoder(MarianPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -816,6 +822,7 @@ class MarianDecoder(MarianPreTrainedModel):
)
self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)])
self.init_weights()
self.gradient_checkpointing = False
def get_input_embeddings(self):
return self.embed_tokens
......@@ -987,12 +994,11 @@ class MarianDecoder(MarianPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -82,8 +82,6 @@ class MBartConfig(PretrainedConfig):
decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
The LayerDrop probability for the decoder. See the `LayerDrop paper <https://arxiv.org/abs/1909.11556>`__ for more details.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
Scale embeddings by dividing by sqrt(d_model).
use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
......@@ -131,7 +129,6 @@ class MBartConfig(PretrainedConfig):
init_std=0.02,
classifier_dropout=0.0,
scale_embedding=False,
gradient_checkpointing=False,
pad_token_id=1,
bos_token_id=0,
eos_token_id=2,
......@@ -157,7 +154,6 @@ class MBartConfig(PretrainedConfig):
self.classifier_dropout = classifier_dropout
self.use_cache = use_cache
self.num_hidden_layers = encoder_layers
self.gradient_checkpointing = gradient_checkpointing
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
super().__init__(
pad_token_id=pad_token_id,
......
......@@ -479,6 +479,7 @@ class MBartClassificationHead(nn.Module):
class MBartPreTrainedModel(PreTrainedModel):
config_class = MBartConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
def _init_weights(self, module):
std = self.config.init_std
......@@ -491,6 +492,10 @@ class MBartPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (MBartDecoder, MBartEncoder)):
module.gradient_checkpointing = value
@property
def dummy_inputs(self):
pad_token = self.config.pad_token_id
......@@ -685,6 +690,7 @@ class MBartEncoder(MBartPreTrainedModel):
self.layer_norm = nn.LayerNorm(config.d_model)
self.init_weights()
self.gradient_checkpointing = False
def forward(
self,
......@@ -780,7 +786,7 @@ class MBartEncoder(MBartPreTrainedModel):
if self.training and (dropout_probability < self.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
......@@ -850,6 +856,7 @@ class MBartDecoder(MBartPreTrainedModel):
self.layer_norm = nn.LayerNorm(config.d_model)
self.init_weights()
self.gradient_checkpointing = False
def get_input_embeddings(self):
return self.embed_tokens
......@@ -1022,12 +1029,11 @@ class MBartDecoder(MBartPreTrainedModel):
past_key_value = past_key_values[idx] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
)
use_cache = False
......
......@@ -65,8 +65,6 @@ class MegatronBertConfig(PretrainedConfig):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
The epsilon used by the layer normalization layers.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
:obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
......@@ -108,7 +106,6 @@ class MegatronBertConfig(PretrainedConfig):
initializer_range=0.02,
layer_norm_eps=1e-12,
pad_token_id=0,
gradient_checkpointing=False,
position_embedding_type="absolute",
use_cache=True,
**kwargs
......@@ -127,6 +124,5 @@ class MegatronBertConfig(PretrainedConfig):
self.type_vocab_size = type_vocab_size
self.initializer_range = initializer_range
self.layer_norm_eps = layer_norm_eps
self.gradient_checkpointing = gradient_checkpointing
self.position_embedding_type = position_embedding_type
self.use_cache = use_cache
......@@ -180,7 +180,6 @@ def convert_megatron_checkpoint(args, input_state_dict):
"type_vocab_size": 2,
"initializer_range": 0.2,
"layer_norm_eps": 1e-12,
"gradient_checkpointing": False,
"position_embedding_type": "absolute",
"use_cache": False,
}
......
......@@ -508,6 +508,7 @@ class MegatronBertEncoder(nn.Module):
# The final layer norm. We removed the 1st LN, moved LN to each hidden layer and this one
# is simply the final LN (Transformer's BERT has it attached to each hidden layer).
self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.gradient_checkpointing = False
def forward(
self,
......@@ -534,12 +535,11 @@ class MegatronBertEncoder(nn.Module):
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if getattr(self.config, "gradient_checkpointing", False) and self.training:
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warn(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
"`use_cache=False`..."
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
......@@ -705,6 +705,7 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
config_class = MegatronBertConfig
load_tf_weights = load_tf_weights_in_megatron_bert
base_model_prefix = "bert"
supports_gradient_checkpointing = True
_keys_to_ignore_on_load_missing = [r"position_ids"]
def _init_weights(self, module):
......@@ -719,6 +720,10 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, MegatronBertEncoder):
module.gradient_checkpointing = value
@dataclass
# Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->MegatronBert
......
......@@ -279,7 +279,6 @@ def main():
summary_proj_to_labels=True,
summary_first_dropout=0.1,
scale_attn_weights=True,
gradient_checkpointing=False,
use_cache=True,
bos_token_id=50256,
eos_token_id=50256,
......
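The last hunks remove the flag from config classes and from a conversion script entirely. The "Adapt BC to new PoC" item in the commit message implies that older configs which still set `gradient_checkpointing=True` keep working; a hypothetical sketch of such a shim (illustrative names and logic only, not the verbatim modeling_utils change):

from functools import partial

import torch.nn as nn


class LegacyConfig:
    def __init__(self):
        self.gradient_checkpointing = True  # the old config flag this PR removes


class TinyModel(nn.Module):
    supports_gradient_checkpointing = True

    def __init__(self, config):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.encoder.gradient_checkpointing = False
        # Assumed BC shim: honor the legacy config flag once at init, then drop it so the
        # setting lives only on the modules from here on.
        if self.supports_gradient_checkpointing and getattr(config, "gradient_checkpointing", False):
            self.apply(partial(self._set_gradient_checkpointing, value=True))
            del config.gradient_checkpointing

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, nn.Linear):
            module.gradient_checkpointing = value


model = TinyModel(LegacyConfig())
assert model.encoder.gradient_checkpointing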