chenpangpang / transformers · Commit 9c6a48c8
Authored Jan 27, 2019 by Matej Svejda

fix learning rate/fp16 and warmup problem for all examples

Parent: 01ff4f82
Showing 5 changed files with 39 additions and 46 deletions
examples/run_classifier.py      +7   -1
examples/run_lm_finetuning.py   +8  -12
examples/run_squad.py           +8  -11
examples/run_squad2.py          +8  -11
examples/run_swag.py            +8  -11
examples/run_classifier.py
@@ -33,7 +33,7 @@ from torch.utils.data.distributed import DistributedSampler

 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.modeling import BertForSequenceClassification
-from pytorch_pretrained_bert.optimization import BertAdam
+from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE

 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -536,6 +536,12 @@ def main():
                 nb_tr_examples += input_ids.size(0)
                 nb_tr_steps += 1
                 if (step + 1) % args.gradient_accumulation_steps == 0:
+                    if args.fp16:
+                        # modify learning rate with special warm up BERT uses
+                        # if args.fp16 is False, BertAdam is used that handles this automatically
+                        lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion)
+                        for param_group in optimizer.param_groups:
+                            param_group['lr'] = lr_this_step
                     optimizer.step()
                     optimizer.zero_grad()
                     global_step += 1
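For reference, the warmup_linear helper that run_classifier.py now imports (and that the other scripts below stop defining locally) is a two-phase schedule: a linear ramp over the first warmup fraction of training, then a linear decay toward zero. Below is a minimal, self-contained sketch of the fp16 learning-rate update shown above; the schedule body mirrors the local warmup_linear definitions removed in the other scripts, while the peak rate, total steps and warmup proportion are hypothetical values, not taken from this commit.

def warmup_linear(x, warmup=0.002):
    # x is the fraction of training completed (global_step / t_total)
    if x < warmup:
        return x / warmup   # linear ramp-up during the warmup phase
    return 1.0 - x          # linear decay for the rest of training

learning_rate = 5e-5       # peak learning rate (hypothetical)
t_total = 1000             # total optimizer steps (hypothetical)
warmup_proportion = 0.1    # warm up over the first 10% of steps (hypothetical)

for global_step in (0, 50, 100, 500, 999):
    lr_this_step = learning_rate * warmup_linear(global_step / t_total, warmup_proportion)
    print(f"step {global_step:4d}: lr = {lr_this_step:.2e}")

With these numbers the rate ramps up over the first 100 optimizer steps and then decays linearly; this is the same schedule BertAdam applies internally on the non-fp16 path.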
examples/run_lm_finetuning.py
@@ -31,7 +31,7 @@ from torch.utils.data.distributed import DistributedSampler

 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.modeling import BertForPreTraining
-from pytorch_pretrained_bert.optimization import BertAdam
+from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 from torch.utils.data import Dataset

 import random
@@ -42,12 +42,6 @@ logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message
 logger = logging.getLogger(__name__)


-def warmup_linear(x, warmup=0.002):
-    if x < warmup:
-        return x/warmup
-    return 1.0 - x
-
-
 class BERTDataset(Dataset):
     def __init__(self, corpus_path, tokenizer, seq_len, encoding="utf-8", corpus_lines=None, on_memory=True):
         self.vocab = tokenizer.vocab
@@ -527,7 +521,7 @@ def main():
         train_dataset = BERTDataset(args.train_file, tokenizer, seq_len=args.max_seq_length,
                                     corpus_lines=None, on_memory=args.on_memory)
         num_train_steps = int(
-            len(train_dataset) / args.train_batch_size * args.num_train_epochs)
+            len(train_dataset) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

     # Prepare model
     model = BertForPreTraining.from_pretrained(args.bert_model)
@@ -607,7 +601,9 @@ def main():
                 nb_tr_examples += input_ids.size(0)
                 nb_tr_steps += 1
                 if (step + 1) % args.gradient_accumulation_steps == 0:
+                    if args.fp16:
+                        # modify learning rate with special warm up BERT uses
+                        # if args.fp16 is False, BertAdam is used that handles this automatically
+                        lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_steps, args.warmup_proportion)
+                        for param_group in optimizer.param_groups:
+                            param_group['lr'] = lr_this_step
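The num_train_steps change above matters because this value is what the warmup schedule treats as the total number of optimizer steps, and with gradient accumulation the optimizer only steps once every gradient_accumulation_steps batches. A small standalone illustration follows; the dataset and batch sizes are hypothetical and not taken from the commit.

# Hypothetical sizes, only to show why the extra division by
# gradient_accumulation_steps is needed when counting optimizer steps.
dataset_size = 10_000
train_batch_size = 32
gradient_accumulation_steps = 4
num_train_epochs = 3

# Old formula: counts forward passes (batches), not optimizer steps.
old_num_train_steps = int(dataset_size / train_batch_size * num_train_epochs)

# New formula: counts actual optimizer.step() calls.
new_num_train_steps = int(dataset_size / train_batch_size
                          / gradient_accumulation_steps * num_train_epochs)

print(old_num_train_steps, new_num_train_steps)   # 937 vs 234

With the old count, global_step / num_train_steps only ever reaches about 1 / gradient_accumulation_steps by the end of training, so the decay phase of warmup_linear is barely entered and the manually scheduled fp16 learning rate stays too high.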
examples/run_squad.py
@@ -36,7 +36,7 @@ from torch.utils.data.distributed import DistributedSampler

 from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
 from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
-from pytorch_pretrained_bert.optimization import BertAdam
+from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE

 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -670,11 +670,6 @@ def _compute_softmax(scores):
         probs.append(score / total_sum)
     return probs


-def warmup_linear(x, warmup=0.002):
-    if x < warmup:
-        return x/warmup
-    return 1.0 - x
-
 def main():
     parser = argparse.ArgumentParser()
@@ -794,7 +789,7 @@ def main():
         train_examples = read_squad_examples(
             input_file=args.train_file, is_training=True)
         num_train_steps = int(
-            len(train_examples) / args.train_batch_size * args.num_train_epochs)
+            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

     # Prepare model
     model = BertForQuestionAnswering.from_pretrained(args.bert_model,
@@ -905,7 +900,9 @@ def main():
                 else:
                     loss.backward()
                 if (step + 1) % args.gradient_accumulation_steps == 0:
+                    if args.fp16:
+                        # modify learning rate with special warm up BERT uses
+                        # if args.fp16 is False, BertAdam is used that handles this automatically
+                        lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion)
+                        for param_group in optimizer.param_groups:
+                            param_group['lr'] = lr_this_step
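Taken together, the corrected loop shape that run_squad.py (and the other examples) now follows looks roughly like the sketch below. The model, data and optimizer are toy stand-ins rather than the script's real BERT / BertAdam / FusedAdam setup; only the gradient-accumulation and fp16 control flow mirrors the diff.

import torch

# Toy stand-ins for the script's model, optimizer and argparse values.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=5e-5)
use_fp16 = False                  # stand-in for args.fp16
learning_rate = 5e-5              # stand-in for args.learning_rate
warmup_proportion = 0.1           # stand-in for args.warmup_proportion
gradient_accumulation_steps = 2
t_total = 100                     # total optimizer steps
global_step = 0

def warmup_linear(x, warmup=0.002):
    return x / warmup if x < warmup else 1.0 - x

for step in range(2 * t_total):   # two batches per optimizer step
    inputs = torch.randn(4, 10)
    targets = torch.randint(0, 2, (4,))
    loss = torch.nn.functional.cross_entropy(model(inputs), targets)
    # Average gradients over the accumulation window, as the scripts do.
    (loss / gradient_accumulation_steps).backward()

    if (step + 1) % gradient_accumulation_steps == 0:
        if use_fp16:
            # fp16 path: the optimizer has no built-in BERT warmup,
            # so the schedule is applied by hand, as in the hunks above.
            lr_this_step = learning_rate * warmup_linear(global_step / t_total, warmup_proportion)
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_this_step
        # non-fp16 path: BertAdam applies the warmup schedule internally.
        optimizer.step()
        optimizer.zero_grad()
        global_step += 1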
examples/run_squad2.py
@@ -36,7 +36,7 @@ from torch.utils.data.distributed import DistributedSampler

 from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
 from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
-from pytorch_pretrained_bert.optimization import BertAdam
+from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE

 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -759,11 +759,6 @@ def _compute_softmax(scores):
         probs.append(score / total_sum)
     return probs


-def warmup_linear(x, warmup=0.002):
-    if x < warmup:
-        return x/warmup
-    return 1.0 - x
-
 def main():
     parser = argparse.ArgumentParser()
@@ -887,7 +882,7 @@ def main():
         train_examples = read_squad_examples(
             input_file=args.train_file, is_training=True)
         num_train_steps = int(
-            len(train_examples) / args.train_batch_size * args.num_train_epochs)
+            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

     # Prepare model
     model = BertForQuestionAnswering.from_pretrained(args.bert_model,
@@ -999,7 +994,9 @@ def main():
                 else:
                     loss.backward()
                 if (step + 1) % args.gradient_accumulation_steps == 0:
+                    if args.fp16:
+                        # modify learning rate with special warm up BERT uses
+                        # if args.fp16 is False, BertAdam is used that handles this automatically
+                        lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion)
+                        for param_group in optimizer.param_groups:
+                            param_group['lr'] = lr_this_step
examples/run_swag.py
@@ -29,7 +29,7 @@ from torch.utils.data.distributed import DistributedSampler

 from pytorch_pretrained_bert.tokenization import BertTokenizer
 from pytorch_pretrained_bert.modeling import BertForMultipleChoice
-from pytorch_pretrained_bert.optimization import BertAdam
+from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
 from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE

 logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
@@ -233,11 +233,6 @@ def select_field(features, field):
         for feature in features
     ]


-def warmup_linear(x, warmup=0.002):
-    if x < warmup:
-        return x/warmup
-    return 1.0 - x
-
 def main():
     parser = argparse.ArgumentParser()
@@ -358,7 +353,7 @@ def main():
     if args.do_train:
         train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True)
         num_train_steps = int(
-            len(train_examples) / args.train_batch_size * args.num_train_epochs)
+            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

     # Prepare model
     model = BertForMultipleChoice.from_pretrained(args.bert_model,
@@ -457,7 +452,9 @@ def main():
                 else:
                     loss.backward()
                 if (step + 1) % args.gradient_accumulation_steps == 0:
+                    if args.fp16:
+                        # modify learning rate with special warm up BERT uses
+                        # if args.fp16 is False, BertAdam is used that handles this automatically
+                        lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion)
+                        for param_group in optimizer.param_groups:
+                            param_group['lr'] = lr_this_step
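The comment added in every script ("if args.fp16 is False, BertAdam is used that handles this automatically") refers to how the examples construct their optimizer. Below is a hedged sketch of the two paths as these scripts appear to use them; the parameter grouping is simplified and the exact keyword arguments are illustrative, not the scripts' verbatim code.

import torch
from pytorch_pretrained_bert.optimization import BertAdam

# Toy stand-ins (hypothetical): the real scripts build parameter groups from
# the BERT model and read these values from argparse.
model = torch.nn.Linear(10, 2)
optimizer_grouped_parameters = [{'params': list(model.parameters())}]
learning_rate, warmup_proportion, t_total = 5e-5, 0.1, 1000
fp16 = False

if fp16:
    # fp16 path (requires NVIDIA apex): a plain Adam-style optimizer with no
    # knowledge of the BERT warmup schedule, hence the manual
    # param_group['lr'] updates added in this commit.
    from apex.optimizers import FusedAdam
    optimizer = FusedAdam(optimizer_grouped_parameters,
                          lr=learning_rate,
                          bias_correction=False)
else:
    # non-fp16 path: BertAdam receives warmup and t_total and applies the
    # warmup_linear schedule internally, so no manual update is needed.
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=learning_rate,
                         warmup=warmup_proportion,
                         t_total=t_total)

Judging from the hunk counts alone (for example @@ -905,7 +900,9 @@ versus run_classifier.py's @@ -536,6 +536,12 @@), run_squad.py, run_squad2.py, run_swag.py and run_lm_finetuning.py previously applied the manual update unconditionally, while run_classifier.py did not apply it at all; this commit makes all five scripts handle warmup consistently in both paths.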