OpenDAS / Megatron-LM · Commits

Commit 641408f5, authored Jan 28, 2022 by Vijay Korthikanti
Parent: 04ecc834

more naming cleanup
Changes: 2 changed files, +90 additions, -50 deletions
  megatron/optimizer_param_scheduler.py  +78 -41
  megatron/training.py                   +12 -9
megatron/optimizer_param_scheduler.py (view file @ 641408f5)
@@ -13,18 +13,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Learning rate decay functions."""
+"""Learning rate decay and weight decay incr functions."""
 
 import math
 
 from megatron import print_rank_0
 
 
 class OptimizerParamScheduler(object):
-    """Anneals the learning rate."""
+    """Anneals learning rate and weight decay"""
 
     def __init__(self, optimizer, max_lr, min_lr,
-                 warmup_steps, decay_steps, decay_style,
-                 start_wd, end_wd, wd_incr_style,
+                 lr_warmup_steps, lr_decay_steps, lr_decay_style,
+                 start_wd, end_wd, wd_incr_steps, wd_incr_style,
                  use_checkpoint_opt_param_scheduler=True,
                  override_opt_param_scheduler=False):
@@ -36,19 +36,19 @@ class OptimizerParamScheduler(object):
         assert self.min_lr >= 0.0
         assert self.max_lr >= self.min_lr
 
-        self.warmup_steps = warmup_steps
+        self.lr_warmup_steps = lr_warmup_steps
         self.num_steps = 0
-        self.decay_steps = decay_steps
-        assert self.decay_steps > 0
-        assert self.warmup_steps < self.decay_steps
+        self.lr_decay_steps = lr_decay_steps
+        assert self.lr_decay_steps > 0
+        assert self.lr_warmup_steps < self.lr_decay_steps
 
-        self.decay_style = decay_style
+        self.lr_decay_style = lr_decay_style
 
         self.start_wd = start_wd
         self.end_wd = end_wd
         assert self.start_wd >= 0.0
         assert self.end_wd >= self.start_wd
+        self.wd_incr_steps = wd_incr_steps
         self.wd_incr_style = wd_incr_style
 
         self.override_opt_param_scheduler = override_opt_param_scheduler
@@ -59,26 +59,27 @@ class OptimizerParamScheduler(object):
         # Set the learning rate
         self.step(0)
 
-        print_rank_0('> learning rate decay style: {}'.format(self.decay_style))
+        print_rank_0('> learning rate decay style: {}'.format(self.lr_decay_style))
 
 
     def get_wd(self):
-        if self.num_steps > self.decay_steps:
+        """ Weight decay incr functions"""
+        if self.num_steps > self.wd_incr_steps:
             return self.end_wd
 
         if self.wd_incr_style == 'constant':
             assert self.start_wd == self.end_wd
             return self.end_wd
 
-        decay_ratio = float(self.num_steps) / float(self.decay_steps)
-        assert decay_ratio >= 0.0
-        assert decay_ratio <= 1.0
+        incr_ratio = float(self.num_steps) / float(self.wd_incr_steps)
+        assert incr_ratio >= 0.0
+        assert incr_ratio <= 1.0
         delta_wd = self.end_wd - self.start_wd
 
         if self.wd_incr_style == 'linear':
-            coeff = decay_ratio
+            coeff = incr_ratio
         elif self.wd_incr_style == 'cosine':
-            coeff = 0.5 * (math.cos(math.pi * (1 - decay_ratio)) + 1.0)
+            coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
         else:
             raise Exception('{} weight decay increment style is not supported.'.format(
                 self.wd_incr_style))
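
For context, the renamed get_wd implements a weight-decay increment schedule over wd_incr_steps. Below is a minimal standalone sketch of that schedule; the helper name weight_decay_at is hypothetical, and the final return value (start_wd plus coeff times the delta) is assumed from the surrounding code, since that line sits just outside the hunk shown above.

import math

# Standalone sketch of the weight decay increment schedule in get_wd.
# Parameter names mirror the renamed attributes; the return expression at
# the end is an assumption (the line is outside the hunk above).
def weight_decay_at(num_steps, wd_incr_steps, start_wd, end_wd,
                    wd_incr_style='linear'):
    if num_steps > wd_incr_steps:
        return end_wd
    if wd_incr_style == 'constant':
        assert start_wd == end_wd
        return end_wd
    incr_ratio = float(num_steps) / float(wd_incr_steps)
    assert 0.0 <= incr_ratio <= 1.0
    delta_wd = end_wd - start_wd
    if wd_incr_style == 'linear':
        coeff = incr_ratio
    elif wd_incr_style == 'cosine':
        coeff = 0.5 * (math.cos(math.pi * (1 - incr_ratio)) + 1.0)
    else:
        raise Exception('{} weight decay increment style is not supported.'
                        .format(wd_incr_style))
    return start_wd + coeff * delta_wd

# e.g. halfway through a linear increment from 0.01 to 0.1:
# weight_decay_at(500, 1000, 0.01, 0.1) -> 0.055
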
@@ -91,33 +92,33 @@ class OptimizerParamScheduler(object):
         https://openreview.net/pdf?id=BJYwwY9ll pg. 4"""
 
         # Use linear warmup for the initial part.
-        if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
+        if self.lr_warmup_steps > 0 and self.num_steps <= self.lr_warmup_steps:
             return self.max_lr * float(self.num_steps) / \
-                float(self.warmup_steps)
+                float(self.lr_warmup_steps)
 
         # If the learning rate is constant, just return the initial value.
-        if self.decay_style == 'constant':
+        if self.lr_decay_style == 'constant':
             return self.max_lr
 
-        # For any steps larger than `self.decay_steps`, use `self.min_lr`.
-        if self.num_steps > self.decay_steps:
+        # For any steps larger than `self.lr_decay_steps`, use `self.min_lr`.
+        if self.num_steps > self.lr_decay_steps:
             return self.min_lr
 
         # If we are done with the warmup period, use the decay style.
-        num_steps_ = self.num_steps - self.warmup_steps
-        decay_steps_ = self.decay_steps - self.warmup_steps
+        num_steps_ = self.num_steps - self.lr_warmup_steps
+        decay_steps_ = self.lr_decay_steps - self.lr_warmup_steps
         decay_ratio = float(num_steps_) / float(decay_steps_)
         assert decay_ratio >= 0.0
         assert decay_ratio <= 1.0
         delta_lr = self.max_lr - self.min_lr
 
-        if self.decay_style == 'linear':
+        if self.lr_decay_style == 'linear':
             coeff = (1.0 - decay_ratio)
-        elif self.decay_style == 'cosine':
+        elif self.lr_decay_style == 'cosine':
             coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
         else:
             raise Exception('{} decay style is not supported.'.format(
-                self.decay_style))
+                self.lr_decay_style))
 
         return self.min_lr + coeff * delta_lr
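
The learning rate path is unchanged in shape; only the attributes gain the lr_ prefix. Here is a self-contained sketch of the same schedule using plain arguments instead of self attributes (the helper name lr_at is hypothetical):

import math

# Sketch of the learning rate schedule with the renamed lr_* names:
# linear warmup, then 'constant', 'linear' or 'cosine' decay to min_lr.
def lr_at(num_steps, max_lr, min_lr, lr_warmup_steps, lr_decay_steps,
          lr_decay_style='cosine'):
    # Linear warmup for the initial part.
    if lr_warmup_steps > 0 and num_steps <= lr_warmup_steps:
        return max_lr * float(num_steps) / float(lr_warmup_steps)
    if lr_decay_style == 'constant':
        return max_lr
    if num_steps > lr_decay_steps:
        return min_lr
    # Past the warmup period, apply the chosen decay style.
    decay_ratio = (float(num_steps - lr_warmup_steps)
                   / float(lr_decay_steps - lr_warmup_steps))
    delta_lr = max_lr - min_lr
    if lr_decay_style == 'linear':
        coeff = 1.0 - decay_ratio
    elif lr_decay_style == 'cosine':
        coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
    else:
        raise Exception('{} decay style is not supported.'
                        .format(lr_decay_style))
    return min_lr + coeff * delta_lr
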
@@ -135,11 +136,15 @@ class OptimizerParamScheduler(object):
     def state_dict(self):
         state_dict = {
             'max_lr': self.max_lr,
-            'warmup_steps': self.warmup_steps,
+            'lr_warmup_steps': self.lr_warmup_steps,
             'num_steps': self.num_steps,
-            'decay_style': self.decay_style,
-            'decay_steps': self.decay_steps,
-            'min_lr': self.min_lr
+            'lr_decay_style': self.lr_decay_style,
+            'lr_decay_steps': self.lr_decay_steps,
+            'min_lr': self.min_lr,
+            'start_wd': self.start_wd,
+            'end_wd': self.end_wd,
+            'wd_incr_style': self.wd_incr_style,
+            'wd_incr_steps': self.wd_incr_steps
         }
         return state_dict
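
With this change the checkpointed scheduler state carries lr_-prefixed keys plus the new weight decay fields. An illustrative dictionary of what state_dict() now returns (all values here are made up):

# Illustrative shape of the checkpointed scheduler state after this commit.
example_state = {
    'max_lr': 1.5e-4,
    'lr_warmup_steps': 2000,
    'num_steps': 0,
    'lr_decay_style': 'cosine',
    'lr_decay_steps': 300000,
    'min_lr': 1.0e-5,
    'start_wd': 0.01,
    'end_wd': 0.1,
    'wd_incr_style': 'linear',
    'wd_incr_steps': 300000,
}
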
@@ -153,7 +158,7 @@ class OptimizerParamScheduler(object):
 
         if not self.use_checkpoint_opt_param_scheduler:
             assert cls_value == sd_value, \
-                f'AnnealingLR: class input value {cls_value} and checkpoint' \
+                f'OptimizerParamScheduler: class input value {cls_value} and checkpoint' \
                 f'value {sd_value} for {name} do not match'
         print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
                                                                   name))
@@ -174,24 +179,56 @@ class OptimizerParamScheduler(object):
         if 'warmup_iter' in sd:
-            warmup_steps_ = sd['warmup_iter']
+            lr_warmup_steps_ = sd['warmup_iter']
+        elif 'warmup_steps' in sd:
+            lr_warmup_steps_ = sd['warmup_steps']
         else:
-            warmup_steps_ = sd['warmup_steps']
-        self.warmup_steps = self._check_and_set(self.warmup_steps,
-                                                warmup_steps_,
-                                                'warmup iterations')
+            lr_warmup_steps_ = sd['lr_warmup_steps']
+        self.lr_warmup_steps = self._check_and_set(self.lr_warmup_steps,
+                                                   lr_warmup_steps_,
+                                                   'warmup iterations')
 
         if 'end_iter' in sd:
-            decay_steps_ = sd['end_iter']
+            lr_decay_steps_ = sd['end_iter']
+        elif 'decay_steps' in sd:
+            lr_decay_steps_ = sd['decay_steps']
         else:
-            decay_steps_ = sd['decay_steps']
-        self.decay_steps = self._check_and_set(self.decay_steps, decay_steps_,
-                                               'total number of iterations')
-        self.decay_style = self._check_and_set(self.decay_style,
-                                               sd['decay_style'],
-                                               'decay style')
+            lr_decay_steps_ = sd['lr_decay_steps']
+        self.lr_decay_steps = self._check_and_set(self.lr_decay_steps,
+                                                  lr_decay_steps_,
+                                                  'total number of iterations')
+
+        if 'decay_style' in sd:
+            lr_decay_style_ = sd['decay_style']
+        else:
+            lr_decay_style_ = sd['lr_decay_style']
+        self.lr_decay_style = self._check_and_set(self.lr_decay_style,
+                                                  lr_decay_style_,
+                                                  'learning rate decay style')
 
         if 'num_iters' in sd:
             num_steps = sd['num_iters']
         else:
             num_steps = sd['num_steps']
         self.step(increment=num_steps)
+
+        if 'start_wd' in sd:
+            self.start_wd = self._check_and_set(self.start_wd,
+                                                sd['start_wd'],
+                                                "start weight decay")
+            self.end_wd = self._check_and_set(self.end_wd,
+                                              sd['end_wd'],
+                                              "end weight decay")
+            self.wd_incr_steps = self._check_and_set(self.wd_incr_steps,
+                                                     sd['wd_incr_steps'],
+                                                     "total number of weight decay iterations")
+            self.wd_incr_style = self._check_and_set(self.wd_incr_style,
+                                                     sd['wd_incr_style'],
+                                                     "weight decay incr style")
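
The load path stays backward compatible: each field is looked up under its oldest checkpoint key first, then the intermediate name, then the new lr_-prefixed key. A compact sketch of that lookup order using a hypothetical _lookup helper (the real code writes out the if/elif/else chains shown above):

# Hypothetical helper, for illustration only; the actual code spells out
# the per-field if/elif/else chains.
def _lookup(sd, *keys):
    for key in keys:
        if key in sd:
            return sd[key]
    raise KeyError(keys)

# An old-style checkpoint still loads cleanly (example values).
old_checkpoint_sd = {'warmup_iter': 2000, 'end_iter': 300000,
                     'decay_style': 'cosine'}
lr_warmup_steps_ = _lookup(old_checkpoint_sd,
                           'warmup_iter', 'warmup_steps', 'lr_warmup_steps')
lr_decay_steps_ = _lookup(old_checkpoint_sd,
                          'end_iter', 'decay_steps', 'lr_decay_steps')
lr_decay_style_ = _lookup(old_checkpoint_sd,
                          'decay_style', 'lr_decay_style')
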
megatron/training.py (view file @ 641408f5)
@@ -312,11 +312,12 @@ def get_optimizer_param_scheduler(optimizer):
     if args.train_iters:
         if args.lr_decay_iters is None:
             args.lr_decay_iters = args.train_iters
-        decay_steps = args.lr_decay_iters * args.global_batch_size
+        lr_decay_steps = args.lr_decay_iters * args.global_batch_size
+        wd_incr_steps = args.train_iters * args.global_batch_size
         if args.lr_warmup_fraction is not None:
-            warmup_steps = args.lr_warmup_fraction * decay_steps
+            lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
         else:
-            warmup_steps = args.lr_warmup_iters * args.global_batch_size
+            lr_warmup_steps = args.lr_warmup_iters * args.global_batch_size
     # Sample-based training.
     elif args.train_samples:
         # We need to set training iters for later use. Technically
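
A worked example of the iteration-based branch above, with made-up argument values, showing how the sample-count step totals are derived:

# Worked example of the iteration-based branch (values are hypothetical).
train_iters = 500000
global_batch_size = 1024
lr_decay_iters = 430000          # defaults to train_iters when unset
lr_warmup_fraction = 0.01

lr_decay_steps = lr_decay_iters * global_batch_size    # 440_320_000 samples
wd_incr_steps = train_iters * global_batch_size        # 512_000_000 samples
lr_warmup_steps = lr_warmup_fraction * lr_decay_steps  # 4_403_200.0 samples
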
@@ -325,11 +326,12 @@ def get_optimizer_param_scheduler(optimizer):
         update_train_iters(args)
         if args.lr_decay_samples is None:
             args.lr_decay_samples = args.train_samples
-        decay_steps = args.lr_decay_samples
+        lr_decay_steps = args.lr_decay_samples
+        wd_incr_steps = args.train_samples
         if args.lr_warmup_fraction is not None:
-            warmup_steps = args.lr_warmup_fraction * decay_steps
+            lr_warmup_steps = args.lr_warmup_fraction * lr_decay_steps
         else:
-            warmup_steps = args.lr_warmup_samples
+            lr_warmup_steps = args.lr_warmup_samples
     else:
         raise Exception(
             'either train-iters or train-samples should be provided.')
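
The sample-based branch counts samples directly, for example (values again made up):

# Worked example of the sample-based branch (values are hypothetical).
train_samples = 300_000_000
lr_decay_samples = 260_000_000
lr_warmup_samples = 3_000_000

lr_decay_steps = lr_decay_samples   # 260_000_000 samples
wd_incr_steps = train_samples       # 300_000_000 samples
lr_warmup_steps = lr_warmup_samples # 3_000_000 samples (no warmup fraction given)
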
@@ -338,11 +340,12 @@ def get_optimizer_param_scheduler(optimizer):
         optimizer,
         max_lr=args.lr,
         min_lr=args.min_lr,
-        warmup_steps=warmup_steps,
-        decay_steps=decay_steps,
-        decay_style=args.lr_decay_style,
+        lr_warmup_steps=lr_warmup_steps,
+        lr_decay_steps=lr_decay_steps,
+        lr_decay_style=args.lr_decay_style,
         start_wd=args.start_weight_decay,
         end_wd=args.end_weight_decay,
+        wd_incr_steps=wd_incr_steps,
         wd_incr_style=args.weight_decay_incr_style,
         use_checkpoint_opt_param_scheduler=args.use_checkpoint_opt_param_scheduler,
         override_opt_param_scheduler=args.override_opt_param_scheduler)
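
A minimal usage sketch of the renamed constructor outside Megatron's argument plumbing. This assumes megatron is importable in your environment and that the scheduler's step() only touches standard param_groups entries; the optimizer is a plain torch.optim.SGD stand-in and every hyperparameter value below is hypothetical.

import torch
from megatron.optimizer_param_scheduler import OptimizerParamScheduler

# Stand-in model/optimizer just to provide param_groups with lr/weight_decay.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=1.5e-4, weight_decay=0.01)

# Keyword names follow the renamed signature introduced by this commit.
opt_param_scheduler = OptimizerParamScheduler(
    optimizer,
    max_lr=1.5e-4,
    min_lr=1.0e-5,
    lr_warmup_steps=2000,
    lr_decay_steps=300000,
    lr_decay_style='cosine',
    start_wd=0.01,
    end_wd=0.1,
    wd_incr_steps=300000,
    wd_incr_style='linear',
    use_checkpoint_opt_param_scheduler=True,
    override_opt_param_scheduler=False)
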