Unverified Commit e5079b0b authored by dancingpipi's avatar dancingpipi Committed by GitHub
Browse files

Support PeftModel signature inspect (#27865)



* Support PeftModel signature inspect

* Use get_base_model() to get the base model

---------
Co-authored-by: default avatarshujunhua1 <shujunhua1@jd.com>
parent 35478182
......@@ -695,7 +695,10 @@ class Trainer:
def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward)
model_to_inspect = self.model
if is_peft_available() and isinstance(self.model, PeftModel):
model_to_inspect = self.model.get_base_model()
signature = inspect.signature(model_to_inspect.forward)
self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment