Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c64c2fc4
Commit
c64c2fc4
authored
Mar 20, 2019
by
Matthew Carrigan
Browse files
Fixed embarrassing indentation problem
parent
0540d360
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
51 additions
and
52 deletions
+51
-52
examples/lm_finetuning/finetune_on_pregenerated.py
examples/lm_finetuning/finetune_on_pregenerated.py
+51
-52
No files found.
examples/lm_finetuning/finetune_on_pregenerated.py
View file @
c64c2fc4
...
...
@@ -241,8 +241,7 @@ def main():
from
apex.optimizers
import
FP16_Optimizer
from
apex.optimizers
import
FusedAdam
except
ImportError
:
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
)
raise
ImportError
(
"Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
)
optimizer
=
FusedAdam
(
optimizer_grouped_parameters
,
lr
=
args
.
learning_rate
,
...
...
@@ -259,57 +258,57 @@ def main():
warmup
=
args
.
warmup_proportion
,
t_total
=
num_train_optimization_steps
)
global_step
=
0
logging
.
info
(
"***** Running training *****"
)
logging
.
info
(
f
" Num examples =
{
total_train_examples
}
"
)
logging
.
info
(
" Batch size = %d"
,
args
.
train_batch_size
)
logging
.
info
(
" Num steps = %d"
,
num_train_optimization_steps
)
model
.
train
()
for
epoch
in
range
(
args
.
epochs
):
epoch_dataset
=
PregeneratedDataset
(
epoch
=
epoch
,
training_path
=
args
.
pregenerated_data
,
tokenizer
=
tokenizer
,
num_data_epochs
=
num_data_epochs
)
if
args
.
local_rank
==
-
1
:
train_sampler
=
RandomSampler
(
epoch_dataset
)
else
:
train_sampler
=
DistributedSampler
(
epoch_dataset
)
train_dataloader
=
DataLoader
(
epoch_dataset
,
sampler
=
train_sampler
,
batch_size
=
args
.
train_batch_size
)
tr_loss
=
0
nb_tr_examples
,
nb_tr_steps
=
0
,
0
with
tqdm
(
total
=
len
(
train_dataloader
),
desc
=
f
"Epoch
{
epoch
}
"
)
as
pbar
:
for
step
,
batch
in
enumerate
(
train_dataloader
):
batch
=
tuple
(
t
.
to
(
device
)
for
t
in
batch
)
input_ids
,
input_mask
,
segment_ids
,
lm_label_ids
,
is_next
=
batch
loss
=
model
(
input_ids
,
segment_ids
,
input_mask
,
lm_label_ids
,
is_next
)
if
n_gpu
>
1
:
loss
=
loss
.
mean
()
# mean() to average on multi-gpu.
if
args
.
gradient_accumulation_steps
>
1
:
loss
=
loss
/
args
.
gradient_accumulation_steps
global_step
=
0
logging
.
info
(
"***** Running training *****"
)
logging
.
info
(
f
" Num examples =
{
total_train_examples
}
"
)
logging
.
info
(
" Batch size = %d"
,
args
.
train_batch_size
)
logging
.
info
(
" Num steps = %d"
,
num_train_optimization_steps
)
model
.
train
()
for
epoch
in
range
(
args
.
epochs
):
epoch_dataset
=
PregeneratedDataset
(
epoch
=
epoch
,
training_path
=
args
.
pregenerated_data
,
tokenizer
=
tokenizer
,
num_data_epochs
=
num_data_epochs
)
if
args
.
local_rank
==
-
1
:
train_sampler
=
RandomSampler
(
epoch_dataset
)
else
:
train_sampler
=
DistributedSampler
(
epoch_dataset
)
train_dataloader
=
DataLoader
(
epoch_dataset
,
sampler
=
train_sampler
,
batch_size
=
args
.
train_batch_size
)
tr_loss
=
0
nb_tr_examples
,
nb_tr_steps
=
0
,
0
with
tqdm
(
total
=
len
(
train_dataloader
),
desc
=
f
"Epoch
{
epoch
}
"
)
as
pbar
:
for
step
,
batch
in
enumerate
(
train_dataloader
):
batch
=
tuple
(
t
.
to
(
device
)
for
t
in
batch
)
input_ids
,
input_mask
,
segment_ids
,
lm_label_ids
,
is_next
=
batch
loss
=
model
(
input_ids
,
segment_ids
,
input_mask
,
lm_label_ids
,
is_next
)
if
n_gpu
>
1
:
loss
=
loss
.
mean
()
# mean() to average on multi-gpu.
if
args
.
gradient_accumulation_steps
>
1
:
loss
=
loss
/
args
.
gradient_accumulation_steps
if
args
.
fp16
:
optimizer
.
backward
(
loss
)
else
:
loss
.
backward
()
tr_loss
+=
loss
.
item
()
nb_tr_examples
+=
input_ids
.
size
(
0
)
nb_tr_steps
+=
1
pbar
.
update
(
1
)
mean_loss
=
tr_loss
/
nb_tr_steps
pbar
.
set_postfix_str
(
f
"Loss:
{
mean_loss
:.
5
f
}
"
)
if
(
step
+
1
)
%
args
.
gradient_accumulation_steps
==
0
:
if
args
.
fp16
:
optimizer
.
backward
(
loss
)
else
:
loss
.
backward
()
tr_loss
+=
loss
.
item
()
nb_tr_examples
+=
input_ids
.
size
(
0
)
nb_tr_steps
+=
1
pbar
.
update
(
1
)
mean_loss
=
tr_loss
/
nb_tr_steps
pbar
.
set_postfix_str
(
f
"Loss:
{
mean_loss
:.
5
f
}
"
)
if
(
step
+
1
)
%
args
.
gradient_accumulation_steps
==
0
:
if
args
.
fp16
:
# modify learning rate with special warm up BERT uses
# if args.fp16 is False, BertAdam is used that handles this automatically
lr_this_step
=
args
.
learning_rate
*
warmup_linear
(
global_step
/
num_train_optimization_steps
,
args
.
warmup_proportion
)
for
param_group
in
optimizer
.
param_groups
:
param_group
[
'lr'
]
=
lr_this_step
optimizer
.
step
()
optimizer
.
zero_grad
()
global_step
+=
1
# Save a trained model
logging
.
info
(
"** ** * Saving fine-tuned model ** ** * "
)
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
# Only save the model it-self
output_model_file
=
args
.
output_dir
/
"pytorch_model.bin"
torch
.
save
(
model_to_save
.
state_dict
(),
str
(
output_model_file
))
# modify learning rate with special warm up BERT uses
# if args.fp16 is False, BertAdam is used that handles this automatically
lr_this_step
=
args
.
learning_rate
*
warmup_linear
(
global_step
/
num_train_optimization_steps
,
args
.
warmup_proportion
)
for
param_group
in
optimizer
.
param_groups
:
param_group
[
'lr'
]
=
lr_this_step
optimizer
.
step
()
optimizer
.
zero_grad
()
global_step
+=
1
# Save a trained model
logging
.
info
(
"** ** * Saving fine-tuned model ** ** * "
)
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
# Only save the model it-self
output_model_file
=
args
.
output_dir
/
"pytorch_model.bin"
torch
.
save
(
model_to_save
.
state_dict
(),
str
(
output_model_file
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment