Commit 7f9ccffc in chenpangpang/transformers (unverified)
authored Dec 07, 2020 by Sylvain Gugger, committed by GitHub on Dec 07, 2020

Use word_ids to get labels in run_ner (#8962)

* Use word_ids to get labels in run_ner
* Add sanity check

parent de6befd4
Changes: 1 changed file with 21 additions and 15 deletions

examples/token-classification/run_ner.py (+21 / -15)
...
@@ -35,6 +35,7 @@ from transformers import (
     AutoTokenizer,
     DataCollatorForTokenClassification,
     HfArgumentParser,
+    PreTrainedTokenizerFast,
     Trainer,
     TrainingArguments,
     set_seed,
...
@@ -250,6 +251,14 @@ def main():
         cache_dir=model_args.cache_dir,
     )
 
+    # Tokenizer check: this script requires a fast tokenizer.
+    if not isinstance(tokenizer, PreTrainedTokenizerFast):
+        raise ValueError(
+            "This example script only works for models that have a fast tokenizer. Checkout the big table of models "
+            "at https://huggingface.co/transformers/index.html#bigtable to find the model types that meet this "
+            "requirement"
+        )
+
     # Preprocessing the dataset
     # Padding strategy
     padding = "max_length" if data_args.pad_to_max_length else False
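Context for the new check (not part of the commit): the label alignment further down relies on word_ids(), which is only available when the tokenizer is a Rust-backed "fast" tokenizer. A minimal sketch of how one could verify this for a checkpoint, with an illustrative model name:

from transformers import AutoTokenizer, PreTrainedTokenizerFast

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # illustrative checkpoint only
# Both checks report whether the fast (Rust) implementation was loaded.
print(isinstance(tokenizer, PreTrainedTokenizerFast))
print(tokenizer.is_fast)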
...
@@ -262,28 +271,25 @@ def main():
             truncation=True,
             # We use this argument because the texts in our dataset are lists of words (with a label for each word).
             is_split_into_words=True,
-            return_offsets_mapping=True,
         )
-        offset_mappings = tokenized_inputs.pop("offset_mapping")
         labels = []
-        for label, offset_mapping in zip(examples[label_column_name], offset_mappings):
-            label_index = 0
-            current_label = -100
+        for i, label in enumerate(examples[label_column_name]):
+            word_ids = tokenized_inputs.word_ids(batch_index=i)
+            previous_word_idx = None
             label_ids = []
-            for offset in offset_mapping:
-                # We set the label for the first token of each word. Special characters will have an offset of (0, 0)
-                # so the test ignores them.
-                if offset[0] == 0 and offset[1] != 0:
-                    current_label = label_to_id[label[label_index]]
-                    label_index += 1
-                    label_ids.append(current_label)
-                # For special tokens, we set the label to -100 so it's automatically ignored in the loss function.
-                elif offset[0] == 0 and offset[1] == 0:
+            for word_idx in word_ids:
+                # Special tokens have a word id that is None. We set the label to -100 so they are automatically
+                # ignored in the loss function.
+                if word_idx is None:
                     label_ids.append(-100)
+                # We set the label for the first token of each word.
+                elif word_idx != previous_word_idx:
+                    label_ids.append(label_to_id[label[word_idx]])
                 # For the other tokens in a word, we set the label to either the current label or -100, depending on
                 # the label_all_tokens flag.
                 else:
-                    label_ids.append(current_label if data_args.label_all_tokens else -100)
+                    label_ids.append(label_to_id[label[word_idx]] if data_args.label_all_tokens else -100)
+                previous_word_idx = word_idx
             labels.append(label_ids)
         tokenized_inputs["labels"] = labels
...
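To see what the new loop does in isolation, here is a minimal self-contained sketch (not part of the commit; the checkpoint, toy words, and label ids are made up for illustration). word_ids() maps every sub-word token back to the index of the word it came from and returns None for special tokens, so labels can be aligned without tracking character offsets:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")  # a fast tokenizer

words = ["HuggingFace", "is", "in", "NYC"]  # pre-split words, as in a NER dataset
word_labels = [3, 0, 0, 5]                  # one label id per word (e.g. B-ORG, O, O, B-LOC)

encoding = tokenizer(words, is_split_into_words=True)
word_ids = encoding.word_ids()              # e.g. [None, 0, 0, 0, 1, 2, 3, None]

label_ids = []
previous_word_idx = None
for word_idx in word_ids:
    if word_idx is None:                     # special tokens ([CLS], [SEP]) -> ignored by the loss
        label_ids.append(-100)
    elif word_idx != previous_word_idx:      # first sub-word of a word -> its real label
        label_ids.append(word_labels[word_idx])
    else:                                    # remaining sub-words -> -100 (or the word label with --label_all_tokens)
        label_ids.append(-100)
    previous_word_idx = word_idx

print(list(zip(tokenizer.convert_ids_to_tokens(encoding["input_ids"]), label_ids)))

Compared with the removed offset-based bookkeeping (label_index, current_label), word indices make the first-sub-word test a simple comparison with the previous index, and special tokens are detected by None rather than a (0, 0) offset.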