chenpangpang / transformers · Commit ec07cf5a

Commit ec07cf5a, authored Jul 11, 2019 by thomwolf
rewamp optimization
Parent: 4fef5919

Showing 7 changed files with 136 additions and 387 deletions (+136, -387)
Files changed:

  examples/run_glue.py                              +45  -32
  pytorch_transformers/__init__.py                   +2   -2
  pytorch_transformers/modeling_utils.py             +1   -1
  pytorch_transformers/modeling_xlnet.py             +0  -13
  pytorch_transformers/optimization.py              +79 -192
  pytorch_transformers/optimization_openai.py        +0 -127
  pytorch_transformers/tests/optimization_test.py    +9  -20
examples/run_glue.py

@@ -25,19 +25,21 @@ import random
 import numpy as np
 import torch
-from tensorboardX import SummaryWriter
 from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                               TensorDataset)
 from torch.utils.data.distributed import DistributedSampler
+from tensorboardX import SummaryWriter
 from tqdm import tqdm, trange

-from pytorch_transformers import WEIGHTS_NAME
-from pytorch_transformers import (BertConfig, BertForSequenceClassification, BertTokenizer,
-                                  XLMConfig, XLMForSequenceClassification, XLMTokenizer,
-                                  XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer)
+from pytorch_transformers import (WEIGHTS_NAME, BertConfig, BertForSequenceClassification, BertTokenizer,
+                                  XLMConfig, XLMForSequenceClassification, XLMTokenizer,
+                                  XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer)

-from pytorch_transformers.optimization import BertAdam
+from pytorch_transformers import AdamW, WarmupLinearSchedule

 from utils_glue import (compute_metrics, convert_examples_to_features,
                         output_modes, processors)
@@ -56,24 +58,24 @@ def train(args, train_dataset, model, tokenizer):
     if args.local_rank in [-1, 0]:
         tb_writer = SummaryWriter()

-    args.train_batch_size = args.per_gpu_train_batch_size * args.n_gpu
+    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
     train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
     train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

     if args.max_steps > 0:
-        num_train_optimization_steps = args.max_steps
+        t_total = args.max_steps
         args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
     else:
-        num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
+        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

-    # Prepare optimizer
+    # Prepare optimizer and schedule (linear warmup and decay)
     no_decay = ['bias', 'LayerNorm.weight']
     optimizer_grouped_parameters = [
         {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
         ]
-    optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate,
-                         t_total=num_train_optimization_steps, warmup=args.warmup_proportion)
+    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
+    schedule = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
     if args.fp16:
         try:
             from apex import amp
@@ -89,11 +91,11 @@ def train(args, train_dataset, model, tokenizer):
     logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d",
                 args.train_batch_size * args.gradient_accumulation_steps * (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
     logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
-    logger.info("  Total optimization steps = %d", num_train_optimization_steps)
+    logger.info("  Total optimization steps = %d", t_total)

     global_step = 0
     tr_loss, logging_loss = 0.0, 0.0
-    optimizer.zero_grad()
+    model.zero_grad()
     for _ in trange(int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0]):
         for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
             model.train()
@@ -103,7 +105,7 @@ def train(args, train_dataset, model, tokenizer):
                       'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,  # XLM don't use segment_ids
                       'labels':         batch[3]}
             ouputs = model(**inputs)
-            loss = ouputs[0]
+            loss = ouputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)

             if args.n_gpu > 1:
                 loss = loss.mean() # mean() to average on multi-gpu parallel training
@@ -113,22 +115,25 @@ def train(args, train_dataset, model, tokenizer):
             if args.fp16:
                 with amp.scale_loss(loss, optimizer) as scaled_loss:
                     scaled_loss.backward()
                 torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
             else:
                 loss.backward()
                 torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

             tr_loss += loss.item()
             if (step + 1) % args.gradient_accumulation_steps == 0:
+                scheduler.step()  # Update learning rate schedule
                 optimizer.step()
-                optimizer.zero_grad()
+                model.zero_grad()
                 global_step += 1

                 if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                     # Log metrics
-                    if args.local_rank == -1:  # Only evaluate on single GPU otherwise metrics may not average well
-                        results = evaluate(args, model, tokenizer, prefix=global_step)
+                    if args.local_rank == -1:  # Only evaluate when single GPU otherwise metrics may not average well
+                        results = evaluate(args, model, tokenizer)
                     for key, value in results.items():
                         tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
-                    tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
+                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                     tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
                     logging_loss = tr_loss
@@ -140,6 +145,7 @@ def train(args, train_dataset, model, tokenizer):
                     model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                     model_to_save.save_pretrained(output_dir)
+                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                     logger.info("Saving model checkpoint to %s", output_dir)

             if args.max_steps > 0 and global_step > args.max_steps:
                 break
@@ -162,20 +168,21 @@ def evaluate(args, model, tokenizer, prefix=""):
        if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(eval_output_dir)

+       args.eval_batch_size = args.per_gpu_eval_batch_size * args.n_gpu
        # Note that DistributedSampler samples randomly
        eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
        eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)

        # Eval!
-       logger.info("***** Running evaluation *****")
+       logger.info("***** Running evaluation {} *****".format(prefix))
        logger.info("  Num examples = %d", len(eval_dataset))
        logger.info("  Batch size = %d", args.eval_batch_size)
-       model.eval()
        eval_loss = 0
        nb_eval_steps = 0
        preds = None
        out_label_ids = None
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
+           model.eval()
            batch = tuple(t.to(args.device) for t in batch)

            with torch.no_grad():
@@ -213,7 +220,7 @@ def evaluate(args, model, tokenizer, prefix=""):
     return results


-def load_and_cache_examples(args, task, tokenizer, evaluate=False, overwrite_cache=False):
+def load_and_cache_examples(args, task, tokenizer, evaluate=False):
     processor = processors[task]()
     output_mode = output_modes[task]
     # Load data features from cache or dataset file
@@ -285,20 +292,22 @@ def main():
    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
                        help="Batch size per GPU for training.")
-   parser.add_argument("--eval_batch_size", default=8, type=int,
-                       help="Total batch size for eval.")
+   parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
+                       help="Batch size per GPU for evaluation.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
-   parser.add_argument("--warmup_proportion", default=0.1, type=float,
-                       help="Proportion of training with linear learning rate warmup (0.1 = 10%% of training).")
+   parser.add_argument("--warmup_steps", default=0, type=int,
+                       help="Linear warmup over warmup_steps.")

    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
@@ -409,6 +418,7 @@ def main():
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

+       logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
@@ -427,15 +437,18 @@ def main():
    if args.do_eval and args.local_rank in [-1, 0]:
        checkpoints = [args.output_dir + './' + WEIGHTS_NAME]
        if args.eval_all_checkpoints:
-           checkpoints = list(os.path.dirname(c) for c in glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True))
+           checkpoints = list(os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + '/**/' + WEIGHTS_NAME, recursive=True)))
+           logging.getLogger("pytorch_transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        results = {}
        for checkpoint in checkpoints:
-           global_step = int(checkpoints.split('-')[-1])
-           model = model_class.from_pretrained(checkpoints)
+           global_step = int(checkpoint.split('-')[-1])
+           model = model_class.from_pretrained(checkpoint)
            model.to(args.device)
            result = evaluate(args, model, tokenizer, prefix=global_step)
-           result = dict(n + '_{}'.format())
+           result = dict((k + '_{}'.format(global_step), v) for k, v in result.items())
            results.update(result)

    return results
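Taken together, the run_glue.py hunks above replace the internally scheduled BertAdam with a plain AdamW optimizer plus an explicit WarmupLinearSchedule stepped once per effective update (note that the setup hunk stores the schedule in a variable named schedule while the update step calls scheduler.step()). Below is a minimal sketch of the new wiring, assuming the post-commit pytorch_transformers API; model, args and t_total are hypothetical stand-ins for the objects built in the script, not the script's real code.

# Sketch only: optimizer/scheduler setup as introduced in this commit.
# `model`, `args` and `t_total` are placeholders for the script's real objects.
from pytorch_transformers import AdamW, WarmupLinearSchedule

def build_optimizer_and_scheduler(model, args, t_total):
    no_decay = ['bias', 'LayerNorm.weight']
    grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
         'weight_decay': 0.0}]
    optimizer = AdamW(grouped_parameters, lr=args.learning_rate)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    return optimizer, scheduler

# Per effective update (after gradient accumulation), mirroring the hunk above:
#   scheduler.step()   # update learning rate schedule
#   optimizer.step()
#   model.zero_grad()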
pytorch_transformers/__init__.py

@@ -36,7 +36,7 @@ from .modeling_xlm import (XLMConfig, XLMModel,
 from .modeling_utils import (WEIGHTS_NAME, CONFIG_NAME, TF_WEIGHTS_NAME,
                              PretrainedConfig, PreTrainedModel, prune_layer, Conv1D)

-from .optimization import BertAdam
-from .optimization_openai import OpenAIAdam
+from .optimization import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
+                           WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)

 from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE, cached_path)
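For downstream code, the practical consequence of this __init__.py change is the import surface: BertAdam and OpenAIAdam are no longer re-exported from the package root, while AdamW and the new schedule classes are. A hedged sketch of the imports available after this commit:

# Import surface after this commit (sketch; the old BertAdam / OpenAIAdam
# imports from the package root would no longer resolve).
from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule,
                                  WarmupCosineSchedule, WarmupCosineWithHardRestartsSchedule,
                                  WarmupLinearSchedule)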
pytorch_transformers/modeling_utils.py

@@ -104,7 +104,7 @@ class PretrainedConfig(object):
         for key in to_remove:
             kwargs.pop(key, None)

-        logger.info("Model config {}".format(config))
+        logger.info("Model config %s", config)
         return config

     @classmethod
pytorch_transformers/modeling_xlnet.py

@@ -211,10 +211,6 @@ class XLNetConfig(PretrainedConfig):
                  layer_norm_eps=1e-12,

                  dropout=0.1,
-                 dropatt=0.1,
-                 init="normal",
-                 init_range=0.1,
-                 init_std=0.02,
                  mem_len=None,
                  reuse_len=None,
                  bi_data=False,

@@ -258,11 +254,6 @@ class XLNetConfig(PretrainedConfig):
            dropout: float, dropout rate.
            dropatt: float, dropout rate on attention probabilities.
-           init: str, the initialization scheme, either "normal" or "uniform".
-           init_range: float, initialize the parameters with a uniform distribution
-               in [-init_range, init_range]. Only effective when init="uniform".
-           init_std: float, initialize the parameters with a normal distribution
-               with mean 0 and stddev init_std. Only effective when init="normal".
            mem_len: int, the number of tokens to cache.
            reuse_len: int, the number of tokens in the currect batch to be cached
                and reused in the future.

@@ -297,11 +288,7 @@ class XLNetConfig(PretrainedConfig):
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps

-        self.init = init
-        self.init_range = init_range
-        self.init_std = init_std
         self.dropout = dropout
-        self.dropatt = dropatt
         self.mem_len = mem_len
         self.reuse_len = reuse_len
         self.bi_data = bi_data
pytorch_transformers/optimization.py

@@ -14,174 +14,92 @@
 # limitations under the License.
 """PyTorch optimization for BERT model."""

+import logging
 import math
+
 import torch
 from torch.optim import Optimizer
-from torch.optim.optimizer import required
-from torch.nn.utils import clip_grad_norm_
-import logging
-import abc
-import sys
+from torch.optim.lr_scheduler import LambdaLR

 logger = logging.getLogger(__name__)


+class ConstantLRSchedule(LambdaLR):
+    def __init__(self, optimizer, last_epoch=-1):
+        super(ConstantLR, self).__init__(optimizer, lambda x: x, last_epoch=last_epoch)


-if sys.version_info >= (3, 4):
-    ABC = abc.ABC
-else:
-    ABC = abc.ABCMeta('ABC', (), {})
-
-
-class _LRSchedule(ABC):
-    """ Parent of all LRSchedules here. """
-    warn_t_total = False        # is set to True for schedules where progressing beyond t_total steps doesn't make sense
-    def __init__(self, warmup=0.002, t_total=-1, **kw):
-        """
-        :param warmup:  what fraction of t_total steps will be used for linear warmup
-        :param t_total: how many training steps (updates) are planned
-        :param kw:
-        """
-        super(_LRSchedule, self).__init__(**kw)
-        if t_total < 0:
-            logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
-        if not 0.0 <= warmup < 1.0 and not warmup == -1:
-            raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
-        warmup = max(warmup, 0.)
-        self.warmup, self.t_total = float(warmup), float(t_total)
-        self.warned_for_t_total_at_progress = -1
-
-    def get_lr(self, step, nowarn=False):
-        """
-        :param step:   which of t_total steps we're on
-        :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
-        :return:       learning rate multiplier for current update
-        """
-        if self.t_total < 0:
-            return 1.
-        progress = float(step) / self.t_total
-        ret = self.get_lr_(progress)
-        # warning for exceeding t_total (only active with warmup_linear
-        if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
-            logger.warning(
-                "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
-                    .format(ret, self.__class__.__name__))
-            self.warned_for_t_total_at_progress = progress
-        # end warning
-        return ret
-
-    @abc.abstractmethod
-    def get_lr_(self, progress):
-        """
-        :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
-        :return: learning rate multiplier for current update
-        """
-        return 1.
-
-
-class ConstantLR(_LRSchedule):
-    def get_lr_(self, progress):
-        return 1.
-
-
-class WarmupCosineSchedule(_LRSchedule):
+class WarmupCosineSchedule(LambdaLR):
     """
-    Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
-    Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve.
+    Linearly increases learning rate from 0 to 1 over `warmup` training steps.
+    Decreases learning rate from 1. to 0. over remaining `t_total - warmup` steps following a cosine curve.
     If `cycles` (default=0.5) is different from default, learning rate follows cosine function after warmup.
     """
-    warn_t_total = True
-    def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
-        """
-        :param warmup:  see LRSchedule
-        :param t_total: see LRSchedule
-        :param cycles:  number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
-        :param kw:
-        """
-        super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
-        self.cycles = cycles
-
-    def get_lr_(self, progress):
-        if progress < self.warmup:
-            return progress / self.warmup
-        else:
-            progress = (progress - self.warmup) / (1 - self.warmup)   # progress after warmup
-            return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
+    warn_t_total = True
+    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
+        def lr_lambda(step):
+            if step < warmup_steps:
+                return step / max(1, warmup_steps)
+            else:
+                progress = (step - warmup_steps) / max(1, t_total - warmup_steps)   # progress after warmup
+                return 0.5 * (1. + math.cos(math.pi * cycles * 2 * progress))
+        super(WarmupCosineSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)


-class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
+class WarmupCosineWithHardRestartsSchedule(LambdaLR):
     """
     Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
     If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
     learning rate (with hard restarts).
     """
-    def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
-        super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
-        assert(cycles >= 1.)
-
-    def get_lr_(self, progress):
-        if progress < self.warmup:
-            return progress / self.warmup
-        else:
-            progress = (progress - self.warmup) / (1 - self.warmup)     # progress after warmup
-            ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1)))
-            return ret
+    def __init__(self, optimizer, warmup_steps, t_total, cycles=1., last_epoch=-1):
+        def lr_lambda(step):
+            if step < warmup_steps:
+                return step / max(1, warmup_steps)
+            else:
+                progress = (step - warmup_steps) / max(1, t_total - warmup_steps)     # progress after warmup
+                ret = 0.5 * (1. + math.cos(math.pi * ((cycles * progress) % 1)))
+                return ret
+        super(WarmupCosineWithHardRestartsSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)


-class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
-    """
-    All training progress is divided in `cycles` (default=1.) parts of equal length.
-    Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1.,
-    followed by a learning rate decreasing from 1. to 0. following a cosine curve.
-    """
-    def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
-        assert(warmup * cycles < 1.)
-        warmup = warmup * cycles if warmup >= 0 else warmup
-        super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
-
-    def get_lr_(self, progress):
-        progress = progress * self.cycles % 1.
-        if progress < self.warmup:
-            return progress / self.warmup
-        else:
-            progress = (progress - self.warmup) / (1 - self.warmup)     # progress after warmup
-            ret = 0.5 * (1. + math.cos(math.pi * progress))
-            return ret


-class WarmupConstantSchedule(_LRSchedule):
+class WarmupConstantSchedule(LambdaLR):
     """
     Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
     Keeps learning rate equal to 1. after warmup.
     """
-    def get_lr_(self, progress):
-        if progress < self.warmup:
-            return progress / self.warmup
-        return 1.
+    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+        def lr_lambda(step):
+            if step < warmup_steps:
+                return step / warmup_steps
+            return 1.
+        super(WarmupConstantSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)


-class WarmupLinearSchedule(_LRSchedule):
+class WarmupLinearSchedule(LambdaLR):
     """
     Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
     Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps.
     """
-    warn_t_total = True
-    def get_lr_(self, progress):
-        if progress < self.warmup:
-            return progress / self.warmup
-        return max((progress - 1.) / (self.warmup - 1.), 0.)
-
-
-SCHEDULES = {
-    None:              ConstantLR,
-    "none":            ConstantLR,
-    "warmup_cosine":   WarmupCosineSchedule,
-    "warmup_constant": WarmupConstantSchedule,
-    "warmup_linear":   WarmupLinearSchedule
-}
+    def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
+        def lr_lambda(step):
+            if step < warmup_steps:
+                return step / max(1, warmup_steps)
+            return (t_total - step) / max(1, t_total - warmup_steps)
+        super(WarmupLinearSchedule, self).__init__(optimizer, lr_lambda, last_epoch=last_epoch)


-class BertAdam(Optimizer):
-    """Implements BERT version of Adam algorithm with weight decay fix.
+class AdamW(Optimizer):
+    """ Implements Adam algorithm with weight decay fix.

     Parameters:
         lr: learning rate
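Because the rewritten schedules above are thin LambdaLR wrappers, their learning-rate multipliers can be checked directly against the lr_lambda definitions. A small sketch, assuming the new optimization module from this commit; the single toy parameter exists only so the optimizer has something to own.

# Sketch: probing WarmupLinearSchedule (a torch.optim.lr_scheduler.LambdaLR subclass).
import torch
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule

param = torch.nn.Parameter(torch.zeros(1))          # toy parameter
optimizer = AdamW([param], lr=1e-3)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=10, t_total=100)

for step in range(100):
    optimizer.step()        # no-op here (no gradients), but keeps the usual order
    scheduler.step()

# Per lr_lambda above: the multiplier rises 0 -> 1 over the first 10 steps,
# then falls linearly to 0; e.g. step 55 gives (100 - 55) / 90 = 0.5.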
@@ -197,46 +115,21 @@ class BertAdam(Optimizer):
         e: Adams epsilon. Default: 1e-6
         weight_decay: Weight decay. Default: 0.01
         max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
+        correct_bias: can be set to False to avoid correcting bias in Adam (e.g. like in Bert repository)
     """
-    def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
-                 b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
-        if lr is not required and lr < 0.0:
+    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01, correct_bias=True):
+        if lr < 0.0:
             raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
-        if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
-            raise ValueError("Invalid schedule parameter: {}".format(schedule))
-        if not 0.0 <= b1 < 1.0:
-            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
-        if not 0.0 <= b2 < 1.0:
-            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
-        if not e >= 0.0:
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
+        if not 0.0 <= eps:
             raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
-        # initialize schedule object
-        if not isinstance(schedule, _LRSchedule):
-            schedule_type = SCHEDULES[schedule]
-            schedule = schedule_type(warmup=warmup, t_total=t_total)
-        else:
-            if warmup != -1 or t_total != -1:
-                logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. "
-                               "Please specify custom warmup and t_total in _LRSchedule object.")
-        defaults = dict(lr=lr, schedule=schedule,
-                        b1=b1, b2=b2, e=e, weight_decay=weight_decay,
-                        max_grad_norm=max_grad_norm)
+        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
+                        correct_bias=correct_bias)
         super(BertAdam, self).__init__(params, defaults)

-    def get_lr(self):
-        lr = []
-        for group in self.param_groups:
-            for p in group['params']:
-                if p.grad is None:
-                    continue
-                state = self.state[p]
-                if len(state) == 0:
-                    return [0]
-                lr_scheduled = group['lr']
-                lr_scheduled *= group['schedule'].get_lr(state['step'])
-                lr.append(lr_scheduled)
-        return lr
-
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -262,22 +155,28 @@ class BertAdam(Optimizer):
                 if len(state) == 0:
                     state['step'] = 0
                     # Exponential moving average of gradient values
-                    state['next_m'] = torch.zeros_like(p.data)
+                    state['exp_avg'] = torch.zeros_like(p.data)
                     # Exponential moving average of squared gradient values
-                    state['next_v'] = torch.zeros_like(p.data)
+                    state['exp_avg_sq'] = torch.zeros_like(p.data)

-                next_m, next_v = state['next_m'], state['next_v']
-                beta1, beta2 = group['b1'], group['b2']
+                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+                beta1, beta2 = group['betas']

-                # Add grad clipping
-                if group['max_grad_norm'] > 0:
-                    clip_grad_norm_(p, group['max_grad_norm'])
+                state['step'] += 1

                 # Decay the first and second moment running average coefficient
                 # In-place operations to update the averages at the same time
-                next_m.mul_(beta1).add_(1 - beta1, grad)
-                next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
-                update = next_m / (next_v.sqrt() + group['e'])
+                exp_avg.mul_(beta1).add_(1 - beta1, grad)
+                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+                denom = exp_avg_sq.sqrt().add_(group['eps'])
+
+                step_size = group['lr']
+                if group['correct_bias']:  # No bias correction for Bert
+                    bias_correction1 = 1 - beta1 ** state['step']
+                    bias_correction2 = 1 - beta2 ** state['step']
+                    step_size = step_size * math.sqrt(bias_correction2) / bias_correction1
+
+                p.data.addcdiv_(-step_size, exp_avg, denom)

                 # Just adding the square of the weights to the loss function is *not*
                 # the correct way of using L2 regularization/weight decay with Adam,
@@ -286,20 +185,8 @@ class BertAdam(Optimizer):
                 # Instead we want to decay the weights in a manner that doesn't interact
                 # with the m/v parameters. This is equivalent to adding the square
                 # of the weights to the loss with plain (non-momentum) SGD.
-                if group['weight_decay'] > 0.0:
-                    update += group['weight_decay'] * p.data
-
-                lr_scheduled = group['lr']
-                lr_scheduled *= group['schedule'].get_lr(state['step'])
-
-                update_with_lr = lr_scheduled * update
-                p.data.add_(-update_with_lr)
-
-                state['step'] += 1
-
-                # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
-                # No bias correction
-                # bias_correction1 = 1 - beta1 ** state['step']
-                # bias_correction2 = 1 - beta2 ** state['step']
+                # Add weight decay at the end (fixed version)
+                if group['weight_decay'] > 0:
+                    p.data.add_(-group['lr'] * group['weight_decay'], p.data)

         return loss
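The optimizer rewrite above removes three things BertAdam used to do internally: the learning-rate schedule, per-parameter gradient clipping, and the skipped bias correction; weight decay stays decoupled and is applied directly to p.data. A hedged sketch of how the old behaviour maps onto the new pieces; model is a stand-in and the numbers are illustrative.

# Sketch: recovering BertAdam-like behaviour with the new API from this commit.
import torch
from pytorch_transformers.optimization import AdamW, WarmupLinearSchedule

model = torch.nn.Linear(10, 2)                      # stand-in model
t_total = 1000                                      # illustrative step budget

# correct_bias=False mimics BertAdam's "no bias correction"; weight decay is
# still decoupled from the Adam moments, as in the update rule above.
optimizer = AdamW(model.parameters(), lr=5e-5, eps=1e-6,
                  weight_decay=0.01, correct_bias=False)
scheduler = WarmupLinearSchedule(optimizer, warmup_steps=100, t_total=t_total)

# Gradient clipping now lives in the training loop rather than inside the optimizer:
#   torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#   scheduler.step(); optimizer.step(); model.zero_grad()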
pytorch_transformers/optimization_openai.py  (deleted, 100644 → 0)

The entire file is removed; its previous contents:

# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch optimization for OpenAI GPT model."""

import math
import torch
from torch.optim import Optimizer
from torch.optim.optimizer import required
from torch.nn.utils import clip_grad_norm_
import logging
from .optimization import SCHEDULES, _LRSchedule, WarmupCosineWithWarmupRestartsSchedule, \
    WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule, WarmupLinearSchedule, WarmupConstantSchedule

logger = logging.getLogger(__name__)


class OpenAIAdam(Optimizer):
    """Implements Open AI version of Adam algorithm with weight decay fix.
    """
    def __init__(self, params, lr=required, schedule='warmup_linear', warmup=-1, t_total=-1,
                 b1=0.9, b2=0.999, e=1e-8, weight_decay=0,
                 vector_l2=False, max_grad_norm=-1, **kwargs):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
            raise ValueError("Invalid schedule parameter: {}".format(schedule))
        if not 0.0 <= b1 < 1.0:
            raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
        if not 0.0 <= b2 < 1.0:
            raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
        if not e >= 0.0:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
        # initialize schedule object
        if not isinstance(schedule, _LRSchedule):
            schedule_type = SCHEDULES[schedule]
            schedule = schedule_type(warmup=warmup, t_total=t_total)
        else:
            if warmup != -1 or t_total != -1:
                logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. "
                               "Please specify custom warmup and t_total in _LRSchedule object.")
        defaults = dict(lr=lr, schedule=schedule,
                        b1=b1, b2=b2, e=e, weight_decay=weight_decay, vector_l2=vector_l2,
                        max_grad_norm=max_grad_norm)
        super(OpenAIAdam, self).__init__(params, defaults)

    def get_lr(self):
        lr = []
        for group in self.param_groups:
            for p in group['params']:
                state = self.state[p]
                if len(state) == 0:
                    return [0]
                lr_scheduled = group['lr']
                lr_scheduled *= group['schedule'].get_lr(state['step'])
                lr.append(lr_scheduled)
        return lr

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['b1'], group['b2']

                state['step'] += 1

                # Add grad clipping
                if group['max_grad_norm'] > 0:
                    clip_grad_norm_(p, group['max_grad_norm'])

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
                denom = exp_avg_sq.sqrt().add_(group['e'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                lr_scheduled = group['lr']
                lr_scheduled *= group['schedule'].get_lr(state['step'])

                step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(-step_size, exp_avg, denom)

                # Add weight decay at the end (fixed version)
                if (len(p.size()) > 1 or group['vector_l2']) and group['weight_decay'] > 0:
                    p.data.add_(-lr_scheduled * group['weight_decay'], p.data)

        return loss
pytorch_transformers/tests/optimization_test.py

@@ -20,10 +20,9 @@ import unittest

 import torch

-from pytorch_transformers import BertAdam
-from pytorch_transformers import OpenAIAdam
-from pytorch_transformers.optimization import ConstantLR, WarmupLinearSchedule, WarmupConstantSchedule, \
-    WarmupCosineWithWarmupRestartsSchedule, WarmupCosineWithHardRestartsSchedule, WarmupCosineSchedule
+from pytorch_transformers import (AdamW, ConstantLRSchedule, WarmupConstantSchedule, WarmupCosineSchedule,
+                                  WarmupCosineWithHardRestartsSchedule, WarmupLinearSchedule)

 import numpy as np

@@ -34,12 +33,12 @@ class OptimizationTest(unittest.TestCase):
         for a, b in zip(list1, list2):
             self.assertAlmostEqual(a, b, delta=tol)

-    def test_adam(self):
+    def test_adam_w(self):
         w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
         target = torch.tensor([0.4, 0.2, -0.5])
         criterion = torch.nn.MSELoss()
         # No warmup, constant schedule, no gradient clipping
-        optimizer = BertAdam(params=[w], lr=2e-1,
+        optimizer = AdamW(params=[w], lr=2e-1,
                              weight_decay=0.0, max_grad_norm=-1)
         for _ in range(100):

@@ -52,23 +51,13 @@ class OptimizationTest(unittest.TestCase):

 class ScheduleInitTest(unittest.TestCase):
-    def test_bert_sched_init(self):
-        m = torch.nn.Linear(50, 50)
-        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
-        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
-        optim = BertAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
-        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
-        optim = BertAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
-        self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
-        # shouldn't fail
-
-    def test_openai_sched_init(self):
+    def test_sched_init(self):
         m = torch.nn.Linear(50, 50)
-        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
+        optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule=None)
         self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
-        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
+        optim = AdamW(m.parameters(), lr=0.001, warmup=.1, t_total=1000, schedule="none")
         self.assertTrue(isinstance(optim.param_groups[0]["schedule"], ConstantLR))
-        optim = OpenAIAdam(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
+        optim = AdamW(m.parameters(), lr=0.001, warmup=.01, t_total=1000)
         self.assertTrue(isinstance(optim.param_groups[0]["schedule"], WarmupLinearSchedule))
         # shouldn't fail
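One caveat worth flagging: the renamed test_sched_init above still passes warmup, t_total and schedule keyword arguments to AdamW, whose new signature in this commit only accepts lr, betas, eps, weight_decay and correct_bias, and it still looks for a "schedule" entry in param_groups that the new defaults no longer create. A hedged sketch of what a check against the new API could look like; this is illustrative and not part of the commit.

# Illustrative only: schedule check written against the new AdamW / LambdaLR API.
import unittest
import torch
from pytorch_transformers import AdamW, WarmupLinearSchedule

class ScheduleSketchTest(unittest.TestCase):
    def test_warmup_linear_is_lambda_lr(self):
        m = torch.nn.Linear(50, 50)
        optimizer = AdamW(m.parameters(), lr=0.001)
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=100, t_total=1000)
        # The schedule now lives outside the optimizer as a torch LambdaLR...
        self.assertIsInstance(scheduler, torch.optim.lr_scheduler.LambdaLR)
        # ...and is no longer stored in the optimizer's param groups.
        self.assertNotIn("schedule", optimizer.param_groups[0])

if __name__ == "__main__":
    unittest.main()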