OpenDAS / Megatron-LM · Commit 49cca4d9

Commit 49cca4d9, authored Feb 10, 2022 by Lawrence McAfee
Parent: 329fe582

    more work on Float16DistributedOptimizer
Showing 4 changed files with 158 additions and 27 deletions (+158, -27)
  megatron/optimizer/__init__.py                 +14   -8
  megatron/optimizer/distributed_fused_adam.py    +1   -0
  megatron/optimizer/optimizer.py               +134  -16
  megatron/training.py                            +9   -3
megatron/optimizer/__init__.py  (+14, -8)

@@ -20,7 +20,7 @@ from megatron import get_args
 from megatron.model import LayerNorm
 # >>>
-from .distributed_fused_adam import DistributedFusedAdam
+# from .distributed_fused_adam import DistributedFusedAdam
 # <<<
 from .grad_scaler import ConstantGradScaler, DynamicGradScaler
 # >>>

@@ -106,10 +106,11 @@ def get_megatron_optimizer(model,
     # <<<

     # >>>
-    if args.use_distributed_optimizer:
-        optimizer = DistributedFusedAdam(param_groups)
+    # if args.use_distributed_optimizer:
+    #     optimizer = DistributedFusedAdam(param_groups)
+    # elif args.optimizer == 'adam':
     # <<<
-    elif args.optimizer == 'adam':
+    if args.optimizer == 'adam':
         optimizer = Adam(param_groups,
                          lr=args.lr,
                          weight_decay=args.weight_decay,

@@ -167,7 +168,12 @@ def get_megatron_optimizer(model,
     # <<<

     # FP32.
-    return FP32Optimizer(optimizer, args.clip_grad,
-                         args.log_num_zeros_in_grad,
-                         params_have_main_grad,
-                         args.use_contiguous_buffers_in_local_ddp)
+    # >>>
+    opt_ty = Float32DistributedOptimizer \
+        if args.use_distributed_optimizer \
+        else Float32Optimizer
+    return opt_ty(optimizer, args.clip_grad,
+                  args.log_num_zeros_in_grad,
+                  params_have_main_grad,
+                  args.use_contiguous_buffers_in_local_ddp)
+    # <<<
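The last hunk replaces the hard-coded FP32Optimizer return with a wrapper type chosen from args.use_distributed_optimizer. For reference, the selection pattern looks roughly like the sketch below; the stub classes and the wrap_fp32_optimizer helper are illustrative stand-ins, not code from this repository.

# Illustrative sketch of the wrapper-selection pattern introduced above.
# The two stub classes are stand-ins with the constructor signature the diff
# implies; they are not the repository's real classes.

class Float32Optimizer:
    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp):
        self.optimizer = optimizer
        self.clip_grad = clip_grad

class Float32DistributedOptimizer(Float32Optimizer):
    pass  # assumed to share the base constructor signature

def wrap_fp32_optimizer(args, optimizer, params_have_main_grad):
    # Pick the distributed or plain wrapper, then construct it with the same
    # argument list either way.
    opt_ty = Float32DistributedOptimizer \
        if args.use_distributed_optimizer \
        else Float32Optimizer
    return opt_ty(optimizer, args.clip_grad,
                  args.log_num_zeros_in_grad,
                  params_have_main_grad,
                  args.use_contiguous_buffers_in_local_ddp)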
megatron/optimizer/distributed_fused_adam.py  (+1, -0)

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import math
 import torch
megatron/optimizer/optimizer.py  (+134, -16)

@@ -29,6 +29,9 @@ from megatron import print_rank_0
 from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32

+# >>>
+from lutil import pax, tp
+# <<<

 def _zero_grad_group_helper(group, set_to_none):
     """Zero out the gradient for a group of parameters.
@@ -361,7 +364,20 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
     # >>>
-    def reduce_gradientss(self):
+    def reduce_gradients(self, model):
+
+        # >>>
+        from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+        from megatron import get_args
+        from megatron import get_timers
+        from megatron.model import DistributedDataParallel as LocalDDP
+        from megatron.model import Float16Module
+        from megatron.utils import unwrap_model
+
+        args = get_args()
+        timers = get_timers()
+        # <<<
+
         # >>>
         # if not args.use_distributed_optimizer:
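The new reduce_gradients(self, model) body pulls in torchDDP, LocalDDP, Float16Module and unwrap_model so it can strip DDP and fp16 wrappers off each model chunk before touching their gradient buffers. For context, Megatron's unwrap_model does roughly the walk sketched below; this is a paraphrase, not the file's contents.

# Rough sketch of the module-unwrapping helper the imports above support.
# `module_instances` is a tuple of wrapper classes such as
# (torchDDP, LocalDDP, Float16Module); each wrapper exposes `.module`.
def unwrap_model(model, module_instances):
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped = []
    for model_module in model:
        # Peel wrappers until the underlying Megatron module is reached.
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped.append(model_module)
    return unwrapped if return_list else unwrapped[0]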
@@ -405,15 +421,15 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
             if unwrapped_model.share_word_embeddings:
                 word_embeddings_weight = unwrapped_model.word_embeddings_weight()
                 # >>>
-                # if args.DDP_impl == 'local':
-                #     grad = word_embeddings_weight.main_grad
-                # else:
-                #     grad = word_embeddings_weight.grad
-                # torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
+                if args.DDP_impl == 'local':
+                    grad = word_embeddings_weight.main_grad
+                else:
+                    grad = word_embeddings_weight.grad
+                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
                 # +++
-                grad_shard = optimizer.get_grad_shard(word_embeddings)
-                torch.distributed.all_reduce(grad_shard,
-                                             group=mpu.get_embedding_group())
+                # grad_shard = optimizer.get_grad_shard(word_embeddings)
+                # torch.distributed.all_reduce(grad_shard,
+                #                              group=mpu.get_embedding_group())
                 # <<<

         # All-reduce position_embeddings grad across first (encoder) and split (decoder)
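This hunk restores the stock embedding all-reduce: when the word embedding is shared between the first and last pipeline stages, its gradient is summed over the embedding process group, and the sharded get_grad_shard variant is commented back out. A self-contained sketch of the restored pattern, with embedding_group standing in for mpu.get_embedding_group():

# Hedged sketch of the shared-embedding gradient all-reduce restored above.
# `embedding_group` stands in for mpu.get_embedding_group(), the process group
# containing the pipeline ranks that hold a copy of the word embedding.
import torch.distributed as dist

def allreduce_shared_word_embedding_grad(word_embeddings_weight, embedding_group,
                                         ddp_impl='local'):
    # Megatron's local DDP accumulates into .main_grad; torch DDP uses .grad.
    grad = (word_embeddings_weight.main_grad
            if ddp_impl == 'local'
            else word_embeddings_weight.grad)
    # Sum the gradient across the ranks that share the embedding weight.
    dist.all_reduce(grad, group=embedding_group)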
@@ -428,13 +444,13 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
                 assert args.DDP_impl == 'local', \
                     'T5 model is only supported with local DDP mode'
                 # >>>
-                # grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
-                # torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
+                grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
+                torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
                 # +++
-                grad_shard = optimizer.get_grad_shard(
-                    unwrapped_model.language_model.embedding.position_embeddings.weight)
-                torch.distributed.all_reduce(grad_shard,
-                                             group=mpu.get_position_embedding_group())
+                # grad_shard = optimizer.get_grad_shard(
+                #     unwrapped_model.language_model.embedding.position_embeddings.weight)
+                # torch.distributed.all_reduce(grad_shard,
+                #                              group=mpu.get_position_embedding_group())
                 # <<<

             timers('backward-embedding-all-reduce').stop()
@@ -629,9 +645,111 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):

 # >>>
 class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):

+    def __init__(self, *args):
+        super().__init__(*args)
+
+        self.initialized = False
+        # >>>
+        self.initialize()
+        # <<<
+
+    def initialize(self):
+
+        # >>>
+        import math
+        # <<<
+
+        if self.initialized:
+            raise Exception("initialization worked.")
+            return
+        self.initialized = True
+
+        data_parallel_rank = mpu.get_data_parallel_rank()
+        data_parallel_world_size = mpu.get_data_parallel_world_size()
+
+        total_param_size = sum(
+            p.numel()
+            for g in self.param_groups
+            for p in g["params"])
+        shard_size = int(math.ceil(total_param_size / data_parallel_world_size))
+        shard_start_index = data_parallel_rank * shard_size
+        shard_end_index = min(total_param_size, shard_start_index + shard_size)
+        self.shard_size = shard_end_index - shard_start_index
+
+        # allocate_shard = lambda dtype : torch.empty(
+        #     [self.shard_size],
+        #     dtype = dtype,
+        #     device = torch.cuda.current_device())
+        allocate_shard = lambda dtype : MemoryBuffer(self.shard_size, dtype)
+
+        self.main_param_shard = allocate_shard(torch.float)
+        self.main_grad_shard = allocate_shard(torch.float)
+        self.adam_m_shard = allocate_shard(torch.float)
+        self.adam_v_shard = allocate_shard(torch.float)
+
+    def reduce_gradients(self, model):
+
+        # >>>
+        # from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+        from megatron import get_args
+        # from megatron import get_timers
+        # from megatron.model import DistributedDataParallel as LocalDDP
+        # from megatron.model import Float16Module
+        # from megatron.utils import unwrap_model
+
+        args = get_args()
+        # timers = get_timers()
+        # <<<
+
+        # >>>
+        assert args.use_contiguous_buffers_in_local_ddp
+        # <<<
+
+        # grad_buffers = [ m._grad_buffers for m in model ]
+        for virtual_model in model:
+            grad_buffers = virtual_model._grad_buffers
+            for dtype, grad_buffer in grad_buffers.items():
+
+                dp_grad_buffers = [
+                    grad_buffer.get(self.shard_sizes[i], self.shard_start_indexes[i])
+                    for i in self.data_parallel_world_size]
+
+                pax(0, {"dp_grad_buffers": dp_grad_buffers})
+
+                torch.distributed.reduce_scatter(
+                    self.main_grad_shard,
+                    grad_buffer.data,
+                    group=mpu.get_data_parallel_group(),
+                )
+
+                # >>>
+                pax(0, {
+                    "virtual_model" : virtual_model,
+                    "grad_buffers" : grad_buffers,
+                    "dtype" : dtype,
+                    "grad_buffer / len" : grad_buffer.numel,
+                    "grad_buffer / data" : tp(grad_buffer.data),
+                    # "optimizer" : self.optimizer,
+                    "main_grad_shard" : tp(self.main_grad_shard),
+                })
+                # <<<
+
+        # >>>
+        from lutil import pax, tp
+        pax(0, {
+            "model" : model,
+            "grad_buffers" : grad_buffers,
+            "grad_buffers / 0" : grad_buffers[0],
+            "grad_buffers / 0 / data" : tp(list(grad_buffers[0].values())[0].data),
+        })
+        # <<<
+
     def step(self):
-        raise Exception("hi.")
+        raise Exception("step.")
 # <<<
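The new Float16DistributedOptimizer splits the flattened parameter space evenly across data-parallel ranks and reduce-scatters the contiguous gradient buffer so that each rank ends up holding only the summed gradients for its own shard. The sketch below illustrates that arithmetic and the collective with plain tensors in place of Megatron's MemoryBuffer; build_shard, reduce_scatter_grads and dp_group are illustrative names, and dp_group stands in for mpu.get_data_parallel_group().

# Illustrative sketch (not the commit's code) of the shard sizing and
# reduce-scatter used by the distributed optimizer above.
import math
import torch
import torch.distributed as dist

def build_shard(total_param_size, dp_rank, dp_world_size):
    # Even (ceil) split of the flat parameter space across data-parallel ranks;
    # the last rank may get a shorter shard.
    shard_size = int(math.ceil(total_param_size / dp_world_size))
    start = dp_rank * shard_size
    end = min(total_param_size, start + shard_size)
    return start, end, end - start

def reduce_scatter_grads(flat_grad_buffer, grad_shard, dp_group, dp_world_size):
    # Each rank contributes the full flat gradient buffer and receives only the
    # summed slice corresponding to its shard. The buffer is assumed to be
    # padded so it divides evenly into dp_world_size equal chunks.
    chunks = list(flat_grad_buffer.chunk(dp_world_size))
    dist.reduce_scatter(grad_shard, chunks, group=dp_group)

A rank's boundaries from build_shard can be used to size grad_shard (and the fp32 main-param and Adam state shards) before the collective runs.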
megatron/training.py  (+9, -3)

@@ -427,12 +427,12 @@ def train_step(forward_step_func, data_iterator,
     # >>>
     # Reduce gradients. (with distributed optimizer option, optimizer
     # now responsible for reducing gradients)
-    optimizer.reduce_gradients()
+    optimizer.reduce_gradients(model)
     # <<<

     # >>>
-    from lutil import pax
-    pax({"optimizer": optimizer})
+    # from lutil import pax
+    # pax(0, {"optimizer": optimizer})
     # <<<

     # Update parameters.

@@ -440,6 +440,12 @@ def train_step(forward_step_func, data_iterator,
     update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
     timers('optimizer').stop()

+    # >>>
+    # Gather params gradients. (with distributed optimizer option, optimizer
+    # now responsible for gathering updated params)
+    optimizer.gather_params()
+    # <<<
+
     # Update learning rate.
     if update_successful:
         increment = get_num_microbatches() * \
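Taken together with the optimizer.py changes, the intended train-step sequence under the distributed-optimizer option is: reduce gradients into per-rank shards, step the optimizer on the local shard, then gather the updated parameters back to all ranks. The sketch below shows that control flow; gather_params() is referenced by the diff but not implemented in this commit, so it appears here only as an assumed method.

# Hedged sketch of the train-step tail introduced above. The optimizer is
# assumed to expose reduce_gradients(model), step(), and gather_params().
def train_step_tail(optimizer, model, timers):
    # 1. Reduce gradients. With the distributed optimizer option, the optimizer
    #    itself owns reduction (e.g. reduce-scatter into its gradient shard).
    optimizer.reduce_gradients(model)

    # 2. Update parameters from the (possibly sharded) gradients.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # 3. Gather updated parameters. With the distributed optimizer option, each
    #    rank holds only its parameter shard after step(), so full parameters
    #    must be re-assembled (e.g. all-gathered) before the next forward pass.
    optimizer.gather_params()

    return update_successful, grad_norm, num_zeros_in_grad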