Unverified Commit 21150cb0 authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Hotfix for failing `MusicgenForConditionalGeneration` tests (#25091)


Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
parent f9cc3338
...@@ -1154,19 +1154,24 @@ class GenerationMixin: ...@@ -1154,19 +1154,24 @@ class GenerationMixin:
# allow encoder kwargs # allow encoder kwargs
encoder = getattr(self, "encoder", None) encoder = getattr(self, "encoder", None)
if encoder is None: # `MusicgenForConditionalGeneration` has `text_encoder` and `audio_encoder`.
encoder = getattr(base_model, "encoder") # Also, it has `base_model_prefix = "encoder_decoder"` but there is no `self.encoder_decoder`
# TODO: A better way to handle this.
if encoder is None and base_model is not None:
encoder = getattr(base_model, "encoder", None)
encoder_model_args = set(inspect.signature(encoder.forward).parameters) if encoder is not None:
model_args |= encoder_model_args encoder_model_args = set(inspect.signature(encoder.forward).parameters)
model_args |= encoder_model_args
# allow decoder kwargs # allow decoder kwargs
decoder = getattr(self, "decoder", None) decoder = getattr(self, "decoder", None)
if decoder is None: if decoder is None and base_model is not None:
decoder = getattr(base_model, "decoder") decoder = getattr(base_model, "decoder", None)
decoder_model_args = set(inspect.signature(decoder.forward).parameters) if decoder is not None:
model_args |= {f"decoder_{x}" for x in decoder_model_args} decoder_model_args = set(inspect.signature(decoder.forward).parameters)
model_args |= {f"decoder_{x}" for x in decoder_model_args}
for key, value in model_kwargs.items(): for key, value in model_kwargs.items():
if value is not None and key not in model_args: if value is not None and key not in model_args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment