Unverified Commit 27d46397 authored by Sylvain Gugger, committed by GitHub

Make gradient_checkpointing a training argument (#13657)



* Make gradient_checkpointing a training argument

* Update src/transformers/modeling_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/configuration_utils.py
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Fix tests

* Style

* Document Gradient Checkpointing as a performance feature

* Small rename

* PoC for not using the config

* Adapt BC to new PoC

* Forgot to save

* Rollout changes to all other models

* Fix typo
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
parent 75f6641e
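
The net effect of this commit is that gradient checkpointing is no longer toggled through the model config but through the training setup. A minimal usage sketch, assuming the `gradient_checkpointing_enable()` helper added in `src/transformers/modeling_utils.py` and the new `gradient_checkpointing` field on `TrainingArguments` (neither file is shown in the excerpt below):

```python
from transformers import AutoModelForSequenceClassification, TrainingArguments

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")

# Turn checkpointing on at the model level (replaces setting
# gradient_checkpointing=True on the config object)...
model.gradient_checkpointing_enable()

# ...or let the Trainer handle it through the new training argument.
args = TrainingArguments(output_dir="out", gradient_checkpointing=True)
```

Moving the flag out of the config means a saved checkpoint no longer silently re-enables checkpointing for every downstream user; it becomes a per-run training decision.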
@@ -525,6 +525,7 @@ class HubertEncoder(nn.Module):
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout)
         self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
 
     def forward(
         self,
@@ -564,7 +565,7 @@ class HubertEncoder(nn.Module):
             skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
             if not skip_the_layer or deepspeed_zero3_is_enabled:
                 # under deepspeed zero3 all gpus must run in sync
-                if getattr(self.config, "gradient_checkpointing", False) and self.training:
+                if self.gradient_checkpointing and self.training:
                     # create gradient checkpointing function
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -612,6 +613,7 @@ class HubertEncoderStableLayerNorm(nn.Module):
         self.layers = nn.ModuleList(
             [HubertEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
         )
+        self.gradient_checkpointing = False
 
     def forward(
         self,
@@ -651,7 +653,7 @@ class HubertEncoderStableLayerNorm(nn.Module):
             if not skip_the_layer or deepspeed_zero3_is_enabled:
                 # under deepspeed zero3 all gpus must run in sync
                 # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
-                if getattr(self.config, "gradient_checkpointing", False) and self.training:
+                if self.gradient_checkpointing and self.training:
                     # create gradient checkpointing function
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -698,6 +700,7 @@ class HubertPreTrainedModel(PreTrainedModel):
     config_class = HubertConfig
     base_model_prefix = "hubert"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
@@ -725,6 +728,10 @@ class HubertPreTrainedModel(PreTrainedModel):
         if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
             module.bias.data.zero_()
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (HubertEncoder, HubertEncoderStableLayerNorm)):
+            module.gradient_checkpointing = value
+
     def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
         """
         Computes the output length of the convolutional layers
...
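The Hubert hunks above show the pattern this PR rolls out to every model: each encoder module owns a plain `gradient_checkpointing` attribute, and the pretrained base class flips it through `_set_gradient_checkpointing`. A rough sketch of how the shared dispatcher could look (the real logic lives in `src/transformers/modeling_utils.py`, which is not part of this excerpt, so treat the method bodies below as illustrative rather than the actual implementation):

```python
from functools import partial

import torch.nn as nn


class PreTrainedModelSketch(nn.Module):
    """Illustrative stand-in for PreTrainedModel, not the library code."""

    # Models that opted in set this to True, as HubertPreTrainedModel does above.
    supports_gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Overridden per model to flag the relevant encoder/decoder modules.
        pass

    def gradient_checkpointing_enable(self):
        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
        # Visit every submodule and flip its per-module flag.
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def gradient_checkpointing_disable(self):
        if self.supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing, value=False))
```

Because the flag now lives on the module rather than the config, each `forward` only needs the `if self.gradient_checkpointing and self.training:` check seen in the hunks above.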
@@ -579,17 +579,13 @@ class IBertEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
-                raise NotImplementedError("gradient checkpointing is not currently supported")
-
-            else:
-                layer_outputs = layer_module(
-                    hidden_states,
-                    hidden_states_scaling_factor,
-                    attention_mask,
-                    layer_head_mask,
-                    output_attentions,
-                )
+            layer_outputs = layer_module(
+                hidden_states,
+                hidden_states_scaling_factor,
+                attention_mask,
+                layer_head_mask,
+                output_attentions,
+            )
 
             hidden_states = layer_outputs[0]
 
             if output_attentions:
...
@@ -71,8 +71,6 @@ class LayoutLMConfig(BertConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
         max_2d_position_embeddings (:obj:`int`, `optional`, defaults to 1024):
             The maximum value that the 2D position embedding might ever used. Typically set this to something large
             just in case (e.g., 1024).
@@ -108,7 +106,6 @@ class LayoutLMConfig(BertConfig):
         initializer_range=0.02,
         layer_norm_eps=1e-12,
         pad_token_id=0,
-        gradient_checkpointing=False,
         max_2d_position_embeddings=1024,
         **kwargs
     ):
@@ -126,7 +123,6 @@ class LayoutLMConfig(BertConfig):
             initializer_range=initializer_range,
             layer_norm_eps=layer_norm_eps,
             pad_token_id=pad_token_id,
-            gradient_checkpointing=gradient_checkpointing,
             **kwargs,
         )
         self.max_2d_position_embeddings = max_2d_position_embeddings
...
@@ -442,6 +442,7 @@ class LayoutLMEncoder(nn.Module):
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList([LayoutLMLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
 
     def forward(
         self,
@@ -468,12 +469,11 @@ class LayoutLMEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:
 
                 if use_cache:
                     logger.warning(
-                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                        "`use_cache=False`..."
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
@@ -609,6 +609,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
     config_class = LayoutLMConfig
     pretrained_model_archive_map = LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST
     base_model_prefix = "layoutlm"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
@@ -627,6 +628,10 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, LayoutLMEncoder):
+            module.gradient_checkpointing = value
+
 
 LAYOUTLM_START_DOCSTRING = r"""
     The LayoutLM model was proposed in `LayoutLM: Pre-training of Text and Layout for Document Image Understanding
...
@@ -378,6 +378,8 @@ class LayoutLMv2Encoder(nn.Module):
             self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
             self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
 
+        self.gradient_checkpointing = False
+
     def _calculate_1d_position_embeddings(self, hidden_states, position_ids):
         rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
         rel_pos = relative_position_bucket(
@@ -443,7 +445,7 @@ class LayoutLMv2Encoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -502,6 +504,7 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):
     config_class = LayoutLMv2Config
     pretrained_model_archive_map = LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST
     base_model_prefix = "layoutlmv2"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
@@ -520,6 +523,10 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, LayoutLMv2Encoder):
+            module.gradient_checkpointing = value
+
 
 def my_convert_sync_batchnorm(module, process_group=None):
     # same as `nn.modules.SyncBatchNorm.convert_sync_batchnorm` but allowing converting from `detectron2.layers.FrozenBatchNorm2d`
...
@@ -82,8 +82,6 @@ class LEDConfig(PretrainedConfig):
             https://arxiv.org/abs/1909.11556>`__ for more details.
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether or not the model should return the last key/values attentions (not used by all models)
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
 
     Example::
@@ -132,7 +130,6 @@ class LEDConfig(PretrainedConfig):
         pad_token_id=1,
         bos_token_id=0,
         eos_token_id=2,
-        gradient_checkpointing=False,
         attention_window: Union[List[int], int] = 512,
         **kwargs
     ):
@@ -157,7 +154,6 @@ class LEDConfig(PretrainedConfig):
         self.use_cache = use_cache
         self.num_hidden_layers = encoder_layers
         self.attention_window = attention_window
-        self.gradient_checkpointing = gradient_checkpointing
 
         super().__init__(
             pad_token_id=pad_token_id,
...
@@ -1077,6 +1077,7 @@ class LEDClassificationHead(nn.Module):
 class LEDPreTrainedModel(PreTrainedModel):
     config_class = LEDConfig
     base_model_prefix = "led"
+    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -1089,6 +1090,10 @@ class LEDPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (LEDDecoder, LEDEncoder)):
+            module.gradient_checkpointing = value
+
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -1625,6 +1630,7 @@ class LEDEncoder(LEDPreTrainedModel):
         self.layernorm_embedding = nn.LayerNorm(embed_dim)
 
         self.init_weights()
+        self.gradient_checkpointing = False
 
     def _merge_to_attention_mask(self, attention_mask: torch.Tensor, global_attention_mask: torch.Tensor):
         # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
@@ -1809,7 +1815,7 @@ class LEDEncoder(LEDPreTrainedModel):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False) and self.training:
+                if self.gradient_checkpointing and self.training:
 
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -1894,6 +1900,7 @@ class LEDDecoder(LEDPreTrainedModel):
         self.layernorm_embedding = nn.LayerNorm(config.d_model)
 
         self.init_weights()
+        self.gradient_checkpointing = False
 
     def forward(
         self,
@@ -2061,12 +2068,11 @@ class LEDDecoder(LEDPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:
 
                 if use_cache:
                     logger.warning(
-                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                        "`use_cache=False`..."
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
...
@@ -1231,6 +1231,7 @@ class LongformerEncoder(nn.Module):
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList([LongformerLayer(config, layer_id=i) for i in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
 
     def forward(
         self,
@@ -1259,7 +1260,7 @@ class LongformerEncoder(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1363,6 +1364,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
     config_class = LongformerConfig
     base_model_prefix = "longformer"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
@@ -1381,6 +1383,10 @@ class LongformerPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, LongformerEncoder):
+            module.gradient_checkpointing = value
+
 
 LONGFORMER_START_DOCSTRING = r"""
...
@@ -68,8 +68,6 @@ class LukeConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
         use_entity_aware_attention (:obj:`bool`, defaults to :obj:`True`):
             Whether or not the model should use the entity-aware self-attention mechanism proposed in `LUKE: Deep
             Contextualized Entity Representations with Entity-aware Self-attention (Yamada et al.)
@@ -106,7 +104,6 @@ class LukeConfig(PretrainedConfig):
         type_vocab_size=2,
         initializer_range=0.02,
         layer_norm_eps=1e-12,
-        gradient_checkpointing=False,
         use_entity_aware_attention=True,
         pad_token_id=1,
         bos_token_id=0,
@@ -130,5 +127,4 @@ class LukeConfig(PretrainedConfig):
         self.type_vocab_size = type_vocab_size
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
-        self.gradient_checkpointing = gradient_checkpointing
         self.use_entity_aware_attention = use_entity_aware_attention
...
@@ -579,6 +579,7 @@ class LukeEncoder(nn.Module):
         super().__init__()
         self.config = config
         self.layer = nn.ModuleList([LukeLayer(config) for _ in range(config.num_hidden_layers)])
+        self.gradient_checkpointing = False
 
     def forward(
         self,
@@ -600,7 +601,7 @@ class LukeEncoder(nn.Module):
                 all_entity_hidden_states = all_entity_hidden_states + (entity_hidden_states,)
 
             layer_head_mask = head_mask[i] if head_mask is not None else None
-            if getattr(self.config, "gradient_checkpointing", False):
+            if self.gradient_checkpointing and self.training:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -681,6 +682,7 @@ class LukePreTrainedModel(PreTrainedModel):
     config_class = LukeConfig
     base_model_prefix = "luke"
+    supports_gradient_checkpointing = True
 
     def _init_weights(self, module: nn.Module):
         """Initialize the weights"""
@@ -699,6 +701,10 @@ class LukePreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, LukeEncoder):
+            module.gradient_checkpointing = value
+
 
 LUKE_START_DOCSTRING = r"""
...
@@ -79,8 +79,6 @@ class M2M100Config(PretrainedConfig):
             https://arxiv.org/abs/1909.11556>`__ for more details.
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
             Whether or not the model should return the last key/values attentions (not used by all models).
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
 
     Example::
@@ -121,7 +119,6 @@ class M2M100Config(PretrainedConfig):
         init_std=0.02,
         decoder_start_token_id=2,
         scale_embedding=True,
-        gradient_checkpointing=False,
         pad_token_id=1,
         bos_token_id=0,
         eos_token_id=2,
@@ -145,7 +142,6 @@ class M2M100Config(PretrainedConfig):
         self.decoder_layerdrop = decoder_layerdrop
         self.use_cache = use_cache
         self.num_hidden_layers = encoder_layers
-        self.gradient_checkpointing = gradient_checkpointing
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
 
         super().__init__(
...
@@ -520,6 +520,7 @@ class M2M100DecoderLayer(nn.Module):
 class M2M100PreTrainedModel(PreTrainedModel):
     config_class = M2M100Config
     base_model_prefix = "model"
+    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -532,6 +533,10 @@ class M2M100PreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (M2M100Decoder, M2M100Encoder)):
+            module.gradient_checkpointing = value
+
 
 M2M_100_START_DOCSTRING = r"""
     This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
@@ -693,6 +698,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
         self.layer_norm = nn.LayerNorm(config.d_model)
 
         self.init_weights()
+        self.gradient_checkpointing = False
 
     def forward(
         self,
@@ -787,7 +793,7 @@ class M2M100Encoder(M2M100PreTrainedModel):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False) and self.training:
+                if self.gradient_checkpointing and self.training:
 
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -857,6 +863,7 @@ class M2M100Decoder(M2M100PreTrainedModel):
         self.layer_norm = nn.LayerNorm(config.d_model)
 
         self.init_weights()
+        self.gradient_checkpointing = False
 
     def forward(
         self,
@@ -1013,12 +1020,11 @@ class M2M100Decoder(M2M100PreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:
 
                 if use_cache:
                     logger.warning(
-                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                        "`use_cache=False`..."
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False
...
@@ -78,8 +78,6 @@ class MarianConfig(PretrainedConfig):
         decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
             The LayerDrop probability for the decoder. See the `LayerDrop paper <see
             https://arxiv.org/abs/1909.11556>`__ for more details.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
         scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Scale embeddings by diving by sqrt(d_model).
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
@@ -128,7 +126,6 @@ class MarianConfig(PretrainedConfig):
         decoder_start_token_id=58100,
         classifier_dropout=0.0,
         scale_embedding=False,
-        gradient_checkpointing=False,
         pad_token_id=58100,
         eos_token_id=0,
         forced_eos_token_id=0,
@@ -153,7 +150,6 @@ class MarianConfig(PretrainedConfig):
         self.classifier_dropout = classifier_dropout
         self.use_cache = use_cache
         self.num_hidden_layers = encoder_layers
-        self.gradient_checkpointing = gradient_checkpointing
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
         super().__init__(
             pad_token_id=pad_token_id,
...
@@ -466,6 +466,7 @@ class MarianDecoderLayer(nn.Module):
 class MarianPreTrainedModel(PreTrainedModel):
     config_class = MarianConfig
     base_model_prefix = "model"
+    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -480,6 +481,10 @@ class MarianPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (MarianDecoder, MarianEncoder)):
+            module.gradient_checkpointing = value
+
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -656,6 +661,7 @@ class MarianEncoder(MarianPreTrainedModel):
         )
         self.layers = nn.ModuleList([MarianEncoderLayer(config) for _ in range(config.encoder_layers)])
         self.init_weights()
+        self.gradient_checkpointing = False
 
     def forward(
         self,
@@ -750,7 +756,7 @@ class MarianEncoder(MarianPreTrainedModel):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False) and self.training:
+                if self.gradient_checkpointing and self.training:
 
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -816,6 +822,7 @@ class MarianDecoder(MarianPreTrainedModel):
         )
         self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)])
         self.init_weights()
+        self.gradient_checkpointing = False
 
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -987,12 +994,11 @@ class MarianDecoder(MarianPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:
 
                 if use_cache:
                     logger.warning(
-                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                        "`use_cache=False`..."
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
...
@@ -82,8 +82,6 @@ class MBartConfig(PretrainedConfig):
         decoder_layerdrop: (:obj:`float`, `optional`, defaults to 0.0):
             The LayerDrop probability for the decoder. See the `LayerDrop paper <see
             https://arxiv.org/abs/1909.11556>`__ for more details.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
         scale_embedding (:obj:`bool`, `optional`, defaults to :obj:`False`):
             Scale embeddings by diving by sqrt(d_model).
         use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`):
@@ -131,7 +129,6 @@ class MBartConfig(PretrainedConfig):
         init_std=0.02,
         classifier_dropout=0.0,
         scale_embedding=False,
-        gradient_checkpointing=False,
         pad_token_id=1,
         bos_token_id=0,
         eos_token_id=2,
@@ -157,7 +154,6 @@ class MBartConfig(PretrainedConfig):
         self.classifier_dropout = classifier_dropout
         self.use_cache = use_cache
         self.num_hidden_layers = encoder_layers
-        self.gradient_checkpointing = gradient_checkpointing
         self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True
         super().__init__(
             pad_token_id=pad_token_id,
...
@@ -479,6 +479,7 @@ class MBartClassificationHead(nn.Module):
 class MBartPreTrainedModel(PreTrainedModel):
     config_class = MBartConfig
     base_model_prefix = "model"
+    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         std = self.config.init_std
@@ -491,6 +492,10 @@ class MBartPreTrainedModel(PreTrainedModel):
             if module.padding_idx is not None:
                 module.weight.data[module.padding_idx].zero_()
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (MBartDecoder, MBartDecoder)):
+            module.gradient_checkpointing = value
+
     @property
     def dummy_inputs(self):
         pad_token = self.config.pad_token_id
@@ -685,6 +690,7 @@ class MBartEncoder(MBartPreTrainedModel):
         self.layer_norm = nn.LayerNorm(config.d_model)
 
         self.init_weights()
+        self.gradient_checkpointing = False
 
     def forward(
         self,
@@ -780,7 +786,7 @@ class MBartEncoder(MBartPreTrainedModel):
             if self.training and (dropout_probability < self.layerdrop):  # skip the layer
                 layer_outputs = (None, None)
             else:
-                if getattr(self.config, "gradient_checkpointing", False) and self.training:
+                if self.gradient_checkpointing and self.training:
 
                     def create_custom_forward(module):
                         def custom_forward(*inputs):
@@ -850,6 +856,7 @@ class MBartDecoder(MBartPreTrainedModel):
         self.layer_norm = nn.LayerNorm(config.d_model)
 
         self.init_weights()
+        self.gradient_checkpointing = False
 
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -1022,12 +1029,11 @@ class MBartDecoder(MBartPreTrainedModel):
             past_key_value = past_key_values[idx] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:
 
                 if use_cache:
                     logger.warning(
-                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                        "`use_cache=False`..."
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
...
@@ -65,8 +65,6 @@ class MegatronBertConfig(PretrainedConfig):
             The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
         layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-12):
             The epsilon used by the layer normalization layers.
-        gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
-            If True, use gradient checkpointing to save memory at the expense of slower backward pass.
         position_embedding_type (:obj:`str`, `optional`, defaults to :obj:`"absolute"`):
             Type of position embedding. Choose one of :obj:`"absolute"`, :obj:`"relative_key"`,
             :obj:`"relative_key_query"`. For positional embeddings use :obj:`"absolute"`. For more information on
@@ -108,7 +106,6 @@ class MegatronBertConfig(PretrainedConfig):
         initializer_range=0.02,
         layer_norm_eps=1e-12,
         pad_token_id=0,
-        gradient_checkpointing=False,
         position_embedding_type="absolute",
         use_cache=True,
         **kwargs
@@ -127,6 +124,5 @@ class MegatronBertConfig(PretrainedConfig):
         self.type_vocab_size = type_vocab_size
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
-        self.gradient_checkpointing = gradient_checkpointing
         self.position_embedding_type = position_embedding_type
         self.use_cache = use_cache
...
@@ -180,7 +180,6 @@ def convert_megatron_checkpoint(args, input_state_dict):
         "type_vocab_size": 2,
         "initializer_range": 0.2,
         "layer_norm_eps": 1e-12,
-        "gradient_checkpointing": False,
         "position_embedding_type": "absolute",
         "use_cache": False,
     }
...
@@ -508,6 +508,7 @@ class MegatronBertEncoder(nn.Module):
         # The final layer norm. We removed the 1st LN, moved LN to each hidden layer and this one
         # is simply the final LN (Transformer's BERT has it attached to each hidden layer).
         self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.gradient_checkpointing = False
 
     def forward(
         self,
@@ -534,12 +535,11 @@ class MegatronBertEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and self.training:
+            if self.gradient_checkpointing and self.training:
 
                 if use_cache:
                     logger.warn(
-                        "`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
-                        "`use_cache=False`..."
+                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                     )
                     use_cache = False
@@ -705,6 +705,7 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
     config_class = MegatronBertConfig
     load_tf_weights = load_tf_weights_in_megatron_bert
     base_model_prefix = "bert"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
@@ -719,6 +720,10 @@ class MegatronBertPreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.Linear) and module.bias is not None:
             module.bias.data.zero_()
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, MegatronBertEncoder):
+            module.gradient_checkpointing = value
+
 
 @dataclass
 # Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->MegatronBert
...
@@ -279,7 +279,6 @@ def main():
         summary_proj_to_labels=True,
         summary_first_dropout=0.1,
         scale_attn_weights=True,
-        gradient_checkpointing=False,
         use_cache=True,
         bos_token_id=50256,
         eos_token_id=50256,
...