"docs/source/en/model_doc/auto.mdx" did not exist on "b5e2b183af5e40e33a4dc7659e697d137259d56e"
Unverified Commit 21150cb0 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Hotfix for failing `MusicgenForConditionalGeneration` tests (#25091)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent f9cc3338
......@@ -1154,17 +1154,22 @@ class GenerationMixin:
# allow encoder kwargs
encoder = getattr(self, "encoder", None)
if encoder is None:
encoder = getattr(base_model, "encoder")
# `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`.
# Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder`
# TODO: A better way to handle this.
if encoder is None and base_model is not None:
encoder = getattr(base_model, "encoder", None)
if encoder is not None:
encoder_model_args = set(inspect.signature(encoder.forward).parameters)
model_args |= encoder_model_args
# allow decoder kwargs
decoder = getattr(self, "decoder", None)
if decoder is None:
decoder = getattr(base_model, "decoder")
if decoder is None and base_model is not None:
decoder = getattr(base_model, "decoder", None)
if decoder is not None:
decoder_model_args = set(inspect.signature(decoder.forward).parameters)
model_args |= {f"decoder_{x}" for x in decoder_model_args}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment