Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
20b65860
Unverified
Commit
20b65860
authored
Nov 19, 2020
by
Sylvain Gugger
Committed by
GitHub
Nov 19, 2020
Browse files
Fix run_ner script (#8664)
* Fix run_ner script * Pin datasets
parent
ca0109bd
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
8 deletions
+19
-8
examples/requirements.txt
examples/requirements.txt
+1
-1
examples/token-classification/run_ner.py
examples/token-classification/run_ner.py
+18
-7
No files found.
examples/requirements.txt
View file @
20b65860
...
...
@@ -13,7 +13,7 @@ streamlit
elasticsearch
nltk
pandas
datasets
datasets
>= 1.1.3
fire
pytest
conllu
...
...
examples/token-classification/run_ner.py
View file @
20b65860
...
...
@@ -15,7 +15,8 @@
"""
Fine-tuning the library models for token classification.
"""
# You can also adapt this script on your own token classification task and datasets. Pointers for this are left as comments.
# You can also adapt this script on your own token classification task and datasets. Pointers for this are left as
# comments.
import
logging
import
os
...
...
@@ -24,7 +25,7 @@ from dataclasses import dataclass, field
from
typing
import
Optional
import
numpy
as
np
from
datasets
import
load_dataset
from
datasets
import
ClassLabel
,
load_dataset
from
seqeval.metrics
import
accuracy_score
,
f1_score
,
precision_score
,
recall_score
import
transformers
...
...
@@ -198,12 +199,17 @@ def main():
if
training_args
.
do_train
:
column_names
=
datasets
[
"train"
].
column_names
features
=
datasets
[
"train"
].
features
else
:
column_names
=
datasets
[
"validation"
].
column_names
text_column_name
=
"words"
if
"words"
in
column_names
else
column_names
[
0
]
label_column_name
=
data_args
.
task_name
if
data_args
.
task_name
in
column_names
else
column_names
[
1
]
features
=
datasets
[
"validation"
].
features
text_column_name
=
"tokens"
if
"tokens"
in
column_names
else
column_names
[
0
]
label_column_name
=
(
f
"
{
data_args
.
task_name
}
_tags"
if
f
"
{
data_args
.
task_name
}
_tags"
in
column_names
else
column_names
[
1
]
)
# Labeling (this part will be easier when https://github.com/huggingface/datasets/issues/797 is solved)
# In the event the labels are not a `Sequence[ClassLabel]`, we will need to go through the dataset to get the
# unique labels.
def
get_label_list
(
labels
):
unique_labels
=
set
()
for
label
in
labels
:
...
...
@@ -212,6 +218,11 @@ def main():
label_list
.
sort
()
return
label_list
if
isinstance
(
features
[
label_column_name
].
feature
,
ClassLabel
):
label_list
=
features
[
label_column_name
].
feature
.
names
# No need to convert the labels since they are already ints.
label_to_id
=
{
i
:
i
for
i
in
range
(
len
(
label_list
))}
else
:
label_list
=
get_label_list
(
datasets
[
"train"
][
label_column_name
])
label_to_id
=
{
l
:
i
for
i
,
l
in
enumerate
(
label_list
)}
num_labels
=
len
(
label_list
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment