OpenDAS / Megatron-LM · Commits · ed90a1b2

Commit ed90a1b2 authored Mar 18, 2022 by Lawrence McAfee

    mem savings factor training; need to check loss.

parent 06b9ebe0
Showing 3 changed files with 170 additions and 145 deletions:

    megatron/arguments.py                      +1    -1
    megatron/optimizer/distrib_optimizer.py    +114  -89
    megatron/optimizer/optimizer.py            +55   -55
megatron/arguments.py (view file @ ed90a1b2)
@@ -733,7 +733,7 @@ def _add_distributed_args(parser):
                        'affects the encoder embedding.)')
     group.add_argument('--use-distributed-optimizer', action='store_true',
                        help='Use distributed optimizer.')
-    group.add_argument('--disrib-opt-comm-mem-savings', default=0., type=float,
+    group.add_argument('--distrib-opt-comm-mem-savings', default=0., type=float,
                        help='Trade-off memory savings & iteration time, for '
                        'disributed optimizer\'s communication operations (i.e., '
                        '(reduce/gather). This value ranges from 0.0 (default, '
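For illustration (not part of this commit's diff): in megatron/optimizer/distrib_optimizer.py below, this factor linearly interpolates the per-collective chunk size between the full grad-buffer view (factor 0.0, one big call) and 1024**2 elements (factor 1.0, many small calls). A minimal sketch of that arithmetic, using a hypothetical helper name chunk_numel_for:

    # Sketch of the interpolation used in the diff below: factor 0.0 keeps a
    # single full-buffer collective call; factor 1.0 shrinks each call to ~1M
    # elements (more calls, smaller temporary communication buffers).
    def chunk_numel_for(mem_savings_factor, view_numel, chunk_numel_min=1024 ** 2):
        return int(mem_savings_factor * chunk_numel_min
                   + (1 - mem_savings_factor) * view_numel)

    if __name__ == "__main__":
        view_numel = 256 * 1024 * 1024  # e.g. a 256M-element per-rank buffer view
        for factor in (0.0, 0.5, 1.0):
            print(factor, chunk_numel_for(factor, view_numel))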
megatron/optimizer/distrib_optimizer.py (view file @ ed90a1b2)
@@ -346,9 +346,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             gbuf_view_items.append((model_index, dtype, gbuf_views))

         return gbuf_view_items

-    def get_model_grad_buffer_dp_views_SUB(self, sub_view_numel):
+    def get_model_grad_buffer_dp_views_chunked(self, mem_savings_factor):

         gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        sub_view_items = []
+        chunk_view_items = []
         for model_index, dtype, gbuf_views in gbuf_view_items:
             # ** Sanity check. ** (should be unnecessary; see comment above)
@@ -356,65 +358,77 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
             for view in gbuf_views:
                 assert view.nelement() == view_numel

-            for start_index in range(0, view_numel, sub_view_numel):
-                end_index = min(view_numel, start_index + sub_view_numel)
-                sub_views = [t[start_index:end_index] for t in gbuf_views]
-                sub_view_items.append((model_index, dtype, sub_views))
+            chunk_numel_min = 1024 ** 2
+            chunk_numel_max = view_numel
+            # chunk_numel_min_log = math.log(chunk_numel_min)
+            # chunk_numel_max_log = math.log(chunk_numel_max)
+            # chunk_numel_log = (chunk_numel_min_log + chunk_numel_max_log) / 2
+            # chunk_numel = int(math.exp(chunk_numel_log))
+            chunk_numel = int(mem_savings_factor * chunk_numel_min
+                              + (1 - mem_savings_factor) * chunk_numel_max)
+            # >>>
+            # from lutil import pax
+            # pax(0, {
+            #     "view_numel" : view_numel,
+            #     "chunk_numel_min" : chunk_numel_min,
+            #     "chunk_numel_max" : chunk_numel_max,
+            #     "chunk_numel_min_log" : chunk_numel_min_log,
+            #     "chunk_numel_max_log" : chunk_numel_max_log,
+            #     "chunk_numel_log" : chunk_numel_log,
+            #     "chunk_numel" : chunk_numel,
+            #     "mem_savings_factor" : mem_savings_factor,
+            # })
+            # <<<
+            for start_index in range(0, view_numel, chunk_numel):
+                end_index = min(view_numel, start_index + chunk_numel)
+                chunk_views = [t[start_index:end_index] for t in gbuf_views]
+                chunk_view_items.append((model_index, dtype, chunk_views))

         # >>>
         # from lutil import pax
         # pax(0, {
         #     "gbuf_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in gbuf_view_items],
-        #     "sub_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in sub_view_items],
+        #     "chunk_view_items" : [(a,b,"%d / %s" % (len(c), [ d.nelement() for d in c ])) for a,b,c in chunk_view_items],
         # })
         # <<<

-        return sub_view_items
-
-    # def get_model_grad_buffers_SINGLE(self):
-    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
-    #     # Grad buffers.
-    #     gbuf_items = []
-    #     for model_index, model in enumerate(self.models):
-    #         for dtype, gbuf in model._grad_buffers.items():
-    #             assert gbuf.numel_padded % data_parallel_world_size == 0
-    #             shard_size = int(gbuf.numel_padded / data_parallel_world_size)
-    #             gbuf_items.append((model_index, dtype, gbuf.data))
-    #     return gbuf_items
+        return chunk_view_items
     # <<<

     # >>>
-    def reduce_model_grads_0(self, args, timers):
-        '''Note: this is a different order of reduction, versus the non-
-        distributed optimizer, which reduces: 1) all grads, 2) embedding
-        grads.
-        '''
-
-        # All-reduce embedding grads.
-        timers('backward-embedding-all-reduce').start()
-        self.allreduce_embedding_grads(args)
-        timers('backward-embedding-all-reduce').stop()
-
-        # Reduce-scatter all grads.
-        timers('backward-params-all-reduce').start()
-
-        data_parallel_rank = mpu.get_data_parallel_rank()
-        data_parallel_world_size = mpu.get_data_parallel_world_size()
-        data_parallel_group = mpu.get_data_parallel_group()
-
-        gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        for model_index, dtype, gbuf_views in gbuf_view_items:
-
-            gbuf = self.models[model_index]._grad_buffers[dtype].data
-            gbuf /= data_parallel_world_size
-
-            torch.distributed.reduce_scatter(
-                gbuf_views[data_parallel_rank],
-                gbuf_views,
-                group = data_parallel_group,
-            )
-
-        timers('backward-params-all-reduce').stop()
-
-    def reduce_model_grads_1(self, args, timers):
+    # def reduce_model_grads_0(self, args, timers):
+    #     '''Note: this is a different order of reduction, versus the non-
+    #     distributed optimizer, which reduces: 1) all grads, 2) embedding
+    #     grads.
+    #     '''
+    #     # All-reduce embedding grads.
+    #     timers('backward-embedding-all-reduce').start()
+    #     self.allreduce_embedding_grads(args)
+    #     timers('backward-embedding-all-reduce').stop()
+    #     # Reduce-scatter all grads.
+    #     timers('backward-params-all-reduce').start()
+    #     data_parallel_rank = mpu.get_data_parallel_rank()
+    #     data_parallel_world_size = mpu.get_data_parallel_world_size()
+    #     data_parallel_group = mpu.get_data_parallel_group()
+    #     gbuf_view_items = self.get_model_grad_buffer_dp_views()
+    #     for model_index, dtype, gbuf_views in gbuf_view_items:
+    #         gbuf = self.models[model_index]._grad_buffers[dtype].data
+    #         gbuf /= data_parallel_world_size
+    #         torch.distributed.reduce_scatter(
+    #             gbuf_views[data_parallel_rank],
+    #             gbuf_views,
+    #             group = data_parallel_group,
+    #         )
+    #     timers('backward-params-all-reduce').stop()
+    # def reduce_model_grads_1(self, args, timers):
+    def reduce_model_grads(self, args, timers):
         '''Note: this is a different order of reduction, versus the non-
         distributed optimizer, which reduces: 1) all grads, 2) embedding
         grads.
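For illustration (not part of this commit's diff): get_model_grad_buffer_dp_views_chunked splits each per-rank grad-buffer view into equal-sized slices so that later reduce-scatter/all-gather calls operate on one slice at a time. A standalone sketch of just the chunking step, with hypothetical names chunk_views and views:

    # Standalone sketch of the chunking idea (assumes all views have equal length).
    import torch

    def chunk_views(views, chunk_numel):
        """Split each per-rank view into aligned [start:end] slices."""
        view_numel = views[0].nelement()
        assert all(v.nelement() == view_numel for v in views)
        chunked = []
        for start in range(0, view_numel, chunk_numel):
            end = min(view_numel, start + chunk_numel)
            chunked.append([v[start:end] for v in views])
        return chunked

    views = [torch.zeros(10) for _ in range(4)]  # 4 "ranks", 10 elements each
    print([[c.nelement() for c in group] for group in chunk_views(views, 4)])
    # -> [[4, 4, 4, 4], [4, 4, 4, 4], [2, 2, 2, 2]]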
@@ -425,14 +439,21 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         self.allreduce_embedding_grads(args)
         timers('backward-embedding-all-reduce').stop()

-        # Reduce-scatter all grads.
+        # Reduce-scatter setup.
         timers('backward-params-all-reduce').start()
         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_world_size = mpu.get_data_parallel_world_size()
         data_parallel_group = mpu.get_data_parallel_group()
+        mem_savings_factor = args.distrib_opt_comm_mem_savings

-        sub_numel = 1 * 1048576
-        gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
+        # Scale grad buffers by '1 / data_parallel_world_size'.
+        for model in self.models:
+            for dtype, gbuf in model._grad_buffers.items():
+                gbuf.data /= data_parallel_world_size
+
+        # Reduce scatter all grads.
+        gbuf_view_items = \
+            self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
         for model_index, dtype, gbuf_views in gbuf_view_items:
             # gbuf = self.models[model_index]._grad_buffers[dtype].data
             # gbuf /= data_parallel_world_size
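For illustration (not part of this commit's diff): reduce_model_grads now scales every grad buffer by 1/data_parallel_world_size up front, then reduce-scatters chunk by chunk. A sketch of that communication pattern, assuming an already-initialized torch.distributed process group with reduce-scatter support and a buffer padded to a multiple of the world size (reduce_scatter_in_chunks is a hypothetical name):

    import torch
    import torch.distributed as dist

    def reduce_scatter_in_chunks(buffer, chunk_numel, group=None):
        world_size = dist.get_world_size(group)
        rank = dist.get_rank(group)
        buffer /= world_size  # pre-scale once, outside the communication loop
        shard_numel = buffer.nelement() // world_size
        views = [buffer[r * shard_numel:(r + 1) * shard_numel]
                 for r in range(world_size)]
        for start in range(0, shard_numel, chunk_numel):
            end = min(shard_numel, start + chunk_numel)
            chunk_views = [v[start:end] for v in views]
            # Smaller collectives trade iteration time for smaller temporary
            # communication buffers (what --distrib-opt-comm-mem-savings tunes).
            dist.reduce_scatter(chunk_views[rank], chunk_views, group=group)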
@@ -442,39 +463,39 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 group = data_parallel_group,
             )

         timers('backward-params-all-reduce').stop()

-    def reduce_model_grads(self, *args):
-        # >>>
-        return
-        # <<<
-        # self.reduce_model_grads_0(*args)
-        self.reduce_model_grads_1(*args)
+    # def reduce_model_grads(self, *args):
+    #     # >>>
+    #     return
+    #     # <<<
+    #     # self.reduce_model_grads_0(*args)
+    #     self.reduce_model_grads_1(*args)
     # <<<

     # >>>
-    def gather_model_params_0(self, args, timers):
+    # def gather_model_params_0(self, args, timers):

-        timers('backward-params-all-gather').start()
+    #     timers('backward-params-all-gather').start()

-        data_parallel_rank = mpu.get_data_parallel_rank()
-        data_parallel_group = mpu.get_data_parallel_group()
+    #     data_parallel_rank = mpu.get_data_parallel_rank()
+    #     data_parallel_group = mpu.get_data_parallel_group()

-        # All-gather updated main params.
-        gbuf_view_items = self.get_model_grad_buffer_dp_views()
-        for model_index, dtype, gbuf_views in gbuf_view_items:
-            torch.distributed.all_gather(
-                gbuf_views,
-                gbuf_views[data_parallel_rank],
-                group = data_parallel_group,
-            )
+    #     # All-gather updated main params.
+    #     gbuf_view_items = self.get_model_grad_buffer_dp_views()
+    #     for model_index, dtype, gbuf_views in gbuf_view_items:
+    #         torch.distributed.all_gather(
+    #             gbuf_views,
+    #             gbuf_views[data_parallel_rank],
+    #             group = data_parallel_group,
+    #         )

-        # Each model param now contains its updated values in its
-        # '.main_grad' field.
-        for model in self.models:
-            for dtype, param_map in model._grad_buffer_param_index_map.items():
-                for param in param_map:
-                    param.detach().copy_(param.main_grad)
+    #     # Each model param now contains its updated values in its
+    #     # '.main_grad' field.
+    #     for model in self.models:
+    #         for dtype, param_map in model._grad_buffer_param_index_map.items():
+    #             for param in param_map:
+    #                 param.detach().copy_(param.main_grad)

-        timers('backward-params-all-gather').stop()
+    #     timers('backward-params-all-gather').stop()

     # def gather_model_params_1(self, args, timers):
     #     timers('backward-params-all-gather').start()
@@ -518,12 +539,14 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     #                 param.detach().copy_(param.main_grad)
     #     timers('backward-params-all-gather').stop()

-    def gather_model_params_1(self, args, timers):
+    # def gather_model_params_1(self, args, timers):
+    def gather_model_params(self, args, timers):

         timers('backward-params-all-gather').start()

         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_group = mpu.get_data_parallel_group()
+        mem_savings_factor = args.distrib_opt_comm_mem_savings

         # All-gather updated main params.
         # - All grad buffer views are guaranteed to have the same num elements
@@ -533,8 +556,10 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # sub_numel = 1 * 1024
         # sub_numel = 1 * 131072
-        sub_numel = 1024 * 1048576
-        gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
+        # sub_numel = 1024 * 1048576
+        # gbuf_view_items = self.get_model_grad_buffer_dp_views_SUB(sub_numel)
+        gbuf_view_items = \
+            self.get_model_grad_buffer_dp_views_chunked(mem_savings_factor)
         for model_index, dtype, gbuf_views in gbuf_view_items:
             torch.distributed.all_gather(
                 gbuf_views,
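For illustration (not part of this commit's diff): gather_model_params uses the same chunked view items in the reverse direction, all-gathering each rank's updated shard chunk by chunk. A short sketch mirroring the reduce-scatter example above (all_gather_in_chunks is a hypothetical name; assumes an initialized process group and equal-length per-rank views):

    import torch.distributed as dist

    def all_gather_in_chunks(views, chunk_numel, group=None):
        rank = dist.get_rank(group)
        view_numel = views[0].nelement()
        for start in range(0, view_numel, chunk_numel):
            end = min(view_numel, start + chunk_numel)
            chunk_views = [v[start:end] for v in views]
            # Each rank contributes its own chunk; every rank receives all chunks.
            dist.all_gather(chunk_views, chunk_views[rank], group=group)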
@@ -671,16 +696,16 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     #     # <<<
     #     timers('backward-params-all-gather').stop()

-    def gather_model_params(self, *args):
-        # >>>
-        # return
-        # <<<
-        # self.gather_model_params_0(*args)
-        self.gather_model_params_1(*args)
-        # self.gather_model_params_2(*args)
-        # ~~~
-        # self.debug_model(0, "after / gather_model_params", 0)
+    # def gather_model_params(self, *args):
+    #     # >>>
+    #     # return
+    #     # <<<
+    #     # self.gather_model_params_0(*args)
+    #     self.gather_model_params_1(*args)
+    #     # self.gather_model_params_2(*args)
+    #     # ~~~
+    #     # self.debug_model(0, "after / gather_model_params", 0)
     # <<<

     def _collect_main_grad_data_for_unscaling(self):
megatron/optimizer/optimizer.py (view file @ ed90a1b2)
@@ -322,61 +322,61 @@ class MixedPrecisionOptimizer(MegatronOptimizer):

         return found_inf_flag

     # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
-    @classmethod
-    def debug_base(cls, ITERATION, key, value):
-
-        from megatron import get_args
-        args = get_args()
-
-        my_rank = torch.distributed.get_rank()
-
-        DEBUG_ITERATION = ITERATION
-        if ITERATION != DEBUG_ITERATION:
-            return
-
-        for r in range(torch.distributed.get_world_size()):
-            if my_rank == r:
-                # prefix = " + "
-                prefix = ""
-                print("%sbr/%s; [r%d, i%d]; %s, %.12e" % (prefix, "fix " if args.use_distributed_optimizer else "main", my_rank, ITERATION, key, value))
-            torch.distributed.barrier()
-        torch.distributed.barrier()
-
-        # if my_rank == 0:
-        #     raise Exception("debug.")
-        # else:
-        #     exit(0)
-        exit(0)
-
-    def debug_model(self, ITERATION, key, use_grad):
-        use_grad = bool(use_grad)
-        tensors = [
-            (p.main_grad.float() if use_grad else p.float())
-            for m in self.models for p in m.parameters()
-        ]
-        count = sum(t.nelement() for t in tensors)
-        return self.debug_base(
-            ITERATION,
-            "model/%s, %s [count %d]" % (
-                "grad" if use_grad else "param",
-                key,
-                count,
-            ),
-            # sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
-            sum(torch.sum(torch.abs(t)) for t in tensors),
-        )
-
-    def debug_main(self, ITERATION, key, use_grad):
-        use_grad = bool(use_grad)
-        tensors = [
-            p.grad if use_grad else p
-            for g in self.optimizer.param_groups
-            for p in g["params"]
-        ]
-        tensors = [ t.float() for t in tensors ]
-        count = sum(t.nelement() for t in tensors)
-        return self.debug_base(
-            ITERATION,
-            "main/%s, %s [count %d]" % (
-                "grad" if use_grad else "param",
-                key,
-                count,
-            ),
-            sum(torch.sum(torch.abs(t)) for t in tensors),
-        )
+    # @classmethod
+    # def debug_base(cls, ITERATION, key, value):
+    #     from megatron import get_args
+    #     args = get_args()
+    #     my_rank = torch.distributed.get_rank()
+    #     DEBUG_ITERATION = ITERATION
+    #     if ITERATION != DEBUG_ITERATION:
+    #         return
+    #     for r in range(torch.distributed.get_world_size()):
+    #         if my_rank == r:
+    #             # prefix = " + "
+    #             prefix = ""
+    #             print("%sbr/%s; [r%d, i%d]; %s, %.12e" % (prefix, "fix " if args.use_distributed_optimizer else "main", my_rank, ITERATION, key, value))
+    #         torch.distributed.barrier()
+    #     torch.distributed.barrier()
+    #     # if my_rank == 0:
+    #     #     raise Exception("debug.")
+    #     # else:
+    #     #     exit(0)
+    #     exit(0)
+    # def debug_model(self, ITERATION, key, use_grad):
+    #     use_grad = bool(use_grad)
+    #     tensors = [
+    #         (p.main_grad.float() if use_grad else p.float())
+    #         for m in self.models for p in m.parameters()
+    #     ]
+    #     count = sum(t.nelement() for t in tensors)
+    #     return self.debug_base(
+    #         ITERATION,
+    #         "model/%s, %s [count %d]" % (
+    #             "grad" if use_grad else "param",
+    #             key,
+    #             count,
+    #         ),
+    #         # sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
+    #         sum(torch.sum(torch.abs(t)) for t in tensors),
+    #     )
+    # def debug_main(self, ITERATION, key, use_grad):
+    #     use_grad = bool(use_grad)
+    #     tensors = [
+    #         p.grad if use_grad else p
+    #         for g in self.optimizer.param_groups
+    #         for p in g["params"]
+    #     ]
+    #     tensors = [ t.float() for t in tensors ]
+    #     count = sum(t.nelement() for t in tensors)
+    #     return self.debug_base(
+    #         ITERATION,
+    #         "main/%s, %s [count %d]" % (
+    #             "grad" if use_grad else "param",
+    #             key,
+    #             count,
+    #         ),
+    #         sum(torch.sum(torch.abs(t)) for t in tensors),
+    #     )
     # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<

     @torch.no_grad()
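For illustration (not part of this commit's diff): the debug helpers being commented out above collapse all params or grads into a single per-rank scalar, sum(|t|), so that a "fix" run and a "main" run can be compared line by line. A minimal standalone version of that signature:

    import torch

    def tensor_signature(tensors):
        """Sum of absolute values over a list of tensors, as one float."""
        return sum(torch.sum(torch.abs(t.float())) for t in tensors).item()

    params = [torch.full((4,), 0.5), torch.ones(3)]
    print("param signature: %.12e" % tensor_signature(params))  # 5.0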