Unverified Commit bebbdd0f authored by Matt's avatar Matt Committed by GitHub
Browse files

Appending label2id and id2label to models to ensure inference works properly (#12102)

parent 4cda08de
...@@ -370,6 +370,10 @@ def main(): ...@@ -370,6 +370,10 @@ def main():
elif data_args.task_name is None and not is_regression: elif data_args.task_name is None and not is_regression:
label_to_id = {v: i for i, v in enumerate(label_list)} label_to_id = {v: i for i, v in enumerate(label_list)}
if label_to_id is not None:
model.config.label2id = label_to_id
model.config.id2label = {id: label for label, id in config.label2id.items()}
if data_args.max_seq_length > tokenizer.model_max_length: if data_args.max_seq_length > tokenizer.model_max_length:
logger.warning( logger.warning(
f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the" f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the"
......
...@@ -282,6 +282,10 @@ def main(): ...@@ -282,6 +282,10 @@ def main():
elif args.task_name is None: elif args.task_name is None:
label_to_id = {v: i for i, v in enumerate(label_list)} label_to_id = {v: i for i, v in enumerate(label_list)}
if label_to_id is not None:
model.config.label2id = label_to_id
model.config.id2label = {id: label for label, id in config.label2id.items()}
padding = "max_length" if args.pad_to_max_length else False padding = "max_length" if args.pad_to_max_length else False
def preprocess_function(examples): def preprocess_function(examples):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment