Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
89896fe0
Commit
89896fe0
authored
Dec 09, 2019
by
Bilal Khan
Browse files
Update run_ner to save optimizer and scheduler states, then resume training from a checkpoint
parent
fdc05cd6
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
34 additions
and
1 deletion
+34
-1
examples/run_ner.py
examples/run_ner.py
+34
-1
No files found.
examples/run_ner.py
View file @
89896fe0
...
...
@@ -85,6 +85,13 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
]
optimizer
=
AdamW
(
optimizer_grouped_parameters
,
lr
=
args
.
learning_rate
,
eps
=
args
.
adam_epsilon
)
scheduler
=
get_linear_schedule_with_warmup
(
optimizer
,
num_warmup_steps
=
args
.
warmup_steps
,
num_training_steps
=
t_total
)
# Check if saved optimizer or scheduler states exist
if
os
.
path
.
isfile
(
os
.
path
.
join
(
args
.
model_name_or_path
,
'optimizer.pt'
))
and
os
.
path
.
isfile
(
os
.
path
.
join
(
args
.
model_name_or_path
,
'scheduler.pt'
)):
# Load in optimizer and scheduler states
optimizer
.
load_state_dict
(
torch
.
load
(
os
.
path
.
join
(
args
.
model_name_or_path
,
'optimizer.pt'
)))
scheduler
.
load_state_dict
(
torch
.
load
(
os
.
path
.
join
(
args
.
model_name_or_path
,
'scheduler.pt'
)))
if
args
.
fp16
:
try
:
from
apex
import
amp
...
...
@@ -114,13 +121,33 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
logger
.
info
(
" Total optimization steps = %d"
,
t_total
)
global_step
=
0
epochs_trained
=
0
steps_trained_in_current_epoch
=
0
# Check if continuing training from a checkpoint
if
os
.
path
.
exists
(
args
.
model_name_or_path
):
# set global_step to global_step of last saved checkpoint from model path
global_step
=
int
(
args
.
model_name_or_path
.
split
(
'-'
)[
-
1
].
split
(
'/'
)[
0
])
epochs_trained
=
global_step
//
(
len
(
train_dataloader
)
//
args
.
gradient_accumulation_steps
)
steps_trained_in_current_epoch
=
global_step
%
(
len
(
train_dataloader
)
//
args
.
gradient_accumulation_steps
)
logger
.
info
(
" Continuing training from checkpoint, will skip to saved global_step"
)
logger
.
info
(
" Continuing training from epoch %d"
,
epochs_trained
)
logger
.
info
(
" Continuing training from global step %d"
,
global_step
)
logger
.
info
(
" Will skip the first %d steps in the first epoch"
,
steps_trained_in_current_epoch
)
tr_loss
,
logging_loss
=
0.0
,
0.0
model
.
zero_grad
()
train_iterator
=
trange
(
int
(
args
.
num_train_epochs
),
desc
=
"Epoch"
,
disable
=
args
.
local_rank
not
in
[
-
1
,
0
])
train_iterator
=
trange
(
epochs_trained
,
int
(
args
.
num_train_epochs
),
desc
=
"Epoch"
,
disable
=
args
.
local_rank
not
in
[
-
1
,
0
])
set_seed
(
args
)
# Added here for reproducibility (even between python 2 and 3)
for
_
in
train_iterator
:
epoch_iterator
=
tqdm
(
train_dataloader
,
desc
=
"Iteration"
,
disable
=
args
.
local_rank
not
in
[
-
1
,
0
])
for
step
,
batch
in
enumerate
(
epoch_iterator
):
# Skip past any already trained steps if resuming training
if
steps_trained_in_current_epoch
>
0
:
steps_trained_in_current_epoch
-=
1
continue
model
.
train
()
batch
=
tuple
(
t
.
to
(
args
.
device
)
for
t
in
batch
)
inputs
=
{
"input_ids"
:
batch
[
0
],
...
...
@@ -172,9 +199,15 @@ def train(args, train_dataset, model, tokenizer, labels, pad_token_label_id):
os
.
makedirs
(
output_dir
)
model_to_save
=
model
.
module
if
hasattr
(
model
,
"module"
)
else
model
# Take care of distributed/parallel training
model_to_save
.
save_pretrained
(
output_dir
)
tokenizer
.
save_pretrained
(
output_dir
)
torch
.
save
(
args
,
os
.
path
.
join
(
output_dir
,
"training_args.bin"
))
logger
.
info
(
"Saving model checkpoint to %s"
,
output_dir
)
torch
.
save
(
optimizer
.
state_dict
(),
os
.
path
.
join
(
output_dir
,
'optimizer.pt'
))
torch
.
save
(
scheduler
.
state_dict
(),
os
.
path
.
join
(
output_dir
,
'scheduler.pt'
))
logger
.
info
(
"Saving optimizer and scheduler states to %s"
,
output_dir
)
if
args
.
max_steps
>
0
and
global_step
>
args
.
max_steps
:
epoch_iterator
.
close
()
break
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment