Commit f69dbecc authored by Morgan Funtowicz's avatar Morgan Funtowicz
Browse files

Expose classification labels mapping (and reverse) in model config.

parent 6709739a
......@@ -58,6 +58,8 @@ class PretrainedConfig(object):
self.use_bfloat16 = kwargs.pop('use_bfloat16', False)
self.pruned_heads = kwargs.pop('pruned_heads', {})
self.is_decoder = kwargs.pop('is_decoder', False)
self.idx2label = kwargs.pop('idx2label', {i: 'LABEL_{}'.format(i) for i in range(self.num_labels)})
self.label2idx = kwargs.pop('label2idx', dict(zip(self.idx2label.values(), self.idx2label.keys())))
def save_pretrained(self, save_directory):
""" Save a configuration object to the directory `save_directory`, so that it
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment