Commit 75e12a27
Authored Aug 08, 2018 by Alexei Baevski
Committed by Myle Ott on Sep 03, 2018

cosine + triangular lr scheduler

Parent: 1d38624f
Showing 3 changed files with 157 additions and 1 deletion:

- fairseq/optim/lr_scheduler/cosine_lr_scheduler.py (+83, -0)
- fairseq/optim/lr_scheduler/triangular_lr_scheduler.py (+73, -0)
- train.py (+1, -1)

fairseq/optim/lr_scheduler/cosine_lr_scheduler.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math

from . import FairseqLRScheduler, register_lr_scheduler


@register_lr_scheduler('cosine')
class CosineSchedule(FairseqLRScheduler):
"""Assign LR based on a cyclical schedule that follows the cosine function.
See https://arxiv.org/pdf/1608.03983.pdf for details
lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i))
where
t_curr is current percentage of updates within the current period range
t_i is the current period range, which is scaled by t_mul after every iteration
"""

    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)
        if len(args.lr) > 1:
            raise ValueError(
                'Cannot use a fixed learning rate schedule with cosine.'
                ' Consider --lr-scheduler=fixed instead.'
            )

        self.min_lr = args.lr[0]
        self.max_lr = args.max_lr
        assert self.max_lr > self.min_lr, 'max_lr must be more than lr'

        self.t_mult = args.t_mult
        self.period = args.lr_period_updates
        self.lr_shrink = args.lr_shrink

        # initial learning rate
        self.lr = self.max_lr
        self.optimizer.set_lr(self.lr)

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        parser.add_argument('--max-lr', required=True, type=float, metavar='LR',
                            help='max learning rate, must be more than args.lr')
        parser.add_argument('--t-mult', default=1, type=float, metavar='LR',
                            help='factor to grow the length of each period')
        parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR',
                            help='initial number of updates per period')

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        if self.t_mult != 1:
            i = math.floor(math.log(1 - num_updates / self.period * (1 - self.t_mult), self.t_mult))
            t_i = self.t_mult ** i * self.period
            t_curr = num_updates - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period
        else:
            i = math.floor(num_updates / self.period)
            t_i = self.period
            t_curr = num_updates - (self.period * i)

        lr_shrink = self.lr_shrink ** i
        min_lr = self.min_lr * lr_shrink
        max_lr = self.max_lr * lr_shrink

        self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / t_i))

        self.optimizer.set_lr(self.lr)
        return self.lr
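
The warm-restart arithmetic in step_update is easy to sanity-check in isolation: the `math.log` expression inverts the geometric series period * (1 + t_mult + ... + t_mult**(i-1)) to recover the index i of the current period. Below is a minimal standalone sketch of the same schedule; the constants (min_lr, max_lr, period, t_mult, lr_shrink) are illustrative values chosen for the demo, not defaults from this commit, and cosine_lr is a hypothetical helper, not part of fairseq.

```python
import math

# Standalone sketch of the cosine schedule above (not fairseq code).
# Illustrative constants; the commit's defaults are t_mult=1 and
# lr_period_updates=5000, with min/max lr taken from --lr / --max-lr.
min_lr, max_lr = 1e-5, 1e-3
period, t_mult, lr_shrink = 1000, 2.0, 0.5

def cosine_lr(num_updates):
    if t_mult != 1:
        # index of the current restart period, from the geometric series
        # period * (1 + t_mult + ... + t_mult**(i-1)) <= num_updates
        i = math.floor(math.log(1 - num_updates / period * (1 - t_mult), t_mult))
        t_i = t_mult ** i * period  # length of period i
        t_curr = num_updates - (1 - t_mult ** i) / (1 - t_mult) * period
    else:
        i = math.floor(num_updates / period)
        t_i = period
        t_curr = num_updates - period * i
    shrink = lr_shrink ** i
    lo, hi = min_lr * shrink, max_lr * shrink
    return lo + 0.5 * (hi - lo) * (1 + math.cos(math.pi * t_curr / t_i))

for n in (0, 500, 999, 1000, 2000, 2999):
    print(n, round(cosine_lr(n), 8))
```

With t_mult=2 each restart doubles the period length, and lr_shrink ** i scales both bounds down at the i-th restart, producing progressively lower peaks. In a training run this scheduler would be selected with flags along the lines of --lr-scheduler cosine --lr 1e-5 --max-lr 1e-3 (per add_args above; the exact invocation depends on the rest of the training command).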
fairseq/optim/lr_scheduler/triangular_lr_scheduler.py (new file, mode 100644)

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math

from . import FairseqLRScheduler, register_lr_scheduler


@register_lr_scheduler('triangular')
class TriangularSchedule(FairseqLRScheduler):
    """Assign LR based on a triangular cyclical schedule.

    See https://arxiv.org/pdf/1506.01186.pdf for details.
    """

    def __init__(self, args, optimizer):
        super().__init__(args, optimizer)
        if len(args.lr) > 1:
            raise ValueError(
                'Cannot use a fixed learning rate schedule with triangular.'
                ' Consider --lr-scheduler=fixed instead.'
            )

        lr = args.lr[0]

        assert args.max_lr > lr, 'max_lr must be more than lr'
        self.min_lr = lr
        self.max_lr = args.max_lr
        self.stepsize = args.lr_period_updates // 2
        self.lr_shrink = args.lr_shrink
        self.shrink_min = args.shrink_min

        # initial learning rate
        self.lr = self.min_lr
        self.optimizer.set_lr(self.lr)

    @staticmethod
    def add_args(parser):
        """Add arguments to the parser for this LR scheduler."""
        parser.add_argument('--max-lr', required=True, type=float, metavar='LR',
                            help='max learning rate, must be more than args.lr')
        parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR',
                            help='initial number of updates per period (cycle length)')
        parser.add_argument('--shrink-min', action='store_true',
                            help='if set, also shrinks min lr')

    def step(self, epoch, val_loss=None):
        """Update the learning rate at the end of the given epoch."""
        super().step(epoch, val_loss)
        # we don't change the learning rate at epoch boundaries
        return self.optimizer.get_lr()

    def step_update(self, num_updates):
        """Update the learning rate after each update."""
        cycle = math.floor(num_updates / (2 * self.stepsize))

        lr_shrink = self.lr_shrink ** cycle
        max_lr = self.max_lr * lr_shrink
        if self.shrink_min:
            min_lr = self.min_lr * lr_shrink
        else:
            min_lr = self.min_lr

        x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1)
        self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x))

        self.optimizer.set_lr(self.lr)
        return self.lr
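
The expression x = |num_updates/stepsize - 2*(cycle+1) + 1| is the triangle wave from the cyclical-learning-rates paper linked in the docstring: x sweeps from 1 down to 0 and back to 1 across each cycle of 2*stepsize updates, so (1 - x) ramps the LR linearly up to max_lr and back down. A minimal standalone sketch, with illustrative constants that are not defaults from this commit:

```python
import math

# Standalone sketch of the triangular schedule above (not fairseq code).
min_lr, max_lr = 1e-5, 1e-3
stepsize = 500      # half of --lr-period-updates
lr_shrink = 0.5     # decay applied to the peak at each new cycle

def triangular_lr(num_updates, shrink_min=False):
    cycle = math.floor(num_updates / (2 * stepsize))
    shrink = lr_shrink ** cycle
    hi = max_lr * shrink
    lo = min_lr * shrink if shrink_min else min_lr
    # x goes 1 -> 0 -> 1 over each cycle, so (1 - x) is the triangle wave
    x = abs(num_updates / stepsize - 2 * (cycle + 1) + 1)
    return lo + (hi - lo) * max(0, 1 - x)

for n in (0, 250, 500, 750, 1000, 1500):
    print(n, round(triangular_lr(n), 8))
```

Note the asymmetry with the cosine scheduler: this one starts at min_lr and climbs, while --shrink-min controls whether the floor decays along with the peak.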
train.py

@@ -141,7 +141,7 @@ def train(args, trainer, task, epoch_itr):
 ...
         trainer.get_meter('wps').reset()
         num_updates = trainer.get_num_updates()
-        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0:
+        if args.save_interval_updates > 0 and num_updates % args.save_interval_updates == 0 and num_updates > 0:
             valid_losses = validate(args, trainer, task, epoch_itr, [first_valid])
             save_checkpoint(args, trainer, epoch_itr, valid_losses[0])
 ...
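
The added num_updates > 0 guard matters because 0 % args.save_interval_updates == 0 is always true, so without it the loop would validate and save a checkpoint before the first update. A tiny illustration with a hypothetical interval:

```python
save_interval_updates = 1000  # hypothetical value for --save-interval-updates

for num_updates in (0, 1000, 2000):
    old = save_interval_updates > 0 and num_updates % save_interval_updates == 0
    new = old and num_updates > 0
    print(num_updates, old, new)
# 0    True  False   <- the spurious save at startup that this commit removes
# 1000 True  True
# 2000 True  True
```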