OpenDAS / Megatron-LM · Commits

Commit 64b94f00
authored Mar 22, 2022 by Lawrence McAfee
setup code to try _reduce_scatter_base, _all_gather_base.
parent 6728a780

Showing 2 changed files with 71 additions and 20 deletions

  megatron/optimizer/clip_grads.py          +13  -8
  megatron/optimizer/distrib_optimizer.py   +58  -12
megatron/optimizer/clip_grads.py
...
@@ -72,6 +72,11 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
     #     grads_for_norm.append(grad)
     # <<<

+    # >>>
+    # Grads.
+    grads = [p.grad for p in parameters if p is not None]
+    # <<<
+
     # Norm parameters.
     max_norm = float(max_norm)
     norm_type = float(norm_type)
...
@@ -115,14 +120,14 @@ def clip_grad_norm_fp32(parameters, grads_for_norm,
         total_norm = total_norm.item() ** (1.0 / norm_type)

     # >>>
-    from megatron import get_args
-    from lutil import pax
-    args = get_args()
-    pax(0, {
-        "use distrib opt" : args.use_distributed_optimizer,
-        "norm_type" : norm_type,
-        "total_norm" : total_norm,
-    })
+    # from megatron import get_args
+    # from lutil import pax
+    # args = get_args()
+    # pax(0, {
+    #     "use distrib opt" : args.use_distributed_optimizer,
+    #     "norm_type" : norm_type,
+    #     "total_norm" : total_norm,
+    # })
     # <<<

     # Scale.
...
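One note on the context line above: by this point in clip_grad_norm_fp32, total_norm holds the sum of each gradient's norm raised to norm_type (already reduced across model-parallel ranks), so the ** (1.0 / norm_type) step completes the combined p-norm. A small illustrative sketch of that two-stage computation, not code from the commit:

    import torch

    def total_grad_norm(grads, norm_type=2.0):
        # Stage 1: accumulate |g|_p ** p over all gradient tensors.
        total = sum(torch.norm(g, p=norm_type) ** norm_type for g in grads)
        # Stage 2: the 1/p power turns the accumulated sum into the overall p-norm,
        # mirroring the `total_norm.item() ** (1.0 / norm_type)` line in the hunk.
        return total.item() ** (1.0 / norm_type)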
megatron/optimizer/distrib_optimizer.py
...
@@ -413,6 +413,19 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 gbuf_view_items.append((model_index, dtype, gbuf_views))

         return gbuf_view_items

+    # >>>
+    def get_model_grad_buffer_dp_views_SINGLE(self):
+
+        data_parallel_world_size = mpu.get_data_parallel_world_size()
+
+        # Grad buffer views.
+        gbuf_items = []
+        for model_index, model in enumerate(self.models):
+            for dtype, gbuf in model._grad_buffers.items():
+                gbuf_items.append((model_index, dtype, gbuf.data))
+
+        return gbuf_items
+    # <<<

     def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor):
...
@@ -466,14 +479,36 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             gbuf.data /= data_parallel_world_size

         # Reduce scatter all grads.
-        gbuf_view_items = \
-            self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
-        for model_index, dtype, gbuf_views in gbuf_view_items:
-            torch.distributed.reduce_scatter(
-                gbuf_views[data_parallel_rank],
-                gbuf_views,
-                group = data_parallel_group,
-            )
+        # >>>
+        # gbuf_view_items = \
+        #     self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
+        # for model_index, dtype, gbuf_views in gbuf_view_items:
+        #     torch.distributed.reduce_scatter(
+        #         gbuf_views[data_parallel_rank],
+        #         gbuf_views,
+        #         group = data_parallel_group,
+        #     )
+        # +++
+        gbuf_view_items = self.get_model_grad_buffer_dp_views()
+        gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
+        for index, (model_index, dtype, gbuf_views) in enumerate(gbuf_view_items):
+            # >>>
+            pax(0, {
+                "gbuf_view" : gbuf_views[data_parallel_rank].shape,
+                "gbuf SINGLE" : gbuf_view_items_SINGLE[index][2].shape,
+            })
+            # <<<
+            torch.distributed._reduce_scatter_base(
+                gbuf_views[data_parallel_rank],
+                gbuf_view_items_SINGLE[index][2],
+                group = data_parallel_group,
+            )
+            # torch.distributed.reduce_scatter(
+            #     gbuf_views[data_parallel_rank],
+            #     gbuf_views,
+            #     group = data_parallel_group,
+            # )
+        # <<<

         timers('backward-params-all-reduce').stop()

     def gather_model_params(self, args, timers):
...
@@ -489,14 +524,25 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # across all data parallel ranks, with grad buffer padding that is done
         # in distributed.py. Thus, all sub-views will have consistent start/end
         # indexes across data parallel ranks.
-        gbuf_view_items = \
-            self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
-        for model_index, dtype, gbuf_views in gbuf_view_items:
-            torch.distributed.all_gather(
-                gbuf_views,
-                gbuf_views[data_parallel_rank],
-                group = data_parallel_group,
-            )
+        # >>>
+        # gbuf_view_items = \
+        #     self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
+        # for model_index, dtype, gbuf_views in gbuf_view_items:
+        #     torch.distributed.all_gather(
+        #         gbuf_views,
+        #         gbuf_views[data_parallel_rank],
+        #         group = data_parallel_group,
+        #     )
+        # +++
+        gbuf_view_items = self.get_model_grad_buffer_dp_views()
+        gbuf_view_items_SINGLE = self.get_model_grad_buffer_dp_views_SINGLE()
+        for index, (model_index, dtype, gbuf_views) in enumerate(gbuf_view_items):
+            torch.distributed._all_gather_base(
+                # gbuf_view_items_SINGLE[index][2],
+                # gbuf_views[data_parallel_rank],
+                # group = data_parallel_group,
+            )
+        # <<<

         # Each model param now contains its updated values in its
         # '.main_grad' field.
...
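The comment at the top of the last hunk leans on the grad buffer padding done in distributed.py: the flat buffer is sized to a multiple of the data-parallel world size, so every rank's sub-view has the same number of elements and the same start/end offsets on every rank. A standalone sketch of that invariant, with illustrative helper names that do not exist in the repo:

    import torch

    def padded_grad_buffer(num_grad_elements, world_size, dtype=torch.float32):
        # Round the element count up to a multiple of world_size, as the grad
        # buffer allocation in distributed.py does, so it splits into equal shards.
        padded = ((num_grad_elements + world_size - 1) // world_size) * world_size
        return torch.zeros(padded, dtype=dtype)

    def rank_view(flat_buffer, rank, world_size):
        shard = flat_buffer.numel() // world_size
        return flat_buffer[rank * shard:(rank + 1) * shard]

    buf = padded_grad_buffer(num_grad_elements=10, world_size=4)  # padded to 12 elements
    views = [rank_view(buf, r, world_size=4) for r in range(4)]   # four 3-element views
    assert all(v.numel() == buf.numel() // 4 for v in views)

This equal-shard layout is also what lets the flat _reduce_scatter_base / _all_gather_base calls line up one-to-one with the per-rank views used by the list-based collectives they are replacing.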