Commit 80eacb8f authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Adding labels mapping for classification models in their respective config.

parent f69dbecc
...@@ -58,8 +58,8 @@ class PretrainedConfig(object): ...@@ -58,8 +58,8 @@ class PretrainedConfig(object):
self.use_bfloat16 = kwargs.pop('use_bfloat16', False) self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
self.pruned_heads = kwargs.pop('pruned_heads', {}) self.pruned_heads = kwargs.pop('pruned_heads', {})
self.is_decoder = kwargs.pop('is_decoder', False) self.is_decoder = kwargs.pop('is_decoder', False)
self.idx2label = kwargs.pop('idx2label', {i: 'LABEL_{}'.format(i) for i in range(self.num_labels)}) self.id2label = kwargs.pop('id2label', {i: 'LABEL_{}'.format(i) for i in range(self.num_labels)})
self.label2idx = kwargs.pop('label2idx', dict(zip(self.idx2label.values(), self.idx2label.keys()))) self.label2id = kwargs.pop('label2id', dict(zip(self.id2label.values(), self.id2label.keys())))
def save_pretrained(self, save_directory): def save_pretrained(self, save_directory):
""" Save a configuration object to the directory `save_directory`, so that it """ Save a configuration object to the directory `save_directory`, so that it
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment