Unverified commit 457dd439, authored by Patrick von Platen, committed by GitHub
Browse files

[Examples] Correct run ner label2id for fine-tuned models (#15017)



* up

* up

* make style

* apply sylvains suggestions

* apply changes to accelerate as well

* more changes

* Apply suggestions from code review
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 8d6acc6c
...@@ -36,6 +36,7 @@ from transformers import ( ...@@ -36,6 +36,7 @@ from transformers import (
AutoTokenizer, AutoTokenizer,
DataCollatorForTokenClassification, DataCollatorForTokenClassification,
HfArgumentParser, HfArgumentParser,
PretrainedConfig,
PreTrainedTokenizerFast, PreTrainedTokenizerFast,
Trainer, Trainer,
TrainingArguments, TrainingArguments,
...@@ -296,20 +297,12 @@ def main(): ...@@ -296,20 +297,12 @@ def main():
if isinstance(features[label_column_name].feature, ClassLabel): if isinstance(features[label_column_name].feature, ClassLabel):
label_list = features[label_column_name].feature.names label_list = features[label_column_name].feature.names
# No need to convert the labels since they are already ints. label_keys = list(range(len(label_list)))
label_to_id = {i: i for i in range(len(label_list))}
else: else:
label_list = get_label_list(raw_datasets["train"][label_column_name]) label_list = get_label_list(raw_datasets["train"][label_column_name])
label_to_id = {l: i for i, l in enumerate(label_list)} label_keys = label_list
num_labels = len(label_list)
# Map that sends B-Xxx label to its I-Xxx counterpart num_labels = len(label_list)
b_to_i_label = []
for idx, label in enumerate(label_list):
if label.startswith("B-") and label.replace("B-", "I-") in label_list:
b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
else:
b_to_i_label.append(idx)
# Load pretrained model and tokenizer # Load pretrained model and tokenizer
# #
...@@ -319,8 +312,6 @@ def main(): ...@@ -319,8 +312,6 @@ def main():
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
model_args.config_name if model_args.config_name else model_args.model_name_or_path, model_args.config_name if model_args.config_name else model_args.model_name_or_path,
num_labels=num_labels, num_labels=num_labels,
label2id=label_to_id,
id2label={i: l for l, i in label_to_id.items()},
finetuning_task=data_args.task_name, finetuning_task=data_args.task_name,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
revision=model_args.model_revision, revision=model_args.model_revision,
...@@ -363,6 +354,30 @@ def main(): ...@@ -363,6 +354,30 @@ def main():
"requirement" "requirement"
) )
# If the model was already fine-tuned with labels (its label2id differs from the
# default one a fresh PretrainedConfig would produce), try to reuse its mapping so
# that label ids stay consistent with the checkpoint.
if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
    label_name_to_id = dict(model.config.label2id)
    if sorted(label_name_to_id.keys()) == sorted(label_list):
        # Model labels match the dataset labels: keep the model's id assignment.
        label_to_id = {k: int(label_name_to_id[k]) for k in label_keys}
    else:
        # BUGFIX: the two message parts must be concatenated into a single string.
        # Passing the f-string after a comma makes logging treat it as a %-format
        # argument, and with no placeholder in the first part the detail was dropped.
        logger.warning(
            "Your model seems to have been trained with labels, but they don't match the dataset: "
            f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
            "\nIgnoring the model labels as a result."
        )
        # BUGFIX: fall back to a fresh dataset-order mapping; previously
        # `label_to_id` was left undefined on this path, raising NameError later.
        label_to_id = {k: i for i, k in enumerate(label_keys)}
else:
    # Model has the default (untrained) labels: build the mapping from the dataset.
    label_to_id = {k: i for i, k in enumerate(label_keys)}

# Persist the chosen mapping on the config so it is saved with the model.
model.config.label2id = label_to_id
model.config.id2label = {i: l for l, i in label_to_id.items()}

# Map that sends a B-Xxx label to its I-Xxx counterpart (used when labelling
# the non-first sub-tokens of a word); labels without a counterpart map to themselves.
b_to_i_label = []
for idx, label in enumerate(label_list):
    if label.startswith("B-") and label.replace("B-", "I-") in label_list:
        b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
    else:
        b_to_i_label.append(idx)
# Preprocessing the dataset # Preprocessing the dataset
# Padding strategy # Padding strategy
padding = "max_length" if data_args.pad_to_max_length else False padding = "max_length" if data_args.pad_to_max_length else False
......
...@@ -42,6 +42,7 @@ from transformers import ( ...@@ -42,6 +42,7 @@ from transformers import (
AutoModelForTokenClassification, AutoModelForTokenClassification,
AutoTokenizer, AutoTokenizer,
DataCollatorForTokenClassification, DataCollatorForTokenClassification,
PretrainedConfig,
SchedulerType, SchedulerType,
default_data_collator, default_data_collator,
get_scheduler, get_scheduler,
...@@ -321,20 +322,12 @@ def main(): ...@@ -321,20 +322,12 @@ def main():
if isinstance(features[label_column_name].feature, ClassLabel): if isinstance(features[label_column_name].feature, ClassLabel):
label_list = features[label_column_name].feature.names label_list = features[label_column_name].feature.names
# No need to convert the labels since they are already ints. label_keys = list(range(len(label_list)))
label_to_id = {i: i for i in range(len(label_list))}
else: else:
label_list = get_label_list(raw_datasets["train"][label_column_name]) label_list = get_label_list(raw_datasets["train"][label_column_name])
label_to_id = {l: i for i, l in enumerate(label_list)} label_keys = label_list
num_labels = len(label_list)
# Map that sends B-Xxx label to its I-Xxx counterpart num_labels = len(label_list)
b_to_i_label = []
for idx, label in enumerate(label_list):
if label.startswith("B-") and label.replace("B-", "I-") in label_list:
b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
else:
b_to_i_label.append(idx)
# Load pretrained model and tokenizer # Load pretrained model and tokenizer
# #
...@@ -372,6 +365,30 @@ def main(): ...@@ -372,6 +365,30 @@ def main():
model.resize_token_embeddings(len(tokenizer)) model.resize_token_embeddings(len(tokenizer))
# If the model was already fine-tuned with labels (its label2id differs from the
# default one a fresh PretrainedConfig would produce), try to reuse its mapping so
# that label ids stay consistent with the checkpoint.
if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
    label_name_to_id = dict(model.config.label2id)
    if sorted(label_name_to_id.keys()) == sorted(label_list):
        # Model labels match the dataset labels: keep the model's id assignment.
        label_to_id = {k: int(label_name_to_id[k]) for k in label_keys}
    else:
        # BUGFIX: the two message parts must be concatenated into a single string.
        # Passing the f-string after a comma makes logging treat it as a %-format
        # argument, and with no placeholder in the first part the detail was dropped.
        logger.warning(
            "Your model seems to have been trained with labels, but they don't match the dataset: "
            f"model labels: {sorted(label_name_to_id.keys())}, dataset labels: {sorted(label_list)}."
            "\nIgnoring the model labels as a result."
        )
        # BUGFIX: fall back to a fresh dataset-order mapping; previously
        # `label_to_id` was left undefined on this path, raising NameError later.
        label_to_id = {k: i for i, k in enumerate(label_keys)}
else:
    # Model has the default (untrained) labels: build the mapping from the dataset.
    label_to_id = {k: i for i, k in enumerate(label_keys)}

# Persist the chosen mapping on the config so it is saved with the model.
model.config.label2id = label_to_id
model.config.id2label = {i: l for l, i in label_to_id.items()}

# Map that sends a B-Xxx label to its I-Xxx counterpart (used when labelling
# the non-first sub-tokens of a word); labels without a counterpart map to themselves.
b_to_i_label = []
for idx, label in enumerate(label_list):
    if label.startswith("B-") and label.replace("B-", "I-") in label_list:
        b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
    else:
        b_to_i_label.append(idx)
# Preprocessing the datasets. # Preprocessing the datasets.
# First we tokenize all the texts. # First we tokenize all the texts.
padding = "max_length" if args.pad_to_max_length else False padding = "max_length" if args.pad_to_max_length else False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment