OpenDAS / Megatron-LM / Commits

Commit c6d20c05, authored Mar 22, 2022 by Lawrence McAfee
implemented 'get_grad_views_for_grad_norm()'.
parent 0481f58e

Showing 3 changed files, with 123 additions and 107 deletions:
- megatron/optimizer/clip_grads.py (+23, -23)
- megatron/optimizer/distrib_optimizer.py (+64, -0)
- megatron/optimizer/optimizer.py (+36, -84)
megatron/optimizer/clip_grads.py  (view file @ c6d20c05)
...
@@ -119,21 +119,21 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
        total_norm = grad_norm ** norm_type

        # >>>
        from megatron import get_args
        from lutil import pax
        args = get_args()
        for r in range(torch.distributed.get_world_size()):
            if torch.distributed.get_rank() == r:
                print("compute: r %d, dist-op %d, gnorm %f ... p %d, g %d, gn %d" % (
                    torch.distributed.get_rank(),
                    args.use_distributed_optimizer,
                    grad_norm.item(),
                    sum(t.nelement() for t in parameters),
                    sum(t.nelement() for t in grads),
                    sum(t.nelement() for t in grads_for_norm),
                ))
            torch.distributed.barrier()
        exit(0)
        # from megatron import get_args
        # from lutil import pax
        # args = get_args()
        # for r in range(torch.distributed.get_world_size()):
        #     if torch.distributed.get_rank() == r:
        #         print("compute: r %d, dist-op %d, gnorm %f ... p %d, g %d, gn %d" % (
        #             torch.distributed.get_rank(),
        #             args.use_distributed_optimizer,
        #             grad_norm.item(),
        #             sum(t.nelement() for t in parameters),
        #             sum(t.nelement() for t in grads),
        #             sum(t.nelement() for t in grads_for_norm),
        #         ))
        #     torch.distributed.barrier()
        # exit(0)
        # pax(2, {
        #     "use distrib opt" : args.use_distributed_optimizer,
        #     "norm_type" : norm_type,
...
...
@@ -154,14 +154,14 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2,
        total_norm = total_norm.item() ** (1.0 / norm_type)

        # >>>
        from megatron import get_args
        from lutil import pax
        args = get_args()
        pax(0, {
            "use distrib opt" : args.use_distributed_optimizer,
            "norm_type" : norm_type,
            "total_norm" : total_norm,
        })
        # from megatron import get_args
        # from lutil import pax
        # args = get_args()
        # pax(0, {
        #     "use distrib opt" : args.use_distributed_optimizer,
        #     "norm_type" : norm_type,
        #     "total_norm" : total_norm,
        # })
        # <<<

    # Scale.
...
...
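For context on what the debug prints above are checking: clip_grad_norm_fp32 accumulates each gradient's norm raised to norm_type, reduces the partial sums, and takes the 1/norm_type root before scaling. Below is a minimal, self-contained sketch of that pattern, not the Megatron implementation; the epsilon constant and the in-place scaling are illustrative assumptions.

import torch

def sketch_clip_grad_norm(grads_for_norm, max_norm, norm_type=2.0):
    # Accumulate norm**norm_type per gradient, as in the hunk above
    # ("total_norm = grad_norm ** norm_type"), then take the root.
    total_norm = 0.0
    for grad in grads_for_norm:
        grad_norm = torch.norm(grad, norm_type)
        total_norm += grad_norm.item() ** norm_type
    total_norm = total_norm ** (1.0 / norm_type)

    # Scale gradients in place if the combined norm exceeds max_norm
    # (the epsilon value here is an illustrative choice).
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        for grad in grads_for_norm:
            grad.mul_(clip_coeff)
    return total_norm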
megatron/optimizer/distrib_optimizer.py  (view file @ c6d20c05)
...
@@ -22,9 +22,17 @@ import torch

from megatron import get_args
from megatron import get_timers
from megatron import mpu
# >>>
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
# <<<
from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
# >>>
from .optimizer import get_clippy
from lutil import pax, tp
# <<<


class Shard:
    def __init__(self, start, end):
...
...
@@ -188,6 +196,45 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

            # Update group's param.
            group_shard["orig_group"]["params"] = [main_param]

    # >>>
    @classmethod
    def get_grad_views_for_grad_norm(cls, opt_group_shards, optimizer):

        grad_views = []
        # grad_views_SKIPPED = []
        for group_index, opt_group_shard in enumerate(opt_group_shards):
            opt_grad = optimizer.param_groups[group_index]["params"][0].grad
            for param, shard in opt_group_shard["param_map"].items():
                if param_is_not_shared(param) and \
                   param_is_not_tensor_parallel_duplicate(param):
                    grad_view = opt_grad[shard.start:shard.end]
                    grad_views.append(grad_view)
                # else:
                #     grad_views_SKIPPED.append(opt_grad[shard.start:shard.end])

        # >>>
        # my_rank = torch.distributed.get_rank()
        # for r in range(torch.distributed.get_world_size()):
        #     if r == my_rank:
        #         print("r %d, grad views %s." % (
        #             my_rank,
        #             ", ".join(str(tuple(g.shape)) for g in grad_views),
        #         ))
        #     torch.distributed.barrier()
        # for r in range(torch.distributed.get_world_size()):
        #     if r == my_rank:
        #         print("r %d, SKIPPED %s." % (
        #             my_rank,
        #             ", ".join(str(tuple(g.shape)) for g in grad_views_SKIPPED),
        #         ))
        #     torch.distributed.barrier()
        # exit(0)
        # <<<

        return grad_views
    # <<<

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
...
...
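To make the slicing in get_grad_views_for_grad_norm() concrete, here is a minimal stand-alone sketch of the same pattern: each parameter owns a [start, end) shard of a flat gradient tensor, and the norm is later computed over views into that flat tensor rather than over the model parameters themselves. The flat_grad and param_map values below are purely illustrative.

import torch

class Shard:
    def __init__(self, start, end):
        self.start = start
        self.end = end

# A flat gradient buffer covering two hypothetical parameters (6 and 4 elements).
flat_grad = torch.randn(10)
param_map = {
    "param_a": Shard(0, 6),
    "param_b": Shard(6, 10),
}

# Each view aliases the flat buffer, so no gradient data is copied.
grad_views = [flat_grad[s.start:s.end] for s in param_map.values()]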
@@ -227,6 +274,22 @@ class DistributedOptimizer(MixedPrecisionOptimizer):

        # Initialize main params.
        self._copy_model_params_to_main_params()

        # >>> numel/nelem per rank >>>
        # for r in range(torch.distributed.get_world_size()):
        #     if r == torch.distributed.get_rank():
        #         for m in self.models:
        #             for b in m._grad_buffers.values():
        #                 print("r %d, %d." % (r, b.data.nelement()))
        #     torch.distributed.barrier()
        # exit(0)
        # <<<

        # Params for grad norm.
        self.grad_views_for_grad_norm = self.get_grad_views_for_grad_norm(
            self.opt_group_shards,
            self.optimizer)

    def get_model_parallel_group(self):
        return None
...
...
@@ -407,6 +470,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                group = data_parallel_group,
            )

        timers('backward-params-all-reduce').stop()


    def gather_model_params(self, args, timers):

        timers('backward-params-all-gather').start()
...
...
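This commit only builds and stores self.grad_views_for_grad_norm; it does not yet show how the views are consumed. One hypothetical way the shard-local views could feed a global norm, assuming the usual pattern of all-reducing norm**norm_type across the group that partitions the gradients, is sketched below (illustration only, not part of this commit).

import torch

def sketch_global_grad_norm(grad_views, process_group=None, norm_type=2.0):
    # Sum norm**norm_type over the locally owned shard views ...
    local_sum = sum(torch.norm(v, norm_type).item() ** norm_type for v in grad_views)
    total = torch.tensor([local_sum])
    # ... then combine the partial sums across ranks and take the root.
    if torch.distributed.is_initialized():
        torch.distributed.all_reduce(total, group=process_group)
    return total.item() ** (1.0 / norm_type)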
megatron/optimizer/optimizer.py  (view file @ c6d20c05)
...
@@ -31,6 +31,20 @@ from megatron.utils import unwrap_model

from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32

# >>>
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
from lutil import pax

get_clippy = lambda params : [
    "%d, %d, %d ... %s" % (
        p.grad is not None,
        param_is_not_shared(p),
        param_is_not_tensor_parallel_duplicate(p),
        str(tuple(p.shape)),
    )
    for p in params
]
# <<<


def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters.
...
...
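A toy illustration of the get_clippy() debug format defined above: the two middle flags are hard-coded stand-ins for param_is_not_shared and param_is_not_tensor_parallel_duplicate, purely to show the string layout on a plain tensor.

import torch

def toy_clippy(params):
    return [
        "%d, %d, %d ... %s" % (
            p.grad is not None,
            1,  # stand-in for param_is_not_shared(p)
            1,  # stand-in for param_is_not_tensor_parallel_duplicate(p)
            str(tuple(p.shape)),
        )
        for p in params
    ]

p = torch.nn.Parameter(torch.zeros(4, 8))
print(toy_clippy([p]))  # -> ['0, 1, 1 ... (4, 8)']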
@@ -105,6 +119,17 @@ class MegatronOptimizer(ABC):

    def clip_grad_norm(self, clip_grad):
        # >>>
        # model_params = [ p for m in self.models for p in m.parameters() ]
        # optim_params = self.get_parameters()
        # from lutil import pax
        # pax(1, {
        #     "model_params" : get_clippy(model_params),
        #     "optim_params" : get_clippy(optim_params),
        # })
        # <<<
        params = self.get_parameters()
        return clip_grad_norm_fp32(params, clip_grad,
...
...
@@ -408,91 +433,7 @@ class MixedPrecisionOptimizer(MegatronOptimizer):
        timers('optimizer-clip-main-grad').start()
        grad_norm = None
        if self.clip_grad > 0.0:
            # >>>
            from megatron.model.module import param_is_not_shared
            from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
            def use_grad(p):
                conditions = [
                    p.grad is not None,
                    param_is_not_shared(p),
                    param_is_not_tensor_parallel_duplicate(p),
                    # getattr(p, "shared", False),
                ]
                return all(conditions)
            # def print_module(m, d):
            #     ps = [ "%d/%s" % (
            #         use_grad(p),
            #         str(tuple(p.shape)),
            #     ) for p in m.parameters(recurse = False) ]
            #     ps = [
            #         str(tuple(p))
            #         for p in m.parameters(recurse = False)
            #         if use_grad(p)
            #     ]
            #     print("%s %s | %s" % (".." * d, type(m).__name__, ", ".join(ps)))
            # if torch.distributed.get_rank() == 0:
            #     visited = []
            #     queue = [ (m, 0) for m in self.models ]
            #     while queue:
            #         m, d = queue.pop()
            #         visited.append((m, d))
            #         # print_module(m, d)
            #         queue.extend(reversed([ (mm, d + 1) for mm in m.children() ]))
            #     for m, d in visited:
            #         print_module(m, d)
            for r in range(torch.distributed.get_world_size()):
                if r == torch.distributed.get_rank():
                    # print("r %d, %s" % (
                    #     torch.distributed.get_rank(),
                    #     "".join(
                    #         "%d" % use_grad(p)
                    #         for m in self.models
                    #         for p in m.parameters()
                    #     ),
                    # ))
                    # print("r %d [ d %d, t %d, p %d ] ... %s" % (
                    #     torch.distributed.get_rank(),
                    #     mpu.get_data_parallel_rank(),
                    #     mpu.get_tensor_model_parallel_rank(),
                    #     mpu.get_pipeline_model_parallel_rank(),
                    #     ", ".join(str(tuple(p.shape)) for p in self.get_parameters() if not use_grad(p)),
                    # ))
                    print("r %d [ d %d, t %d, p %d ] ... %d, %d ... %s" % (
                        torch.distributed.get_rank(),
                        mpu.get_data_parallel_rank(),
                        mpu.get_tensor_model_parallel_rank(),
                        mpu.get_pipeline_model_parallel_rank(),
                        sum(p.nelement() for p in self.get_parameters() if use_grad(p)),
                        sum(p.nelement() for p in self.get_parameters() if not use_grad(p)),
                        "".join("%d" % use_grad(p) for p in self.get_parameters()),
                    ))
                torch.distributed.barrier()
            torch.distributed.barrier()
            exit(0)
            # <<<
            grad_norm = self.clip_grad_norm(self.clip_grad)
            # >>>
            from lutil import pax
            pax(0, {
                "use distrib opt" : args.use_distributed_optimizer,
                "grad_norm" : grad_norm,
            })
            # <<<
        timers('optimizer-clip-main-grad').stop()

        # count the zeros in the grads
...
...
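The debug block above centers on the use_grad() predicate: a parameter counts toward the grad norm only if it has a gradient, is not a shared (tied) parameter, and is not a tensor-parallel duplicate. Here is a stand-alone sketch of that filter, with illustrative stand-ins for the two Megatron helpers.

def _is_not_shared(p):
    # Stand-in for megatron.model.module.param_is_not_shared: tied params are
    # assumed to carry a truthy `shared` attribute.
    return not getattr(p, "shared", False)

def _is_not_tp_duplicate(p, tp_rank=0):
    # Stand-in for megatron.mpu.layers.param_is_not_tensor_parallel_duplicate:
    # replicated params are counted on tensor-parallel rank 0 only.
    return getattr(p, "tensor_model_parallel", False) or tp_rank == 0

def use_grad(p, tp_rank=0):
    conditions = [
        p.grad is not None,
        _is_not_shared(p),
        _is_not_tp_duplicate(p, tp_rank),
    ]
    return all(conditions)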
@@ -607,6 +548,17 @@ class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
        # recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())

        # >>>
        # model_params = [ p for m in self.models for p in m.parameters() ]
        # optim_params = self.get_parameters()
        # model_params.sort(key = lambda p : p.nelement(), reverse = True)
        # optim_params.sort(key = lambda p : p.nelement(), reverse = True)
        # # assert len(model_params) == len(optim_params
        # pax(7, {
        #     "model_params" : get_clippy(model_params),
        #     "optim_params" : get_clippy(optim_params),
        # })
        # <<<

    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
...
...