"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "39fa40096984f9d3995a507f2514b71cd675319a"
Unverified commit 2de99e6c authored by Sanchit Gandhi, committed by GitHub

Fix Loading of Flax(Speech)EncoderDecoderModel kwargs from PreTrained Encoder-Decoder Checkpoints (#16056)

* Fix Loading of Flax(Speech)EncoderDecoderModel kwargs from PreTrained Encoder-Decoder Checkpoints

* change wording
parent 802984ad
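In effect, the fix makes encoder_- and decoder_-prefixed kwargs passed to from_encoder_decoder_pretrained override the values stored in the saved encoder and decoder configs, instead of being dropped when those configs are loaded from checkpoints. A minimal sketch of the intended usage, assuming a BERT encoder and a GPT2 decoder checkpoint (the checkpoint names are illustrative only, not taken from the commit):

from transformers import FlaxEncoderDecoderModel

# encoder_* / decoder_* kwargs are routed to the respective sub-configs
model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-cased",
    "gpt2",
    encoder_use_cache=False,  # overrides use_cache stored in the encoder config
    decoder_use_cache=False,  # overrides use_cache stored in the decoder config
)

# after this commit the overrides are reflected in the combined config
print(model.config.encoder.use_cache)  # False
print(model.config.decoder.use_cache)  # False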
modeling_flax_encoder_decoder.py

@@ -822,7 +822,9 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
                 )

             if "config" not in kwargs_encoder:
-                encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
+                encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
+                    encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
+                )
                 if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                     logger.info(
                         f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
@@ -846,7 +848,9 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
                 )

             if "config" not in kwargs_decoder:
-                decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
+                decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+                    decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+                )
                 if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                     logger.info(
                         f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
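The key change above is loading the config with return_unused_kwargs=True: kwargs that correspond to config attributes are applied to the loaded AutoConfig, and the remainder is handed back so it can still be forwarded to the model. A rough illustration of that call, with an example checkpoint and example kwargs that are not taken from the diff:

from transformers import AutoConfig

# "use_cache" is a config attribute, so it is applied to the loaded config;
# "foo" is not, so it comes back in the unused-kwargs dict
config, unused_kwargs = AutoConfig.from_pretrained(
    "bert-base-cased", use_cache=False, foo=False, return_unused_kwargs=True
)

print(config.use_cache)  # False
print(unused_kwargs)     # {'foo': False}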
modeling_flax_speech_encoder_decoder.py

@@ -835,7 +835,9 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
                 )

             if "config" not in kwargs_encoder:
-                encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
+                encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
+                    encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
+                )
                 if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                     logger.info(
                         f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
@@ -859,7 +861,9 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
                 )

             if "config" not in kwargs_decoder:
-                decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
+                decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+                    decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+                )
                 if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                     logger.info(
                         f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
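The speech variant follows the same pattern. A hypothetical sketch with a Wav2Vec2 encoder checkpoint and a GPT2 decoder checkpoint (the names and kwargs are illustrative, not part of the diff):

from transformers import FlaxSpeechEncoderDecoderModel

model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/wav2vec2-base-960h",
    "gpt2",
    encoder_add_adapter=True,  # forwarded to the Wav2Vec2 encoder config
    decoder_use_cache=False,   # forwarded to the GPT2 decoder config
)

print(model.config.encoder.add_adapter)  # True
print(model.config.decoder.use_cache)    # False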
test_modeling_flax_encoder_decoder.py

@@ -160,6 +160,51 @@ class FlaxEncoderDecoderMixin:
         max_diff = np.amax(np.abs(out_1 - out_2))
         self.assertLessEqual(max_diff, 1e-5)

+    def check_encoder_decoder_model_from_encoder_decoder_pretrained(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        **kwargs
+    ):
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        # assert that model attributes match those of configs
+        self.assertEqual(config.use_cache, encoder_model.config.use_cache)
+        self.assertEqual(decoder_config.use_cache, decoder_model.config.use_cache)
+
+        with tempfile.TemporaryDirectory() as enc_tmpdir:
+            with tempfile.TemporaryDirectory() as dec_tmpdir:
+                encoder_model.save_pretrained(enc_tmpdir)
+                decoder_model.save_pretrained(dec_tmpdir)
+                # load a model from pretrained encoder and decoder checkpoints, setting one encoder and one decoder kwarg opposite to that specified in their respective configs
+                enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
+                    encoder_pretrained_model_name_or_path=enc_tmpdir,
+                    decoder_pretrained_model_name_or_path=dec_tmpdir,
+                    encoder_use_cache=not config.use_cache,
+                    decoder_use_cache=not decoder_config.use_cache,
+                )
+
+        # assert that setting encoder and decoder kwargs opposite to those in the configs has correctly been applied
+        self.assertNotEqual(config.use_cache, enc_dec_model.config.encoder.use_cache)
+        self.assertNotEqual(decoder_config.use_cache, enc_dec_model.config.decoder.use_cache)
+
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            output_hidden_states=True,
+            return_dict=True,
+        )
+
+        self.assertEqual(
+            outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
+        )
+
     def check_encoder_decoder_model_output_attentions(
         self,
         config,
@@ -326,6 +371,10 @@ class FlaxEncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_save_and_load(**input_ids_dict)

+    def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
+        input_ids_dict = self.prepare_config_and_inputs()
+        self.check_encoder_decoder_model_from_encoder_decoder_pretrained(**input_ids_dict)
+
     def test_encoder_decoder_model_output_attentions(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
test_modeling_flax_speech_encoder_decoder.py

@@ -196,6 +196,51 @@ class FlaxEncoderDecoderMixin:
         max_diff = np.amax(np.abs(out_1 - out_2))
         self.assertLessEqual(max_diff, 4e-2)

+    def check_encoder_decoder_model_from_encoder_decoder_pretrained(
+        self,
+        config,
+        inputs,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        **kwargs
+    ):
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        # assert that loading encoder and decoder models from configs has been correctly executed
+        self.assertEqual(config.add_adapter, encoder_model.config.add_adapter)
+        self.assertEqual(decoder_config.use_cache, decoder_model.config.use_cache)
+
+        with tempfile.TemporaryDirectory() as enc_tmpdir:
+            with tempfile.TemporaryDirectory() as dec_tmpdir:
+                encoder_model.save_pretrained(enc_tmpdir)
+                decoder_model.save_pretrained(dec_tmpdir)
+                # load a model from pretrained encoder and decoder checkpoints, setting one encoder and one decoder kwarg opposite to that specified in their respective configs
+                enc_dec_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
+                    encoder_pretrained_model_name_or_path=enc_tmpdir,
+                    decoder_pretrained_model_name_or_path=dec_tmpdir,
+                    encoder_add_adapter=not config.add_adapter,
+                    decoder_use_cache=not decoder_config.use_cache,
+                )
+
+        # assert that setting encoder and decoder kwargs opposite to those in the configs has correctly been applied
+        self.assertNotEqual(config.add_adapter, enc_dec_model.config.encoder.add_adapter)
+        self.assertNotEqual(decoder_config.use_cache, enc_dec_model.config.decoder.use_cache)
+
+        outputs_encoder_decoder = enc_dec_model(
+            inputs=inputs,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            output_hidden_states=True,
+            return_dict=True,
+        )
+
+        self.assertEqual(
+            outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
+        )
+
     def check_encoder_decoder_model_output_attentions(
         self,
         config,
@@ -441,6 +486,10 @@ class FlaxEncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_save_and_load(**input_ids_dict)

+    def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
+        input_ids_dict = self.prepare_config_and_inputs()
+        self.check_encoder_decoder_model_from_encoder_decoder_pretrained(**input_ids_dict)
+
     def test_encoder_decoder_model_output_attentions(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_output_attentions(**input_ids_dict)