"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d97fd871e5ba57b23b1775ef2939ffea128dd08d"
Unverified Commit b8810847 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Trainer] suppress warning for length-related columns (#15421)

* [Trainer] suppress warning for length-related columns

* improve message

* Update src/transformers/trainer.py
parent 3385ca25
......@@ -546,15 +546,19 @@ class Trainer:
self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.
self._signature_columns += ["label", "label_ids"]
columns = [k for k in self._signature_columns if k in dataset.column_names]
ignored_columns = list(set(dataset.column_names) - set(self._signature_columns))
if len(ignored_columns) > 0:
dset_description = "" if description is None else f"in the {description} set "
logger.info(
f"The following columns {dset_description} don't have a corresponding argument in "
f"`{self.model.__class__.__name__}.forward` and have been ignored: {', '.join(ignored_columns)}."
f" If {', '.join(ignored_columns)} are not expected by `{self.model.__class__.__name__}.forward`, "
f" you can safely ignore this message."
)
columns = [k for k in self._signature_columns if k in dataset.column_names]
if version.parse(datasets.__version__) < version.parse("1.4.0"):
dataset.set_format(
type=dataset.format["type"], columns=columns, format_kwargs=dataset.format["format_kwargs"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment