Unverified Commit 7f9ccffc authored by Sylvain Gugger, committed by GitHub

Use word_ids to get labels in run_ner (#8962)

* Use word_ids to get labels in run_ner

* Add sanity check
parent de6befd4
@@ -35,6 +35,7 @@ from transformers import (
     AutoTokenizer,
     DataCollatorForTokenClassification,
     HfArgumentParser,
+    PreTrainedTokenizerFast,
     Trainer,
     TrainingArguments,
     set_seed,
@@ -250,6 +251,14 @@ def main():
         cache_dir=model_args.cache_dir,
     )
 
+    # Tokenizer check: this script requires a fast tokenizer.
+    if not isinstance(tokenizer, PreTrainedTokenizerFast):
+        raise ValueError(
+            "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
+            "at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this "
+            "requirement"
+        )
+
     # Preprocessing the dataset
     # Padding strategy
     padding = "max_length" if data_args.pad_to_max_length else False
@@ -262,28 +271,25 @@ def main():
             truncation=True,
             # We use this argument because the texts in our dataset are lists of words (with a label for each word).
            is_split_into_words=True,
-            return_offsets_mapping=True,
         )
-        offset_mappings = tokenized_inputs.pop("offset_mapping")
         labels = []
-        for label, offset_mapping in zip(examples[label_column_name], offset_mappings):
-            label_index = 0
-            current_label = -100
+        for i, label in enumerate(examples[label_column_name]):
+            word_ids = tokenized_inputs.word_ids(batch_index=i)
+            previous_word_idx = None
             label_ids = []
-            for offset in offset_mapping:
-                # We set the label for the first token of each word. Special characters will have an offset of (0, 0)
-                # so the test ignores them.
-                if offset[0] == 0 and offset[1] != 0:
-                    current_label = label_to_id[label[label_index]]
-                    label_index += 1
-                    label_ids.append(current_label)
-                # For special tokens, we set the label to -100 so it's automatically ignored in the loss function.
-                elif offset[0] == 0 and offset[1] == 0:
+            for word_idx in word_ids:
+                # Special tokens have a word id that is None. We set the label to -100 so they are automatically
+                # ignored in the loss function.
+                if word_idx is None:
                     label_ids.append(-100)
+                # We set the label for the first token of each word.
+                elif word_idx != previous_word_idx:
+                    label_ids.append(label_to_id[label[word_idx]])
                 # For the other tokens in a word, we set the label to either the current label or -100, depending on
                 # the label_all_tokens flag.
                 else:
-                    label_ids.append(current_label if data_args.label_all_tokens else -100)
+                    label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100)
+                previous_word_idx = word_idx
             labels.append(label_ids)
 
         tokenized_inputs["labels"] = labels
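For reference, a minimal, self-contained sketch of the same word_ids-based label alignment outside the training script. The checkpoint name, example words, and integer labels below are placeholders chosen for illustration, and plain word_labels / label_all_tokens variables stand in for the script's label_to_id mapping and data_args.label_all_tokens flag:

from transformers import AutoTokenizer

# A fast tokenizer is required: only fast tokenizers expose word_ids().
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased", use_fast=True)

# One pre-tokenized example: a list of words with one label id per word (placeholder values).
words = ["HuggingFace", "is", "based", "in", "NYC"]
word_labels = [3, 0, 0, 0, 5]
label_all_tokens = False

encoding = tokenizer(words, is_split_into_words=True)
word_ids = encoding.word_ids(batch_index=0)

label_ids = []
previous_word_idx = None
for word_idx in word_ids:
    if word_idx is None:
        # Special tokens ([CLS], [SEP], ...) get -100 so the loss ignores them.
        label_ids.append(-100)
    elif word_idx != previous_word_idx:
        # The first sub-token of each word carries the word's label.
        label_ids.append(word_labels[word_idx])
    else:
        # Remaining sub-tokens of a word: repeat the label or mask them, per label_all_tokens.
        label_ids.append(word_labels[word_idx] if label_all_tokens else -100)
    previous_word_idx = word_idx

print(list(zip(encoding.tokens(), word_ids, label_ids)))

With label_all_tokens left at False, only the first sub-token of each word keeps the word's label; every other sub-token and every special token is masked with -100, which the PyTorch cross-entropy loss ignores by default.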