OpenDAS / Fairseq · Commit ba9f32cc

add warmup support back to cosine lr sched (important for mt)

authored Aug 16, 2018 by alexeib, committed by Myle Ott, Sep 03, 2018
parent 8d6665f2
Showing 1 changed file with 43 additions and 21 deletions:

fairseq/optim/lr_scheduler/cosine_lr_scheduler.py (+43, -21)
@@ -13,16 +13,18 @@ from . import FairseqLRScheduler, register_lr_scheduler
 @register_lr_scheduler('cosine')
 class CosineSchedule(FairseqLRScheduler):
     """Assign LR based on a cyclical schedule that follows the cosine function.
     See https://arxiv.org/pdf/1608.03983.pdf for details
+
+    We also support a warmup phase where we linearly increase the learning rate
+    from some initial learning rate (`--warmup-init-lr`) until the configured
+    learning rate (`--lr`).
+
+    During warmup:
+        lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates)
+        lr = lrs[update_num]
+
+    After warmup:
     lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i))

     where
         t_curr is current percentage of updates within the current period range
         t_i is the current period range, which is scaled by t_mul after every iteration
     """

     def __init__(self, args, optimizer):
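The schedule described in the new docstring can be sanity-checked numerically. The sketch below is illustrative only and not part of this commit: the hyperparameter values are made up, t_mult is fixed at 1, and the cosine term is written as cos(pi * t_curr / t_i) following the SGDR paper linked above. Note also that in the __init__ changes further down, warmup actually ends at args.max_lr (warmup_end_lr = args.max_lr), so the sketch warms up toward max_lr rather than the plain --lr the docstring's linspace suggests.

    # Illustrative sketch of the documented schedule (hypothetical values, not from this commit).
    import math

    warmup_init_lr = 1e-7   # --warmup-init-lr (hypothetical)
    min_lr = 1e-5           # --lr (hypothetical)
    max_lr = 1e-3           # --max-lr (hypothetical)
    warmup_updates = 100    # --warmup-updates (hypothetical)
    period = 1000           # --lr-period-updates (hypothetical), t_mult == 1

    def lr_at(num_updates):
        if num_updates < warmup_updates:
            # linear interpolation from warmup_init_lr up to max_lr
            step = (max_lr - warmup_init_lr) / warmup_updates
            return warmup_init_lr + num_updates * step
        # cosine annealing from max_lr down to min_lr within each period (SGDR form)
        curr = num_updates - warmup_updates
        t_curr = curr - period * math.floor(curr / period)
        return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t_curr / period))

    for n in (0, 50, 100, 600, 1100):
        print(n, lr_at(n))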
@@ -33,6 +35,10 @@ class CosineSchedule(FairseqLRScheduler):
                 ' Consider --lr-scheduler=fixed instead.'
             )

+        warmup_end_lr = args.max_lr
+        if args.warmup_init_lr < 0:
+            args.warmup_init_lr = args.lr[0]
+
         self.min_lr = args.lr[0]
         self.max_lr = args.max_lr
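In other words, warmup always ends at --max-lr, while --warmup-init-lr (default -1, see add_args below) falls back to the first --lr value when left unset. As a hypothetical example, --lr 1e-5 --max-lr 1e-3 with no --warmup-init-lr ramps linearly from 1e-5 up to 1e-3, and 1e-5 also serves as the cosine floor via self.min_lr.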
@@ -40,15 +46,27 @@ class CosineSchedule(FairseqLRScheduler):
         self.t_mult = args.t_mult
         self.period = args.lr_period_updates

+        if args.warmup_updates > 0:
+            # linearly warmup for the first args.warmup_updates
+            self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates
+        else:
+            self.lr_step = 1
+
+        self.warmup_updates = args.warmup_updates
         self.lr_shrink = args.lr_shrink

         # initial learning rate
-        self.lr = self.max_lr
+        self.lr = args.warmup_init_lr
         self.optimizer.set_lr(self.lr)

     @staticmethod
     def add_args(parser):
         """Add arguments to the parser for this LR scheduler."""
+        parser.add_argument('--warmup-updates', default=0, type=int, metavar='N',
+                            help='warmup the learning rate linearly for the first N updates')
+        parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR',
+                            help='initial learning rate during warmup phase; default is args.lr')
         parser.add_argument('--max-lr', required=True, type=float, metavar='LR',
                             help='max learning rate, must be more than args.lr')
         parser.add_argument('--t-mult', default=1, type=float, metavar='LR',
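Taken together, the new flags might be combined roughly as follows. This invocation is a hedged illustration rather than something from the commit: the data path and values are placeholders, and the usual model/task arguments are omitted.

    python train.py /path/to/data \
        --lr-scheduler cosine --lr 1e-5 --max-lr 1e-3 \
        --warmup-init-lr 1e-7 --warmup-updates 4000 \
        --t-mult 1 --lr-period-updates 20000

With those numbers, self.lr_step = (1e-3 - 1e-7) / 4000, i.e. roughly 2.5e-7 per update, and thanks to the `self.lr = args.warmup_init_lr` change the run now starts at 1e-7 instead of jumping straight to --max-lr.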
@@ -64,12 +82,16 @@ class CosineSchedule(FairseqLRScheduler):
     def step_update(self, num_updates):
         """Update the learning rate after each update."""
-        if self.t_mult != 1:
-            i = math.floor(math.log(1 - num_updates / self.period * (1 - self.t_mult), self.t_mult))
-            t_i = self.t_mult ** i * self.period
-            t_curr = num_updates - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period
-        else:
-            i = math.floor(num_updates / self.period)
-            t_i = self.period
-            t_curr = num_updates - (self.period * i)
+        if num_updates < self.args.warmup_updates:
+            self.lr = self.args.warmup_init_lr + num_updates * self.lr_step
+        else:
+            curr_updates = num_updates - self.args.warmup_updates
+            if self.t_mult != 1:
+                i = math.floor(math.log(1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult))
+                t_i = self.t_mult ** i * self.period
+                t_curr = curr_updates - (1 - self.t_mult ** i) / (1 - self.t_mult) * self.period
+            else:
+                i = math.floor(curr_updates / self.period)
+                t_i = self.period
+                t_curr = num_updates - (self.period * i)
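Since step_update now measures time in curr_updates (updates past warmup), the period bookkeeping for t_mult != 1 shifts accordingly. A small illustrative check, again with made-up numbers and not part of the commit:

    # Cycle lengths grow geometrically: 1000, 2000, 4000, ... updates after warmup.
    import math

    t_mult, period, warmup_updates = 2.0, 1000, 100

    def cycle_position(num_updates):
        curr_updates = num_updates - warmup_updates
        i = math.floor(math.log(1 - curr_updates / period * (1 - t_mult), t_mult))  # cycle index
        t_i = t_mult ** i * period                                                   # cycle length
        t_curr = curr_updates - (1 - t_mult ** i) / (1 - t_mult) * period            # offset inside the cycle
        return i, t_i, t_curr

    print(cycle_position(600))    # (0, 1000.0, 500.0): midway through the first cycle
    print(cycle_position(1600))   # (1, 2000.0, 500.0): 500 updates into the second cycle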