Unverified commit 4c9e0f02 authored by NielsRogge, committed by GitHub

Add support for gradient checkpointing (#19990)


Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 8214a9f6
@@ -581,6 +581,7 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
     config_class = BertGenerationConfig
     base_model_prefix = "bert"
+    supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
@@ -599,6 +600,10 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, BertEncoder):
+            module.gradient_checkpointing = value
+
 BERT_GENERATION_START_DOCSTRING = r"""
...
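With `supports_gradient_checkpointing = True` and the `_set_gradient_checkpointing` hook above, checkpointing on `BertGenerationEncoder` can be toggled through the standard `PreTrainedModel` API. A minimal sketch, assuming a tiny randomly initialized config so it runs without downloading weights (the sizes are illustrative only):

```python
from transformers import BertGenerationConfig, BertGenerationEncoder

# Deliberately tiny config; the specific sizes are illustrative, not from the commit.
config = BertGenerationConfig(
    hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64
)
model = BertGenerationEncoder(config)

model.gradient_checkpointing_enable()   # sets module.gradient_checkpointing = True on the BertEncoder
assert model.is_gradient_checkpointing  # True once the hook has run

model.gradient_checkpointing_disable()  # and back off again
```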
@@ -175,6 +175,8 @@ class EncoderDecoderModel(PreTrainedModel):
     """

     config_class = EncoderDecoderConfig
     base_model_prefix = "encoder_decoder"
+    main_input_name = "input_ids"
+    supports_gradient_checkpointing = True

     def __init__(
         self,
@@ -255,6 +257,11 @@ class EncoderDecoderModel(PreTrainedModel):
                 self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
             )

+    def _set_gradient_checkpointing(self, module, value=False):
+        # call the gradient checkpointing hook on both the encoder and the decoder
+        self.encoder._set_gradient_checkpointing(module, value=value)
+        self.decoder._set_gradient_checkpointing(module, value=value)
+
     def get_encoder(self):
         return self.encoder
...
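Because the composite model's hook simply delegates, a single `gradient_checkpointing_enable()` call on the `EncoderDecoderModel` reaches both submodels. A sketch under the assumption that the checkpoint name (`prajjwal1/bert-tiny`) is just a convenient small model, not anything this commit prescribes:

```python
from transformers import EncoderDecoderModel

# Any encoder/decoder pair whose classes set supports_gradient_checkpointing = True
# behaves the same; bert-tiny is used here only to keep the download small.
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "prajjwal1/bert-tiny", "prajjwal1/bert-tiny"
)

model.gradient_checkpointing_enable()
# The enable call fans out through _set_gradient_checkpointing to both halves.
assert model.encoder.is_gradient_checkpointing
assert model.decoder.is_gradient_checkpointing
```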
@@ -611,6 +611,27 @@ class EncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.create_and_check_encoder_decoder_shared_weights(**input_ids_dict)

+    def test_training_gradient_checkpointing(self):
+        inputs_dict = self.prepare_config_and_inputs()
+        encoder_model, decoder_model = self.get_encoder_decoder_model(
+            inputs_dict["config"], inputs_dict["decoder_config"]
+        )
+
+        model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        model.train()
+        model.gradient_checkpointing_enable()
+        model.config.decoder_start_token_id = 0
+        model.config.pad_token_id = 0
+
+        model_inputs = {
+            "input_ids": inputs_dict["input_ids"],
+            "attention_mask": inputs_dict["attention_mask"],
+            "labels": inputs_dict["labels"],
+            "decoder_input_ids": inputs_dict["decoder_input_ids"],
+        }
+
+        loss = model(**model_inputs).loss
+        loss.backward()
+
     @slow
     def test_real_model_save_load_from_pretrained(self):
         model_2 = self.get_pretrained_model()
...
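Outside the test harness, the same enable-forward-backward pattern looks like the following self-contained sketch; the tiny configs and token ids are illustrative assumptions, not values taken from the commit:

```python
import torch
from transformers import BertConfig, BertLMHeadModel, BertModel, EncoderDecoderModel

# Deliberately tiny configs so the sketch runs on CPU in seconds.
enc_cfg = BertConfig(hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64)
dec_cfg = BertConfig(
    hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64,
    is_decoder=True, add_cross_attention=True,  # required for the decoder half
)

model = EncoderDecoderModel(encoder=BertModel(enc_cfg), decoder=BertLMHeadModel(dec_cfg))
model.config.decoder_start_token_id = 0  # needed so labels can be shifted into decoder inputs
model.config.pad_token_id = 0

model.train()
model.gradient_checkpointing_enable()  # activations are recomputed during backward to save memory

input_ids = torch.randint(0, enc_cfg.vocab_size, (2, 8))
labels = torch.randint(0, dec_cfg.vocab_size, (2, 8))

loss = model(input_ids=input_ids, labels=labels).loss
loss.backward()  # gradients flow through the checkpointed layers
```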