chenpangpang / transformers / Commits

Commit fdc05cd6
Authored Dec 09, 2019 by Bilal Khan
Update run_squad to save optimizer and scheduler states, then resume training from a checkpoint
parent 854ec578

Showing 1 changed file with 205 additions and 98 deletions.
examples/run_squad.py (view file @ fdc05cd6). Changed hunks, with the resulting code shown under each @@ header:

@@ -27,7 +27,8 @@ import glob
import timeit

import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler, TensorDataset)
from torch.utils.data.distributed import DistributedSampler

try:
@@ -38,21 +39,21 @@ except:
from tqdm import tqdm, trange

from transformers import (WEIGHTS_NAME, BertConfig,
                          BertForQuestionAnswering, BertTokenizer,
                          XLMConfig, XLMForQuestionAnswering,
                          XLMTokenizer, XLNetConfig,
                          XLNetForQuestionAnswering,
                          XLNetTokenizer,
                          DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer,
                          AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer,
                          XLMConfig, XLMForQuestionAnswering, XLMTokenizer,
                          )

from transformers import AdamW, get_linear_schedule_with_warmup, squad_convert_examples_to_features

logger = logging.getLogger(__name__)

ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) \
                  for conf in (BertConfig, XLNetConfig, XLMConfig)), ())

MODEL_CLASSES = {
@@ -64,6 +65,7 @@ MODEL_CLASSES = {
    'xlm': (XLMConfig, XLMForQuestionAnswering, XLMTokenizer)
}

def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
@@ -71,40 +73,60 @@ def set_seed(args):
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

def to_list(tensor):
    return tensor.detach().cpu().tolist()

def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, 'optimizer.pt')) and os.path.isfile(os.path.join(args.model_name_or_path, 'scheduler.pt')):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'optimizer.pt')))
        scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, 'scheduler.pt')))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
@@ -120,21 +142,50 @@ def train(args, train_dataset, model, tokenizer):
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 1
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        global_step = int(args.model_name_or_path.split('-')[-1].split('/')[0])
        epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
        steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from epoch %d", epochs_trained)
        logger.info("  Continuing training from global step %d", global_step)
        logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0])
    # Added here for reproducibility (even between python 2 and 3)
    set_seed(args)

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
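The resume bookkeeping in this hunk is plain integer arithmetic over the optimizer-step count encoded in the checkpoint directory name, where one epoch corresponds to len(train_dataloader) // args.gradient_accumulation_steps optimizer steps. A small sketch with made-up numbers (the path and counts below are illustrative, not taken from the script):

# Made-up numbers, purely to illustrate the resume arithmetic used above.
checkpoint_path = "output/checkpoint-1250"                          # hypothetical checkpoint dir
global_step = int(checkpoint_path.split('-')[-1].split('/')[0])     # -> 1250
steps_per_epoch = 7000 // 4   # len(train_dataloader) // args.gradient_accumulation_steps = 1750

epochs_trained = global_step // steps_per_epoch                     # 0 completed epochs
steps_trained_in_current_epoch = global_step % steps_per_epoch      # 1250 batches to skip

assert (epochs_trained, steps_trained_in_current_epoch) == (0, 1250)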
@@ -152,10 +203,11 @@ def train(args, train_dataset, model, tokenizer):
                inputs.update({'cls_index': batch[5], 'p_mask': batch[6]})

            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel (not distributed) training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
@@ -168,9 +220,11 @@ def train(args, train_dataset, model, tokenizer):
            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
@@ -179,24 +233,41 @@ def train(args, train_dataset, model, tokenizer):
                # Log metrics
                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging_loss = tr_loss

                # Save model checkpoint
                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(model, 'module') else model
                    model_to_save.save_pretrained(output_dir)
                    tokenizer.save_pretrained(output_dir)

                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(), os.path.join(output_dir, 'optimizer.pt'))
                    torch.save(scheduler.state_dict(), os.path.join(output_dir, 'scheduler.pt'))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
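With this hunk, each checkpoint-{global_step} directory carries everything needed to resume: the model weights and config from save_pretrained, the tokenizer files, the pickled argparse namespace, and the new optimizer and scheduler states. A rough sketch of the expected layout (file names are indicative; the tokenizer files shown are the BERT-style ones and differ for other model types):

# Indicative contents of output_dir/checkpoint-500 after this change.
expected_files = {
    "pytorch_model.bin",   # model_to_save.save_pretrained(output_dir)
    "config.json",         # also written by save_pretrained
    "vocab.txt",           # tokenizer.save_pretrained(output_dir), BERT-style vocab
    "training_args.bin",   # torch.save(args, ...)
    "optimizer.pt",        # new: optimizer.state_dict()
    "scheduler.pt",        # new: scheduler.state_dict()
}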
@@ -211,7 +282,8 @@ def train(args, train_dataset, model, tokenizer):
def evaluate(args, model, tokenizer, prefix=""):
    dataset, examples, features = load_and_cache_examples(args, tokenizer, evaluate=True, output_examples=True)

    if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
        os.makedirs(args.output_dir)
@@ -220,7 +292,8 @@ def evaluate(args, model, tokenizer, prefix=""):
    # Note that DistributedSampler samples randomly
    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

    # multi-gpu evaluate
    if args.n_gpu > 1:
@@ -243,12 +316,13 @@ def evaluate(args, model, tokenizer, prefix=""):
                'input_ids':      batch[0],
                'attention_mask': batch[1]
            }

            if args.model_type != 'distilbert':
                # XLM don't use segment_ids
                inputs['token_type_ids'] = None if args.model_type == 'xlm' else batch[2]

            example_indices = batch[3]

            # XLNet and XLM use more arguments for their predictions
            if args.model_type in ['xlnet', 'xlm']:
                inputs.update({'cls_index': batch[4], 'p_mask': batch[5]})
@@ -271,9 +345,9 @@ def evaluate(args, model, tokenizer, prefix=""):
                cls_logits = output[4]

                result = SquadResult(
                    unique_id, start_logits, end_logits,
                    start_top_index=start_top_index,
                    end_top_index=end_top_index,
                    cls_logits=cls_logits
                )
@@ -286,40 +360,48 @@ def evaluate(args, model, tokenizer, prefix=""):
            all_results.append(result)

    evalTime = timeit.default_timer() - start_time
    logger.info("  Evaluation done in total %f secs (%f sec per example)", evalTime, evalTime / len(dataset))

    # Compute predictions
    output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
    output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))

    if args.version_2_with_negative:
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
    else:
        output_null_log_odds_file = None

    # XLNet and XLM use a more complex post-processing procedure
    if args.model_type in ['xlnet', 'xlm']:
        start_n_top = model.config.start_n_top if hasattr(model, "config") else model.module.config.start_n_top
        end_n_top = model.config.end_n_top if hasattr(model, "config") else model.module.config.end_n_top

        predictions = compute_predictions_log_probs(examples, features, all_results, args.n_best_size,
                                                    args.max_answer_length, output_prediction_file,
                                                    output_nbest_file, output_null_log_odds_file,
                                                    start_n_top, end_n_top,
                                                    args.version_2_with_negative, tokenizer, args.verbose_logging)
    else:
        predictions = compute_predictions_logits(examples, features, all_results, args.n_best_size,
                                                 args.max_answer_length, args.do_lower_case, output_prediction_file,
                                                 output_nbest_file, output_null_log_odds_file, args.verbose_logging,
                                                 args.version_2_with_negative, args.null_score_diff_threshold)

    # Compute the F1 and exact scores.
    results = squad_evaluate(examples, predictions)
    return results

def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False):
    if args.local_rank not in [-1, 0] and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    # Load data features from cache or dataset file
    input_dir = args.data_dir if args.data_dir else "."
@@ -331,7 +413,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
    # Init features and dataset from cache if it exists
    if os.path.exists(cached_features_file) and not args.overwrite_cache and not output_examples:
        logger.info("Loading features from cached file %s", cached_features_file)
        features_and_dataset = torch.load(cached_features_file)
        features, dataset = features_and_dataset["features"], features_and_dataset["dataset"]
    else:
@@ -341,18 +424,22 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
            try:
                import tensorflow_datasets as tfds
            except ImportError:
                raise ImportError("If not data_dir is specified, tensorflow_datasets needs to be installed.")

            if args.version_2_with_negative:
                logger.warn("tensorflow_datasets does not handle version 2 of SQuAD.")

            tfds_examples = tfds.load("squad")
            examples = SquadV1Processor().get_examples_from_dataset(tfds_examples, evaluate=evaluate)
        else:
            processor = SquadV2Processor() if args.version_2_with_negative else SquadV1Processor()
            examples = processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)

        features, dataset = squad_convert_examples_to_features(
            examples=examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
@@ -363,11 +450,14 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
        )

        if args.local_rank in [-1, 0]:
            logger.info("Saving features into cached file %s", cached_features_file)
            torch.save({"features": features, "dataset": dataset}, cached_features_file)

    if args.local_rank == 0 and not evaluate:
        # Make sure only the first process in distributed training process the dataset, and the others will use the cache
        torch.distributed.barrier()

    if output_examples:
        return dataset, examples, features
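The caching path above bundles the Python feature objects and the ready-made dataset into a single torch.save payload keyed by a dict, so a later run can skip squad_convert_examples_to_features entirely. A minimal sketch of that round trip (the file name and stand-in objects are hypothetical, not the real cache path):

import torch
from torch.utils.data import TensorDataset

# Stand-ins for the real SquadFeatures list and the tensorised dataset.
features = [{"unique_id": 1000000000, "input_ids": [101, 2054, 102]}]
dataset = TensorDataset(torch.tensor([[101, 2054, 102]]))

cache_file = "cached_train_example"   # illustrative name only
torch.save({"features": features, "dataset": dataset}, cache_file)

# On the next run, the same dict comes back and both pieces are reused as-is.
loaded = torch.load(cache_file)
features, dataset = loaded["features"], loaded["dataset"]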
@@ -377,7 +467,7 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
@@ -385,7 +475,7 @@ def main():
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model checkpoints and predictions will be written.")

    # Other parameters
    parser.add_argument("--data_dir", default=None, type=str,
                        help="The input data dir. Should contain the .json files for the task. If not specified, will run with tensorflow_datasets.")
    parser.add_argument("--config_name", default="", type=str,
@@ -468,8 +558,10 @@ def main():
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()

    args.predict_file = os.path.join(args.output_dir, 'predictions_{}_{}.txt'.format(
@@ -478,19 +570,22 @@ def main():
    )

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
        raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))

    # Setup distant debugging if needed
    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
@@ -500,18 +595,19 @@ def main():
    args.device = device

    # Setup logging
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                   args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)

    # Set seed
    set_seed(args)

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
@@ -521,12 +617,14 @@ def main():
                                                do_lower_case=args.do_lower_case,
                                                cache_dir=args.cache_dir if args.cache_dir else None)
    model = model_class.from_pretrained(args.model_name_or_path,
                                        from_tf=bool('.ckpt' in args.model_name_or_path),
                                        config=config,
                                        cache_dir=args.cache_dir if args.cache_dir else None)

    if args.local_rank == 0:
        # Make sure only the first process in distributed training will download model & vocab
        torch.distributed.barrier()

    model.to(args.device)
@@ -540,14 +638,16 @@ def main():
            import apex
            apex.amp.register_half_function(torch, 'einsum')
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=False)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Save the trained model and the tokenizer
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
@@ -558,7 +658,8 @@ def main():
        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        # Take care of distributed/parallel training
        model_to_save = model.module if hasattr(model, 'module') else model
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
@@ -566,31 +667,37 @@ def main():
        torch.save(args, os.path.join(args.output_dir, 'training_args.bin'))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir, force_download=True)
        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
        model.to(args.device)

    # Evaluation - we can ask to evaluate all the checkpoints (sub-directories) in a directory
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
            logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce model loading logs

        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            # Reload the model
            global_step = checkpoint.split('-')[-1] if len(checkpoints) > 1 else ""
            model = model_class.from_pretrained(checkpoint, force_download=True)
            model.to(args.device)

            # Evaluate
            result = evaluate(args, model, tokenizer, prefix=global_step)

            result = dict((k + ('_{}'.format(global_step) if global_step else ''), v) for k, v in result.items())
            results.update(result)

    logger.info("Results: {}".format(results))
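Taken together, the diff writes optimizer.pt and scheduler.pt next to each saved model checkpoint and, when --model_name_or_path points at such a checkpoint, restores those states and fast-forwards past the already-trained epochs and steps. A minimal, self-contained sketch of the same pattern outside run_squad.py (toy model, toy data, illustrative directory names; torch.optim.AdamW and LambdaLR stand in for transformers.AdamW and the warmup schedule):

import os
import torch
import torch.nn.functional as F
from torch.optim import AdamW                      # stand-in for transformers.AdamW
from torch.optim.lr_scheduler import LambdaLR      # stand-in for the warmup/decay schedule
from torch.utils.data import DataLoader, TensorDataset

loader = DataLoader(TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,))), batch_size=8)
model = torch.nn.Linear(4, 2)
optimizer = AdamW(model.parameters(), lr=3e-5)
scheduler = LambdaLR(optimizer, lr_lambda=lambda step: max(0.0, 1.0 - step / 100))

output_dir, save_steps, num_epochs = "out", 4, 3
os.makedirs(output_dir, exist_ok=True)
global_step, epochs_trained, steps_to_skip = 0, 0, 0

# Resume from the most recent checkpoint-N directory, if one exists.
ckpts = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")]
if ckpts:
    last = os.path.join(output_dir, max(ckpts, key=lambda d: int(d.split("-")[-1])))
    global_step = int(last.split("-")[-1])
    steps_per_epoch = len(loader)
    epochs_trained = global_step // steps_per_epoch
    steps_to_skip = global_step % steps_per_epoch
    model.load_state_dict(torch.load(os.path.join(last, "model.pt")))
    optimizer.load_state_dict(torch.load(os.path.join(last, "optimizer.pt")))
    scheduler.load_state_dict(torch.load(os.path.join(last, "scheduler.pt")))

for epoch in range(epochs_trained, num_epochs):
    for x, y in loader:
        if steps_to_skip > 0:          # skip batches already seen before the interruption
            steps_to_skip -= 1
            continue
        loss = F.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        global_step += 1
        if global_step % save_steps == 0:
            ckpt = os.path.join(output_dir, "checkpoint-{}".format(global_step))
            os.makedirs(ckpt, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(ckpt, "model.pt"))
            torch.save(optimizer.state_dict(), os.path.join(ckpt, "optimizer.pt"))
            torch.save(scheduler.state_dict(), os.path.join(ckpt, "scheduler.pt"))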