Commit d6522e28 in chenpangpang/transformers
Authored Jul 17, 2019 by yzy5630

change loss and optimizer to new API

Parent 123da5a2
Showing 2 changed files with 24 additions and 42 deletions:

examples/lm_finetuning/finetune_on_pregenerated.py   +10 -19
examples/lm_finetuning/simple_lm_finetuning.py       +14 -23
examples/lm_finetuning/finetune_on_pregenerated.py @ d6522e28
@@ -155,11 +155,10 @@ def main():
                         help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                              "0 (default value): dynamic loss scaling.\n"
                              "Positive power of 2: static loss scaling value.\n")
-    parser.add_argument("--warmup_proportion",
-                        default=0.1,
-                        type=float,
-                        help="Proportion of training to perform linear learning rate warmup for. "
-                             "E.g., 0.1 = 10%% of training.")
+    parser.add_argument("--warmup_steps",
+                        default=0,
+                        type=int,
+                        help="Linear warmup over warmup_steps.")
     parser.add_argument("--learning_rate",
                         default=3e-5,
                         type=float,
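For readers migrating existing commands: the removed --warmup_proportion expressed warmup as a fraction of training, while the new --warmup_steps takes an absolute number of optimizer updates. A minimal sketch of the translation, assuming the total step count is computed the same way as the script's num_train_optimization_steps (the concrete numbers below are stand-ins, not values from this commit):

# Stand-in values for illustration only; the script derives
# num_train_optimization_steps from dataset size, batch size,
# gradient accumulation and the number of epochs.
num_train_optimization_steps = 10000
old_warmup_proportion = 0.1  # the removed flag's default

warmup_steps = int(old_warmup_proportion * num_train_optimization_steps)
print(warmup_steps)  # 1000 -> pass as --warmup_steps 1000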
@@ -270,13 +269,9 @@ def main():
             optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
         else:
             optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
-        warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
-                                             t_total=num_train_optimization_steps)
     else:
-        optimizer = AdamW(optimizer_grouped_parameters,
-                          lr=args.learning_rate,
-                          warmup=args.warmup_proportion,
-                          t_total=num_train_optimization_steps)
+        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)
 
     global_step = 0
     logging.info("***** Running training *****")
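For context, a minimal sketch of the optimizer/scheduler construction the new code adopts, assuming the pytorch-transformers 1.0 API (AdamW and WarmupLinearSchedule importable from the top-level package); the nn.Linear model and the step counts are stand-ins rather than anything taken from this commit:

import torch
from pytorch_transformers import AdamW, WarmupLinearSchedule  # assumed top-level exports

model = torch.nn.Linear(10, 2)       # stand-in for BertForPreTraining
num_train_optimization_steps = 1000  # stand-in for the computed total
warmup_steps = 100                   # stand-in for args.warmup_steps

# Weight-decay handling via parameter groups is omitted for brevity;
# the script builds optimizer_grouped_parameters before this point.
optimizer = AdamW(model.parameters(), lr=3e-5, eps=1e-8)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps,
                                 t_total=num_train_optimization_steps)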
@@ -298,7 +293,8 @@ def main():
             for step, batch in enumerate(train_dataloader):
                 batch = tuple(t.to(device) for t in batch)
                 input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
-                loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
+                outputs = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
+                loss = outputs[0]
                 if n_gpu > 1:
                     loss = loss.mean() # mean() to average on multi-gpu.
                 if args.gradient_accumulation_steps > 1:
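This is the "loss" half of the commit message: pytorch-transformers models return a tuple of outputs rather than a bare loss, so the script now unpacks outputs[0]. A small stand-in module illustrating the same convention (the real script gets the tuple from BertForPreTraining when labels are supplied):

import torch
import torch.nn.functional as F

class TinyTupleModel(torch.nn.Module):
    """Stand-in that mimics the (loss, ...) tuple convention of the new API."""
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 2)

    def forward(self, x, labels=None):
        logits = self.linear(x)
        loss = F.cross_entropy(logits, labels)
        return (loss, logits)  # loss comes first, as in outputs[0]

model = TinyTupleModel()
outputs = model(torch.randn(4, 10), labels=torch.tensor([0, 1, 0, 1]))
loss = outputs[0]  # replaces the old `loss = model(...)` pattern
loss.backward()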
@@ -314,18 +310,13 @@ def main():
                 mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
                 pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
                 if (step + 1) % args.gradient_accumulation_steps == 0:
-                    if args.fp16:
-                        # modify learning rate with special warm up BERT uses
-                        # if args.fp16 is False, BertAdam is used that handles this automatically
-                        lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
-                        for param_group in optimizer.param_groups:
-                            param_group['lr'] = lr_this_step
+                    scheduler.step()  # Update learning rate schedule
                     optimizer.step()
                     optimizer.zero_grad()
                     global_step += 1
 
     # Save a trained model
-    if torch.distributed.get_rank() == 0:
+    if n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <= 1:
         logging.info("** ** * Saving fine-tuned model ** ** * ")
         model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
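Putting the pieces of this file together: scheduler.step() replaces the manual warmup_linear learning-rate adjustment and runs once per optimizer update. A minimal sketch of that cadence under gradient accumulation; the stand-in model, dummy loss and loop bounds are illustrative only, and the scheduler.step()-before-optimizer.step() order simply mirrors the script:

import torch
from pytorch_transformers import AdamW, WarmupLinearSchedule  # assumed top-level exports

gradient_accumulation_steps = 2  # stand-in for args.gradient_accumulation_steps
model = torch.nn.Linear(10, 2)   # stand-in model
optimizer = AdamW(model.parameters(), lr=3e-5, eps=1e-8)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=10, t_total=100)

global_step = 0
for step in range(20):  # stand-in for enumerate(train_dataloader)
    loss = model(torch.randn(4, 10)).pow(2).mean()  # dummy loss
    if gradient_accumulation_steps > 1:
        loss = loss / gradient_accumulation_steps
    loss.backward()
    if (step + 1) % gradient_accumulation_steps == 0:
        scheduler.step()      # replaces the manual warmup_linear lr update
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1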
examples/lm_finetuning/simple_lm_finetuning.py @ d6522e28
@@ -438,11 +438,10 @@ def main():
                         default=3.0,
                         type=float,
                         help="Total number of training epochs to perform.")
-    parser.add_argument("--warmup_proportion",
-                        default=0.1,
-                        type=float,
-                        help="Proportion of training to perform linear learning rate warmup for. "
-                             "E.g., 0.1 = 10%% of training.")
+    parser.add_argument("--warmup_steps",
+                        default=0,
+                        type=int,
+                        help="Linear warmup over warmup_steps.")
     parser.add_argument("--no_cuda",
                         action='store_true',
                         help="Whether not to use CUDA when available")
@@ -504,7 +503,7 @@ def main():
     if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
         raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
-    if not os.path.exists(args.output_dir) and torch.distributed.get_rank() == 0:
+    if not os.path.exists(args.output_dir) and (n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <= 1):
         os.makedirs(args.output_dir)
 
     tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
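The new guard reads correctly because `and` binds tighter than `or` in Python, so the unparenthesized form used in finetune_on_pregenerated.py and the parenthesized form used here evaluate identically: act on a single device, or on rank 0 of a multi-GPU run. A tiny illustration with a hypothetical helper (the real code queries torch.distributed.get_rank() rather than taking rank as an argument):

# Hypothetical helper, only to show the operator precedence of the guard.
def should_write(n_gpu, rank):
    return n_gpu > 1 and rank == 0 or n_gpu <= 1  # == (n_gpu > 1 and rank == 0) or n_gpu <= 1

assert should_write(n_gpu=1, rank=0)       # single GPU / CPU: always write
assert should_write(n_gpu=4, rank=0)       # multi-GPU: only rank 0 writes
assert not should_write(n_gpu=4, rank=2)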
@@ -558,14 +557,10 @@ def main():
             optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
         else:
             optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
-        warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
-                                             t_total=num_train_optimization_steps)
     else:
-        optimizer = BertAdam(optimizer_grouped_parameters,
-                             lr=args.learning_rate,
-                             warmup=args.warmup_proportion,
-                             t_total=num_train_optimization_steps)
+        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
+        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=num_train_optimization_steps)
 
     global_step = 0
 
     if args.do_train:
@@ -589,7 +584,8 @@ def main():
             for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                 batch = tuple(t.to(device) for t in batch)
                 input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
-                loss = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
+                outputs = model(input_ids, segment_ids, input_mask, lm_label_ids, is_next)
+                loss = outputs[0]
                 if n_gpu > 1:
                     loss = loss.mean() # mean() to average on multi-gpu.
                 if args.gradient_accumulation_steps > 1:
@@ -602,22 +598,17 @@ def main():
                 nb_tr_examples += input_ids.size(0)
                 nb_tr_steps += 1
                 if (step + 1) % args.gradient_accumulation_steps == 0:
-                    if args.fp16:
-                        # modify learning rate with special warm up BERT uses
-                        # if args.fp16 is False, BertAdam is used that handles this automatically
-                        lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
-                        for param_group in optimizer.param_groups:
-                            param_group['lr'] = lr_this_step
+                    scheduler.step()  # Update learning rate schedule
                     optimizer.step()
                     optimizer.zero_grad()
                     global_step += 1
 
     # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
-    output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
-    output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
-    if args.do_train and torch.distributed.get_rank() == 0:
+    if args.do_train and (n_gpu > 1 and torch.distributed.get_rank() == 0 or n_gpu <= 1):
         logger.info("** ** * Saving fine - tuned model ** ** * ")
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
+        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
+        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
         torch.save(model_to_save.state_dict(), output_model_file)
         model_to_save.config.to_json_file(output_config_file)
         tokenizer.save_vocabulary(args.output_dir)
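For completeness, a minimal sketch of the save/reload round trip performed by the guarded branch above, assuming the pytorch-transformers 1.0 BertForPreTraining/BertTokenizer API; the file-name constants are written out literally, and "bert-base-uncased" plus the output path are stand-ins:

import os
import torch
from pytorch_transformers import BertForPreTraining, BertTokenizer  # assumed top-level exports

WEIGHTS_NAME = "pytorch_model.bin"  # matches the library's WEIGHTS_NAME constant
CONFIG_NAME = "config.json"         # matches the library's CONFIG_NAME constant

output_dir = "./finetuned_lm"       # stand-in for args.output_dir
os.makedirs(output_dir, exist_ok=True)

model = BertForPreTraining.from_pretrained("bert-base-uncased")
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Mirror the saving branch: unwrap DataParallel/DistributedDataParallel first.
model_to_save = model.module if hasattr(model, 'module') else model
torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
model_to_save.config.to_json_file(os.path.join(output_dir, CONFIG_NAME))
tokenizer.save_vocabulary(output_dir)

# The saved directory can then be reloaded with from_pretrained.
reloaded = BertForPreTraining.from_pretrained(output_dir)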