chenpangpang / transformers · Commit 383ef967

Implement fine-tuning BERT on CoNLL-2003 named entity recognition task

Authored Sep 17, 2019 by Marianne Stecklina; committed Oct 15, 2019 by thomwolf
Parent: 5adb39e7
Showing 2 changed files with 30 additions and 64 deletions (+30 / -64)
examples/run_ner.py    +18 / -46
examples/utils_ner.py  +12 / -18
examples/run_ner.py
@@ -55,7 +55,7 @@ def set_seed(args):
         torch.cuda.manual_seed_all(args.seed)
 
 
-def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
+def train(args, train_dataset, model, tokenizer, pad_token_label_id):
     """ Train the model """
     if args.local_rank in [-1, 0]:
         tb_writer = SummaryWriter()
@@ -148,7 +148,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
                 if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                     # Log metrics
                     if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
-                        results, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id)
+                        results = evaluate(args, model, tokenizer, pad_token_label_id)
                         for key, value in results.items():
                             tb_writer.add_scalar("eval_{}".format(key), value, global_step)
                     tb_writer.add_scalar("lr", scheduler.get_lr()[0], global_step)
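Since `evaluate` now returns only the metrics dict (the predictions list is gone), the logging block consumes it directly. A minimal, self-contained sketch of the same TensorBoard pattern with made-up metric values; the import path and the fixed learning rate (standing in for `scheduler.get_lr()[0]`) are assumptions, not taken from this diff:

```python
from torch.utils.tensorboard import SummaryWriter  # run_ner.py resolves SummaryWriter itself

tb_writer = SummaryWriter()
global_step = 100
results = {"precision": 0.85, "recall": 0.83, "f1": 0.84, "loss": 0.12}  # stand-in eval metrics

for key, value in results.items():
    tb_writer.add_scalar("eval_{}".format(key), value, global_step)  # same tag scheme as above

tb_writer.add_scalar("lr", 5e-5, global_step)  # scheduler.get_lr()[0] in the real loop
tb_writer.close()
```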
@@ -160,7 +160,8 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
                     output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
                     if not os.path.exists(output_dir):
                         os.makedirs(output_dir)
-                    model_to_save = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
+                    model_to_save = model.module if hasattr(model, "module") else model  # Take care of distributed/parallel training
                     model_to_save.save_pretrained(output_dir)
+                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                     logger.info("Saving model checkpoint to %s", output_dir)
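The added `torch.save(args, ...)` line persists the argparse namespace next to each checkpoint. A small sketch of what gets written and how it can be read back later; the directory name and the namespace fields below are hypothetical:

```python
import argparse
import os
import torch

output_dir = "out/checkpoint-500"  # hypothetical checkpoint directory
os.makedirs(output_dir, exist_ok=True)

# Stand-in for the real args namespace built in main()
args = argparse.Namespace(model_type="bert", max_seq_length=128, learning_rate=5e-5)
torch.save(args, os.path.join(output_dir, "training_args.bin"))

# Reading it back when inspecting a checkpoint (newer PyTorch may need
# weights_only=False here, since this is a pickled Namespace, not a tensor file)
reloaded = torch.load(os.path.join(output_dir, "training_args.bin"))
print(reloaded.max_seq_length)  # 128
```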
@@ -178,8 +179,8 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
     return global_step, tr_loss / global_step
 
 
-def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""):
-    eval_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode=mode)
+def evaluate(args, model, tokenizer, pad_token_label_id, prefix=""):
+    eval_dataset = load_and_cache_examples(args, tokenizer, pad_token_label_id, evaluate=True)
 
     args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
     # Note that DistributedSampler samples randomly
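For reference, the effective eval batch size set above is just the per-GPU size scaled by the device count, with `max(1, n_gpu)` guarding the CPU-only case; the values below are made up:

```python
per_gpu_eval_batch_size = 8
for n_gpu in (0, 1, 4):
    eval_batch_size = per_gpu_eval_batch_size * max(1, n_gpu)
    print(n_gpu, eval_batch_size)  # 0 -> 8, 1 -> 8, 4 -> 32
```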
@@ -219,7 +220,7 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""):
     eval_loss = eval_loss / nb_eval_steps
     preds = np.argmax(preds, axis=2)
 
-    label_map = {i: label for i, label in enumerate(labels)}
+    label_map = {i: label for i, label in enumerate(get_labels())}
 
     out_label_list = [[] for _ in range(out_label_ids.shape[0])]
     preds_list = [[] for _ in range(out_label_ids.shape[0])]
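The `label_map` built from `get_labels()` is used further down in `evaluate` (outside the hunk shown here) to turn the argmax ids back into tag strings, skipping positions marked with `pad_token_label_id`. A rough, self-contained sketch of that reconstruction with toy arrays; the filling loop mirrors the rest of `evaluate` but is written here from the surrounding context, not copied from this diff:

```python
import numpy as np

labels = ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
label_map = {i: label for i, label in enumerate(labels)}
pad_token_label_id = -100  # CrossEntropyLoss().ignore_index, as set in main()

preds = np.random.rand(1, 5, len(labels))          # (batch, seq_len, num_labels) logits, made up
out_label_ids = np.array([[3, 4, 0, -100, -100]])  # gold ids; -100 marks sub-tokens/padding

preds = np.argmax(preds, axis=2)

out_label_list = [[] for _ in range(out_label_ids.shape[0])]
preds_list = [[] for _ in range(out_label_ids.shape[0])]
for i in range(out_label_ids.shape[0]):
    for j in range(out_label_ids.shape[1]):
        if out_label_ids[i, j] != pad_token_label_id:
            out_label_list[i].append(label_map[out_label_ids[i, j]])
            preds_list[i].append(label_map[preds[i, j]])

print(out_label_list)  # [['B-PER', 'I-PER', 'O']]
```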
@@ -241,15 +242,15 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""):
     for key in sorted(results.keys()):
         logger.info("  %s = %s", key, str(results[key]))
 
-    return results, preds_list
+    return results
 
 
-def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
+def load_and_cache_examples(args, tokenizer, pad_token_label_id, evaluate=False):
     if args.local_rank not in [-1, 0] and not evaluate:
         torch.distributed.barrier()  # Make sure only the first process in distributed training process the dataset, and the others will use the cache
 
     # Load data features from cache or dataset file
-    cached_features_file = os.path.join(args.data_dir, "cached_{}_{}_{}".format(mode,
+    cached_features_file = os.path.join(args.data_dir, "cached_{}_{}_{}".format("dev" if evaluate else "train",
         list(filter(None, args.model_name_or_path.split("/"))).pop(),
         str(args.max_seq_length)))
     if os.path.exists(cached_features_file):
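To make the naming concrete, here is the cache path the new format string produces for a hypothetical run with bert-base-cased and a maximum sequence length of 128 (all values below are assumptions, not from the diff):

```python
import os

data_dir = "data/conll2003"             # hypothetical --data_dir
model_name_or_path = "bert-base-cased"  # hypothetical --model_name_or_path
max_seq_length = 128
evaluate = False

cached_features_file = os.path.join(data_dir, "cached_{}_{}_{}".format(
    "dev" if evaluate else "train",
    list(filter(None, model_name_or_path.split("/"))).pop(),
    str(max_seq_length)))

print(cached_features_file)  # data/conll2003/cached_train_bert-base-cased_128
```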
@@ -257,8 +258,9 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
         features = torch.load(cached_features_file)
     else:
         logger.info("Creating features from dataset file at %s", args.data_dir)
-        examples = read_examples_from_file(args.data_dir, mode)
-        features = convert_examples_to_features(examples, labels, args.max_seq_length, tokenizer,
+        label_list = get_labels()
+        examples = read_examples_from_file(args.data_dir, evaluate=evaluate)
+        features = convert_examples_to_features(examples, label_list, args.max_seq_length, tokenizer,
                                                 cls_token_at_end=bool(args.model_type in ["xlnet"]),  # xlnet has a cls token at the end
                                                 cls_token=tokenizer.cls_token,
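`cls_token_at_end` exists because XLNet places its classification token after the sequence while BERT-style models put `[CLS]` first. A rough illustration of the two layouts; the token strings are illustrative and not produced by a real tokenizer:

```python
def add_special_tokens(tokens, cls_token, sep_token, cls_token_at_end):
    tokens = tokens + [sep_token]
    if cls_token_at_end:
        return tokens + [cls_token]  # XLNet-style: ... <sep> <cls>
    return [cls_token] + tokens      # BERT-style: [CLS] ... [SEP]

words = ["John", "lives", "in", "Berlin"]
print(add_special_tokens(words, "[CLS]", "[SEP]", cls_token_at_end=False))
# ['[CLS]', 'John', 'lives', 'in', 'Berlin', '[SEP]']
print(add_special_tokens(words, "<cls>", "<sep>", cls_token_at_end=True))
# ['John', 'lives', 'in', 'Berlin', '<sep>', '<cls>']
```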
@@ -303,8 +305,6 @@ def main():
                         help="The output directory where the model predictions and checkpoints will be written.")
 
     ## Other parameters
-    parser.add_argument("--labels", default="", type=str,
-                        help="Path to a file containing all labels. If not specified, CoNLL-2003 labels are used.")
     parser.add_argument("--config_name", default="", type=str,
                         help="Pretrained config name or path if not the same as model_name")
     parser.add_argument("--tokenizer_name", default="", type=str,
@@ -318,8 +318,6 @@ def main():
                         help="Whether to run training.")
     parser.add_argument("--do_eval", action="store_true",
                         help="Whether to run eval on the dev set.")
-    parser.add_argument("--do_predict", action="store_true",
-                        help="Whether to run predictions on the test set.")
     parser.add_argument("--evaluate_during_training", action="store_true",
                         help="Whether to run evaluation during training at each logging step.")
     parser.add_argument("--do_lower_case", action="store_true",
@@ -408,8 +406,8 @@ def main():
     set_seed(args)
 
     # Prepare CONLL-2003 task
-    labels = get_labels(args.labels)
-    num_labels = len(labels)
+    label_list = get_labels()
+    num_labels = len(label_list)
     # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
     pad_token_label_id = CrossEntropyLoss().ignore_index
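The comment above is the key trick for token classification: `CrossEntropyLoss().ignore_index` is -100, so any position labelled with `pad_token_label_id` (padding and non-first sub-tokens) contributes nothing to the loss. A quick check with made-up logits:

```python
import torch
from torch.nn import CrossEntropyLoss

pad_token_label_id = CrossEntropyLoss().ignore_index
print(pad_token_label_id)  # -100

loss_fct = CrossEntropyLoss()
logits = torch.randn(4, 9)                 # 4 token positions, 9 CoNLL-2003 labels
labels = torch.tensor([3, 4, -100, -100])  # last two positions are ignored
print(loss_fct(logits, labels).item())     # averaged over the first two positions only
```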
@@ -435,8 +433,8 @@ def main():
     # Training
     if args.do_train:
-        train_dataset = load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode="train")
-        global_step, tr_loss = train(args, train_dataset, model, tokenizer, labels, pad_token_label_id)
+        train_dataset = load_and_cache_examples(args, tokenizer, pad_token_label_id, evaluate=False)
+        global_step, tr_loss = train(args, train_dataset, model, tokenizer, pad_token_label_id)
         logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)
 
     # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
@@ -468,7 +466,7 @@ def main():
             global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
             model = model_class.from_pretrained(checkpoint)
             model.to(args.device)
-            result, _ = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="dev", prefix=global_step)
+            result = evaluate(args, model, tokenizer, pad_token_label_id, prefix=global_step)
             if global_step:
                 result = {"{}_{}".format(global_step, k): v for k, v in result.items()}
             results.update(result)
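For reference, the `checkpoint.split("-")[-1]` expression above just peels the step count off the checkpoint directory name; the directory names below are hypothetical:

```python
checkpoints = ["out", "out/checkpoint-500", "out/checkpoint-1000"]
for checkpoint in checkpoints:
    global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
    print(checkpoint, "->", repr(global_step))
# out -> 'out', out/checkpoint-500 -> '500', out/checkpoint-1000 -> '1000'
```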
@@ -477,32 +475,6 @@ def main():
             for key in sorted(results.keys()):
                 writer.write("{} = {}\n".format(key, str(results[key])))
 
-    if args.do_predict and args.local_rank in [-1, 0]:
-        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
-        model = model_class.from_pretrained(args.output_dir)
-        model.to(args.device)
-        result, predictions = evaluate(args, model, tokenizer, labels, pad_token_label_id, mode="test")
-        # Save results
-        output_test_results_file = os.path.join(args.output_dir, "test_results.txt")
-        with open(output_test_results_file, "w") as writer:
-            for key in sorted(result.keys()):
-                writer.write("{} = {}\n".format(key, str(result[key])))
-        # Save predictions
-        output_test_predictions_file = os.path.join(args.output_dir, "test_predictions.txt")
-        with open(output_test_predictions_file, "w") as writer:
-            with open(os.path.join(args.data_dir, "test.txt"), "r") as f:
-                example_id = 0
-                for line in f:
-                    if line.startswith("-DOCSTART-") or line == "" or line == "\n":
-                        writer.write(line)
-                        if not predictions[example_id]:
-                            example_id += 1
-                    elif predictions[example_id]:
-                        output_line = line.split()[0] + " " + predictions[example_id].pop(0) + "\n"
-                        writer.write(output_line)
-                    else:
-                        logger.warning("Maximum sequence length exceeded: No prediction for '%s'.", line.split()[0])
 
     return results
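The `--do_predict` block removed here wrote `test_predictions.txt` as one `token predicted-tag` pair per input line. A toy reconstruction of that output line, with made-up inputs standing in for a real test file and the predictions returned by `evaluate`:

```python
line = "Berlin NNP B-NP B-LOC\n"  # a CoNLL-2003 test line: token, POS, chunk, gold tag
predictions = [["B-LOC"]]         # per-example predicted tags (stand-in)
example_id = 0

output_line = line.split()[0] + " " + predictions[example_id].pop(0) + "\n"
print(repr(output_line))  # 'Berlin B-LOC\n'
```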
examples/utils_ner.py
@@ -51,8 +51,13 @@ class InputFeatures(object):
         self.label_ids = label_ids
 
 
-def read_examples_from_file(data_dir, mode):
-    file_path = os.path.join(data_dir, "{}.txt".format(mode))
+def read_examples_from_file(data_dir, evaluate=False):
+    if evaluate:
+        file_path = os.path.join(data_dir, "dev.txt")
+        guid_prefix = "dev"
+    else:
+        file_path = os.path.join(data_dir, "train.txt")
+        guid_prefix = "train"
     guid_index = 1
     examples = []
     with open(file_path, encoding="utf-8") as f:
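`read_examples_from_file` now hard-codes the two file names, so `--data_dir` is expected to contain `train.txt` and `dev.txt` in CoNLL-2003 column format (token and tag separated by spaces, blank line between sentences, optional `-DOCSTART-` markers). A toy setup that writes such files; the directory and the sentence content are made up:

```python
import os

data_dir = "data/conll2003"  # hypothetical --data_dir
os.makedirs(data_dir, exist_ok=True)

toy = (
    "-DOCSTART- -X- -X- O\n"
    "\n"
    "John NNP B-NP B-PER\n"
    "lives VBZ B-VP O\n"
    "in IN B-PP O\n"
    "Berlin NNP B-NP B-LOC\n"
    ". . O O\n"
    "\n"
)
for name in ("train.txt", "dev.txt"):
    with open(os.path.join(data_dir, name), "w", encoding="utf-8") as f:
        f.write(toy)
```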
@@ -61,7 +66,7 @@ def read_examples_from_file(data_dir, mode):
         for line in f:
             if line.startswith("-DOCSTART-") or line == "" or line == "\n":
                 if words:
-                    examples.append(InputExample(guid="{}-{}".format(mode, guid_index),
+                    examples.append(InputExample(guid="{}-{}".format(guid_prefix, guid_index),
                                                  words=words,
                                                  labels=labels))
                     guid_index += 1
@@ -70,13 +75,9 @@ def read_examples_from_file(data_dir, mode):
             else:
                 splits = line.split(" ")
                 words.append(splits[0])
-                if len(splits) > 1:
-                    labels.append(splits[-1].replace("\n", ""))
-                else:
-                    # Examples could have no label for mode = "test"
-                    labels.append("O")
+                labels.append(splits[-1][:-1])
         if words:
-            examples.append(InputExample(guid="%s-%d".format(mode, guid_index),
+            examples.append(InputExample(guid="%s-%d".format(guid_prefix, guid_index),
                                          words=words,
                                          labels=labels))
     return examples
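The old branch stripped the newline with `replace` and fell back to "O" when a line carried no tag (as in an unlabelled test file); the new one-liner assumes every line ends in "\n" and has a tag. On a well-formed line the two agree, as this small comparison with a sample line shows:

```python
line = "Berlin NNP B-NP B-LOC\n"
splits = line.split(" ")

old_style = splits[-1].replace("\n", "")  # 'B-LOC'
new_style = splits[-1][:-1]               # 'B-LOC'
print(old_style == new_style)             # True for newline-terminated, labelled lines
```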
@@ -201,12 +202,5 @@ def convert_examples_to_features(examples,
     return features
 
 
-def get_labels(path):
-    if path:
-        with open(path, "r") as f:
-            labels = f.read().splitlines()
-        if "O" not in labels:
-            labels = ["O"] + labels
-        return labels
-    else:
+def get_labels():
+    return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]
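With the file-based option gone, the label set is fixed to the nine CoNLL-2003 IOB tags above, and `run_ner.py` derives both `num_labels` and the id-to-tag map from it. A quick usage sketch:

```python
def get_labels():
    return ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]

label_list = get_labels()
num_labels = len(label_list)                                  # 9
label_map = {i: label for i, label in enumerate(label_list)}  # id -> tag, as in evaluate()
print(num_labels, label_map[3])                               # 9 B-PER
```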