Unverified Commit 139e8301 authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Update label2id in the model config for run_glue (#13334)

parent 6f3c99ac
...@@ -380,6 +380,9 @@ def main(): ...@@ -380,6 +380,9 @@ def main():
if label_to_id is not None: if label_to_id is not None:
model.config.label2id = label_to_id model.config.label2id = label_to_id
model.config.id2label = {id: label for label, id in config.label2id.items()} model.config.id2label = {id: label for label, id in config.label2id.items()}
elif data_args.task_name is not None and not is_regression:
model.config.label2id = {l: i for i, l in enumerate(label_list)}
model.config.id2label = {id: label for label, id in config.label2id.items()}
if data_args.max_seq_length > tokenizer.model_max_length: if data_args.max_seq_length > tokenizer.model_max_length:
logger.warning( logger.warning(
......
...@@ -288,6 +288,9 @@ def main(): ...@@ -288,6 +288,9 @@ def main():
if label_to_id is not None: if label_to_id is not None:
model.config.label2id = label_to_id model.config.label2id = label_to_id
model.config.id2label = {id: label for label, id in config.label2id.items()} model.config.id2label = {id: label for label, id in config.label2id.items()}
elif args.task_name is not None and not is_regression:
model.config.label2id = {l: i for i, l in enumerate(label_list)}
model.config.id2label = {id: label for label, id in config.label2id.items()}
padding = "max_length" if args.pad_to_max_length else False padding = "max_length" if args.pad_to_max_length else False
......
...@@ -355,6 +355,9 @@ def main(): ...@@ -355,6 +355,9 @@ def main():
if label_to_id is not None: if label_to_id is not None:
config.label2id = label_to_id config.label2id = label_to_id
config.id2label = {id: label for label, id in config.label2id.items()} config.id2label = {id: label for label, id in config.label2id.items()}
elif data_args.task_name is not None and not is_regression:
config.label2id = {l: i for i, l in enumerate(label_list)}
config.id2label = {id: label for label, id in config.label2id.items()}
if data_args.max_seq_length > tokenizer.model_max_length: if data_args.max_seq_length > tokenizer.model_max_length:
logger.warning( logger.warning(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment