Unverified Commit 443fdaf2 authored by Patrick von Platen, committed by GitHub

[SpeechEncoderDecoder] Fix from pretrained (#15043)

parent ae929dcb
@@ -380,7 +380,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
                 )
             if "config" not in kwargs_encoder:
-                encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path, **kwargs_encoder)
+                encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
                 if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
                     logger.info(
                         f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
@@ -391,7 +391,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
                 kwargs_encoder["config"] = encoder_config
-            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args)
+            encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
         decoder = kwargs_decoder.pop("model", None)
         if decoder is None:
@@ -402,7 +402,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
                 )
             if "config" not in kwargs_decoder:
-                decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
+                decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
                 if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
                     logger.info(
                         f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
@@ -424,7 +424,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
                     "`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
                 )
-            decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path)
+            decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
         # instantiate config with corresponding kwargs
         config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
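As context (not part of the commit; the checkpoint names and keyword arguments below are only illustrative), a minimal sketch of the call path this change affects: arguments prefixed with `encoder_` / `decoder_` passed to `from_encoder_decoder_pretrained` are now forwarded to the underlying `AutoModel.from_pretrained` / `AutoModelForCausalLM.from_pretrained` calls rather than only to the `AutoConfig.from_pretrained` calls.

```python
from transformers import SpeechEncoderDecoderModel

# Illustrative sketch only: checkpoint names and kwargs are assumptions, not taken
# from this commit. Arguments prefixed with "encoder_" / "decoder_" are split off
# by from_encoder_decoder_pretrained; after this fix they reach the model-loading
# calls (AutoModel.from_pretrained / AutoModelForCausalLM.from_pretrained).
model = SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
    "facebook/wav2vec2-base-960h",  # speech encoder checkpoint
    "bert-base-uncased",            # decoder checkpoint, loaded as a causal LM
    encoder_revision="main",        # loading kwarg now forwarded to the encoder's from_pretrained
    decoder_revision="main",        # loading kwarg now forwarded to the decoder's from_pretrained
)
```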