OpenDAS / Megatron-LM · Commits · 28062e14

Commit 28062e14, authored Dec 24, 2020 by mohammad
Parent: fb218c9d

    moved entire optimizer build and tested

4 changed files with 44 additions and 55 deletions (+44, -55)
megatron/model/__init__.py       +0   -1
megatron/model/utils.py          +0   -26
megatron/optimizer/optimizer.py  +34  -1
megatron/training.py             +10  -27
megatron/model/__init__.py

@@ -33,7 +33,6 @@ from .distributed import *
 from .bert_model import BertModel, BertModelFirstStage, BertModelIntermediateStage, BertModelLastStage
 from .realm_model import ICTBertModel
 from .gpt2_model import GPT2Model, GPT2ModelFirstStage, GPT2ModelIntermediateStage, GPT2ModelLastStage
-from .utils import get_params_for_weight_decay_optimization
 from .language_model import get_language_model
megatron/model/utils.py

@@ -20,7 +20,6 @@ import math
 import torch
 
 from megatron import get_args
-from megatron.model import import_layernorm
 
 def init_method_normal(sigma):
     """Init method based on N(0, sigma)."""
@@ -60,28 +59,3 @@ def openai_gelu(x):
 @torch.jit.script
 def erf_gelu(x):
     return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
-
-
-def get_params_for_weight_decay_optimization(module):
-    """Divide params into with-weight-decay and without-weight-decay groups.
-    Layernorms and baises will have no weight decay but the rest will.
-    """
-    args = get_args()
-    LayerNorm = import_layernorm(args.fp32_residual_connection)
-
-    weight_decay_params = {'params': []}
-    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
-    for module_ in module.modules():
-        if isinstance(module_, LayerNorm):
-            no_weight_decay_params['params'].extend(
-                [p for p in list(module_._parameters.values())
-                 if p is not None])
-        else:
-            weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n != 'bias'])
-            no_weight_decay_params['params'].extend(
-                [p for n, p in list(module_._parameters.items())
-                 if p is not None and n == 'bias'])
-
-    return weight_decay_params, no_weight_decay_params
megatron/optimizer/optimizer.py

@@ -21,16 +21,49 @@ from abc import abstractmethod
 import torch
 
 from apex.multi_tensor_apply import multi_tensor_applier
+from apex.optimizers import FusedAdam as Adam
 import amp_C
 
 from megatron import get_args
 from megatron import get_timers
 from megatron import mpu
+from megatron.model import import_layernorm
 
 
-def get_megatron_optimizer(optimizer, model):
+def get_params_for_weight_decay_optimization(module):
+    """Divide params into with-weight-decay and without-weight-decay groups.
+    Layernorms and baises will have no weight decay but the rest will.
+    """
+    args = get_args()
+    LayerNorm = import_layernorm(args.fp32_residual_connection)
+
+    weight_decay_params = {'params': []}
+    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
+    for module_ in module.modules():
+        if isinstance(module_, LayerNorm):
+            no_weight_decay_params['params'].extend(
+                [p for p in list(module_._parameters.values())
+                 if p is not None])
+        else:
+            weight_decay_params['params'].extend(
+                [p for n, p in list(module_._parameters.items())
+                 if p is not None and n != 'bias'])
+            no_weight_decay_params['params'].extend(
+                [p for n, p in list(module_._parameters.items())
+                 if p is not None and n == 'bias'])
+
+    return weight_decay_params, no_weight_decay_params
+
+
+def get_megatron_optimizer(model):
 
     args = get_args()
 
+    # Base optimizer.
+    param_groups = get_params_for_weight_decay_optimization(model)
+    optimizer = Adam(param_groups,
+                     lr=args.lr,
+                     weight_decay=args.weight_decay,
+                     betas=(args.adam_beta1, args.adam_beta2),
+                     eps=args.adam_eps)
+
     if args.fp16:
         # Constant loss scale.
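
For reference, here is a minimal standalone sketch of the parameter-grouping logic this file now owns. It is not Megatron code: plain torch.nn.LayerNorm stands in for the class returned by Megatron's import_layernorm(args.fp32_residual_connection), torch.optim.AdamW stands in for apex's FusedAdam, and the helper name split_params_for_weight_decay is made up for the example.

# Sketch only; assumptions: torch.nn.LayerNorm replaces Megatron's
# import_layernorm(), torch.optim.AdamW replaces apex FusedAdam.
import torch


def split_params_for_weight_decay(module):
    """Group parameters: LayerNorm params and all biases get weight_decay=0."""
    weight_decay_params = {'params': []}
    no_weight_decay_params = {'params': [], 'weight_decay': 0.0}
    for m in module.modules():
        if isinstance(m, torch.nn.LayerNorm):
            no_weight_decay_params['params'].extend(
                p for p in m._parameters.values() if p is not None)
        else:
            weight_decay_params['params'].extend(
                p for n, p in m._parameters.items()
                if p is not None and n != 'bias')
            no_weight_decay_params['params'].extend(
                p for n, p in m._parameters.items()
                if p is not None and n == 'bias')
    return weight_decay_params, no_weight_decay_params


if __name__ == '__main__':
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
    decay, no_decay = split_params_for_weight_decay(model)
    optimizer = torch.optim.AdamW([decay, no_decay], lr=1e-3, weight_decay=0.01)
    # Prints "1 3": the Linear weight decays; its bias plus both LayerNorm
    # parameters do not.
    print(len(decay['params']), len(no_decay['params']))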
megatron/training.py

@@ -24,7 +24,6 @@ _TRAIN_START_TIME = time.time()
 import torch
 
 from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
-from apex.optimizers import FusedAdam as Adam
 
 from megatron import get_args
 from megatron import get_timers
@@ -45,7 +44,6 @@ from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.model import get_params_for_weight_decay_optimization
 from megatron.model.realm_model import ICTBertModel
 from megatron.utils import check_adlr_autoresume_termination
 from megatron.data.data_loaders import build_pretraining_data_loader
@@ -184,6 +182,10 @@ def get_model(model_provider_func):
     # Build model on cpu.
     model = model_provider_func()
 
+    # Set tensor model parallel attributes if not set.
+    for param in model.parameters():
+        mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
+
     # Print number of parameters.
     if mpu.get_data_parallel_rank() == 0:
         print(' > number of parameters on (tensor, pipeline) '
@@ -212,30 +214,6 @@ def get_model(model_provider_func):
                   'Exiting.'.format(args.DDP_impl))
 
 
-def get_optimizer(model):
-    """Set up the optimizer."""
-    args = get_args()
-
-    # Build parameter groups (weight decay and non-decay).
-    while isinstance(model, (torchDDP, LocalDDP, FP16_Module)):
-        model = model.module
-    param_groups = get_params_for_weight_decay_optimization(model)
-
-    # Add model parallel attribute if it is not set.
-    for param_group in param_groups:
-        for param in param_group['params']:
-            if not hasattr(param, 'tensor_model_parallel'):
-                param.tensor_model_parallel = False
-
-    # Use Adam.
-    optimizer = Adam(param_groups,
-                     lr=args.lr,
-                     weight_decay=args.weight_decay,
-                     betas=(args.adam_beta1, args.adam_beta2),
-                     eps=args.adam_eps)
-
-    # Wrap into fp16 optimizer.
-    optimizer = get_megatron_optimizer(optimizer, model)
-
-    return optimizer
-
-
 def get_learning_rate_scheduler(optimizer):
     """Build the learning rate scheduler."""
     args = get_args()
@@ -284,7 +262,12 @@ def setup_model_and_optimizer(model_provider_func):
     args = get_args()
 
     model = get_model(model_provider_func)
-    optimizer = get_optimizer(model)
+
+    unwrapped_model = model
+    while isinstance(unwrapped_model, (torchDDP, LocalDDP, FP16_Module)):
+        unwrapped_model = unwrapped_model.module
+    optimizer = get_megatron_optimizer(unwrapped_model)
+
    lr_scheduler = get_learning_rate_scheduler(optimizer)
 
     if args.load is not None:
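
The net effect in training.py: the local get_optimizer() helper is gone, and setup_model_and_optimizer() now peels off the DDP/FP16 wrappers itself before handing the bare module to get_megatron_optimizer(). A minimal standalone sketch of that unwrap-then-build pattern follows; the Wrapper class is a toy stand-in for torchDDP / LocalDDP / FP16_Module, and torch.optim.AdamW stands in for the Megatron optimizer (both are assumptions for illustration).

# Sketch only; Wrapper and build_optimizer are illustrative stand-ins.
import torch


class Wrapper(torch.nn.Module):
    """Toy wrapper that stores the real model in .module, like DDP does."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        return self.module(x)


def build_optimizer(model):
    # Unwrap every wrapper layer so the optimizer is built from the
    # underlying model, mirroring the while-isinstance loop in the diff above.
    unwrapped_model = model
    while isinstance(unwrapped_model, Wrapper):
        unwrapped_model = unwrapped_model.module
    return torch.optim.AdamW(unwrapped_model.parameters(), lr=1e-3)


if __name__ == '__main__':
    model = Wrapper(Wrapper(torch.nn.Linear(4, 4)))
    optimizer = build_optimizer(model)
    # Prints "AdamW 2": the weight and bias of the unwrapped Linear layer.
    print(type(optimizer).__name__, len(optimizer.param_groups[0]['params']))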