Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
383ef967
Commit
383ef967
authored
Sep 17, 2019
by
Marianne Stecklina
Committed by
thomwolf
Oct 15, 2019
Browse files
Implement fine-tuning BERT on CoNLL-2003 named entity recognition task
parent
5adb39e7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
30 additions
and
64 deletions
+30
-64
examples/run_ner.py
examples/run_ner.py
+18
-46
examples/utils_ner.py
examples/utils_ner.py
+12
-18
No files found.
examples/run_ner.py
View file @
383ef967
...
@@ -55,7 +55,7 @@ def set_seed(args):
...
@@ -55,7 +55,7 @@ def set_seed(args):
torch
.
cuda
.
manual_seed_all
(
args
.
seed
)
torch
.
cuda
.
manual_seed_all
(
args
.
seed
)
def
train
(
args
,
train_dataset
,
model
,
tokenizer
,
labels
,
pad_token_label_id
):
def
train
(
args
,
train_dataset
,
model
,
tokenizer
,
pad_token_label_id
):
""" Train the model """
""" Train the model """
if
args
.
local_rank
in
[
-
1
,
0
]:
if
args
.
local_rank
in
[
-
1
,
0
]:
tb_writer
=
SummaryWriter
()
tb_writer
=
SummaryWriter
()
...
@@ -148,7 +148,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
...
@@ -148,7 +148,7 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
if
args
.
local_rank
in
[
-
1
,
0
]
and
args
.
logging_steps
>
0
and
global_step
%
args
.
logging_steps
==
0
:
if
args
.
local_rank
in
[
-
1
,
0
]
and
args
.
logging_steps
>
0
and
global_step
%
args
.
logging_steps
==
0
:
# Log metrics
# Log metrics
if
args
.
local_rank
==
-
1
and
args
.
evaluate_during_training
:
# Only evaluate when single GPU otherwise metrics may not average well
if
args
.
local_rank
==
-
1
and
args
.
evaluate_during_training
:
# Only evaluate when single GPU otherwise metrics may not average well
results
,
_
=
evaluate
(
args
,
model
,
tokenizer
,
labels
,
pad_token_label_id
)
results
=
evaluate
(
args
,
model
,
tokenizer
,
pad_token_label_id
)
for
key
,
value
in
results
.
items
():
for
key
,
value
in
results
.
items
():
tb_writer
.
add_scalar
(
"eval_{}"
.
format
(
key
),
value
,
global_step
)
tb_writer
.
add_scalar
(
"eval_{}"
.
format
(
key
),
value
,
global_step
)
tb_writer
.
add_scalar
(
"lr"
,
scheduler
.
get_lr
()[
0
],
global_step
)
tb_writer
.
add_scalar
(
"lr"
,
scheduler
.
get_lr
()[
0
],
global_step
)
...
@@ -160,7 +160,8 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
...
@@ -160,7 +160,8 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"checkpoint-{}"
.
format
(
global_step
))
output_dir
=
os
.
path
.
join
(
args
.
output_dir
,
"checkpoint-{}"
.
format
(
global_step
))
if
not
os
.
path
.
exists
(
output_dir
):
if
not
os
.
path
.
exists
(
output_dir
):
os
.
makedirs
(
output_dir
)
os
.
makedirs
(
output_dir
)
model_to_save
=
model
.
module
if
hasattr
(
model
,
"module"
)
else
model
# Take care of distributed/parallel training
model_to_save
=
model
.
module
if
hasattr
(
model
,
"module"
)
else
model
# Take care of distributed/parallel training
model_to_save
.
save_pretrained
(
output_dir
)
model_to_save
.
save_pretrained
(
output_dir
)
torch
.
save
(
args
,
os
.
path
.
join
(
output_dir
,
"training_args.bin"
))
torch
.
save
(
args
,
os
.
path
.
join
(
output_dir
,
"training_args.bin"
))
logger
.
info
(
"Saving model checkpoint to %s"
,
output_dir
)
logger
.
info
(
"Saving model checkpoint to %s"
,
output_dir
)
...
@@ -178,8 +179,8 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
...
@@ -178,8 +179,8 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
return
global_step
,
tr_loss
/
global_step
return
global_step
,
tr_loss
/
global_step
def
evaluate
(
args
,
model
,
tokenizer
,
labels
,
pad_token_label_id
,
mode
,
prefix
=
""
):
def
evaluate
(
args
,
model
,
tokenizer
,
pad_token_label_id
,
prefix
=
""
):
eval_dataset
=
load_and_cache_examples
(
args
,
tokenizer
,
labels
,
pad_token_label_id
,
mode
=
mod
e
)
eval_dataset
=
load_and_cache_examples
(
args
,
tokenizer
,
pad_token_label_id
,
evaluate
=
Tru
e
)
args
.
eval_batch_size
=
args
.
per_gpu_eval_batch_size
*
max
(
1
,
args
.
n_gpu
)
args
.
eval_batch_size
=
args
.
per_gpu_eval_batch_size
*
max
(
1
,
args
.
n_gpu
)
# Note that DistributedSampler samples randomly
# Note that DistributedSampler samples randomly
...
@@ -219,7 +220,7 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
...
@@ -219,7 +220,7 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
eval_loss
=
eval_loss
/
nb_eval_steps
eval_loss
=
eval_loss
/
nb_eval_steps
preds
=
np
.
argmax
(
preds
,
axis
=
2
)
preds
=
np
.
argmax
(
preds
,
axis
=
2
)
label_map
=
{
i
:
label
for
i
,
label
in
enumerate
(
labels
)}
label_map
=
{
i
:
label
for
i
,
label
in
enumerate
(
get_
labels
()
)}
out_label_list
=
[[]
for
_
in
range
(
out_label_ids
.
shape
[
0
])]
out_label_list
=
[[]
for
_
in
range
(
out_label_ids
.
shape
[
0
])]
preds_list
=
[[]
for
_
in
range
(
out_label_ids
.
shape
[
0
])]
preds_list
=
[[]
for
_
in
range
(
out_label_ids
.
shape
[
0
])]
...
@@ -241,15 +242,15 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
...
@@ -241,15 +242,15 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
for
key
in
sorted
(
results
.
keys
()):
for
key
in
sorted
(
results
.
keys
()):
logger
.
info
(
" %s = %s"
,
key
,
str
(
results
[
key
]))
logger
.
info
(
" %s = %s"
,
key
,
str
(
results
[
key
]))
return
results
,
preds_list
return
results
def
load_and_cache_examples
(
args
,
tokenizer
,
labels
,
pad_token_label_id
,
mod
e
):
def
load_and_cache_examples
(
args
,
tokenizer
,
pad_token_label_id
,
evaluate
=
Fals
e
):
if
args
.
local_rank
not
in
[
-
1
,
0
]
and
not
evaluate
:
if
args
.
local_rank
not
in
[
-
1
,
0
]
and
not
evaluate
:
torch
.
distributed
.
barrier
()
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
torch
.
distributed
.
barrier
()
# Make sure only the first process in distributed training process the dataset, and the others will use the cache
# Load data features from cache or dataset file
# Load data features from cache or dataset file
cached_features_file
=
os
.
path
.
join
(
args
.
data_dir
,
"cached_{}_{}_{}"
.
format
(
mode
,
cached_features_file
=
os
.
path
.
join
(
args
.
data_dir
,
"cached_{}_{}_{}"
.
format
(
"dev"
if
evaluate
else
"train"
,
list
(
filter
(
None
,
args
.
model_name_or_path
.
split
(
"/"
))).
pop
(),
list
(
filter
(
None
,
args
.
model_name_or_path
.
split
(
"/"
))).
pop
(),
str
(
args
.
max_seq_length
)))
str
(
args
.
max_seq_length
)))
if
os
.
path
.
exists
(
cached_features_file
):
if
os
.
path
.
exists
(
cached_features_file
):
...
@@ -257,8 +258,9 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
...
@@ -257,8 +258,9 @@ def load_and_cache_examples(args, tokenizer, labels, pad_token_label_id, mode):
features
=
torch
.
load
(
cached_features_file
)
features
=
torch
.
load
(
cached_features_file
)
else
:
else
:
logger
.
info
(
"Creating features from dataset file at %s"
,
args
.
data_dir
)
logger
.
info
(
"Creating features from dataset file at %s"
,
args
.
data_dir
)
examples
=
read_examples_from_file
(
args
.
data_dir
,
mode
)
label_list
=
get_labels
()
features
=
convert_examples_to_features
(
examples
,
labels
,
args
.
max_seq_length
,
tokenizer
,
examples
=
read_examples_from_file
(
args
.
data_dir
,
evaluate
=
evaluate
)
features
=
convert_examples_to_features
(
examples
,
label_list
,
args
.
max_seq_length
,
tokenizer
,
cls_token_at_end
=
bool
(
args
.
model_type
in
[
"xlnet"
]),
cls_token_at_end
=
bool
(
args
.
model_type
in
[
"xlnet"
]),
# xlnet has a cls token at the end
# xlnet has a cls token at the end
cls_token
=
tokenizer
.
cls_token
,
cls_token
=
tokenizer
.
cls_token
,
...
@@ -303,8 +305,6 @@ def main():
...
@@ -303,8 +305,6 @@ def main():
help
=
"The output directory where the model predictions and checkpoints will be written."
)
help
=
"The output directory where the model predictions and checkpoints will be written."
)
## Other parameters
## Other parameters
parser
.
add_argument
(
"--labels"
,
default
=
""
,
type
=
str
,
help
=
"Path to a file containing all labels. If not specified, CoNLL-2003 labels are used."
)
parser
.
add_argument
(
"--config_name"
,
default
=
""
,
type
=
str
,
parser
.
add_argument
(
"--config_name"
,
default
=
""
,
type
=
str
,
help
=
"Pretrained config name or path if not the same as model_name"
)
help
=
"Pretrained config name or path if not the same as model_name"
)
parser
.
add_argument
(
"--tokenizer_name"
,
default
=
""
,
type
=
str
,
parser
.
add_argument
(
"--tokenizer_name"
,
default
=
""
,
type
=
str
,
...
@@ -318,8 +318,6 @@ def main():
...
@@ -318,8 +318,6 @@ def main():
help
=
"Whether to run training."
)
help
=
"Whether to run training."
)
parser
.
add_argument
(
"--do_eval"
,
action
=
"store_true"
,
parser
.
add_argument
(
"--do_eval"
,
action
=
"store_true"
,
help
=
"Whether to run eval on the dev set."
)
help
=
"Whether to run eval on the dev set."
)
parser
.
add_argument
(
"--do_predict"
,
action
=
"store_true"
,
help
=
"Whether to run predictions on the test set."
)
parser
.
add_argument
(
"--evaluate_during_training"
,
action
=
"store_true"
,
parser
.
add_argument
(
"--evaluate_during_training"
,
action
=
"store_true"
,
help
=
"Whether to run evaluation during training at each logging step."
)
help
=
"Whether to run evaluation during training at each logging step."
)
parser
.
add_argument
(
"--do_lower_case"
,
action
=
"store_true"
,
parser
.
add_argument
(
"--do_lower_case"
,
action
=
"store_true"
,
...
@@ -408,8 +406,8 @@ def main():
...
@@ -408,8 +406,8 @@ def main():
set_seed
(
args
)
set_seed
(
args
)
# Prepare CONLL-2003 task
# Prepare CONLL-2003 task
label
s
=
get_labels
(
args
.
labels
)
label
_list
=
get_labels
()
num_labels
=
len
(
label
s
)
num_labels
=
len
(
label
_list
)
# Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
# Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
pad_token_label_id
=
CrossEntropyLoss
().
ignore_index
pad_token_label_id
=
CrossEntropyLoss
().
ignore_index
...
@@ -435,8 +433,8 @@ def main():
...
@@ -435,8 +433,8 @@ def main():
# Training
# Training
if
args
.
do_train
:
if
args
.
do_train
:
train_dataset
=
load_and_cache_examples
(
args
,
tokenizer
,
labels
,
pad_token_label_id
,
mode
=
"train"
)
train_dataset
=
load_and_cache_examples
(
args
,
tokenizer
,
pad_token_label_id
,
evaluate
=
False
)
global_step
,
tr_loss
=
train
(
args
,
train_dataset
,
model
,
tokenizer
,
labels
,
pad_token_label_id
)
global_step
,
tr_loss
=
train
(
args
,
train_dataset
,
model
,
tokenizer
,
pad_token_label_id
)
logger
.
info
(
" global_step = %s, average loss = %s"
,
global_step
,
tr_loss
)
logger
.
info
(
" global_step = %s, average loss = %s"
,
global_step
,
tr_loss
)
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
# Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
...
@@ -468,7 +466,7 @@ def main():
...
@@ -468,7 +466,7 @@ def main():
global_step
=
checkpoint
.
split
(
"-"
)[
-
1
]
if
len
(
checkpoints
)
>
1
else
""
global_step
=
checkpoint
.
split
(
"-"
)[
-
1
]
if
len
(
checkpoints
)
>
1
else
""
model
=
model_class
.
from_pretrained
(
checkpoint
)
model
=
model_class
.
from_pretrained
(
checkpoint
)
model
.
to
(
args
.
device
)
model
.
to
(
args
.
device
)
result
,
_
=
evaluate
(
args
,
model
,
tokenizer
,
labels
,
pad_token_label_id
,
mode
=
"dev"
,
prefix
=
global_step
)
result
=
evaluate
(
args
,
model
,
tokenizer
,
pad_token_label_id
,
prefix
=
global_step
)
if
global_step
:
if
global_step
:
result
=
{
"{}_{}"
.
format
(
global_step
,
k
):
v
for
k
,
v
in
result
.
items
()}
result
=
{
"{}_{}"
.
format
(
global_step
,
k
):
v
for
k
,
v
in
result
.
items
()}
results
.
update
(
result
)
results
.
update
(
result
)
...
@@ -477,32 +475,6 @@ def main():
...
@@ -477,32 +475,6 @@ def main():
for
key
in
sorted
(
results
.
keys
()):
for
key
in
sorted
(
results
.
keys
()):
writer
.
write
(
"{} = {}
\n
"
.
format
(
key
,
str
(
results
[
key
])))
writer
.
write
(
"{} = {}
\n
"
.
format
(
key
,
str
(
results
[
key
])))
if
args
.
do_predict
and
args
.
local_rank
in
[
-
1
,
0
]:
tokenizer
=
tokenizer_class
.
from_pretrained
(
args
.
output_dir
,
do_lower_case
=
args
.
do_lower_case
)
model
=
model_class
.
from_pretrained
(
args
.
output_dir
)
model
.
to
(
args
.
device
)
result
,
predictions
=
evaluate
(
args
,
model
,
tokenizer
,
labels
,
pad_token_label_id
,
mode
=
"test"
)
# Save results
output_test_results_file
=
os
.
path
.
join
(
args
.
output_dir
,
"test_results.txt"
)
with
open
(
output_test_results_file
,
"w"
)
as
writer
:
for
key
in
sorted
(
result
.
keys
()):
writer
.
write
(
"{} = {}
\n
"
.
format
(
key
,
str
(
result
[
key
])))
# Save predictions
output_test_predictions_file
=
os
.
path
.
join
(
args
.
output_dir
,
"test_predictions.txt"
)
with
open
(
output_test_predictions_file
,
"w"
)
as
writer
:
with
open
(
os
.
path
.
join
(
args
.
data_dir
,
"test.txt"
),
"r"
)
as
f
:
example_id
=
0
for
line
in
f
:
if
line
.
startswith
(
"-DOCSTART-"
)
or
line
==
""
or
line
==
"
\n
"
:
writer
.
write
(
line
)
if
not
predictions
[
example_id
]:
example_id
+=
1
elif
predictions
[
example_id
]:
output_line
=
line
.
split
()[
0
]
+
" "
+
predictions
[
example_id
].
pop
(
0
)
+
"
\n
"
writer
.
write
(
output_line
)
else
:
logger
.
warning
(
"Maximum sequence length exceeded: No prediction for '%s'."
,
line
.
split
()[
0
])
return
results
return
results
...
...
examples/utils_ner.py
View file @
383ef967
...
@@ -51,8 +51,13 @@ class InputFeatures(object):
...
@@ -51,8 +51,13 @@ class InputFeatures(object):
self
.
label_ids
=
label_ids
self
.
label_ids
=
label_ids
def
read_examples_from_file
(
data_dir
,
mode
):
def
read_examples_from_file
(
data_dir
,
evaluate
=
False
):
file_path
=
os
.
path
.
join
(
data_dir
,
"{}.txt"
.
format
(
mode
))
if
evaluate
:
file_path
=
os
.
path
.
join
(
data_dir
,
"dev.txt"
)
guid_prefix
=
"dev"
else
:
file_path
=
os
.
path
.
join
(
data_dir
,
"train.txt"
)
guid_prefix
=
"train"
guid_index
=
1
guid_index
=
1
examples
=
[]
examples
=
[]
with
open
(
file_path
,
encoding
=
"utf-8"
)
as
f
:
with
open
(
file_path
,
encoding
=
"utf-8"
)
as
f
:
...
@@ -61,7 +66,7 @@ def read_examples_from_file(data_dir, mode):
...
@@ -61,7 +66,7 @@ def read_examples_from_file(data_dir, mode):
for
line
in
f
:
for
line
in
f
:
if
line
.
startswith
(
"-DOCSTART-"
)
or
line
==
""
or
line
==
"
\n
"
:
if
line
.
startswith
(
"-DOCSTART-"
)
or
line
==
""
or
line
==
"
\n
"
:
if
words
:
if
words
:
examples
.
append
(
InputExample
(
guid
=
"{}-{}"
.
format
(
mode
,
guid_index
),
examples
.
append
(
InputExample
(
guid
=
"{}-{}"
.
format
(
guid_prefix
,
guid_index
),
words
=
words
,
words
=
words
,
labels
=
labels
))
labels
=
labels
))
guid_index
+=
1
guid_index
+=
1
...
@@ -70,13 +75,9 @@ def read_examples_from_file(data_dir, mode):
...
@@ -70,13 +75,9 @@ def read_examples_from_file(data_dir, mode):
else
:
else
:
splits
=
line
.
split
(
" "
)
splits
=
line
.
split
(
" "
)
words
.
append
(
splits
[
0
])
words
.
append
(
splits
[
0
])
if
len
(
splits
)
>
1
:
labels
.
append
(
splits
[
-
1
][:
-
1
])
labels
.
append
(
splits
[
-
1
].
replace
(
"
\n
"
,
""
))
else
:
# Examples could have no label for mode = "test"
labels
.
append
(
"O"
)
if
words
:
if
words
:
examples
.
append
(
InputExample
(
guid
=
"%s-%d"
.
format
(
mode
,
guid_index
),
examples
.
append
(
InputExample
(
guid
=
"%s-%d"
.
format
(
guid_prefix
,
guid_index
),
words
=
words
,
words
=
words
,
labels
=
labels
))
labels
=
labels
))
return
examples
return
examples
...
@@ -201,12 +202,5 @@ def convert_examples_to_features(examples,
...
@@ -201,12 +202,5 @@ def convert_examples_to_features(examples,
return
features
return
features
def
get_labels
(
path
):
def
get_labels
():
if
path
:
return
[
"O"
,
"B-MISC"
,
"I-MISC"
,
"B-PER"
,
"I-PER"
,
"B-ORG"
,
"I-ORG"
,
"B-LOC"
,
"I-LOC"
]
with
open
(
path
,
"r"
)
as
f
:
labels
=
f
.
read
().
splitlines
()
if
"O"
not
in
labels
:
labels
=
[
"O"
]
+
labels
return
labels
else
:
return
[
"O"
,
"B-MISC"
,
"I-MISC"
,
"B-PER"
,
"I-PER"
,
"B-ORG"
,
"I-ORG"
,
"B-LOC"
,
"I-LOC"
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment