Unverified commit 7ca46335, authored by Sanchit Gandhi and committed by GitHub
Browse files

[FlaxSpeechEncoderDecoderModel] Ensure Input and Output Word Embeddings Are **Not** Tied (#16444)

* [FlaxSpeechEncoderDecoderModel] Ensure Input and Output Word Embeddings Are **Not** Tied

* rebase
parent e0ac72b7
...@@ -347,6 +347,8 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): ...@@ -347,6 +347,8 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`." f"and {config.encoder.hidden_size} for `config.encoder.hidden_size`."
) )
# make sure input & output embeddings are not tied
config.tie_word_embeddings = False
module = self.module_class(config=config, dtype=dtype, **kwargs) module = self.module_class(config=config, dtype=dtype, **kwargs)
if input_shape is None: if input_shape is None:
...@@ -890,6 +892,9 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel): ...@@ -890,6 +892,9 @@ class FlaxSpeechEncoderDecoderModel(FlaxPreTrainedModel):
dtype = kwargs.pop("dtype", jnp.float32) dtype = kwargs.pop("dtype", jnp.float32)
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs) config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
# make sure input & output word embeddings are not tied
config.tie_word_embeddings = False
# init model # init model
model = cls(config, dtype=dtype) model = cls(config, dtype=dtype)
model.params["encoder"] = encoder.params model.params["encoder"] = encoder.params
......
...@@ -79,6 +79,7 @@ class FlaxEncoderDecoderMixin: ...@@ -79,6 +79,7 @@ class FlaxEncoderDecoderMixin:
enc_dec_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config) enc_dec_model = FlaxSpeechEncoderDecoderModel(encoder_decoder_config)
self.assertTrue(enc_dec_model.config.is_encoder_decoder) self.assertTrue(enc_dec_model.config.is_encoder_decoder)
self.assertFalse(enc_dec_model.config.tie_word_embeddings)
outputs_encoder_decoder = enc_dec_model( outputs_encoder_decoder = enc_dec_model(
inputs=inputs, inputs=inputs,
......
...@@ -72,6 +72,7 @@ class EncoderDecoderMixin: ...@@ -72,6 +72,7 @@ class EncoderDecoderMixin:
enc_dec_model.eval() enc_dec_model.eval()
self.assertTrue(enc_dec_model.config.is_encoder_decoder) self.assertTrue(enc_dec_model.config.is_encoder_decoder)
self.assertFalse(enc_dec_model.config.tie_word_embeddings)
outputs_encoder_decoder = enc_dec_model( outputs_encoder_decoder = enc_dec_model(
input_values=input_values, input_values=input_values,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment