Unverified commit 2de99e6c authored by Sanchit Gandhi, committed by GitHub


Fix Loading of Flax(Speech)EncoderDecoderModel kwargs from PreTrained Encoder-Decoder Checkpoints (#16056)

* Fix Loading of Flax(Speech)EncoderDecoderModel kwargs from PreTrained Encoder-Decoder Checkpoints

* change wording
parent 802984ad
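
What this fixes, in user-facing terms: kwargs prefixed with encoder_ or decoder_ that are
passed to from_encoder_decoder_pretrained are meant to override attributes of the encoder and
decoder configs loaded from the checkpoints. Before this change, AutoConfig.from_pretrained was
called without those kwargs, so the overrides never reached the loaded configs. A minimal sketch
of the intended behaviour, assuming illustrative checkpoints (bert-base-cased, gpt2) that are
not part of this commit:

    from transformers import FlaxEncoderDecoderModel

    # Illustrative checkpoints; any encoder/decoder pair with Flax weights behaves the same.
    model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
        "bert-base-cased",        # encoder checkpoint
        "gpt2",                   # decoder checkpoint
        encoder_use_cache=False,  # applied to the encoder config as use_cache
        decoder_use_cache=False,  # applied to the decoder config as use_cache
    )

    # With this fix the overrides are reflected on the combined config,
    # mirroring what the new tests below assert.
    assert model.config.encoder.use_cache is False
    assert model.config.decoder.use_cache is False
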
@@ -822,7 +822,9 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
                 )
 
             if "config" not in kwargs_encoder:
-                encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
+                encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
+                    encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
+                )
                 if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                     logger.info(
                         f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
@@ -846,7 +848,9 @@ class FlaxEncoderDecoderModel(FlaxPreTrainedModel):
                 )
 
             if "config" not in kwargs_decoder:
-                decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
+                decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+                    decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+                )
                 if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                     logger.info(
                         f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
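The substantive change in all four hunks (encoder and decoder config loading, in
FlaxEncoderDecoderModel above and FlaxSpeechEncoderDecoderModel below) is the same:
AutoConfig.from_pretrained is now called with the user kwargs plus return_unused_kwargs=True,
so kwargs that match config attributes are applied to the loaded config and everything else is
handed back for the later model-loading step. A small standalone sketch of that call pattern,
using an illustrative checkpoint name:

    from transformers import AutoConfig

    # return_unused_kwargs=True makes from_pretrained return a (config, unused_kwargs) tuple:
    # kwargs that match config attributes are applied, the rest are passed through untouched.
    config, unused = AutoConfig.from_pretrained(
        "bert-base-cased",       # illustrative checkpoint
        use_cache=False,         # matches a config attribute, so it is applied to the config
        from_pt=True,            # not a config attribute, so it is returned in unused
        return_unused_kwargs=True,
    )

    print(config.use_cache)  # False
    print(unused)            # {'from_pt': True}
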
@@ -835,7 +835,9 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
                 )
 
             if "config" not in kwargs_encoder:
-                encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
+                encoder_config, kwargs_encoder = AutoConfig.from_pretrained(
+                    encoder_pretrained_model_name_or_path, **kwargs_encoder, return_unused_kwargs=True
+                )
                 if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                     logger.info(
                         f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
@@ -859,7 +861,9 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
                 )
 
             if "config" not in kwargs_decoder:
-                decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
+                decoder_config, kwargs_decoder = AutoConfig.from_pretrained(
+                    decoder_pretrained_model_name_or_path, **kwargs_decoder, return_unused_kwargs=True
+                )
                 if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                     logger.info(
                         f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
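The speech variant gets the identical treatment, so overrides such as encoder_add_adapter (a
Wav2Vec2Config attribute) or decoder_use_cache now reach the configs loaded by
FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained as well. A hedged sketch with
illustrative checkpoints (they are not taken from this commit; if the chosen encoder checkpoint
only ships PyTorch weights, add encoder_from_pt=True, which is not a config attribute and
therefore stays in the unused kwargs and is forwarded to the weight-loading step):

    from transformers import FlaxSpeechEncoderDecoderModel

    # Illustrative checkpoints; the point is only that the prefixed kwargs are applied.
    model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
        "facebook/wav2vec2-large-lv60",  # speech encoder checkpoint
        "gpt2",                          # decoder checkpoint
        encoder_add_adapter=True,        # applied to the Wav2Vec2 encoder config
        decoder_use_cache=False,         # applied to the decoder config
    )

    assert model.config.encoder.add_adapter is True
    assert model.config.decoder.use_cache is False
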
@@ -160,6 +160,51 @@ class FlaxEncoderDecoderMixin:
         max_diff = np.amax(np.abs(out_1 - out_2))
         self.assertLessEqual(max_diff, 1e-5)
 
+    def check_encoder_decoder_model_from_encoder_decoder_pretrained(
+        self,
+        config,
+        input_ids,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        **kwargs
+    ):
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        # assert that model attributes match those of configs
+        self.assertEqual(config.use_cache, encoder_model.config.use_cache)
+        self.assertEqual(decoder_config.use_cache, decoder_model.config.use_cache)
+
+        with tempfile.TemporaryDirectory() as enc_tmpdir:
+            with tempfile.TemporaryDirectory() as dec_tmpdir:
+                encoder_model.save_pretrained(enc_tmpdir)
+                decoder_model.save_pretrained(dec_tmpdir)
+                # load a model from pretrained encoder and decoder checkpoints, setting one encoder and one decoder kwarg opposite to that specified in their respective configs
+                enc_dec_model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained(
+                    encoder_pretrained_model_name_or_path=enc_tmpdir,
+                    decoder_pretrained_model_name_or_path=dec_tmpdir,
+                    encoder_use_cache=not config.use_cache,
+                    decoder_use_cache=not decoder_config.use_cache,
+                )
+
+        # assert that setting encoder and decoder kwargs opposite to those in the configs has correctly been applied
+        self.assertNotEqual(config.use_cache, enc_dec_model.config.encoder.use_cache)
+        self.assertNotEqual(decoder_config.use_cache, enc_dec_model.config.decoder.use_cache)
+
+        outputs_encoder_decoder = enc_dec_model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            output_hidden_states=True,
+            return_dict=True,
+        )
+
+        self.assertEqual(
+            outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
+        )
+
     def check_encoder_decoder_model_output_attentions(
         self,
         config,
@@ -326,6 +371,10 @@ class FlaxEncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_save_and_load(**input_ids_dict)
 
+    def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
+        input_ids_dict = self.prepare_config_and_inputs()
+        self.check_encoder_decoder_model_from_encoder_decoder_pretrained(**input_ids_dict)
+
     def test_encoder_decoder_model_output_attentions(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_output_attentions(**input_ids_dict)
@@ -196,6 +196,51 @@ class FlaxEncoderDecoderMixin:
         max_diff = np.amax(np.abs(out_1 - out_2))
         self.assertLessEqual(max_diff, 4e-2)
 
+    def check_encoder_decoder_model_from_encoder_decoder_pretrained(
+        self,
+        config,
+        inputs,
+        attention_mask,
+        encoder_hidden_states,
+        decoder_config,
+        decoder_input_ids,
+        decoder_attention_mask,
+        **kwargs
+    ):
+        encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
+        # assert that loading encoder and decoder models from configs has been correctly executed
+        self.assertEqual(config.add_adapter, encoder_model.config.add_adapter)
+        self.assertEqual(decoder_config.use_cache, decoder_model.config.use_cache)
+
+        with tempfile.TemporaryDirectory() as enc_tmpdir:
+            with tempfile.TemporaryDirectory() as dec_tmpdir:
+                encoder_model.save_pretrained(enc_tmpdir)
+                decoder_model.save_pretrained(dec_tmpdir)
+                # load a model from pretrained encoder and decoder checkpoints, setting one encoder and one decoder kwarg opposite to that specified in their respective configs
+                enc_dec_model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
+                    encoder_pretrained_model_name_or_path=enc_tmpdir,
+                    decoder_pretrained_model_name_or_path=dec_tmpdir,
+                    encoder_add_adapter=not config.add_adapter,
+                    decoder_use_cache=not decoder_config.use_cache,
+                )
+
+        # assert that setting encoder and decoder kwargs opposite to those in the configs has correctly been applied
+        self.assertNotEqual(config.add_adapter, enc_dec_model.config.encoder.add_adapter)
+        self.assertNotEqual(decoder_config.use_cache, enc_dec_model.config.decoder.use_cache)
+
+        outputs_encoder_decoder = enc_dec_model(
+            inputs=inputs,
+            attention_mask=attention_mask,
+            decoder_input_ids=decoder_input_ids,
+            decoder_attention_mask=decoder_attention_mask,
+            output_hidden_states=True,
+            return_dict=True,
+        )
+
+        self.assertEqual(
+            outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
+        )
+
     def check_encoder_decoder_model_output_attentions(
         self,
         config,
@@ -441,6 +486,10 @@ class FlaxEncoderDecoderMixin:
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_save_and_load(**input_ids_dict)
 
+    def test_encoder_decoder_model_from_encoder_decoder_pretrained(self):
+        input_ids_dict = self.prepare_config_and_inputs()
+        self.check_encoder_decoder_model_from_encoder_decoder_pretrained(**input_ids_dict)
+
     def test_encoder_decoder_model_output_attentions(self):
         input_ids_dict = self.prepare_config_and_inputs()
         self.check_encoder_decoder_model_output_attentions(**input_ids_dict)