Unverified Commit 8e0d3b42 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Add defensive check for config num_labels and id2label (#16709)

* Add defensive check for config num_labels and id2label

* Actually check value...

* Only warning inside init plus better error message
parent 6bed0647
......@@ -304,7 +304,12 @@ class PretrainedConfig(PushToHubMixin):
self.id2label = kwargs.pop("id2label", None)
self.label2id = kwargs.pop("label2id", None)
if self.id2label is not None:
kwargs.pop("num_labels", None)
num_labels = kwargs.pop("num_labels", None)
if num_labels is not None and len(self.id2label) != num_labels:
logger.warn(
f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}."
)
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
# Keys are always strings in JSON so convert ids to int here.
else:
......@@ -678,6 +683,15 @@ class PretrainedConfig(PushToHubMixin):
config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
# Update config with kwargs if needed
if "num_labels" in kwargs and "id2label" in kwargs:
num_labels = kwargs["num_labels"]
id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
if len(id2label) != num_labels:
raise ValueError(
f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
"one of them."
)
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment