OpenDAS / Megatron-LM · Commit 0888a3e1

further refactoring

Authored Dec 25, 2020 by mohammad
Parent: dfd8ed47

Showing 4 changed files, with 194 additions and 157 deletions:
megatron/optimizer/__init__.py      +80   -0
megatron/optimizer/grad_scaler.py   +113  -0
megatron/optimizer/optimizer.py     +0    -156
megatron/training.py                +1    -1
megatron/optimizer/__init__.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from apex.optimizers import FusedAdam as Adam

from megatron import get_args
from megatron.model import import_layernorm

from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import FP16OptimizerWithFP16Params, FP32Optimizer


def _get_params_for_weight_decay_optimization(module):
    """Divide params into with-weight-decay and without-weight-decay groups.
    Layernorms and biases will have no weight decay but the rest will.
    """
    args = get_args()
    LayerNorm = import_layernorm(args.fp32_residual_connection)

    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for module_ in module.modules():
        if isinstance(module_, LayerNorm):
            no_weight_decay_params['params'].extend(
                [p for p in list(module_._parameters.values())
                 if p is not None])
        else:
            weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items())
                 if p is not None and n != 'bias'])
            no_weight_decay_params['params'].extend(
                [p for n, p in list(module_._parameters.items())
                 if p is not None and n == 'bias'])

    return weight_decay_params, no_weight_decay_params


def get_megatron_optimizer(model):

    args = get_args()

    # Base optimizer.
    param_groups = _get_params_for_weight_decay_optimization(model)
    optimizer = Adam(param_groups,
                     lr=args.lr,
                     weight_decay=args.weight_decay,
                     betas=(args.adam_beta1, args.adam_beta2),
                     eps=args.adam_eps)

    if args.fp16:
        # Constant loss scale.
        if args.loss_scale:
            grad_scaler = ConstantGradScaler(args.loss_scale)
        # Dynamic loss scale.
        else:
            grad_scaler = DynamicGradScaler(
                initial_scale=args.initial_loss_scale,
                min_scale=args.min_loss_scale,
                growth_factor=2.0,
                backoff_factor=0.5,
                growth_interval=args.loss_scale_window,
                hysteresis=args.hysteresis)
        # Megatron optimizer.
        return FP16OptimizerWithFP16Params(optimizer, grad_scaler,
                                           args.clip_grad)

    # FP32.
    return FP32Optimizer(optimizer, model, args.clip_grad)
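As a side note, the grouping performed by _get_params_for_weight_decay_optimization can be exercised outside Megatron. The stand-alone sketch below is illustrative only: it substitutes torch.nn.LayerNorm for the fused LayerNorm returned by import_layernorm, and the toy module TinyModel and the expected counts are assumptions made for the example, not part of the commit.

    # Illustrative sketch only: mirrors the grouping logic above using plain
    # torch.nn.LayerNorm (the commit resolves LayerNorm via import_layernorm).
    import torch

    class TinyModel(torch.nn.Module):  # hypothetical toy module, not Megatron code
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(8, 8)
            self.norm = torch.nn.LayerNorm(8)

    def group_params(module, LayerNorm=torch.nn.LayerNorm):
        weight_decay_params = {'params': []}
        no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
        for module_ in module.modules():
            if isinstance(module_, LayerNorm):
                # All LayerNorm parameters (weight and bias) get no weight decay.
                no_weight_decay_params['params'].extend(
                    [p for p in module_._parameters.values() if p is not None])
            else:
                # Non-LayerNorm weights decay; biases do not.
                weight_decay_params['params'].extend(
                    [p for n, p in module_._parameters.items()
                     if p is not None and n != 'bias'])
                no_weight_decay_params['params'].extend(
                    [p for n, p in module_._parameters.items()
                     if p is not None and n == 'bias'])
        return weight_decay_params, no_weight_decay_params

    model = TinyModel()
    decay, no_decay = group_params(model)
    # Expect 1 decayed tensor (linear.weight) and 3 non-decayed tensors
    # (linear.bias, norm.weight, norm.bias).
    print(len(decay['params']), len(no_decay['params']))

In get_megatron_optimizer the two returned dicts become Adam parameter groups, with the second group overriding weight_decay to 0.0.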
megatron/optimizer/grad_scaler.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron grad scaler."""

from abc import ABC
from abc import abstractmethod

import torch


class MegatronGradScaler(ABC):

    def __init__(self, initial_scale):
        """Initialize scale value with the input initial scale."""
        assert initial_scale > 0.0
        self._scale = torch.cuda.FloatTensor([initial_scale])

    @property
    def scale(self):
        return self._scale

    @property
    def inv_scale(self):
        return self._scale.double().reciprocal().float()

    @abstractmethod
    def update(self, found_inf):
        pass

    '''
    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        pass
    '''


class ConstantGradScaler(MegatronGradScaler):

    def update(self, found_inf):
        pass


class DynamicGradScaler(MegatronGradScaler):

    def __init__(self, initial_scale, min_scale,
                 growth_factor, backoff_factor,
                 growth_interval, hysteresis):
        """Grad scaler with dynamic scale that gets adjusted
        during training."""
        super(DynamicGradScaler, self).__init__(initial_scale)

        # Lower bound on the scale.
        assert min_scale > 0.0
        assert min_scale <= initial_scale
        self.min_scale = torch.cuda.FloatTensor([min_scale])
        # Growth and backoff factors for the scale.
        assert growth_factor > 1.0
        self.growth_factor = torch.cuda.FloatTensor([growth_factor])
        assert backoff_factor < 1.0
        assert backoff_factor > 0.0
        self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
        # Interval over which, if we don't see any inf/nan,
        # we will scale the grad scale by the growth factor.
        assert growth_interval > 0
        self.growth_interval = growth_interval
        # Number of inf/nans we should see before scaling down
        # the grad scale by the backoff factor.
        assert hysteresis > 0
        self.hysteresis = hysteresis

        # Trackers.
        self._growth_tracker = 0
        self._hysteresis_tracker = self.hysteresis

    def update(self, found_inf):

        # If we have an inf/nan, the growth tracker is set to 0
        # and the hysteresis tracker is reduced by 1.
        if found_inf:
            self._growth_tracker = 0
            self._hysteresis_tracker -= 1
            # Now if we are out of the hysteresis count, scale down the loss.
            if self._hysteresis_tracker <= 0:
                self._scale = torch.max(self._scale * self.backoff_factor,
                                        self.min_scale)
        else:
            # If there is no nan/inf, increment the growth tracker.
            self._growth_tracker += 1
            # If we have had enough consecutive intervals with no nan/inf:
            if self._growth_tracker == self.growth_interval:
                # Reset the growth and hysteresis trackers,
                self._growth_tracker = 0
                self._hysteresis_tracker = self.hysteresis
                # and scale up the loss scale.
                self._scale = self._scale * self.growth_factor
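Because MegatronGradScaler stores its state in torch.cuda.FloatTensor, running the class as-is requires a GPU. The CPU-only mimic below reproduces the same growth/backoff/hysteresis policy with plain Python floats, purely for illustration; the class name ToyDynamicScaler and the simulated overflow pattern are assumptions, not part of the commit.

    # CPU-only mimic of DynamicGradScaler.update() for illustration; the real
    # class keeps its scale and factors in torch.cuda.FloatTensor on the GPU.
    class ToyDynamicScaler:  # hypothetical name, not Megatron code
        def __init__(self, initial_scale, min_scale, growth_factor,
                     backoff_factor, growth_interval, hysteresis):
            self.scale = float(initial_scale)
            self.min_scale = float(min_scale)
            self.growth_factor = growth_factor
            self.backoff_factor = backoff_factor
            self.growth_interval = growth_interval
            self.hysteresis = hysteresis
            self._growth_tracker = 0
            self._hysteresis_tracker = hysteresis

        def update(self, found_inf):
            if found_inf:
                # Overflow: stop counting clean steps, spend one hysteresis credit.
                self._growth_tracker = 0
                self._hysteresis_tracker -= 1
                if self._hysteresis_tracker <= 0:
                    self.scale = max(self.scale * self.backoff_factor,
                                     self.min_scale)
            else:
                # Clean step: after growth_interval of them in a row, grow the scale.
                self._growth_tracker += 1
                if self._growth_tracker == self.growth_interval:
                    self._growth_tracker = 0
                    self._hysteresis_tracker = self.hysteresis
                    self.scale *= self.growth_factor

    scaler = ToyDynamicScaler(initial_scale=2**16, min_scale=1.0,
                              growth_factor=2.0, backoff_factor=0.5,
                              growth_interval=3, hysteresis=2)
    for step, found_inf in enumerate([False, False, False, True, True, False]):
        scaler.update(found_inf)
        print(step, found_inf, scaler.scale)
    # Steps 0-2 are clean, so the scale doubles to 131072 at step 2; the two
    # overflows at steps 3-4 exhaust hysteresis=2 and halve it back to 65536.

The hysteresis count is what keeps a single stray overflow from immediately shrinking the scale: it takes hysteresis overflows, with no full growth interval in between, before the backoff factor is applied.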
megatron/optimizer/optimizer.py
...
...
@@ -22,166 +22,10 @@ import torch

from torch._six import inf

from apex.multi_tensor_apply import multi_tensor_applier
from apex.optimizers import FusedAdam as Adam
import amp_C

from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron.model import import_layernorm

[This hunk deletes 156 lines from optimizer.py: get_params_for_weight_decay_optimization, get_megatron_optimizer, and the MegatronGradScaler, ConstantGradScaler, and DynamicGradScaler classes. The deleted code is identical to what this commit adds in megatron/optimizer/__init__.py (where the helper is renamed _get_params_for_weight_decay_optimization) and megatron/optimizer/grad_scaler.py, shown above.]

def _zero_grad_group_helper(group, set_to_none):
...
...
megatron/training.py
...
...
@@ -38,7 +38,7 @@ from megatron import print_rank_last
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.fp16 import FP16_Module
-from megatron.optimizer.optimizer import get_megatron_optimizer
+from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
...
...
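The one-line change in megatron/training.py reflects the new layout: get_megatron_optimizer now lives in the package __init__.py, so it is imported from megatron.optimizer rather than from megatron.optimizer.optimizer. The runnable sketch below reproduces that shape with a throwaway package; every name in it (toyopt, impl, HeavyOptimizer, build_optimizer) is a placeholder, not Megatron code.

    # Sketch of the refactor's import shape: the package __init__ defines the
    # factory, the submodule holds the implementation. Placeholder names only.
    import os, sys, tempfile, textwrap

    pkg_root = tempfile.mkdtemp()
    pkg_dir = os.path.join(pkg_root, "toyopt")
    os.makedirs(pkg_dir)

    with open(os.path.join(pkg_dir, "impl.py"), "w") as f:
        f.write("class HeavyOptimizer:\n    pass\n")

    with open(os.path.join(pkg_dir, "__init__.py"), "w") as f:
        f.write(textwrap.dedent("""\
            from .impl import HeavyOptimizer

            def build_optimizer():
                return HeavyOptimizer()
        """))

    sys.path.insert(0, pkg_root)

    # Callers import the factory from the package, not the submodule --
    # the same shape as the training.py change in this commit.
    from toyopt import build_optimizer
    print(type(build_optimizer()).__name__)   # HeavyOptimizer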