Unverified Commit f875fb0e authored by Sylvain Gugger's avatar Sylvain Gugger Committed by GitHub
Browse files

Fix label attribution in token classification examples (#14055)

parent 31560f63
...@@ -303,6 +303,14 @@ def main(): ...@@ -303,6 +303,14 @@ def main():
label_to_id = {l: i for i, l in enumerate(label_list)} label_to_id = {l: i for i, l in enumerate(label_list)}
num_labels = len(label_list) num_labels = len(label_list)
# Map that sends B-Xxx label to its I-Xxx counterpart
b_to_i_label = []
for idx, label in enumerate(label_list):
if label.startswith("B-") and label.replace("B-", "I-") in label_list:
b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
else:
b_to_i_label.append(idx)
# Load pretrained model and tokenizer # Load pretrained model and tokenizer
# #
# Distributed training: # Distributed training:
...@@ -385,7 +393,10 @@ def main(): ...@@ -385,7 +393,10 @@ def main():
# For the other tokens in a word, we set the label to either the current label or -100, depending on # For the other tokens in a word, we set the label to either the current label or -100, depending on
# the label_all_tokens flag. # the label_all_tokens flag.
else: else:
label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100) if data_args.label_all_tokens:
label_ids.append(b_to_i_label[label_to_id[label[word_idx]]])
else:
label_ids.append(-100)
previous_word_idx = word_idx previous_word_idx = word_idx
labels.append(label_ids) labels.append(label_ids)
......
...@@ -328,6 +328,14 @@ def main(): ...@@ -328,6 +328,14 @@ def main():
label_to_id = {l: i for i, l in enumerate(label_list)} label_to_id = {l: i for i, l in enumerate(label_list)}
num_labels = len(label_list) num_labels = len(label_list)
# Map that sends B-Xxx label to its I-Xxx counterpart
b_to_i_label = []
for idx, label in enumerate(label_list):
if label.startswith("B-") and label.replace("B-", "I-") in label_list:
b_to_i_label.append(label_list.index(label.replace("B-", "I-")))
else:
b_to_i_label.append(idx)
# Load pretrained model and tokenizer # Load pretrained model and tokenizer
# #
# In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
...@@ -396,7 +404,10 @@ def main(): ...@@ -396,7 +404,10 @@ def main():
# For the other tokens in a word, we set the label to either the current label or -100, depending on # For the other tokens in a word, we set the label to either the current label or -100, depending on
# the label_all_tokens flag. # the label_all_tokens flag.
else: else:
label_ids.append(label_to_id[label[word_idx]] if args.label_all_tokens else -100) if args.label_all_tokens:
label_ids.append(b_to_i_label[label_to_id[label[word_idx]]])
else:
label_ids.append(-100)
previous_word_idx = word_idx previous_word_idx = word_idx
labels.append(label_ids) labels.append(label_ids)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment