Unverified Commit 45cac3fa authored by Sylvain Gugger, committed by GitHub

Fix labels stored in model config for token classification examples (#15482)

* Playing

* Properly set labels in model config for token classification example

* Port to run_ner_no_trainer

* Quality
parent c74f3d4c
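
This change touches the token-classification examples (run_ner.py and, per the last commit message bullet, run_ner_no_trainer.py). Previously the scripts wrote label_to_id straight into model.config.label2id; when the dataset column is a ClassLabel, that map is an int-to-int identity map, so the saved config lost the human-readable label names. A minimal sketch of the old vs. new behavior, with hypothetical NER tags not taken from the commit:

# Minimal sketch (not part of the commit) of the bug this change fixes.
# With a ClassLabel column, labels arrive as integers, so label_to_id is an
# int -> int identity map:
label_list = ["O", "B-PER", "I-PER"]  # hypothetical NER tags
label_to_id = {i: i for i in range(len(label_list))}  # {0: 0, 1: 1, 2: 2}

# Old behavior: the identity map was written to the config, dropping the names.
old_label2id = label_to_id  # {0: 0, 1: 1, 2: 2}

# New behavior: the config always gets the name <-> id correspondences.
new_label2id = {l: i for i, l in enumerate(label_list)}  # {'O': 0, 'B-PER': 1, 'I-PER': 2}
new_id2label = {i: l for i, l in enumerate(label_list)}  # {0: 'O', 1: 'B-PER', 2: 'I-PER'}
print(new_label2id, new_id2label)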
@@ -295,12 +295,15 @@ def main():
         label_list.sort()
         return label_list

-    if isinstance(features[label_column_name].feature, ClassLabel):
+    # If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere.
+    # Otherwise, we have to get the list of labels manually.
+    labels_are_int = isinstance(features[label_column_name].feature, ClassLabel)
+    if labels_are_int:
         label_list = features[label_column_name].feature.names
-        label_keys = list(range(len(label_list)))
+        label_to_id = {i: i for i in range(len(label_list))}
     else:
         label_list = get_label_list(raw_datasets["train"][label_column_name])
-        label_keys = label_list
+        label_to_id = {l: i for i, l in enumerate(label_list)}
     num_labels = len(label_list)

@@ -354,21 +357,26 @@ def main():
             "requirement"
         )

+    # Model has labels -> use them.
     if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
-        label_name_to_id = {k: v for k, v in model.config.label2id.items()}
-        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
-            label_to_id = {k: int(label_name_to_id[k]) for k in label_keys}
+        if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)):
+            # Reorganize `label_list` to match the ordering of the model.
+            if labels_are_int:
+                label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)}
+                label_list = [model.config.id2label[i] for i in range(num_labels)]
+            else:
+                label_list = [model.config.id2label[i] for i in range(num_labels)]
+                label_to_id = {l: i for i, l in enumerate(label_list)}
         else:
             logger.warning(
                 "Your model seems to have been trained with labels, but they don't match the dataset: ",
-                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
+                f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
                 "\nIgnoring the model labels as a result.",
             )
-    else:
-        label_to_id = {k: i for i, k in enumerate(label_keys)}

-    model.config.label2id = label_to_id
-    model.config.id2label = {i: l for l, i in label_to_id.items()}
+    # Set the correspondences label/ID inside the model config
+    model.config.label2id = {l: i for i, l in enumerate(label_list)}
+    model.config.id2label = {i: l for i, l in enumerate(label_list)}

     # Map that sends B-Xxx label to its I-Xxx counterpart
     b_to_i_label = []
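
The new inner branch above also covers resuming from a checkpoint whose config already carries the same label set in a different order: label_list is reordered to match the model's ordering instead of clobbering the model's ids. A short sketch of that branch, with a hypothetical model config:

# Sketch (hypothetical model config) of the new reordering branch.
model_label2id = {"B-PER": 0, "I-PER": 1, "O": 2}
model_id2label = {i: l for l, i in model_label2id.items()}
num_labels = len(model_label2id)

dataset_label_list = ["O", "B-PER", "I-PER"]  # same label set, different order
if sorted(model_label2id.keys()) == sorted(dataset_label_list):
    # Reorganize label_list to match the ordering of the model.
    label_list = [model_id2label[i] for i in range(num_labels)]
    label_to_id = {l: i for i, l in enumerate(label_list)}

print(label_list)   # ['B-PER', 'I-PER', 'O'] -- matches the model's ids
print(label_to_id)  # {'B-PER': 0, 'I-PER': 1, 'O': 2}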
@@ -320,12 +320,15 @@ def main():
         label_list.sort()
         return label_list

-    if isinstance(features[label_column_name].feature, ClassLabel):
+    # If the labels are of type ClassLabel, they are already integers and we have the map stored somewhere.
+    # Otherwise, we have to get the list of labels manually.
+    labels_are_int = isinstance(features[label_column_name].feature, ClassLabel)
+    if labels_are_int:
         label_list = features[label_column_name].feature.names
-        label_keys = list(range(len(label_list)))
+        label_to_id = {i: i for i in range(len(label_list))}
     else:
         label_list = get_label_list(raw_datasets["train"][label_column_name])
-        label_keys = label_list
+        label_to_id = {l: i for i, l in enumerate(label_list)}
     num_labels = len(label_list)

@@ -365,21 +368,26 @@ def main():
     model.resize_token_embeddings(len(tokenizer))

+    # Model has labels -> use them.
     if model.config.label2id != PretrainedConfig(num_labels=num_labels).label2id:
-        label_name_to_id = {k: v for k, v in model.config.label2id.items()}
-        if list(sorted(label_name_to_id.keys())) == list(sorted(label_list)):
-            label_to_id = {k: int(label_name_to_id[k]) for k in label_keys}
+        if list(sorted(model.config.label2id.keys())) == list(sorted(label_list)):
+            # Reorganize `label_list` to match the ordering of the model.
+            if labels_are_int:
+                label_to_id = {i: int(model.config.label2id[l]) for i, l in enumerate(label_list)}
+                label_list = [model.config.id2label[i] for i in range(num_labels)]
+            else:
+                label_list = [model.config.id2label[i] for i in range(num_labels)]
+                label_to_id = {l: i for i, l in enumerate(label_list)}
         else:
             logger.warning(
                 "Your model seems to have been trained with labels, but they don't match the dataset: ",
-                f"model labels: {list(sorted(label_name_to_id.keys()))}, dataset labels: {list(sorted(label_list))}."
+                f"model labels: {list(sorted(model.config.label2id.keys()))}, dataset labels: {list(sorted(label_list))}."
                 "\nIgnoring the model labels as a result.",
             )
-    else:
-        label_to_id = {k: i for i, k in enumerate(label_keys)}

-    model.config.label2id = label_to_id
-    model.config.id2label = {i: l for l, i in label_to_id.items()}
+    # Set the correspondences label/ID inside the model config
+    model.config.label2id = {l: i for i, l in enumerate(label_list)}
+    model.config.id2label = {i: l for i, l in enumerate(label_list)}

     # Map that sends B-Xxx label to its I-Xxx counterpart
     b_to_i_label = []
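
After either script runs with this fix, the saved config should round-trip the label names. A quick sanity check, assuming a hypothetical output directory:

# Quick check (hypothetical path) that a saved config now carries label names.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("path/to/your-finetuned-ner-model")
print(config.label2id)  # expected: name -> id, e.g. {'B-PER': 0, 'I-PER': 1, 'O': 2}
print(config.id2label)  # expected: id -> name, e.g. {0: 'B-PER', 1: 'I-PER', 2: 'O'}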