OpenDAS / Megatron-LM · Commits

Commit e7f0cdee, authored Feb 23, 2022 by Lawrence McAfee
renamed reduce_gradients -> reduce_grads [ matches gather_params ]
Parent: 4b843668

Showing 1 changed file with 77 additions and 68 deletions:

    megatron/optimizer/optimizer.py  (+77, -68)
megatron/optimizer/optimizer.py — view file @ e7f0cdee

@@ -31,6 +31,8 @@ from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
 # >>>
 from lutil import pax, tp
+
+DEBUG_ITERATION = 10
 # <<<
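Note: DEBUG_ITERATION gates the pax() dumps added later in this diff, so a run proceeds normally until a chosen step and then halts with a state snapshot. pax and tp come from lutil, the author's private debugging helpers; they are not part of Megatron-LM. A minimal sketch of the same gating pattern, with a hypothetical dump_state standing in for pax:

    # Sketch of the iteration-gated debug-dump pattern (names hypothetical;
    # `dump_state` stands in for lutil's private `pax` helper).
    DEBUG_ITERATION = 10

    def dump_state(label, state):
        """Print a labeled snapshot, then halt so branches can be compared."""
        print("=== %s ===" % label)
        for key, value in state.items():
            print("  %s: %s" % (key, value))
        raise SystemExit  # stop here; compare the dump across branches

    def train_step(iteration, model):
        # ... forward / backward / optimizer step ...
        if iteration == DEBUG_ITERATION:
            dump_state("step", {
                "iteration": iteration,
                "num params": sum(p.numel() for p in model.parameters()),
            })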
@@ -130,7 +132,7 @@ class MegatronOptimizer(ABC):
     @abstractmethod
-    def reduce_gradients(self):
+    def reduce_grads(self):
         pass
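The rename makes the abstract interface a symmetric verb pair: reduce_grads after the backward pass, gather_params after the optimizer step (hence "[ matches gather_params ]" in the commit message). Condensed, the contract now reads roughly as follows; the real class carries more methods and arguments than shown:

    from abc import ABC, abstractmethod

    class MegatronOptimizer(ABC):
        # Condensed sketch of the post-rename interface.

        @abstractmethod
        def reduce_grads(self):
            """Reduce gradients across data-parallel ranks after backward."""
            pass

        @abstractmethod
        def gather_params(self):
            """Gather updated params back to the model after the step."""
            pass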
@@ -466,7 +468,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
     # >>>
-    def reduce_gradients(self, model):
+    def reduce_grads(self, model):
         # >>>
         from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
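The local torchDDP import feeds the unwrap_model() call further down in reduce_grads: Megatron stacks wrapper modules (torchDDP or LocalDDP, plus Float16Module) around the raw model, and each layer must be peeled off before touching attributes like share_word_embeddings. A simplified sketch of that unwrapping (Megatron's real unwrap_model in megatron/utils.py also handles lists of models):

    def unwrap_model(model, module_instances):
        # Peel wrapper layers (DDP, Float16Module, ...) off a module.
        # Simplified: the real helper also accepts a list of models.
        unwrapped = model
        while isinstance(unwrapped, module_instances):
            unwrapped = unwrapped.module  # each wrapper exposes .module
        return unwrapped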
@@ -481,26 +483,10 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
         timers = get_timers()
         # <<<

-        # >>>
-        # if not args.use_distributed_optimizer:
         # All-reduce if needed.
+        # >>>
+        # if args.DDP_impl == 'local' and not args.use_distributed_optimizer:
         if args.DDP_impl == 'local':
+        # <<<
             timers('backward-params-all-reduce').start()
             for model_module in model:
-                # >>>
-                # from lutil import pax, tp
-                # pax(0, {
-                #     "model" : model,
-                #     "model_module" : model_module,
-                # })
-                # <<<
+                # >>>
+                # e.g., grad_shard = optimizer.get_grad_shard()
+                # <<<
                 model_module.allreduce_gradients()
             timers('backward-params-all-reduce').stop()
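For context, allreduce_gradients() on the local-DDP wrapper averages each rank's gradients over the data-parallel group; with contiguous buffers enabled it reduces one fused buffer rather than looping per tensor. A per-tensor sketch of the same operation:

    import torch

    def allreduce_gradients_sketch(module, data_parallel_group):
        # Average grads across data-parallel ranks, one tensor at a time.
        world_size = torch.distributed.get_world_size(group=data_parallel_group)
        for param in module.parameters():
            if param.grad is not None:
                param.grad.data.div_(world_size)
                torch.distributed.all_reduce(param.grad.data,
                                             group=data_parallel_group)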
@@ -559,7 +545,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
     def gather_params(self):
         pass

-    def _copy_model_grads_to_main_grads(self):
+    def _copy_model_grads_to_main_grads(self, ITERATION):
         # This only needs to be done for the float16 group.
         for model_group, main_group in zip(self.float16_groups,
                                            self.fp32_from_float16_groups):
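The new ITERATION argument exists only to feed the DEBUG_ITERATION gate; the method's real job is unchanged: copy half-precision gradients onto the fp32 master params before the optimizer steps. Roughly:

    def copy_model_grads_to_main_grads_sketch(float16_groups,
                                              fp32_from_float16_groups):
        # Upcast fp16 model grads onto their fp32 master copies (sketch).
        for model_group, main_group in zip(float16_groups,
                                           fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if model_param.grad is not None:
                    main_param.grad = model_param.grad.float()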
@@ -627,11 +613,19 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
         return model_data, main_data

-    def _copy_main_params_to_model_params(self):
+    def _copy_main_params_to_model_params(self, ITERATION):
         # Only needed for the float16 params.
         model_data, main_data = self._get_model_and_main_params_data_float16()
         _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                         overflow_buf=self._dummy_overflow_buf)

+        # >>>
+        if ITERATION == DEBUG_ITERATION:
+            pax(0, {
+                "** branch **" : "** main. **",
+                "ITERATION" : ITERATION,
+                "model params" : [ p for m in self.models for p in m.parameters() ],
+            })
+        # <<<

     def _copy_model_params_to_main_params(self):
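The dump added here is labeled "** main. **"; its counterpart in Float16DistributedOptimizer at the end of this diff is labeled "** fix. **", so the two code paths can be run to the same iteration and compared. The surrounding _multi_tensor_copy_this_to_that call reduces to an elementwise copy when Apex's fused multi-tensor kernel is unavailable:

    def _multi_tensor_copy_this_to_that_sketch(this, that):
        # Fallback path of Megatron's helper: copy each tensor in `this`
        # onto the matching tensor in `that`. The fused path uses Apex's
        # multi_tensor_applier with an overflow buffer instead.
        for this_, that_ in zip(this, that):
            that_.copy_(this_)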
@@ -766,14 +760,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                     "gbuf_local" : param_local_shard,
                     "param" : sub_param_shard,
                 }
-                pax(1, {
-                    "gbuf_world_shard" : gbuf_world_shard,
-                    "param shards" : param_shard_map[param],
-                })
-                # >>>
-                # if param_world_start < gbuf_world_shard.start:
-                #     pax({"param shards": param_shard_map[param]})
-                # <<<

         # pax(0, {"param_shard_map": [ str((str(p.shape), s)) for p,s in param_shard_map.items() ]})
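For the names in this hunk: the distributed optimizer splits each model's contiguous grad buffer into one shard per data-parallel rank ("gbuf_world_shard" is a rank's range in whole-buffer coordinates, "gbuf_local" the same range shifted to shard-local coordinates). A toy version of the range computation, with hypothetical names:

    def build_shard_ranges(gbuf_numel, data_parallel_world_size):
        # Split a buffer of gbuf_numel elements into per-rank (start, end)
        # ranges; hypothetical stand-in for the real shard bookkeeping.
        shard_size = (gbuf_numel + data_parallel_world_size - 1) \
            // data_parallel_world_size
        return [(r * shard_size, min((r + 1) * shard_size, gbuf_numel))
                for r in range(data_parallel_world_size)]

    # build_shard_ranges(10, 4) -> [(0, 3), (3, 6), (6, 9), (9, 10)]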
@@ -1070,10 +1056,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # for main_group in self.optimizer.param_groups:
         #     main_params.extend(main_group["params"])

-        _zero_grad_group_helper(model_params, set_to_none)
+        # ** using contiguous buffer; don't set_to_none **
+        _zero_grad_group_helper(model_params, set_to_none = False) # set_to_none)
         # _zero_grad_group_helper(params, set_to_none = False)
-        # pax(0, {"params": params})
+        # pax(0, {"model_params": model_params})

     def get_model_grad_buffer_dp_views(self):
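The set_to_none=False change is the one behavioral fix in this hunk: the distributed optimizer's .grad tensors are views into a contiguous grad buffer, so zeroing them in place keeps the views (and the buffer) intact, whereas .grad = None would orphan them. The helper behaves roughly like:

    def _zero_grad_group_helper_sketch(group, set_to_none):
        # Mirrors the intent of Megatron's _zero_grad_group_helper.
        for param in group:
            if param.grad is not None:
                if set_to_none:
                    param.grad = None      # drops the tensor entirely
                else:
                    param.grad.detach_()   # cut any autograd history
                    param.grad.zero_()     # keep storage and buffer views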
@@ -1100,13 +1087,44 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         return gbuf_view_items

-    def reduce_gradients(self, model):
+    def reduce_grads(self, model):

         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # Sync word embedding params.
         # ... todo ...
+        # All-reduce word_embeddings' grad across first and last stages to ensure
+        # that word_embeddings parameters stay in sync.
+        # This should only run for models that support pipelined model parallelism
+        # (BERT and GPT-2).
+        timers('backward-embedding-all-reduce').start()
+        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
+                mpu.get_pipeline_model_parallel_world_size() > 1:
+            if mpu.is_pipeline_first_stage(ignore_virtual=True):
+                unwrapped_model = model[0]
+            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
+                unwrapped_model = model[-1]
+            else:  # We do not support the interleaved schedule for T5 yet.
+                unwrapped_model = model[0]
+            unwrapped_model = unwrap_model(
+                unwrapped_model, (torchDDP, LocalDDP, Float16Module))
+
+            if unwrapped_model.share_word_embeddings:
+                word_embeddings_weight = unwrapped_model.word_embeddings_weight()
+                # >>>
+                if args.DDP_impl == 'local':
+                    grad = word_embeddings_weight.main_grad
+                else:
+                    grad = word_embeddings_weight.grad
+                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
+                # +++
+                # grad_shard = optimizer.get_grad_shard(word_embeddings)
+                # torch.distributed.all_reduce(grad_shard,
+                #                              group=mpu.get_embedding_group())
+                # <<<

         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # Sync T5 position embedding params.
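The "# +++" comment sketches where this is headed for the distributed optimizer: all-reduce only the locally owned shard of the embedding gradient instead of the full tensor. With the hypothetical get_grad_shard accessor named in those comments (it does not exist yet in this commit), that would look like:

    import torch
    from megatron import mpu

    def sync_word_embedding_grad_shard(optimizer, word_embeddings_weight):
        # Hypothetical shard-aware embedding grad sync; `get_grad_shard`
        # is only sketched in this commit's comments.
        grad_shard = optimizer.get_grad_shard(word_embeddings_weight)
        torch.distributed.all_reduce(grad_shard,
                                     group=mpu.get_embedding_group())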
@@ -1153,27 +1171,16 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # #     "grad" : tp(param.grad),
         # })
-        # pax(0, {
-        #     "gbuf_view_items" : gbuf_view_items,
-        #     "param_gbuf_map" : [
-        #         (str(tuple(p.shape)), d)
-        #         for p, d in self.param_gbuf_map.items()
-        #     ],
-        # })
-        pax(1, {
-            "data_parallel_rank" : data_parallel_rank,
-            "main params" : self.get_main_params(),
-            # "model params / world" : self.get_world_model_params(),
-            **{"gbuf_view_items / %d"%i:v[2] for i,v in enumerate(gbuf_view_items)},
-            # "gbuf_view_item" : tp(gbuf_view[data_parallel_rank]),
-            # "model params / local" : self.get_local_model_param_views(),
-        })
+        # pax(1, {
+        #     "data_parallel_rank" : data_parallel_rank,
+        #     "main params" : self.get_main_params(),
+        #     "model params / world" : self.get_world_model_params(),
+        #     **{"gbuf_view_items / %d"%i:v[2] for i,v in enumerate(gbuf_view_items)},
+        #     # "gbuf_view_item" : tp(gbuf_view[data_parallel_rank]),
+        #     # "model params / local" : self.get_local_model_param_views(),
+        # })

     def _collect_main_grad_data_for_unscaling(self):
-        # return [ p.grad.data for p in self.main_param_shards ]
+        # return [ p.grad.data for p in self.main_param_shards if p is not None ]
+        # return [ self.get_main_grad(gi).data
+        #     for gi in range(len(self.opt_group_shards)) ]
         return [ g.data for g in self.get_main_grads() ]

     def _copy_model_params_to_main_params(self):
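_collect_main_grad_data_for_unscaling now simply hands back the raw .data tensors of the main grads; the mixed-precision caller scales them by the inverse loss scale before the step. In essence:

    def unscale_main_grads(main_grad_data, loss_scale):
        # Multiply collected grad tensors by 1/loss_scale in place (sketch);
        # Megatron's actual path uses a fused foreach/multi-tensor kernel.
        inv_scale = 1.0 / loss_scale
        for grad in main_grad_data:
            grad.mul_(inv_scale)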
@@ -1319,19 +1326,19 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                     model_view.detach().copy_(main_view)

                     # Debug.
-                    pax(1, {
-                        "group_index" : group_index,
-                        "group_shard" : group_shard,
-                        "model_param" : tp(model_param),
-                        "model_index" : model_index,
-                        "dtype" : str(dtype),
-                        "model_param" : tp(model_param),
-                        "main_param" : tp(main_param),
-                        "model_view" : tp(model_view),
-                        "main_view" : tp(main_view),
-                        "model_shard" : str(model_shard),
-                        "main_shard" : str(main_shard),
-                    })
+                    # pax(1, {
+                    #     "group_index" : group_index,
+                    #     "group_shard" : group_shard,
+                    #     "model_param" : tp(model_param),
+                    #     "model_index" : model_index,
+                    #     "dtype" : str(dtype),
+                    #     "model_param" : tp(model_param),
+                    #     "main_param" : tp(main_param),
+                    #     "model_view" : tp(model_view),
+                    #     "main_view" : tp(main_view),
+                    #     "model_shard" : str(model_shard),
+                    #     "main_shard" : str(main_shard),
+                    # })

         # pax(0, {
         #     "model_gbuf_shards" : self.model_gbuf_shards,
@@ -1347,12 +1354,14 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         #     "is_nan" : is_nan,
         # })

-        # model_param_views = self.get_local_model_param_views()
-        # pax(1, {
-        #     "ITERATION" : ITERATION,
-        #     "main params" : self.get_main_params(),
-        #     "model params / local" : self.get_local_model_param_views(),
-        # })
+        if ITERATION == DEBUG_ITERATION:
+            pax(0, {
+                "** branch **" : "** fix. **",
+                "ITERATION" : ITERATION,
+                # "main params" : self.get_main_params(),
+                # "model params / local" : self.get_local_model_param_views(),
+                "model params" : [ p for m in self.models for p in m.parameters() ],
+            })
         # <<<
         # <<<
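With both gated dumps in place ("** main. **" in the baseline fp16 optimizer, "** fix. **" here), the two branches can be run to DEBUG_ITERATION and their parameter lists diffed offline. A sketch of that comparison, assuming each dump were saved with torch.save rather than printed:

    import torch

    def compare_param_dumps(main_path, fix_path, atol=0.0):
        # Compare two saved lists of parameter tensors, entry by entry.
        main_params = torch.load(main_path)
        fix_params = torch.load(fix_path)
        assert len(main_params) == len(fix_params)
        for i, (a, b) in enumerate(zip(main_params, fix_params)):
            if not torch.allclose(a, b, atol=atol):
                err = (a - b).abs().max().item()
                print("param %d: mismatch, max abs err %g" % (i, err))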