Unverified Commit 443fdaf2 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[SpeechEncoderDecoder] Fix from pretrained (#15043)

parent ae929dcb
......@@ -380,7 +380,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
)
if "config" not in kwargs_encoder:
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path, **kwargs_encoder)
encoder_config = AutoConfig.from_pretrained(encoder_pretrained_model_name_or_path)
if encoder_config.is_decoder is True or encoder_config.add_cross_attention is True:
logger.info(
f"Initializing {encoder_pretrained_model_name_or_path} as a encoder model "
......@@ -391,7 +391,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
kwargs_encoder["config"] = encoder_config
encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args)
encoder = AutoModel.from_pretrained(encoder_pretrained_model_name_or_path, *model_args, **kwargs_encoder)
decoder = kwargs_decoder.pop("model", None)
if decoder is None:
......@@ -402,7 +402,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
)
if "config" not in kwargs_decoder:
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
decoder_config = AutoConfig.from_pretrained(decoder_pretrained_model_name_or_path)
if decoder_config.is_decoder is False or decoder_config.add_cross_attention is False:
logger.info(
f"Initializing {decoder_pretrained_model_name_or_path} as a decoder model. "
......@@ -424,7 +424,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
"`decoder_config` to `.from_encoder_decoder_pretrained(...)`"
)
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path)
decoder = AutoModelForCausalLM.from_pretrained(decoder_pretrained_model_name_or_path, **kwargs_decoder)
# instantiate config with corresponding kwargs
config = SpeechEncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config, **kwargs)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment