"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "331fe04df7dcfb2d22e2ecc39525b6cf74fae575"
Unverified Commit 8869bf41 authored by Patrick von Platen, committed by GitHub

[VisionEncoderDecoder] Add gradient checkpointing (#18697)

* add first generation tutorial

* VisionEncoderDecoder gradient checkpointing

* remove generation

* add tests
parent 06a6a4bd
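
Gradient checkpointing trades compute for memory: intermediate activations are dropped during the forward pass and recomputed during backward, which makes fine-tuning large encoder-decoder models feasible on smaller GPUs. With this commit, the composite model opts into the standard `gradient_checkpointing_enable()` API. A minimal usage sketch follows; the checkpoint names, token ids, and dummy shapes are illustrative assumptions, not part of the diff:

```python
import torch
from transformers import VisionEncoderDecoderModel

# Illustrative checkpoints; any compatible vision encoder / text decoder pair works.
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    "google/vit-base-patch16-224-in21k", "bert-base-uncased"
)
model.config.decoder_start_token_id = 0  # assumed; use the decoder's real BOS/pad id
model.config.pad_token_id = 0

model.train()
model.gradient_checkpointing_enable()  # newly supported for the composite model

pixel_values = torch.randn(1, 3, 224, 224)  # dummy image batch
labels = torch.randint(0, 1000, (1, 16))    # dummy target token ids

# Activations inside checkpointed blocks are recomputed during backward.
loss = model(pixel_values=pixel_values, labels=labels).loss
loss.backward()
```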
@@ -155,6 +155,7 @@ class VisionEncoderDecoderModel(PreTrainedModel):
     config_class = VisionEncoderDecoderConfig
     base_model_prefix = "vision_encoder_decoder"
     main_input_name = "pixel_values"
+    supports_gradient_checkpointing = True

     def __init__(
         self,
@@ -221,6 +222,11 @@ class VisionEncoderDecoderModel(PreTrainedModel):
                 f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
             )
+    def _set_gradient_checkpointing(self, module, value=False):
+        # call both the encoder's and the decoder's gradient checkpointing functions
+        self.encoder._set_gradient_checkpointing(module, value=value)
+        self.decoder._set_gradient_checkpointing(module, value=value)
+
     def get_encoder(self):
         return self.encoder
...
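
Why is overriding `_set_gradient_checkpointing` sufficient? In Transformers of this era, `PreTrainedModel.gradient_checkpointing_enable()` checks `supports_gradient_checkpointing` and then walks the module tree, calling `_set_gradient_checkpointing` on each submodule; the composite model's override simply forwards that call to both of its sub-models. A rough sketch of those base-class mechanics, reconstructed as an assumption about the library internals rather than quoted from this diff:

```python
from functools import partial

import torch.nn as nn

# Assumed sketch of the PreTrainedModel plumbing the override above plugs into.
class PreTrainedModelSketch(nn.Module):
    supports_gradient_checkpointing = False

    def gradient_checkpointing_enable(self):
        if not self.supports_gradient_checkpointing:
            raise ValueError(f"{type(self).__name__} does not support gradient checkpointing.")
        # nn.Module.apply visits self and every submodule, so an override like
        # VisionEncoderDecoderModel's gets the chance to forward the flag to
        # both self.encoder and self.decoder.
        self.apply(partial(self._set_gradient_checkpointing, value=True))

    def _set_gradient_checkpointing(self, module, value=False):
        # Default behaviour: flip the flag on any submodule that defines it.
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value
```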
@@ -396,6 +396,28 @@ class EncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_generate(**input_ids_dict)

+    def test_training_gradient_checkpointing(self):
+        inputs_dict = self.prepare_config_and_inputs()
+        encoder_model, decoder_model = self.get_encoder_decoder_model(
+            inputs_dict["config"], inputs_dict["decoder_config"]
+        )
+        model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        model.train()
+        model.gradient_checkpointing_enable()
+        model.config.decoder_start_token_id = 0
+        model.config.pad_token_id = 0
+
+        model_inputs = {
+            "attention_mask": inputs_dict["attention_mask"],
+            "labels": inputs_dict["labels"],
+            "decoder_input_ids": inputs_dict["decoder_input_ids"],
+        }
+        inputs = inputs_dict["input_features"] if "input_features" in inputs_dict else inputs_dict["input_values"]
+
+        loss = model(inputs, **model_inputs).loss
+        loss.backward()
+
     @slow
     def test_real_model_save_load_from_pretrained(self):
         model_2, inputs = self.get_pretrained_model_and_inputs()
@@ -590,6 +612,7 @@ class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase):
             "decoder_config": decoder_config,
             "decoder_input_ids": decoder_input_ids,
             "decoder_attention_mask": decoder_attention_mask,
+            "labels": decoder_input_ids,
         }

     # there are no published pretrained Speech2Text2ForCausalLM for now
...
@@ -324,6 +324,27 @@ class EncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_generate(**input_ids_dict)

+    def test_training_gradient_checkpointing(self):
+        inputs_dict = self.prepare_config_and_inputs()
+        encoder_model, decoder_model = self.get_encoder_decoder_model(
+            inputs_dict["config"], inputs_dict["decoder_config"]
+        )
+        model = VisionEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
+        model.train()
+        model.gradient_checkpointing_enable()
+        model.config.decoder_start_token_id = 0
+        model.config.pad_token_id = 0
+
+        model_inputs = {
+            "pixel_values": inputs_dict["pixel_values"],
+            "labels": inputs_dict["labels"],
+            "decoder_input_ids": inputs_dict["decoder_input_ids"],
+        }
+
+        loss = model(**model_inputs).loss
+        loss.backward()
+
     @slow
     def test_real_model_save_load_from_pretrained(self):
         model_2, inputs = self.get_pretrained_model_and_inputs()
@@ -547,6 +568,7 @@ class Swin2BartModelTest(EncoderDecoderMixin, unittest.TestCase):
         decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs()
         config, pixel_values, _ = encoder_config_and_inputs
         decoder_config, decoder_inputs_dict = decoder_config_and_inputs
+        decoder_inputs_dict["labels"] = decoder_inputs_dict["decoder_input_ids"]

         # make sure that cross attention layers are added
         decoder_config.add_cross_attention = True
@@ -644,6 +666,7 @@ class ViT2TrOCR(EncoderDecoderMixin, unittest.TestCase):
             "decoder_config": decoder_config,
             "decoder_input_ids": decoder_input_ids,
             "decoder_attention_mask": decoder_attention_mask,
+            "labels": decoder_input_ids,
         }

     # there are no published pretrained TrOCR checkpoints for now
...