Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
b6e0377b
Commit
b6e0377b
authored
Mar 29, 2020
by
Mohammad
Browse files
refactored learning-rate
parent
3366a5b0
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
62 additions
and
65 deletions
+62
-65
megatron/learning_rates.py
megatron/learning_rates.py
+60
-47
megatron/module.py
megatron/module.py
+0
-1
megatron/training.py
megatron/training.py
+2
-2
megatron/utils.py
megatron/utils.py
+0
-15
No files found.
megatron/learning_rates.py
View file @
b6e0377b
...
...
@@ -12,59 +12,68 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch DataLoader for TFRecords"""
import
torch
from
torch.optim.lr_scheduler
import
_LRScheduler
"""Learning rate decay functions."""
import
math
from
megatron
import
print_rank_0
class AnnealingLR(object):
    """Anneals the learning rate.

    Tracks the current iteration count and computes a warmup-then-decay
    learning rate which is written into the wrapped optimizer's parameter
    groups (via ``step``/``get_lr`` defined elsewhere in this class).
    """

    def __init__(self, optimizer, start_lr,
                 warmup_iter, total_iters,
                 decay_style, last_iter, min_lr=0.0,
                 use_checkpoint_lr_scheduler=True,
                 override_lr_scheduler=False):
        """Build the scheduler and set the initial learning rate.

        Args:
            optimizer: torch-style optimizer exposing ``param_groups``.
            start_lr: peak learning rate reached at the end of warmup.
            warmup_iter: number of linear-warmup iterations.
            total_iters: total number of iterations to decay over
                (must be > 0).
            decay_style: schedule name consumed by ``get_lr`` (e.g.
                'linear', 'cosine', 'exponential'; anything else means
                constant). NOTE(review): unlike the previous version this
                is no longer lower-cased here — assumed already normalized
                by the caller; confirm.
            last_iter: iteration to resume from (0 for a fresh run).
            min_lr: floor applied to the decayed learning rate.
            use_checkpoint_lr_scheduler: prefer checkpoint values when
                loading state; mutually exclusive with
                ``override_lr_scheduler``.
            override_lr_scheduler: force the constructor arguments over
                any checkpoint values.
        """
        # Class values.
        self.optimizer = optimizer
        self.start_lr = start_lr
        self.min_lr = min_lr
        self.warmup_iter = warmup_iter
        self.num_iters = last_iter
        self.end_iter = total_iters
        assert self.end_iter > 0
        self.decay_style = decay_style
        self.override_lr_scheduler = override_lr_scheduler
        self.use_checkpoint_lr_scheduler = use_checkpoint_lr_scheduler
        # Override and use-checkpoint are contradictory policies.
        if self.override_lr_scheduler:
            assert not self.use_checkpoint_lr_scheduler, 'both override and '\
                'use-checkpoint are set.'
        # Set the learning rate
        self.step(self.num_iters)

        print_rank_0('> learning rate decay style: {}'.format(
            self.decay_style))
def get_lr(self):
    """Learning rate decay functions from:
          https://openreview.net/pdf?id=BJYwwY9ll pg. 4

    Returns the learning rate for the current ``self.num_iters``: linear
    warmup up to ``self.warmup_iter``, then the configured decay
    ('linear', 'cosine', 'exponential'; any other style is constant),
    floored at ``self.min_lr``.
    """
    # Cap the decay-phase step count so the schedules never run past
    # end_iter - warmup_iter iterations.
    num_iters_ = min(self.num_iters, self.end_iter - self.warmup_iter)
    # Warmup.
    if self.warmup_iter > 0 and self.num_iters <= self.warmup_iter:
        return float(self.start_lr) * num_iters_ / self.warmup_iter

    num_iters_ = num_iters_ - self.warmup_iter
    if self.decay_style == 'linear':
        # Straight line from start_lr down to 0 over end_iter steps.
        lr = self.start_lr * (self.end_iter - num_iters_) / self.end_iter
    elif self.decay_style == 'cosine':
        # Half cosine from start_lr down to 0 over end_iter steps.
        lr = self.start_lr / 2.0 * (math.cos(
            math.pi * num_iters_ / self.end_iter) + 1)
    elif self.decay_style == 'exponential':
        # exp(-0.693) = 1/2
        lr = self.start_lr * math.exp(-0.693 * num_iters_ / self.end_iter)
    else:
        # Unknown or 'constant' style: keep the base rate.
        lr = self.start_lr
    return max(lr, self.min_lr)
def
step
(
self
,
step_num
=
None
):
"""Set lr for all parameters groups."""
if
step_num
is
None
:
step_num
=
self
.
num_iters
+
1
self
.
num_iters
=
step_num
...
...
@@ -72,42 +81,46 @@ class AnnealingLR(_LRScheduler):
for
group
in
self
.
optimizer
.
param_groups
:
group
[
'lr'
]
=
new_lr
def state_dict(self):
    """Return the scheduler state to be stored in a checkpoint."""
    state_dict = {
        'start_lr': self.start_lr,
        'warmup_iter': self.warmup_iter,
        'num_iters': self.num_iters,
        'decay_style': self.decay_style,
        'end_iter': self.end_iter,
        'min_lr': self.min_lr
    }
    return state_dict
def _check_and_set(self, cls_value, sd_value, name):
    """Auxiliary function for checking the values in the checkpoint and
    setting them.

    Returns ``cls_value`` when ``override_lr_scheduler`` is set,
    otherwise the checkpoint's ``sd_value`` (asserting the two match
    when ``use_checkpoint_lr_scheduler`` is off).
    """
    if self.override_lr_scheduler:
        print_rank_0(' > overriding {} value to {}'.format(name, cls_value))
        return cls_value

    if not self.use_checkpoint_lr_scheduler:
        # Fixed message: original concatenation lacked the separating
        # space ("valueand checkpoint").
        assert cls_value == sd_value, 'AnnealingLR: class input value ' \
            'and checkpoint values for {} do not match'.format(name)
    print_rank_0(' > using checkpoint value {} for {}'.format(sd_value,
                                                              name))
    return sd_value
def
load_state_dict
(
self
,
sd
):
self
.
start_lr
=
self
.
check_and_set
_
(
self
.
start_lr
,
sd
[
'start_lr'
],
self
.
start_lr
=
self
.
_
check_and_set
(
self
.
start_lr
,
sd
[
'start_lr'
],
'learning rate'
)
self
.
min_lr
=
self
.
check_and_set
_
(
self
.
min_lr
,
sd
[
'min_lr'
],
self
.
min_lr
=
self
.
_
check_and_set
(
self
.
min_lr
,
sd
[
'min_lr'
],
'minimum learning rate'
)
self
.
warmup_iter
=
self
.
check_and_set
_
(
self
.
warmup_iter
,
self
.
warmup_iter
=
self
.
_
check_and_set
(
self
.
warmup_iter
,
sd
[
'warmup_iter'
],
'warmup iterations'
)
self
.
end_iter
=
self
.
check_and_set
_
(
self
.
end_iter
,
sd
[
'end_iter'
],
self
.
end_iter
=
self
.
_
check_and_set
(
self
.
end_iter
,
sd
[
'end_iter'
],
'total number of iterations'
)
self
.
decay_style
=
self
.
check_and_set
_
(
self
.
decay_style
,
self
.
decay_style
=
self
.
_
check_and_set
(
self
.
decay_style
,
sd
[
'decay_style'
],
'decay style'
)
...
...
megatron/module.py
View file @
b6e0377b
...
...
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron Module"""
import
torch
...
...
megatron/training.py
View file @
b6e0377b
...
...
@@ -197,13 +197,13 @@ def get_learning_rate_scheduler(optimizer):
else
:
num_iters
=
args
.
train_iters
num_iters
=
max
(
1
,
num_iters
)
init_step
=
-
1
init_step
=
0
warmup_iter
=
args
.
warmup
*
num_iters
lr_scheduler
=
AnnealingLR
(
optimizer
,
start_lr
=
args
.
lr
,
warmup_iter
=
warmup_iter
,
num
_iters
=
num_iters
,
total
_iters
=
num_iters
,
decay_style
=
args
.
lr_decay_style
,
last_iter
=
init_step
,
min_lr
=
args
.
min_lr
,
...
...
megatron/utils.py
View file @
b6e0377b
...
...
@@ -89,8 +89,6 @@ def check_adlr_autoresume_termination(iteration, model,
###################################################
from
megatron
import
mpu
def
get_ltor_masks_and_position_ids
(
data
,
eod_token
,
...
...
@@ -148,16 +146,3 @@ def get_ltor_masks_and_position_ids(data,
return
attention_mask
,
loss_mask
,
position_ids
def vocab_size_with_padding(num_tokens, args):
    """Pad the vocabulary size up to the next multiple of
    ``args.make_vocab_size_divisible_by`` times the model-parallel world
    size, and return the padded size."""
    multiple = args.make_vocab_size_divisible_by * \
        mpu.get_model_parallel_world_size()
    # Ceiling division instead of the original increment-by-one loop:
    # same result, O(1) instead of O(multiple).
    after = ((num_tokens + multiple - 1) // multiple) * multiple
    print_rank_0('> padded vocab (size: {}) with {} dummy '
                 'tokens (new size: {})'.format(
                     num_tokens, after - num_tokens, after))
    return after
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment