OpenDAS / Megatron-LM · Commit 329fe582
Authored Feb 10, 2022 by Lawrence McAfee

working on Float16DistributedOptimizer

Parent: 7dc8c475
Showing 5 changed files with 243 additions and 56 deletions (+243 −56).
megatron/arguments.py                          +12  −0
megatron/optimizer/__init__.py                 +27  −10
megatron/optimizer/distributed_fused_adam.py   +95  −0
megatron/optimizer/optimizer.py                +93  −0
megatron/training.py                           +16  −46
megatron/arguments.py (view file @ 329fe582)
...
...
@@ -168,6 +168,14 @@ def parse_args(extra_args_provider=None, defaults={},
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp

    # >>>
    # If we use the distributed optimizer, we need to have local DDP
    # and we should make sure use-contiguous-buffers-in-local-ddp is on.
    if args.use_distributed_optimizer:
        assert args.DDP_impl == 'local'
        assert args.use_contiguous_buffers_in_local_ddp
    # <<<

    # For torch DDP, we do not use contiguous buffer
    if args.DDP_impl == 'torch':
        args.use_contiguous_buffers_in_local_ddp = False
...
...
@@ -700,6 +708,10 @@ def _add_distributed_args(parser):
                       help='Call torch.cuda.empty_cache() each iteration '
                       '(training and eval), to reduce fragmentation.'
                       '0=off, 1=moderate, 2=aggressive.')
    # >>>
    group.add_argument('--use-distributed-optimizer', action='store_true',
                       help='Use distributed optimizer.')
    # <<<

    return parser
...
...
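For context, the new flag can be exercised outside of Megatron with a small standalone sketch (a hypothetical script, not part of this commit) that mirrors the argparse registration and the assertions added to parse_args() above; the flag names below are stand-ins for the real Megatron arguments:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--use-distributed-optimizer', action='store_true',
                    help='Use distributed optimizer.')
parser.add_argument('--DDP-impl', default='local', choices=['local', 'torch'])
parser.add_argument('--use-contiguous-buffers-in-local-ddp', action='store_true')
args = parser.parse_args(['--use-distributed-optimizer',
                          '--use-contiguous-buffers-in-local-ddp'])

# The distributed optimizer requires local DDP with contiguous grad buffers,
# mirroring the checks added to parse_args().
if args.use_distributed_optimizer:
    assert args.DDP_impl == 'local'
    assert args.use_contiguous_buffers_in_local_ddp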
megatron/optimizer/__init__.py (view file @ 329fe582)
...
...
@@ -19,9 +19,17 @@ from apex.optimizers import FusedSGD as SGD
from megatron import get_args
from megatron.model import LayerNorm

# >>>
from .distributed_fused_adam import DistributedFusedAdam
# <<<
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer
# >>>
from .optimizer import (
    Float16OptimizerWithFloat16Params,
    Float16DistributedOptimizer,
    FP32Optimizer,
)
# <<<


def get_param_groups(modules,
                     no_weight_decay_cond,
...
...
@@ -97,7 +105,11 @@ def get_megatron_optimizer(model,
    #     })
    # <<<

    if args.optimizer == 'adam':
    # >>>
    if args.use_distributed_optimizer:
        optimizer = DistributedFusedAdam(param_groups)
    # <<<
    elif args.optimizer == 'adam':
        optimizer = Adam(param_groups,
                         lr=args.lr,
                         weight_decay=args.weight_decay,
...
...
@@ -141,13 +153,18 @@ def get_megatron_optimizer(model,
                hysteresis=args.hysteresis)

        # Megatron optimizer.
        return Float16OptimizerWithFloat16Params(optimizer,
                                                 args.clip_grad,
                                                 args.log_num_zeros_in_grad,
                                                 params_have_main_grad,
                                                 args.use_contiguous_buffers_in_local_ddp,
                                                 args.bf16,
                                                 grad_scaler)
        # >>>
        opt_ty = Float16DistributedOptimizer \
            if args.use_distributed_optimizer \
            else Float16OptimizerWithFloat16Params
        return opt_ty(optimizer,
                      args.clip_grad,
                      args.log_num_zeros_in_grad,
                      params_have_main_grad,
                      args.use_contiguous_buffers_in_local_ddp,
                      args.bf16,
                      grad_scaler)
        # <<<

    # FP32.
    return FP32Optimizer(optimizer,
                         args.clip_grad,
...
...
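To make the new branching in get_megatron_optimizer() easier to follow, here is a minimal selection sketch with toy stand-in classes (not the real Megatron or apex implementations): the single flag picks both the inner Adam variant and the float16 wrapper class.

from types import SimpleNamespace

class Adam: ...
class DistributedFusedAdam: ...
class Float16OptimizerWithFloat16Params: ...
class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params): ...

def pick_optimizer_classes(args):
    # Mirrors the commit: DistributedFusedAdam replaces Adam as the inner
    # optimizer, and Float16DistributedOptimizer replaces the usual
    # Float16OptimizerWithFloat16Params wrapper.
    inner = DistributedFusedAdam if args.use_distributed_optimizer else Adam
    wrapper = Float16DistributedOptimizer \
        if args.use_distributed_optimizer \
        else Float16OptimizerWithFloat16Params
    return inner, wrapper

print(pick_optimizer_classes(SimpleNamespace(use_distributed_optimizer=True)))
print(pick_optimizer_classes(SimpleNamespace(use_distributed_optimizer=False)))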
megatron/optimizer/distributed_fused_adam.py (new file, 0 → 100644, view file @ 329fe582)
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math

import torch

from megatron import mpu

# >>>
from lutil import pax, tp
# <<<


class DistributedFusedAdam(torch.optim.Optimizer):

    def __init__(self, params):

        super().__init__(params, defaults={})

        self.initialized = False
        # self.params_32 = None
        # self.grads_32 = None
        # self.opt_m = None
        # self.opt_v = None

        # pax(0, {
        #     "param_groups" : self.param_groups,
        #     "param_groups / 0" : self.param_groups[0],
        #     "param_groups / 1" : self.param_groups[1],
        #     "param_groups / 0 / params" : self.param_groups[0]["params"],
        #     # "param_groups / params" : [ g["params"] for g in self.param_groups ],
        # })

    def initialize(self):

        if self.initialized:
            raise Exception("initialization worked.")
            return

        self.initialized = True

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        total_param_size = sum(
            p.numel()
            for g in self.param_groups
            for p in g["params"]
        )
        shard_size = int(math.ceil(total_param_size / data_parallel_world_size))
        shard_start_index = data_parallel_rank * shard_size
        shard_end_index = min(total_param_size, shard_start_index + shard_size)
        shard_size = shard_end_index - shard_start_index

        allocate_shard = lambda dtype : torch.empty(
            [shard_size],
            dtype=dtype,
            device=torch.cuda.current_device())

        self.main_param_shard = allocate_shard(torch.float)
        self.main_grad_shard = allocate_shard(torch.float)
        self.adam_m_shard = allocate_shard(torch.float)
        self.adam_v_shard = allocate_shard(torch.float)

        # pax(2, {
        #     "data_parallel_rank" : data_parallel_rank,
        #     "data_parallel_world_size" : data_parallel_world_size,
        #     "total_param_size" : total_param_size,
        #     "shard_size" : shard_size,
        #     "shard" : "%d [ %d, %d ]" % (
        #         shard_size,
        #         shard_start_index,
        #         shard_end_index,
        #     ),
        # })

    def step(self):

        self.initialize()

        raise Exception("what's next?")

# >>>
# eof
# <<<
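For intuition, the shard arithmetic in initialize() can be worked through standalone with made-up sizes (the parameter count and world size below are hypothetical, chosen only to show the rounding behavior): each data-parallel rank owns one contiguous, roughly equal slice of the flattened parameter space, and the last rank's slice is clipped to the total size.

import math

total_param_size = 10_000_003        # hypothetical total number of elements
data_parallel_world_size = 4         # hypothetical data-parallel group size

shard_size = int(math.ceil(total_param_size / data_parallel_world_size))
for data_parallel_rank in range(data_parallel_world_size):
    start = data_parallel_rank * shard_size
    end = min(total_param_size, start + shard_size)
    print("rank %d: [%d, %d) -> %d elements" % (
        data_parallel_rank, start, end, end - start))

# rank 0: [0, 2500001) -> 2500001 elements
# rank 1: [2500001, 5000002) -> 2500001 elements
# rank 2: [5000002, 7500003) -> 2500001 elements
# rank 3: [7500003, 10000003) -> 2500000 elements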
megatron/optimizer/optimizer.py (view file @ 329fe582)
...
...
@@ -275,6 +275,12 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                    # <<<
                    # >>>
                    # debug()
                    # from lutil import pax, tp
                    # pax(0, {
                    #     "param" : tp(param),
                    #     "main_param" : tp(main_param),
                    # })
                    # <<<
                    fp32_from_float16_params_this_group.append(main_param)

                    # Reset existing state dict key to the new main param.
...
...
@@ -354,6 +360,84 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
        return self.grad_scaler.scale

    # >>>
    def reduce_gradients(self):

        # >>>
        # if not args.use_distributed_optimizer:
        # All-reduce if needed.
        # >>>
        # if args.DDP_impl == 'local' and not args.use_distributed_optimizer:
        if args.DDP_impl == 'local':
        # <<<
            timers('backward-params-all-reduce').start()
            for model_module in model:
                # >>>
                # from lutil import pax, tp
                # pax(0, {
                #     "model" : model,
                #     "model_module" : model_module,
                # })
                # <<<
                # >>>
                # e.g., grad_shard = optimizer.get_grad_shard()
                # <<<
                model_module.allreduce_gradients()
            timers('backward-params-all-reduce').stop()

        # All-reduce word_embeddings' grad across first and last stages to ensure
        # that word_embeddings parameters stay in sync. This should only run for
        # models that support pipelined model parallelism (BERT and GPT-2).
        timers('backward-embedding-all-reduce').start()
        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = model[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
                unwrapped_model = model[-1]
            else:  # We do not support the interleaved schedule for T5 yet.
                unwrapped_model = model[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))

            if unwrapped_model.share_word_embeddings:
                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                # >>>
                # if args.DDP_impl == 'local':
                #     grad = word_embeddings_weight.main_grad
                # else:
                #     grad = word_embeddings_weight.grad
                # torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
                # +++
                grad_shard = optimizer.get_grad_shard(word_embeddings_weight)
                torch.distributed.all_reduce(grad_shard,
                                             group=mpu.get_embedding_group())
                # <<<

        # All-reduce position_embeddings grad across first (encoder) and split
        # (decoder) stages to ensure that position embeddings parameters stay in
        # sync. This should only run for T5 models with pipeline parallelism.
        if mpu.is_rank_in_position_embedding_group() and \
                mpu.get_pipeline_model_parallel_world_size() > 1 and \
                args.pipeline_model_parallel_split_rank is not None:
            unwrapped_model = model[0]
            unwrapped_model = unwrap_model(
                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
            assert args.DDP_impl == 'local', \
                'T5 model is only supported with local DDP mode'
            # >>>
            # grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
            # torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
            # +++
            grad_shard = optimizer.get_grad_shard(
                unwrapped_model.language_model.embedding.position_embeddings.weight)
            torch.distributed.all_reduce(grad_shard,
                                         group=mpu.get_position_embedding_group())
            # <<<
        timers('backward-embedding-all-reduce').stop()
    # <<<

    def _copy_model_grads_to_main_grads(self):
        # This only needs to be done for the float16 group.
        for model_group, main_group in zip(self.float16_groups,
...
...
@@ -542,6 +626,15 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
            current_param.data.copy_(saved_param.data)


# >>>
class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):

    def step(self):
        raise Exception("hi.")
# <<<


class FP32Optimizer(MegatronOptimizer):
...
...
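The commented-out get_grad_shard() calls in reduce_gradients() hint at where this work-in-progress is heading: once gradients live in per-rank shards, each data-parallel rank only needs the reduced values for its own slice. A minimal sketch of that pattern using plain torch.distributed collectives follows; this is an assumption about the eventual design, not an API from this commit.

import torch
import torch.distributed as dist

def reduce_grad_shard(flat_grads, group=None):
    """Reduce-scatter a flat gradient buffer so each rank ends up holding the
    summed gradients for its own contiguous shard only.

    Assumes torch.distributed is initialized and that flat_grads has been
    padded to a multiple of the (data-parallel) world size.
    """
    world_size = dist.get_world_size(group)
    shard_size = flat_grads.numel() // world_size
    grad_shard = torch.empty(shard_size,
                             dtype=flat_grads.dtype,
                             device=flat_grads.device)
    dist.reduce_scatter(grad_shard,
                        list(flat_grads.chunk(world_size)),
                        op=dist.ReduceOp.SUM,
                        group=group)
    return grad_shard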
megatron/training.py (view file @ 329fe582)
...
...
@@ -410,60 +410,30 @@ def train_step(forward_step_func, data_iterator,
            partition.zero_grad_buffer()
    optimizer.zero_grad()

    # >>>
    # Forward pass.
    # <<<
    forward_backward_func = get_forward_backward_func()
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # Empty unused memory
    # >>>
    # Empty unused memory.
    # <<<
    if args.empty_unused_memory_level >= 1:
        torch.cuda.empty_cache()

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for model_module in model:
            model_module.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync. This should only run for
    # models that support pipelined model parallelism (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
        else:  # We do not support the interleaved schedule for T5 yet.
            unwrapped_model = model[0]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            if args.DDP_impl == 'local':
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())

    # All-reduce position_embeddings grad across first (encoder) and split
    # (decoder) stages to ensure that position embeddings parameters stay in
    # sync. This should only run for T5 models with pipeline parallelism.
    if mpu.is_rank_in_position_embedding_group() and \
            mpu.get_pipeline_model_parallel_world_size() > 1 and \
            args.pipeline_model_parallel_split_rank is not None:
        unwrapped_model = model[0]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
        assert args.DDP_impl == 'local', \
            'T5 model is only supported with local DDP mode'
        grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
        torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # >>>
    # Reduce gradients. (with distributed optimizer option, optimizer
    # now responsible for reducing gradients)
    optimizer.reduce_gradients()
    # <<<

    # >>>
    from lutil import pax
    pax({"optimizer": optimizer})
    # <<<

    # Update parameters.
    timers('optimizer').start()
...
...
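Taken together with the optimizer.py change, the train_step() refactor shrinks the training loop: gradient reduction moves behind optimizer.reduce_gradients(). A toy sketch of the resulting control flow (stand-in classes only, nothing Megatron-specific):

class _ToyOptimizer:
    # Hypothetical stand-in for the Megatron optimizer wrapper.
    def zero_grad(self): print("zeroing grads")
    def reduce_gradients(self): print("reducing grads inside the optimizer")
    def step(self): print("updating parameters")

def toy_train_step(optimizer, run_forward_backward):
    optimizer.zero_grad()
    run_forward_backward()        # forward + backward passes
    optimizer.reduce_gradients()  # previously inlined in train_step()
    optimizer.step()

toy_train_step(_ToyOptimizer(), lambda: print("forward/backward"))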