Unverified Commit 27d55125 authored by Julien Chaumond's avatar Julien Chaumond Committed by GitHub
Browse files

Configs: saner num_labels in configs. (#3967)

parent e80be7f1
......@@ -86,11 +86,13 @@ class PretrainedConfig(object):
# Fine-tuning task arguments
self.architectures = kwargs.pop("architectures", None)
self.finetuning_task = kwargs.pop("finetuning_task", None)
self.num_labels = kwargs.pop("num_labels", 2)
self.id2label = kwargs.pop("id2label", {i: f"LABEL_{i}" for i in range(self.num_labels)})
self.id2label = kwargs.pop("id2label", None)
self.label2id = kwargs.pop("label2id", None)
if self.id2label is not None:
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys())))
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
# Keys are always strings in JSON so convert ids to int here.
else:
self.num_labels = kwargs.pop("num_labels", 2)
# Tokenizer arguments TODO: eventually tokenizer and models should share the same config
self.prefix = kwargs.pop("prefix", None)
......@@ -115,15 +117,12 @@ class PretrainedConfig(object):
@property
def num_labels(self):
return self._num_labels
return len(self.id2label)
@num_labels.setter
def num_labels(self, num_labels):
self._num_labels = num_labels
self.id2label = {i: "LABEL_{}".format(i) for i in range(self.num_labels)}
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
def save_pretrained(self, save_directory):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment