Unverified Commit b623ddc0 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Pass kwargs to configuration (#3147)

* Pass kwargs to configuration

* Setter

* test
parent 0001d056
...@@ -98,6 +98,18 @@ class PretrainedConfig(object): ...@@ -98,6 +98,18 @@ class PretrainedConfig(object):
logger.error("Can't set {} with value {} for {}".format(key, value, self)) logger.error("Can't set {} with value {} for {}".format(key, value, self))
raise err raise err
@property
def num_labels(self):
return self._num_labels
@num_labels.setter
def num_labels(self, num_labels):
self._num_labels = num_labels
self.id2label = {i: "LABEL_{}".format(i) for i in range(self.num_labels)}
self.id2label = dict((int(key), value) for key, value in self.id2label.items())
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
self.label2id = dict((key, int(value)) for key, value in self.label2id.items())
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory):
""" """
Save a configuration object to the directory `save_directory`, so that it Save a configuration object to the directory `save_directory`, so that it
......
...@@ -57,8 +57,18 @@ class ConfigTester(object): ...@@ -57,8 +57,18 @@ class ConfigTester(object):
self.parent.assertEqual(config_second.to_dict(), config_first.to_dict()) self.parent.assertEqual(config_second.to_dict(), config_first.to_dict())
def create_and_test_config_with_num_labels(self):
config = self.config_class(**self.inputs_dict, num_labels=5)
self.parent.assertEqual(len(config.id2label), 5)
self.parent.assertEqual(len(config.label2id), 5)
config.num_labels = 3
self.parent.assertEqual(len(config.id2label), 3)
self.parent.assertEqual(len(config.label2id), 3)
def run_common_tests(self): def run_common_tests(self):
self.create_and_test_config_common_properties() self.create_and_test_config_common_properties()
self.create_and_test_config_to_json_string() self.create_and_test_config_to_json_string()
self.create_and_test_config_to_json_file() self.create_and_test_config_to_json_file()
self.create_and_test_config_from_and_save_pretrained() self.create_and_test_config_from_and_save_pretrained()
self.create_and_test_config_with_num_labels()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment