Unverified Commit 040fd471 authored by Sylvain Gugger, committed by GitHub

Fix gradient_checkpointing backward compatibility (#14408)



* Fix gradient_checkpointing backward compatibility

* Remove needless line

* Make sure mask prob is big enough and mask length small enough

* Fix tests
Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
parent 1cc453d3
@@ -412,6 +412,17 @@ class ModuleUtilsMixin:
         return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
 
 
+def gradient_checkpointing_hook(module, _):
+    # Hook to enable backward compatibility for gradient checkpointing. Will be removed once all models have a
+    # proper post_init method.
+    if getattr(module.config, "gradient_checkpointing", False):
+        module.gradient_checkpointing_enable()
+        # Remove the attribute now that it has been consumed, so it's not saved in the config.
+        delattr(module.config, "gradient_checkpointing")
+    # The hook will remove itself after the first execution.
+    module._gradient_checkpointing_hook.remove()
+
+
 class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
     r"""
     Base class for all models.
@@ -479,10 +490,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         # Save config and origin of the pretrained weights if given in model
         self.config = config
         self.name_or_path = config.name_or_path
-        if getattr(self.config, "gradient_checkpointing", False):
-            self.gradient_checkpointing_enable()
-            # Remove the attribute now that it has been consumed, so it's not saved in the config.
-            delattr(self.config, "gradient_checkpointing")
+        if self.supports_gradient_checkpointing:
+            self._gradient_checkpointing_hook = self.register_forward_pre_hook(gradient_checkpointing_hook)
 
     @classmethod
     def _from_config(cls, config, **kwargs):
...
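For readers skimming the diff: the fix hinges on a self-removing forward pre-hook. Instead of consuming the deprecated config.gradient_checkpointing flag in __init__ (before subclasses have finished building their layers), __init__ now only registers a hook; the hook fires right before the first forward pass, enables checkpointing if the legacy flag is present, strips the flag so it is not re-saved, and then unregisters itself. Below is a minimal, standalone sketch of that pattern in plain PyTorch; TinyModel, legacy_flag_hook, and the stand-in gradient_checkpointing attribute are illustrative names, not the actual Transformers code.

import torch
from torch import nn
from types import SimpleNamespace


def legacy_flag_hook(module, _inputs):
    # Consume the deprecated flag the first time the module is called...
    if getattr(module.config, "gradient_checkpointing", False):
        module.gradient_checkpointing = True  # stand-in for gradient_checkpointing_enable()
        delattr(module.config, "gradient_checkpointing")
    # ...then unregister, so later forward passes pay no overhead.
    module._legacy_hook.remove()


class TinyModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.gradient_checkpointing = False
        self.linear = nn.Linear(4, 4)
        # Registered unconditionally; it decides at call time whether the legacy flag is set.
        self._legacy_hook = self.register_forward_pre_hook(legacy_flag_hook)

    def forward(self, x):
        return self.linear(x)


model = TinyModel(SimpleNamespace(gradient_checkpointing=True))
print(model.gradient_checkpointing)                      # False: nothing happens at init
model(torch.zeros(1, 4))
print(model.gradient_checkpointing)                      # True: the hook ran on the first forward
print(hasattr(model.config, "gradient_checkpointing"))   # False: the legacy flag was stripped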
@@ -784,7 +784,6 @@ class DetrClassificationHead(nn.Module):
 class DetrPreTrainedModel(PreTrainedModel):
     config_class = DetrConfig
     base_model_prefix = "model"
-    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         std = self.config.init_std
...
@@ -504,7 +504,6 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):
     config_class = LayoutLMv2Config
     pretrained_model_archive_map = LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST
     base_model_prefix = "layoutlmv2"
-    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
...
@@ -265,6 +265,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
                 [self.model_tester.batch_size, height, width], device=torch_device
             ).long()
             model = model_class(config)
+            model.gradient_checkpointing_enable()
             model.to(torch_device)
             model.train()
             inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
...
@@ -213,6 +213,25 @@ class ModelTesterMixin:
             )
             self.assertTrue(len(load_result.unexpected_keys) == 0)
 
+    def test_gradient_checkpointing_backward_compatibility(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            if not model_class.supports_gradient_checkpointing:
+                continue
+
+            config.gradient_checkpointing = True
+            model = model_class(config)
+            # Model does not have gradient checkpointing activated yet, it will be done at the first forward.
+            self.assertFalse(model.is_gradient_checkpointing)
+
+            model.to(torch_device)
+            inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
+            _ = model(**inputs)
+            # Model has gradient checkpointing activated after the first forward.
+            self.assertTrue(model.is_gradient_checkpointing)
+
     def test_gradient_checkpointing_enable_disable(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
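The new test added above captures the user-visible contract. Here is a rough sketch of the same behavior outside the test harness; the tiny BertConfig values are arbitrary and chosen only so the example runs fast, and any model class with supports_gradient_checkpointing should behave the same way.

import torch
from transformers import BertConfig, BertModel

config = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64, vocab_size=100)

# Legacy path kept working by this PR: the flag sits on the config, as it would in an
# old saved config, and is only consumed (and stripped) at the first forward pass.
config.gradient_checkpointing = True
model = BertModel(config)
print(model.is_gradient_checkpointing)                   # False before any forward pass
model(torch.tensor([[1, 2, 3]]))
print(model.is_gradient_checkpointing)                   # True after the first forward pass
print(hasattr(model.config, "gradient_checkpointing"))   # False: the flag was stripped

# Current path: explicit toggles on the model.
model = BertModel(config)
model.gradient_checkpointing_enable()
print(model.is_gradient_checkpointing)                   # True
model.gradient_checkpointing_disable()
print(model.is_gradient_checkpointing)                   # False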
@@ -418,6 +437,7 @@ class ModelTesterMixin:
                 continue
             model = model_class(config)
             model.to(torch_device)
+            model.gradient_checkpointing_enable()
             model.train()
             inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
             loss = model(**inputs).loss
...
@@ -367,6 +367,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
             if model_class.__name__ == "DeiTForImageClassificationWithTeacher":
                 continue
             model = model_class(config)
+            model.gradient_checkpointing_enable()
             model.to(torch_device)
             model.train()
             inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
...
@@ -66,6 +66,8 @@ class UniSpeechSatModelTester:
         layer_norm_eps=1e-5,
         hidden_act="gelu",
         initializer_range=0.02,
+        mask_time_prob=0.5,
+        mask_time_length=2,
         vocab_size=32,
         do_stable_layer_norm=False,
         scope=None,
@@ -93,6 +95,8 @@ class UniSpeechSatModelTester:
         self.initializer_range = initializer_range
         self.vocab_size = vocab_size
         self.do_stable_layer_norm = do_stable_layer_norm
+        self.mask_time_prob = mask_time_prob
+        self.mask_time_length = mask_time_length
         self.scope = scope
 
         output_seq_length = self.seq_length
@@ -121,6 +125,8 @@ class UniSpeechSatModelTester:
             conv_bias=self.conv_bias,
             num_conv_pos_embeddings=self.num_conv_pos_embeddings,
             num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
+            mask_time_prob=self.mask_time_prob,
+            mask_time_length=self.mask_time_length,
             num_hidden_layers=self.num_hidden_layers,
             num_attention_heads=self.num_attention_heads,
             hidden_dropout_prob=self.hidden_dropout_prob,
...
@@ -78,6 +78,8 @@ class Wav2Vec2ModelTester:
         layer_norm_eps=1e-5,
         hidden_act="gelu",
         initializer_range=0.02,
+        mask_time_prob=0.5,
+        mask_time_length=2,
         vocab_size=32,
         do_stable_layer_norm=False,
         scope=None,
@@ -105,6 +107,8 @@ class Wav2Vec2ModelTester:
         self.initializer_range = initializer_range
         self.vocab_size = vocab_size
         self.do_stable_layer_norm = do_stable_layer_norm
+        self.mask_time_prob = mask_time_prob
+        self.mask_time_length = mask_time_length
         self.scope = scope
 
         output_seq_length = self.seq_length
@@ -131,6 +135,8 @@ class Wav2Vec2ModelTester:
             conv_stride=self.conv_stride,
             conv_kernel=self.conv_kernel,
             conv_bias=self.conv_bias,
+            mask_time_prob=self.mask_time_prob,
+            mask_time_length=self.mask_time_length,
             num_conv_pos_embeddings=self.num_conv_pos_embeddings,
             num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
             num_hidden_layers=self.num_hidden_layers,
...
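The mask_time_prob / mask_time_length overrides implement the "make sure mask prob is big enough and mask length small enough" part of the commit message. A back-of-the-envelope sketch of the motivation follows; it is an approximation, not the exact _compute_mask_indices formula, and the sequence length is an assumed, tester-sized value. With the config defaults, the expected number of masked spans on the testers' very short feature sequences rounds down to zero, so SpecAugment-style masking can silently mask nothing and make the affected tests meaningless or flaky.

def approx_num_masked_spans(seq_len, mask_prob, mask_length):
    # Roughly mask_prob of the timesteps, grouped into spans of mask_length timesteps.
    return int(mask_prob * seq_len / mask_length)

short_seq = 24  # assumption: on the order of the reduced feature-sequence length used by the testers

# Config defaults (mask_time_prob=0.05, mask_time_length=10): rounds down to no spans at all.
print(approx_num_masked_spans(short_seq, mask_prob=0.05, mask_length=10))  # 0

# Tester overrides from this commit (mask_time_prob=0.5, mask_time_length=2): masking reliably applies.
print(approx_num_masked_spans(short_seq, mask_prob=0.5, mask_length=2))    # 6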