OpenDAS / Megatron-LM · Commits
"...text-generation-inference.git" did not exist on "09674e6df99918dad5169b5cc81241c884c543d0"
Commit 12d91733, authored Mar 18, 2022 by Lawrence McAfee

split reduce_grads/gather_params using gbuf sub-views.

Parent: e30ad67e
Showing 2 changed files, with 264 additions and 11 deletions:

megatron/optimizer/distrib_optimizer.py  (+263, -11)
megatron/optimizer/optimizer.py          (+1, -0)
megatron/optimizer/distrib_optimizer.py (view file @ 12d91733)
@@ -230,14 +230,15 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

    def get_model_parallel_group(self):
        return None

    # @staticmethod
    # def has_nan_debug(tensors):
    #     if isinstance(tensors, torch.Tensor):
    #         tensors = [ tensors ]
    #     assert isinstance(tensors, list)
    #     has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
    #     has_nan = any(has_nans)
    #     return has_nan
    # >>>
    @staticmethod
    def has_nan_debug(tensors):
        if isinstance(tensors, torch.Tensor):
            tensors = [ tensors ]
        assert isinstance(tensors, list)
        has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
        has_nan = any(has_nans)
        return has_nan

    # def get_local_model_param_views(self):
    #     '''** FOR DEBUGGING. **'''
    #     model_param_views = []
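The new has_nan_debug staticmethod flags any tensor that contains NaN or Inf values. A minimal self-contained sketch of the same check, using only plain torch and made-up example tensors:

import torch

def has_nan_debug(tensors):
    # Accept a single tensor or a list of tensors.
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    assert isinstance(tensors, list)
    # A tensor is flagged if any element is NaN or +/-Inf.
    return any(not torch.all(torch.isfinite(t)).item() for t in tensors)

clean = torch.ones(4)
dirty = torch.tensor([1.0, float("nan"), 2.0])
print(has_nan_debug(clean))           # False
print(has_nan_debug([clean, dirty]))  # True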
@@ -269,6 +270,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

    # def get_world_model_grads(self):
    #     '''** FOR DEBUGGING. **'''
    #     return [ p.main_grad for p in self.get_world_model_params() ]
    # <<<

    def get_main_params(self):
        return [ g["params"][0] for g in self.optimizer.param_groups ]
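get_main_params takes the first parameter from each optimizer param group; in this optimizer each group holds a single flattened main parameter, so index 0 is the whole group. A small sketch of the same access pattern with a plain torch.optim.SGD (the two-group layout and sizes here are purely illustrative):

import torch

# Two hypothetical param groups, each holding one flat parameter tensor.
p0 = torch.nn.Parameter(torch.zeros(8))
p1 = torch.nn.Parameter(torch.zeros(16))
optimizer = torch.optim.SGD([{"params": [p0]}, {"params": [p1]}], lr=0.1)

# Same expression as get_main_params() above.
main_params = [g["params"][0] for g in optimizer.param_groups]
print([p.shape for p in main_params])  # [torch.Size([8]), torch.Size([16])]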
@@ -327,6 +329,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

        # Distributed optimizer requires contiguous buffer; don't set to None.
        _zero_grad_group_helper(model_params, set_to_none = False)

    # >>>
    def get_model_grad_buffer_dp_views(self):

        data_parallel_world_size = mpu.get_data_parallel_world_size()
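get_model_grad_buffer_dp_views splits each contiguous grad buffer into one equal-sized view per data-parallel rank, which only works because the buffer is padded to a multiple of the data-parallel world size. A framework-free sketch of that arithmetic (the helper name, buffer size, and world size below are made up; mpu and self.models are replaced by plain arguments):

import math
import torch

def make_rank_views(numel, data_parallel_world_size):
    # Pad the buffer so it divides evenly into per-rank shards.
    numel_padded = (
        math.ceil(numel / data_parallel_world_size) * data_parallel_world_size)
    gbuf = torch.zeros(numel_padded)
    shard_size = numel_padded // data_parallel_world_size
    # One equal-sized view per data-parallel rank.
    return [gbuf[r * shard_size:(r + 1) * shard_size]
            for r in range(data_parallel_world_size)]

views = make_rank_views(numel=10, data_parallel_world_size=4)
print([v.nelement() for v in views])  # [3, 3, 3, 3]  (10 padded up to 12)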
@@ -343,8 +346,48 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

            gbuf_view_items.append((model_index, dtype, gbuf_views))

        return gbuf_view_items

    def get_model_grad_buffer_dp_views_SUB(self, sub_view_numel):

        gbuf_view_items = self.get_model_grad_buffer_dp_views()

        sub_view_items = []
        for model_index, dtype, gbuf_views in gbuf_view_items:

            # ** Sanity check. ** (should be unnecessary; see comment above)
            view_numel = gbuf_views[0].nelement()
            for view in gbuf_views:
                assert view.nelement() == view_numel

            for start_index in range(0, view_numel, sub_view_numel):
                end_index = min(view_numel, start_index + sub_view_numel)
                sub_views = [ t[start_index:end_index] for t in gbuf_views ]
                sub_view_items.append((model_index, dtype, sub_views))

        # >>>
        from lutil import pax
        pax(0, {
            "gbuf_view_items" : [ (a, b, c.shape) for a, b, c in gbuf_view_items ],
            "sub_view_items" : [ (a, b, c.shape) for a, b, c in sub_view_items ],
        })
        # <<<

        return sub_view_items

    # def get_model_grad_buffers_SINGLE(self):
    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
    #     # Grad buffers.
    #     gbuf_items = []
    #     for model_index, model in enumerate(self.models):
    #         for dtype, gbuf in model._grad_buffers.items():
    def reduce_model_grads(self, args, timers):
    #             assert gbuf.numel_padded % data_parallel_world_size == 0
    #             shard_size = int(gbuf.numel_padded / data_parallel_world_size)
    #             gbuf_items.append((model_index, dtype, gbuf.data))
    #     return gbuf_items
    # <<<

    # >>>
    def reduce_model_grads_0(self, args, timers):
        '''Note: this is a different order of reduction, versus the non-
        distributed optimizer, which reduces: 1) all grads, 2) embedding
        grads.
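get_model_grad_buffer_dp_views_SUB then slices every per-rank view into fixed-size sub-views, so that each collective moves at most sub_view_numel elements per rank. A self-contained sketch of that chunking step (the helper name and the toy per-rank views, like those in the previous snippet, are assumptions for illustration):

import torch

def chunk_views(gbuf_views, sub_view_numel):
    # All per-rank views are assumed to hold the same number of elements.
    view_numel = gbuf_views[0].nelement()
    assert all(v.nelement() == view_numel for v in gbuf_views)
    sub_view_items = []
    for start_index in range(0, view_numel, sub_view_numel):
        end_index = min(view_numel, start_index + sub_view_numel)
        # One sub-view per rank, all covering the same [start, end) range.
        sub_view_items.append([v[start_index:end_index] for v in gbuf_views])
    return sub_view_items

rank_views = [torch.zeros(10) for _ in range(4)]   # 4 "ranks", 10 elements each
chunks = chunk_views(rank_views, sub_view_numel=4)
print([c[0].nelement() for c in chunks])           # [4, 4, 2]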
@@ -371,9 +414,44 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

                group = data_parallel_group,
            )
        timers('backward-params-all-reduce').stop()

    def reduce_model_grads_1(self, args, timers):
        '''Note: this is a different order of reduction, versus the non-
        distributed optimizer, which reduces: 1) all grads, 2) embedding
        grads.
        '''
    def gather_model_params(self, args, timers):

        # All-reduce embedding grads.
        timers('backward-embedding-all-reduce').start()
        self.allreduce_embedding_grads(args)
        timers('backward-embedding-all-reduce').stop()

        # Reduce-scatter all grads.
        timers('backward-params-all-reduce').start()

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_group = mpu.get_data_parallel_group()

        sub_numel = 1 * 1048576
        sub_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
        for model_index, dtype, gbuf_views in gbuf_view_items:
            gbuf = self.models[model_index]._grad_buffers[dtype].data
            gbuf /= data_parallel_world_size
            torch.distributed.reduce_scatter(
                gbuf_views[data_parallel_rank],
                gbuf_views,
                group = data_parallel_group,
            )

        timers('backward-params-all-reduce').stop()

    def reduce_model_grads(self, *args):
        # >>>
        # return
        # <<<
        # self.reduce_model_grads_0(*args)
        self.reduce_model_grads_1(*args)
    # <<<

    # >>>
    def gather_model_params_0(self, args, timers):

        timers('backward-params-all-gather').start()
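reduce_model_grads_1 replaces the single all-reduce with a reduce-scatter, so after the collective each data-parallel rank owns the reduced gradients for its own shard only; the gbuf /= data_parallel_world_size line turns the summed result into an average. (As shown, the loop still reads gbuf_view_items even though only sub_view_items is computed above, which looks like in-progress debug code.) Since reduce_scatter is generally only available on the NCCL backend, the sketch below simply simulates the arithmetic on plain CPU tensors, one "rank" per list entry; the function name and sizes are made up:

import torch

def simulated_reduce_scatter(per_rank_buffers):
    """Each entry of per_rank_buffers is one rank's full grad buffer.
    Returns, per rank, the shard that a reduce-scatter (plus the division
    by world size) would leave that rank owning."""
    world_size = len(per_rank_buffers)
    numel = per_rank_buffers[0].nelement()
    assert numel % world_size == 0
    shard = numel // world_size
    # Summing then dividing equals dividing each buffer first, as the code does.
    total = torch.stack(per_rank_buffers).sum(dim=0) / world_size
    return [total[r * shard:(r + 1) * shard] for r in range(world_size)]

bufs = [torch.full((8,), float(r + 1)) for r in range(4)]  # ranks hold 1,2,3,4
owned = simulated_reduce_scatter(bufs)
print(owned[0])  # rank 0's shard: the mean of 1..4 (2.5) in both slots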
@@ -397,7 +475,181 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

                    param.detach().copy_(param.main_grad)

        timers('backward-params-all-gather').stop()

    def gather_model_params_1(self, args, timers):

        timers('backward-params-all-gather').start()

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()

        # All-gather updated main params.
        # - All grad buffer views are guaranteed to have the same num elements
        #   across all data parallel ranks, with grad buffer padding that is done
        #   in distributed.py. Thus, all sub-views will have consistent start/end
        #   indexes across data parallel ranks.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        # sub_view_numel = 1 * 1024
        # sub_view_numel = 1 * 131072
        sub_view_numel = 1 * 1048576
        for model_index, dtype, gbuf_views in gbuf_view_items:

            # ** Sanity check. ** (should be unnecessary; see comment above)
            view_numel = gbuf_views[0].nelement()
            for view in gbuf_views:
                assert view.nelement() == view_numel

            for start_index in range(0, view_numel, sub_view_numel):
                end_index = min(view_numel, start_index + sub_view_numel)
                sub_views = [ t[start_index:end_index] for t in gbuf_views ]
                torch.distributed.all_gather(
                    sub_views,
                    sub_views[data_parallel_rank],
                    group = data_parallel_group,
                )

        # Each model param now contains its updated values in its
        # '.main_grad' field.
        for model in self.models:
            for dtype, param_map in model._grad_buffer_param_index_map.items():
                for param in param_map:
                    param.detach().copy_(param.main_grad)

        timers('backward-params-all-gather').stop()
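gather_model_params_1 all-gathers the updated shards back into the full grad buffer in fixed-size chunks. The sketch below runs the same chunked all-gather pattern on CPU with the gloo backend; the port, world size, and buffer sizes are arbitrary choices for the example. Each process starts with only its own shard populated and ends with the complete buffer:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(rank, world_size):
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29511",   # arbitrary free port
        rank=rank, world_size=world_size)

    shard_numel, sub_view_numel = 8, 3          # deliberately not a divisor
    gbuf = torch.zeros(world_size * shard_numel)
    # Per-rank views of the same flat buffer; fill only this rank's shard.
    views = [gbuf[r * shard_numel:(r + 1) * shard_numel]
             for r in range(world_size)]
    views[rank].fill_(float(rank + 1))

    # Chunked all-gather: each pass exchanges at most sub_view_numel elements.
    for start in range(0, shard_numel, sub_view_numel):
        end = min(shard_numel, start + sub_view_numel)
        sub_views = [v[start:end] for v in views]
        dist.all_gather(sub_views, sub_views[rank])

    if rank == 0:
        print(gbuf)   # shard r is now filled with r+1 for every rank r

    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 2
    mp.spawn(run, args=(world_size,), nprocs=world_size)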
# def gather_model_params_2(self, args, timers):
# raise Exception("_all_gather_base not applicable when each DP rank owns contiguous range of grad buffer.")
# timers('backward-params-all-gather').start()
# data_parallel_rank = mpu.get_data_parallel_rank()
# data_parallel_world_size = mpu.get_data_parallel_world_size()
# data_parallel_group = mpu.get_data_parallel_group()
# # All-gather updated main params.
# # - All grad buffer views are guaranteed to have the same num elements
# # across all data parallel ranks, with grad buffer padding that is done
# # in distributed.py. Thus, all sub-views will have consistent start/end
# # indexes across data parallel ranks.
# gbuf_items = self.get_model_grad_buffers_SINGLE()
# # local_sub_numel = 1 * 1024
# # local_sub_numel = 1 * 131072
# ideal_local_numel = 128 * 1048576
# ideal_world_numel = data_parallel_world_size * ideal_local_numel
# for model_index, dtype, gbuf in gbuf_items:
# gbuf_numel = gbuf.nelement()
# # >>>
# # from lutil import pax
# # pax(0, {
# # "gbuf_items" : [ (a, b, c.shape) for a, b, c in gbuf_items ],
# # "gbuf" : str(gbuf.shape),
# # "gbuf_numel" : gbuf_numel,
# # "local_sub_numel" : local_sub_numel,
# # "world_sub_numel" : world_sub_numel,
# # })
# # <<<
# for world_start_index in range(0, gbuf_numel, ideal_world_numel):
# world_end_index = \
# min(gbuf_numel, world_start_index + ideal_world_numel)
# world_numel = world_end_index - world_start_index
# assert world_numel % data_parallel_world_size == 0
# local_numel = int(world_numel / data_parallel_world_size)
# local_start_index = \
# world_start_index + data_parallel_rank * local_numel
# local_end_index = \
# min(gbuf_numel, local_start_index + local_numel)
# try:
# world_view = gbuf[world_start_index:world_end_index]
# local_view = gbuf[local_start_index:local_end_index]
# except:
# # >>>
# from lutil import pax
# pax(0, {
# "world_start_index" : world_start_index,
# "world_end_index" : world_end_index,
# "local_start_index" : local_start_index,
# "local_end_index" : local_end_index,
# })
# # <<<
# try:
# torch.distributed._all_gather_base(
# world_view,
# local_view,
# group = data_parallel_group,
# )
# except:
# # >>>
# from lutil import pax
# pax(0, {
# "data_parallel_rank" : data_parallel_rank,
# # "local_sub_numel" : local_sub_numel,
# # "world_sub_numel" : world_sub_numel,
# "world_start_index" : world_start_index,
# "world_end_index" : world_end_index,
# "local_start_index" : local_start_index,
# "local_end_index" : local_end_index,
# "gbuf" : str(gbuf.shape),
# "world_view" : str(world_view.shape),
# "local_view" : str(local_view.shape),
# "local_sub_numel / ideal" : local_sub_numel,
# "local_sub_numel / act" :
# local_end_index - local_start_index,
# })
# # <<<
# # >>>
# # from lutil import pax, tp
# # pax(0, {
# # # "gbuf" : tp(gbuf),
# # "world range" : "%d, %d"%(world_start_index, world_end_index),
# # "local range" : "%d, %d"%(local_start_index, local_end_index),
# # "world_view" : tp(world_view),
# # "local_view" : tp(local_view),
# # "gbuf view" : tp(gbuf[world_start_index:world_end_index]),
# # })
# # <<<
# # >>>
# for model_index, dtype, gbuf in gbuf_items:
# if self.has_nan_debug(gbuf):
# raise Exception("hi.")
# # from lutil import pax, tp
# # pax(0, {
# # "gbuf_items" : [ (a, b, tp(c)) for a, b, c in gbuf_items ],
# # })
# # <<<
# # Each model param now contains its updated values in its
# # '.main_grad' field.
# for model in self.models:
# for dtype, param_map in model._grad_buffer_param_index_map.items():
# for param in param_map:
# param.detach().copy_(param.main_grad)
# # >>>
# if self.has_nan_debug(param):
# raise Exception("wha?")
# # <<<
# timers('backward-params-all-gather').stop()
    def gather_model_params(self, *args):
        # >>>
        return
        # <<<
        # self.gather_model_params_0(*args)
        self.gather_model_params_1(*args)
        # self.gather_model_params_2(*args)
        # ~~~
        # self.debug_model(0, "after / gather_model_params", 0)
    # <<<

    def _collect_main_grad_data_for_unscaling(self):
        return [ g.data for g in self.get_main_grads() ]
megatron/optimizer/optimizer.py (view file @ 12d91733)
@@ -327,6 +327,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):

        from megatron import get_args
        args = get_args()
        my_rank = torch.distributed.get_rank()
        DEBUG_ITERATION = ITERATION
        if ITERATION != DEBUG_ITERATION:
            return
        for r in range(torch.distributed.get_world_size()):
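The loop over torch.distributed.get_world_size() in this debug hook suggests rank-by-rank serialized output, though the loop body is not shown in the hunk. A common pattern for that (a sketch only; the function name and message are made up) is to let each rank take its turn between barriers:

import torch.distributed as dist
import torch.multiprocessing as mp

def debug_in_rank_order(message):
    # Each rank prints on its turn; barriers keep the output from interleaving.
    my_rank = dist.get_rank()
    for r in range(dist.get_world_size()):
        if r == my_rank:
            print(f"[rank {my_rank}] {message}", flush=True)
        dist.barrier()

def run(rank, world_size):
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29512",   # arbitrary free port
        rank=rank, world_size=world_size)
    debug_in_rank_order("optimizer state looks sane here")
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(run, args=(2,), nprocs=2)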