Unverified Commit 17b6e0d4 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix run_glue evaluation when model has a label correspondence (#10401)

parent 26f8b2cb
......@@ -324,7 +324,7 @@ def main():
# Some have all caps in their config, some don't.
label_name_to_id = {k.lower(): v for k, v in model.config.label2id.items()}
if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
label_to_id = {i: label_name_to_id[label_list[i]] for i in range(num_labels)}
label_to_id = {i: int(label_name_to_id[label_list[i]]) for i in range(num_labels)}
else:
logger.warn(
"Your model seems to have been trained with labels, but they don't match the dataset: ",
......@@ -350,7 +350,7 @@ def main():
# Map labels to IDs (not necessary for GLUE tasks)
if label_to_id is not None and "label" in examples:
result["label"] = [label_to_id[l] for l in examples["label"]]
result["label"] = [(label_to_id[l] if l != -1 else -1) for l in examples["label"]]
return result
datasets = datasets.map(preprocess_function, batched=True, load_from_cache_file=not data_args.overwrite_cache)
......
......@@ -289,8 +289,9 @@ class PretrainedConfig(object):
@num_labels.setter
def num_labels(self, num_labels: int):
self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
if self.id2label is None or len(self.id2label) != num_labels:
self.id2label = {i: "LABEL_{}".format(i) for i in range(num_labels)}
self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
"""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment