Unverified Commit 27d55125 authored by Julien Chaumond's avatar Julien Chaumond Committed by GitHub
Browse files

Configs: saner num_labels in configs. (#3967)

parent e80be7f1
...@@ -86,11 +86,13 @@ class PretrainedConfig(object): ...@@ -86,11 +86,13 @@ class PretrainedConfig(object):
# Fine-tuning task arguments # Fine-tuning task arguments
self.architectures = kwargs.pop("architectures", None) self.architectures = kwargs.pop("architectures", None)
self.finetuning_task = kwargs.pop("finetuning_task", None) self.finetuning_task = kwargs.pop("finetuning_task", None)
self.num_labels = kwargs.pop("num_labels", 2) self.id2label = kwargs.pop("id2label", None)
self.id2label = kwargs.pop("id2label", {i: f"LABEL_{i}" for i in range(self.num_labels)}) self.label2id = kwargs.pop("label2id", None)
self.id2label = dict((int(key), value) for key, value in self.id2label.items()) if self.id2label is not None:
self.label2id = kwargs.pop("label2id", dict(zip(self.id2label.values(), self.id2label.keys()))) self.id2label = dict((int(key), value) for key, value in self.id2label.items())
self.label2id = dict((key, int(value)) for key, value in self.label2id.items()) # Keys are always strings in JSON so convert ids to int here.
else:
self.num_labels = kwargs.pop("num_labels", 2)
# Tokenizer arguments TODO: eventually tokenizer and models should share the same config # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
self.prefix = kwargs.pop("prefix", None) self.prefix = kwargs.pop("prefix", None)
...@@ -115,15 +117,12 @@ class PretrainedConfig(object): ...@@ -115,15 +117,12 @@ class PretrainedConfig(object):
@property
def num_labels(self):
    """Number of labels usable by classification heads.

    Derived from ``id2label`` so the count can never drift out of sync
    with the label mappings (``id2label`` is the single source of truth).
    """
    return len(self.id2label)

@num_labels.setter
def num_labels(self, num_labels):
    """Regenerate the label mappings for ``num_labels`` generic labels.

    Replaces both ``id2label`` and ``label2id`` with auto-generated
    ``"LABEL_0" .. "LABEL_{num_labels - 1}"`` entries; any previously set
    custom label names are discarded.
    """
    self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
    # label2id is the exact inverse of id2label.
    self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory):
""" """
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment