Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
7f9ccffc
Unverified
Commit
7f9ccffc
authored
Dec 07, 2020
by
Sylvain Gugger
Committed by
GitHub
Dec 07, 2020
Browse files
Use word_ids to get labels in run_ner (#8962)
* Use word_ids to get labels in run_ner * Add sanity check
parent
de6befd4
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
21 additions
and
15 deletions
+21
-15
examples/token-classification/run_ner.py
examples/token-classification/run_ner.py
+21
-15
No files found.
examples/token-classification/run_ner.py
View file @
7f9ccffc
...
@@ -35,6 +35,7 @@ from transformers import (
...
@@ -35,6 +35,7 @@ from transformers import (
AutoTokenizer
,
AutoTokenizer
,
DataCollatorForTokenClassification
,
DataCollatorForTokenClassification
,
HfArgumentParser
,
HfArgumentParser
,
PreTrainedTokenizerFast
,
Trainer
,
Trainer
,
TrainingArguments
,
TrainingArguments
,
set_seed
,
set_seed
,
...
@@ -250,6 +251,14 @@ def main():
...
@@ -250,6 +251,14 @@ def main():
cache_dir
=
model_args
.
cache_dir
,
cache_dir
=
model_args
.
cache_dir
,
)
)
# Tokenizer check: this script requires a fast tokenizer.
if
not
isinstance
(
tokenizer
,
PreTrainedTokenizerFast
):
raise
ValueError
(
"This example script only works for models that have a fast tokenizer. Checkout the big table of models "
"at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this "
"requirement"
)
# Preprocessing the dataset
# Preprocessing the dataset
# Padding strategy
# Padding strategy
padding
=
"max_length"
if
data_args
.
pad_to_max_length
else
False
padding
=
"max_length"
if
data_args
.
pad_to_max_length
else
False
...
@@ -262,28 +271,25 @@ def main():
...
@@ -262,28 +271,25 @@ def main():
truncation
=
True
,
truncation
=
True
,
# We use this argument because the texts in our dataset are lists of words (with a label for each word).
# We use this argument because the texts in our dataset are lists of words (with a label for each word).
is_split_into_words
=
True
,
is_split_into_words
=
True
,
return_offsets_mapping
=
True
,
)
)
offset_mappings
=
tokenized_inputs
.
pop
(
"offset_mapping"
)
labels
=
[]
labels
=
[]
for
label
,
offset_mapping
in
zip
(
examples
[
label_column_name
]
,
offset_mappings
):
for
i
,
label
in
enumerate
(
examples
[
label_column_name
]):
label
_index
=
0
word_ids
=
tokenized_inputs
.
word_ids
(
batch
_index
=
i
)
current_label
=
-
100
previous_word_idx
=
None
label_ids
=
[]
label_ids
=
[]
for
offset
in
offset_mapping
:
for
word_idx
in
word_ids
:
# We set the label for the first token of each word. Special characters will have an offset of (0, 0)
# Special tokens have a word id that is None. We set the label to -100 so they are automatically
# so the test ignores them.
# ignored in the loss function.
if
offset
[
0
]
==
0
and
offset
[
1
]
!=
0
:
if
word_idx
is
None
:
current_label
=
label_to_id
[
label
[
label_index
]]
label_index
+=
1
label_ids
.
append
(
current_label
)
# For special tokens, we set the label to -100 so it's automatically ignored in the loss function.
elif
offset
[
0
]
==
0
and
offset
[
1
]
==
0
:
label_ids
.
append
(
-
100
)
label_ids
.
append
(
-
100
)
# We set the label for the first token of each word.
elif
word_idx
!=
previous_word_idx
:
label_ids
.
append
(
label_to_id
[
label
[
word_idx
]])
# For the other tokens in a word, we set the label to either the current label or -100, depending on
# For the other tokens in a word, we set the label to either the current label or -100, depending on
# the label_all_tokens flag.
# the label_all_tokens flag.
else
:
else
:
label_ids
.
append
(
current_label
if
data_args
.
label_all_tokens
else
-
100
)
label_ids
.
append
(
label_to_id
[
label
[
word_idx
]]
if
data_args
.
label_all_tokens
else
-
100
)
previous_word_idx
=
word_idx
labels
.
append
(
label_ids
)
labels
.
append
(
label_ids
)
tokenized_inputs
[
"labels"
]
=
labels
tokenized_inputs
[
"labels"
]
=
labels
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment