OpenDAS / Megatron-LM

Commit 3d433e8a, authored Mar 18, 2022 by Lawrence McAfee

cleanup.

Parent: 19730725
Showing 2 changed files with 12 additions and 285 deletions (+12, -285)
megatron/arguments.py                    +1    -1
megatron/optimizer/distrib_optimizer.py  +11  -284
megatron/arguments.py
@@ -737,7 +737,7 @@ def _add_distributed_args(parser):
                        help='Trade-off memory savings & iteration time, for '
                        'distributed optimizer\'s communication operations (i.e., '
                        'reduce/gather). This value ranges from 0.0 (default, '
-                       'no memory savings) to 1.0 (max memory savings, at '
+                       'no memory savings) to 1.0 (max memory savings, at the '
                        'expense of iteration time).')
    return parser
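For reference, the 0.0 to 1.0 factor described in this help string is consumed by get_model_grad_buffer_dp_views_chunked() in distrib_optimizer.py below, which linearly interpolates the communication chunk size. A minimal sketch of that mapping follows (the 131072-element floor comes from the code below; the sample view size is made up):

# Sketch only, not part of this commit: how a savings factor in [0.0, 1.0]
# maps to a communication chunk size, mirroring the interpolation in
# get_model_grad_buffer_dp_views_chunked().
def chunk_numel_for_factor(mem_savings_factor, view_numel, chunk_numel_min=131072):
    chunk_numel_max = view_numel  # 0.0 -> a single chunk spanning the whole view
    return int(mem_savings_factor * chunk_numel_min
               + (1 - mem_savings_factor) * chunk_numel_max)

for factor in (0.0, 0.5, 1.0):
    # view_numel=1_000_000 is an arbitrary example size.
    print(factor, chunk_numel_for_factor(factor, view_numel=1_000_000))
# 0.0 -> 1000000  (no memory savings: one large chunk, fewest collective calls)
# 0.5 -> 565536
# 1.0 -> 131072   (max memory savings: many small chunks, more collective calls)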
megatron/optimizer/distrib_optimizer.py
@@ -231,14 +231,14 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        return None

    # >>>
    @staticmethod
    def has_nan_debug(tensors):
        if isinstance(tensors, torch.Tensor):
            tensors = [ tensors ]
        assert isinstance(tensors, list)
        has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
        has_nan = any(has_nans)
        return has_nan
    # @staticmethod
    # def has_nan_debug(tensors):
    #     if isinstance(tensors, torch.Tensor):
    #         tensors = [ tensors ]
    #     assert isinstance(tensors, list)
    #     has_nans = [ (not torch.all(torch.isfinite(t)).item()) for t in tensors ]
    #     has_nan = any(has_nans)
    #     return has_nan
    # def get_local_model_param_views(self):
    #     '''** FOR DEBUGGING. **'''
    #     model_param_views = []
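For reference, a minimal standalone sketch of how the has_nan_debug() helper above behaves; it is written here as a free function purely for illustration (in the commit it is a @staticmethod on DistributedOptimizer):

import torch

def has_nan_debug(tensors):
    # Accept a single tensor or a list of tensors.
    if isinstance(tensors, torch.Tensor):
        tensors = [tensors]
    assert isinstance(tensors, list)
    # True if any tensor contains NaN or Inf.
    return any(not torch.all(torch.isfinite(t)).item() for t in tensors)

good = torch.ones(4)
bad = torch.tensor([1.0, float('nan'), 2.0])
assert has_nan_debug(good) is False
assert has_nan_debug([good, bad]) is True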
@@ -329,7 +329,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        # Distributed optimizer requires contiguous buffer; don't set to None.
        _zero_grad_group_helper(model_params, set_to_none = False)

    # >>>
    def get_model_grad_buffer_dp_views(self):

        data_parallel_world_size = mpu.get_data_parallel_world_size()
@@ -349,6 +348,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
    def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor):

        # Iterate grad buffers & chunk.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        chunk_view_items = []
        for model_index, dtype, gbuf_views in gbuf_view_items:
@@ -358,79 +358,22 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
            for view in gbuf_views:
                assert view.nelement() == view_numel

            # chunk_numel_min = 1024
            # chunk_numel_min = 16384
            # Compute chunk size (via savings factor).
            chunk_numel_min = 131072
            # chunk_numel_min = 1048576
            chunk_numel_max = view_numel
            # chunk_numel_min_log = math.log(chunk_numel_min)
            # chunk_numel_max_log = math.log(chunk_numel_max)
            # chunk_numel_log = (chunk_numel_min_log + chunk_numel_max_log) / 2
            # chunk_numel = int(math.exp(chunk_numel_log))
            chunk_numel = int(mem_savings_factor * chunk_numel_min
                              + (1 - mem_savings_factor) * chunk_numel_max)
            # >>>
            # from lutil import pax
            # pax(0, {
            #     "view_numel" : view_numel,
            #     "chunk_numel_min" : chunk_numel_min,
            #     "chunk_numel_max" : chunk_numel_max,
            #     "chunk_numel_min_log" : chunk_numel_min_log,
            #     "chunk_numel_max_log" : chunk_numel_max_log,
            #     "chunk_numel_log" : chunk_numel_log,
            #     "chunk_numel" : chunk_numel,
            #     "mem_savings_factor" : mem_savings_factor,
            # })
            # <<<

            # Chunk views.
            for start_index in range(0, view_numel, chunk_numel):
                end_index = min(view_numel, start_index + chunk_numel)
                chunk_views = [ t[start_index:end_index] for t in gbuf_views ]
                chunk_view_items.append((model_index, dtype, chunk_views))

        # >>>
        # from lutil import pax
        # pax(0, {
        #     "gbuf_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in gbuf_view_items],
        #     "chunk_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in chunk_view_items],
        # })
        # <<<

        return chunk_view_items
    # <<<
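A toy illustration of the chunking pattern above, under the assumption stated in the sanity check that every data-parallel rank's grad-buffer view has the same number of elements; the sizes here are made up:

import torch

# Stand-ins for two data-parallel ranks' grad-buffer views of 10 elements each.
gbuf_views = [torch.arange(10.0), torch.arange(10.0, 20.0)]
view_numel = gbuf_views[0].nelement()
chunk_numel = 4

chunk_view_items = []
for start_index in range(0, view_numel, chunk_numel):
    end_index = min(view_numel, start_index + chunk_numel)
    # The same index range is applied to every rank's view, so chunks stay
    # aligned across ranks: [0:4], [4:8], [8:10].
    chunk_view_items.append([t[start_index:end_index] for t in gbuf_views])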
    # >>>
    # def reduce_model_grads_0(self, args, timers):
    #     '''Note: this is a different order of reduction, versus the non-
    #     distributed optimizer, which reduces: 1) all grads, 2) embedding
    #     grads.
    #     '''
    #     # All-reduce embedding grads.
    #     timers('backward-embedding-all-reduce').start()
    #     self.allreduce_embedding_grads(args)
    #     timers('backward-embedding-all-reduce').stop()
    #     # Reduce-scatter all grads.
    #     timers('backward-params-all-reduce').start()
    #     data_parallel_rank = mpu.get_data_parallel_rank()
    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
    #     data_parallel_group = mpu.get_data_parallel_group()
    #     gbuf_view_items = self.get_model_grad_buffer_dp_views()
    #     for model_index, dtype, gbuf_views in gbuf_view_items:
    #         gbuf = self.models[model_index]._grad_buffers[dtype].data
    #         gbuf /= data_parallel_world_size
    #         torch.distributed.reduce_scatter(
    #             gbuf_views[data_parallel_rank],
    #             gbuf_views,
    #             group = data_parallel_group,
    #         )
    #     timers('backward-params-all-reduce').stop()
    # def reduce_model_grads_1(self, args, timers):
    def reduce_model_grads(self, args, timers):
        '''Note: this is a different order of reduction, versus the non-
        distributed optimizer, which reduces: 1) all grads, 2) embedding
@@ -458,91 +401,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        gbuf_view_items = \
            self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
        for model_index, dtype, gbuf_views in gbuf_view_items:
            # gbuf = self.models[model_index]._grad_buffers[dtype].data
            # gbuf /= data_parallel_world_size
            torch.distributed.reduce_scatter(
                gbuf_views[data_parallel_rank],
                gbuf_views,
                group = data_parallel_group,
            )

        timers('backward-params-all-reduce').stop()

    # def reduce_model_grads(self, *args):
    #     # >>>
    #     return
    #     # <<<
    #     # self.reduce_model_grads_0(*args)
    #     self.reduce_model_grads_1(*args)
    # <<<

    # >>>
    # def gather_model_params_0(self, args, timers):
    #     timers('backward-params-all-gather').start()
    #     data_parallel_rank = mpu.get_data_parallel_rank()
    #     data_parallel_group = mpu.get_data_parallel_group()
    #     # All-gather updated main params.
    #     gbuf_view_items = self.get_model_grad_buffer_dp_views()
    #     for model_index, dtype, gbuf_views in gbuf_view_items:
    #         torch.distributed.all_gather(
    #             gbuf_views,
    #             gbuf_views[data_parallel_rank],
    #             group = data_parallel_group,
    #         )
    #     # Each model param now contains its updated values in its
    #     # '.main_grad' field.
    #     for model in self.models:
    #         for dtype, param_map in model._grad_buffer_param_index_map.items():
    #             for param in param_map:
    #                 param.detach().copy_(param.main_grad)
    #     timers('backward-params-all-gather').stop()
    # def gather_model_params_1(self, args, timers):
    #     timers('backward-params-all-gather').start()
    #     data_parallel_rank = mpu.get_data_parallel_rank()
    #     data_parallel_group = mpu.get_data_parallel_group()
    #     # All-gather updated main params.
    #     # - All grad buffer views are guaranteed to have the same num elements
    #     #   across all data parallel ranks, with grad buffer padding that is done
    #     #   in distributed.py. Thus, all sub-views will have consistent start/end
    #     #   indexes across data parallel ranks.
    #     gbuf_view_items = self.get_model_grad_buffer_dp_views()
    #     # sub_view_numel = 1 * 1024
    #     # sub_view_numel = 1 * 131072
    #     sub_view_numel = 256 * 1048576
    #     for model_index, dtype, gbuf_views in gbuf_view_items:
    #         # ** Sanity check. ** (should be unnecessary; see comment above)
    #         view_numel = gbuf_views[0].nelement()
    #         for view in gbuf_views:
    #             assert view.nelement() == view_numel
    #         for start_index in range(0, view_numel, sub_view_numel):
    #             end_index = min(view_numel, start_index + sub_view_numel)
    #             sub_views = [ t[start_index:end_index] for t in gbuf_views ]
    #             torch.distributed.all_gather(
    #                 sub_views,
    #                 sub_views[data_parallel_rank],
    #                 group = data_parallel_group,
    #             )
    #     # Each model param now contains its updated values in its
    #     # '.main_grad' field.
    #     for model in self.models:
    #         for dtype, param_map in model._grad_buffer_param_index_map.items():
    #             for param in param_map:
    #                 param.detach().copy_(param.main_grad)
    #     timers('backward-params-all-gather').stop()
    # def gather_model_params_1(self, args, timers):
    def gather_model_params(self, args, timers):

        timers('backward-params-all-gather').start()
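What the torch.distributed.reduce_scatter call in reduce_model_grads() above computes, emulated here without a process group so the sketch runs standalone; the world size and view sizes are arbitrary:

import torch

# Emulation only, not the real collective: with reduce_scatter(output, input_list),
# rank r ends up holding the element-wise sum, across all ranks, of each rank's
# r-th input view.
world_size = 4
rank_inputs = [[torch.full((3,), float(r)) for _ in range(world_size)]
               for r in range(world_size)]   # one list of views per rank

outputs = []
for r in range(world_size):
    summed = torch.zeros(3)
    for inputs in rank_inputs:
        summed += inputs[r]
    outputs.append(summed)   # what rank r would hold after the collective
# Every entry here is [6., 6., 6.] since the per-rank contributions are 0+1+2+3.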
@@ -556,11 +420,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
        #   across all data parallel ranks, with grad buffer padding that is done
        #   in distributed.py. Thus, all sub-views will have consistent start/end
        #   indexes across data parallel ranks.
        # sub_numel = 1 * 1024
        # sub_numel = 1 * 131072
        # sub_numel = 1024 * 1048576
        # gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
        gbuf_view_items = \
            self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
        for model_index, dtype, gbuf_views in gbuf_view_items:
@@ -578,138 +437,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                    param.detach().copy_(param.main_grad)

        timers('backward-params-all-gather').stop()

    # def gather_model_params_2(self, args, timers):
    #     raise Exception("_all_gather_base not applicable when each DP rank owns contiguous range of grad buffer.")
    #     timers('backward-params-all-gather').start()
    #     data_parallel_rank = mpu.get_data_parallel_rank()
    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
    #     data_parallel_group = mpu.get_data_parallel_group()
    #     # All-gather updated main params.
    #     # - All grad buffer views are guaranteed to have the same num elements
    #     #   across all data parallel ranks, with grad buffer padding that is done
    #     #   in distributed.py. Thus, all sub-views will have consistent start/end
    #     #   indexes across data parallel ranks.
    #     gbuf_items = self.get_model_grad_buffers_SINGLE()
    #     # local_sub_numel = 1 * 1024
    #     # local_sub_numel = 1 * 131072
    #     ideal_local_numel = 128 * 1048576
    #     ideal_world_numel = data_parallel_world_size * ideal_local_numel
    #     for model_index, dtype, gbuf in gbuf_items:
    #         gbuf_numel = gbuf.nelement()
    #         # >>>
    #         # from lutil import pax
    #         # pax(0, {
    #         #     "gbuf_items" : [ (a, b, c.shape) for a, b, c in gbuf_items ],
    #         #     "gbuf" : str(gbuf.shape),
    #         #     "gbuf_numel" : gbuf_numel,
    #         #     "local_sub_numel" : local_sub_numel,
    #         #     "world_sub_numel" : world_sub_numel,
    #         # })
    #         # <<<
    #         for world_start_index in range(0, gbuf_numel, ideal_world_numel):
    #             world_end_index = \
    #                 min(gbuf_numel, world_start_index + ideal_world_numel)
    #             world_numel = world_end_index - world_start_index
    #             assert world_numel % data_parallel_world_size == 0
    #             local_numel = int(world_numel / data_parallel_world_size)
    #             local_start_index = \
    #                 world_start_index + data_parallel_rank * local_numel
    #             local_end_index = \
    #                 min(gbuf_numel, local_start_index + local_numel)
    #             try:
    #                 world_view = gbuf[world_start_index:world_end_index]
    #                 local_view = gbuf[local_start_index:local_end_index]
    #             except:
    #                 # >>>
    #                 from lutil import pax
    #                 pax(0, {
    #                     "world_start_index" : world_start_index,
    #                     "world_end_index" : world_end_index,
    #                     "local_start_index" : local_start_index,
    #                     "local_end_index" : local_end_index,
    #                 })
    #                 # <<<
    #             try:
    #                 torch.distributed._all_gather_base(
    #                     world_view,
    #                     local_view,
    #                     group = data_parallel_group,
    #                 )
    #             except:
    #                 # >>>
    #                 from lutil import pax
    #                 pax(0, {
    #                     "data_parallel_rank" : data_parallel_rank,
    #                     # "local_sub_numel" : local_sub_numel,
    #                     # "world_sub_numel" : world_sub_numel,
    #                     "world_start_index" : world_start_index,
    #                     "world_end_index" : world_end_index,
    #                     "local_start_index" : local_start_index,
    #                     "local_end_index" : local_end_index,
    #                     "gbuf" : str(gbuf.shape),
    #                     "world_view" : str(world_view.shape),
    #                     "local_view" : str(local_view.shape),
    #                     "local_sub_numel / ideal" : local_sub_numel,
    #                     "local_sub_numel / act" :
    #                         local_end_index - local_start_index,
    #                 })
    #                 # <<<
    #             # >>>
    #             # from lutil import pax, tp
    #             # pax(0, {
    #             #     # "gbuf" : tp(gbuf),
    #             #     "world range" : "%d, %d"%(world_start_index, world_end_index),
    #             #     "local range" : "%d, %d"%(local_start_index, local_end_index),
    #             #     "world_view" : tp(world_view),
    #             #     "local_view" : tp(local_view),
    #             #     "gbuf view" : tp(gbuf[world_start_index:world_end_index]),
    #             # })
    #             # <<<
    #     # >>>
    #     for model_index, dtype, gbuf in gbuf_items:
    #         if self.has_nan_debug(gbuf):
    #             raise Exception("hi.")
    #     # from lutil import pax, tp
    #     # pax(0, {
    #     #     "gbuf_items" : [ (a, b, tp(c)) for a, b, c in gbuf_items ],
    #     # })
    #     # <<<
    #     # Each model param now contains its updated values in its
    #     # '.main_grad' field.
    #     for model in self.models:
    #         for dtype, param_map in model._grad_buffer_param_index_map.items():
    #             for param in param_map:
    #                 param.detach().copy_(param.main_grad)
    #                 # >>>
    #                 if self.has_nan_debug(param):
    #                     raise Exception("wha?")
    #                 # <<<
    #     timers('backward-params-all-gather').stop()
    # def gather_model_params(self, *args):
    #     # >>>
    #     # return
    #     # <<<
    #     # self.gather_model_params_0(*args)
    #     self.gather_model_params_1(*args)
    #     # self.gather_model_params_2(*args)
    #     # ~~~
    #     # self.debug_model(0, "after / gather_model_params", 0)
    # <<<

    def _collect_main_grad_data_for_unscaling(self):
        return [ g.data for g in self.get_main_grads() ]
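A toy illustration of the copy-back step retained above: after the all-gather, each parameter's '.main_grad' view of the grad buffer holds its updated values, and param.detach().copy_(param.main_grad) writes them into the parameter storage. The tensor values here are made up:

import torch

param = torch.nn.Parameter(torch.zeros(4))
# Stand-in for the grad-buffer view that gather_model_params() fills in.
param.main_grad = torch.tensor([1.0, 2.0, 3.0, 4.0])

param.detach().copy_(param.main_grad)   # in-place update, no autograd tracking
assert torch.equal(param.data, param.main_grad)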