Unverified Commit 1e662f0f authored by Yih-Dar's avatar Yih-Dar Committed by GitHub
Browse files

Allow generic composite models to pass more kwargs (#24927)



* fix

* Update src/transformers/generation/utils.py
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>

* update

---------
Co-authored-by: default avatarydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: default avatarJoao Gante <joaofranciscocardosogante@gmail.com>
parent b51312e2
...@@ -1147,6 +1147,27 @@ class GenerationMixin: ...@@ -1147,6 +1147,27 @@ class GenerationMixin:
# `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;) # `prepare_inputs_for_generation` doesn't accept them, then a stricter check can be made ;)
if "kwargs" in model_args or "model_kwargs" in model_args: if "kwargs" in model_args or "model_kwargs" in model_args:
model_args |= set(inspect.signature(self.forward).parameters) model_args |= set(inspect.signature(self.forward).parameters)
# Encoder-Decoder models may also need Encoder arguments from `model_kwargs`
if self.config.is_encoder_decoder:
base_model = getattr(self, self.base_model_prefix, None)
# allow encoder kwargs
encoder = getattr(self, "encoder", None)
if encoder is None:
encoder = getattr(base_model, "encoder")
encoder_model_args = set(inspect.signature(encoder.forward).parameters)
model_args |= encoder_model_args
# allow decoder kwargs
decoder = getattr(self, "decoder", None)
if decoder is None:
decoder = getattr(base_model, "decoder")
decoder_model_args = set(inspect.signature(decoder.forward).parameters)
model_args |= {f"decoder_{x}" for x in decoder_model_args}
for key, value in model_kwargs.items(): for key, value in model_kwargs.items():
if value is not None and key not in model_args: if value is not None and key not in model_args:
unused_model_args.append(key) unused_model_args.append(key)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment