OpenDAS / Megatron-LM · Commits

Commit f8fd34e3, authored Mar 02, 2022 by Lawrence McAfee
Parent: cbcd5579

    fixed gather params; now copying all params, not just local dp subset.
Showing 3 changed files with 143 additions and 72 deletions:

    megatron/optimizer/clip_grads.py    +7    -5
    megatron/optimizer/optimizer.py     +127  -66
    megatron/training.py                +9    -1
megatron/optimizer/clip_grads.py
@@ -155,14 +155,13 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2, ITERATION=None):
     # >>>
     from megatron import get_args
     args = get_args()
-    if not args.use_distributed_optimizer:
+    if args.use_distributed_optimizer:
         torch.distributed.all_reduce(total_norm,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=mpu.get_model_parallel_group())
-    # +++
+                                     op=torch.distributed.ReduceOp.SUM)
     else:
         torch.distributed.all_reduce(total_norm,
-                                     op=torch.distributed.ReduceOp.SUM)
+                                     op=torch.distributed.ReduceOp.SUM,
+                                     group=mpu.get_model_parallel_group())
     # <<<
     total_norm = total_norm.item() ** (1.0 / norm_type)
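For orientation, the group choice above encodes where gradient state lives: with the distributed optimizer, each data-parallel rank holds only a shard of the main gradients, so the shard-local partial norms must be summed over every rank (the default, world-wide group); without it, every data-parallel rank already holds a full gradient copy, and summing across the model-parallel group alone yields the global norm. A minimal sketch of that logic follows; the function and argument names are illustrative, not this file's actual helpers.

# Sketch only: assumes torch.distributed is already initialized;
# `global_grad_norm` and `local_grads` are hypothetical names.
import torch

def global_grad_norm(local_grads, norm_type=2.0, group=None):
    # Partial sum of |g|^p over the gradient tensors this rank owns.
    total_norm = sum(g.abs().pow(norm_type).sum() for g in local_grads)
    # group=None reduces over the whole world (distributed-optimizer case);
    # pass the model-parallel group when each DP rank holds full gradients.
    torch.distributed.all_reduce(
        total_norm, op=torch.distributed.ReduceOp.SUM, group=group)
    return total_norm.item() ** (1.0 / norm_type)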
@@ -223,9 +222,12 @@ def count_zeros_fp32(parameters):
     # Sum across all model-parallel GPUs.
     # >>>
+    from megatron import get_args
+    args = get_args()
     if args.use_distributed_optimizer:
         torch.distributed.all_reduce(total_num_zeros,
                                      op=torch.distributed.ReduceOp.SUM)
+        # pax({"total_num_zeros": total_num_zeros.item()})
     else:
         torch.distributed.all_reduce(total_num_zeros,
                                      op=torch.distributed.ReduceOp.SUM,
     ...
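The count_zeros_fp32 change applies the same group logic to zero-counting, since under the distributed optimizer each rank only sees its shard of the gradients. A hedged sketch with illustrative names:

# Sketch only: count zeros in this rank's gradient shard, then sum the
# per-rank counts across the appropriate process group.
import torch

def count_zeros(local_grads, group=None):
    total_num_zeros = sum((g == 0).sum().float() for g in local_grads)
    torch.distributed.all_reduce(
        total_num_zeros, op=torch.distributed.ReduceOp.SUM, group=group)
    return int(total_num_zeros.item())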
megatron/optimizer/optimizer.py
@@ -32,7 +32,7 @@ from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
 # >>>
 from lutil import pax, tp
-DEBUG_ITERATION = 0 # 10
+DEBUG_ITERATION = 2 # 10
 # <<<
@@ -273,7 +273,7 @@ class BaseFloat16Optimizer(MegatronOptimizer):
             return
         for r in range(torch.distributed.get_world_size()):
             if my_rank == r:
-                print(" + %4s; [r%d]; %s, %.12e" % ("fix" if args.use_distributed_optimizer else "main", my_rank, key, value))
+                print(" + br/%s; [r%d, i%d]; %s, %.12e" % ("fix" if args.use_distributed_optimizer else "main", my_rank, ITERATION, key, value))
             torch.distributed.barrier()
         torch.distributed.barrier()
         # if my_rank == 0:
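The loop above is the standard rank-serialized printing pattern: each rank takes its turn printing while the others wait at a barrier, so debug lines from different GPUs do not interleave. A self-contained sketch of the pattern (hypothetical helper name, assuming an initialized process group):

# Sketch only: serialize stdout across ranks with barriers.
import torch

def print_in_rank_order(message):
    my_rank = torch.distributed.get_rank()
    for r in range(torch.distributed.get_world_size()):
        if my_rank == r:
            print("[r%d] %s" % (my_rank, message), flush=True)
        torch.distributed.barrier()  # everyone syncs after each rank prints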
@@ -282,9 +282,11 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         # exit(0)
         exit(0)

-    def _debug_model(self, ITERATION, key, use_param):
+    # def _debug_model(self, ITERATION, key, use_param):
+    def debug_model(self, ITERATION, key, use_grad):
+        use_grad = bool(use_grad)
         tensors = [
-            (p.float() if use_param else p.main_grad.float())
+            (p.main_grad.float() if use_grad else p.float())
             for m in self.models for p in m.parameters()
         ]
         # pax(0, {
@@ -296,65 +298,72 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         return self.debug_general(
             ITERATION,
             "model/%s, %s [count %d]" % (
-                "param" if use_param else "grad",
+                "grad" if use_grad else "param",
                 key,
                 count,
             ),
-            sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
+            # sum(torch.sum(torch.abs(t)) for t in tensors).item() / count,
+            sum(torch.sum(torch.abs(t)) for t in tensors),
         )
-    def _debug_main(self, ITERATION, key0, key1, f, ff):
-        count = sum(
-            p.nelement()
-            for g in self.optimizer.param_groups
-            for p in g["params"]
-        )
-        return self.debug_general(
-            ITERATION,
-            "main/%s, %s [count %d]" % (key1, key0, count),
-            sum(ff(f(p))
-                for g in self.optimizer.param_groups
-                for p in g["params"]).item() / count,
-        )
-    # def debug_main_param_mean(self, ITERATION, key):
+    # def debug_model_param(self, ITERATION, key):
+    #     return self._debug_model(ITERATION, key, True)
+    # def debug_model_grad(self, ITERATION, key):
+    #     return self._debug_model(ITERATION, key, False)
+    # def _debug_main(self, ITERATION, key0, key1, f, ff):
+    #     count = sum(
+    #         p.nelement()
+    #         for g in self.optimizer.param_groups
+    #         for p in g["params"]
+    #     )
+    #     return self.debug_general(
+    #         ITERATION,
+    #         "main/%s, %s [count %d]" % (key1, key0, count),
+    #         sum(ff(f(p))
+    #             for g in self.optimizer.param_groups
+    #             for p in g["params"]).item() / count,
+    #     )
+    # def debug_main_param(self, ITERATION, key):
     #     return self._debug_main(
     #         ITERATION,
     #         key,
-    #         "param mean",
-    #         lambda p : p,
-    #         torch.mean,
+    #         "param", # sum",
+    #         # lambda p : p,
+    #         lambda p : torch.abs(p),
+    #         torch.sum,
     #     )
-    # def debug_main_param_sum(self, ITERATION, key):
-    def debug_model_param(self, ITERATION, key):
-        return self._debug_model(ITERATION, key, True)
-    def debug_model_grad(self, ITERATION, key):
-        return self._debug_model(ITERATION, key, False)
-    def debug_main_param(self, ITERATION, key):
-        return self._debug_main(
-            ITERATION,
-            key,
-            "param", # sum",
-            # lambda p : p,
-            lambda p : torch.abs(p),
-            torch.sum,
-        )
-    # def debug_main_grad_mean(self, ITERATION, key):
+    # def debug_main_grad(self, ITERATION, key):
     #     return self._debug_main(
     #         ITERATION,
     #         key,
-    #         "grad mean",
-    #         lambda p : p.grad,
-    #         torch.mean,
+    #         "grad", # sum",
+    #         # lambda p : p.grad,
+    #         lambda p : torch.abs(p.grad),
+    #         torch.sum,
     #     )
-    # def debug_main_grad_sum(self, ITERATION, key):
-    def debug_main_grad(self, ITERATION, key):
-        return self._debug_main(
-            ITERATION,
-            key,
-            "grad", # sum",
-            # lambda p : p.grad,
-            lambda p : torch.abs(p.grad),
-            torch.sum,
-        )
+    # def _debug_main(self, ITERATION, key, use_param):
+    def debug_main(self, ITERATION, key, use_grad):
+        use_grad = bool(use_grad)
+        tensors = [
+            p.grad if use_grad else p
+            for g in self.optimizer.param_groups
+            for p in g["params"]
+        ]
+        tensors = [t.float() for t in tensors]
+        count = sum(t.nelement() for t in tensors)
+        return self.debug_general(
+            ITERATION,
+            "main/%s, %s [count %d]" % (
+                "grad" if use_grad else "param",
+                key,
+                count,
+            ),
+            sum(torch.sum(torch.abs(t)) for t in tensors),
+        )
+    # def debug_main_param(self, ITERATION, key):
+    #     return self._debug_main(ITERATION, key, True)
+    # def debug_main_grad(self, ITERATION, key):
+    #     return self._debug_main(ITERATION, key, False)
     # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<

     @torch.no_grad()
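The refactor above collapses the four param/grad-specific helpers into two flag-driven ones, debug_model and debug_main, where a truthy use_grad selects gradients. A usage sketch per the new signatures (opt and iteration are hypothetical caller-side names):

# Usage sketch, assuming `opt` is a BaseFloat16Optimizer subclass instance:
opt.debug_model(iteration, "after copy grad.", 0)  # model params (use_grad=0)
opt.debug_model(iteration, "after copy grad.", 1)  # model .main_grad buffers
opt.debug_main(iteration, "after step.", 0)        # fp32 main params
opt.debug_main(iteration, "after step.", 1)        # fp32 main grads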
@@ -365,6 +374,8 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         # >>>
         # self.debug_model_param(ITERATION, "before copy grad.")
         # self.debug_model_grad(ITERATION, "before copy grad.")
+        # self.debug_main_param(ITERATION, "before copy grad.")
+        # self.debug_main_grad(ITERATION, "before copy grad.")
         # <<<

         # Copy gradients from model params to main params.
@@ -373,10 +384,8 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         timers('optimizer-copy-to-main-grad').stop()

         # >>>
-        # self.debug_model_param(ITERATION, "after copy grad.")
-        # self.debug_model_grad(ITERATION, "after copy grad.")
-        # self.debug_main_param(ITERATION, "after copy grad.")
-        # self.debug_main_grad(ITERATION, "after copy grad.")
+        # self.debug_model(ITERATION, "after copy grad.", 0)
+        # self.debug_main(ITERATION, "after copy grad.", 1)
         # <<<

         # Do unscale, check for inf, and update grad scaler only for
@@ -412,12 +421,23 @@ class BaseFloat16Optimizer(MegatronOptimizer):
         num_zeros_in_grad = self.count_zeros() if \
             self.log_num_zeros_in_grad else None

+        # >>>
+        # param = self.optimizer.param_groups[0]["params"][0]
+        # pax(0, {
+        #     "param" : tp(param),
+        #     "grad" : tp(param.grad),
+        # })
+        # <<<
+        # >>>
+        # self.debug_main(ITERATION, "before step.", 0)
+        # <<<

         # Step the optimizer.
         self.optimizer.step()

         # >>>
-        # self.debug_main_param(ITERATION, "after step.")
-        # self.debug_main_grad(ITERATION, "after step.")
+        # self.debug_main(ITERATION, "after step.", 0)
         # <<<

         # Update params from main params.
@@ -652,7 +672,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
         # <<<
         timers('backward-embedding-all-reduce').stop()

-    def gather_params(self):
+    def gather_params(self, ITERATION):
         pass

     def _copy_model_grads_to_main_grads(self, ITERATION):
@@ -1273,6 +1293,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
         # Reduce-scatter.
+        # timers('backward-params-reduce-scatter').start()
+        timers('backward-params-all-reduce').start()
         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_world_size = mpu.get_data_parallel_world_size()
         data_parallel_group = mpu.get_data_parallel_group()
@@ -1292,6 +1314,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         #     buffer_.data, group=mpu.get_data_parallel_group())
         # <<<

+        # >>>
+        # self.debug_main_param(0, "before reduce scatter")
+        # self.debug_main_grad(0, "before reduce scatter")
+        # <<<
         for model_index, dtype, gbuf_views in gbuf_view_items:
             # coalesced /= mpu.get_data_parallel_world_size()
             gbuf = self.models[model_index]._grad_buffers[dtype].data
@@ -1320,10 +1347,18 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             #     gbuf,
             #     group = data_parallel_group,
             # )
+        # timers('backward-params-reduce-scatter').stop()
+        timers('backward-params-all-reduce').stop()

         # pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})

-    def gather_params(self):
+    def gather_params(self, ITERATION):
+        # >>>
+        timers = get_timers()
+        # <<<
+        timers('backward-params-all-gather').start()
         data_parallel_rank = mpu.get_data_parallel_rank()
         data_parallel_group = mpu.get_data_parallel_group()
@@ -1340,11 +1375,32 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # Each model param now contains its updated values in its
         # '.main_grad' field.
-        for param in self.param_gbuf_map:
-            param.detach().copy_(param.main_grad)
+        # for param in self.param_gbuf_map: # ... incomplete param list.
+        for model in self.models:
+            for dtype, param_map in model._grad_buffer_param_index_map.items():
+                for param in param_map:
+                    param.detach().copy_(param.main_grad)
+        timers('backward-params-all-gather').stop()

         # pax(0, {"gbuf_view_items": gbuf_view_items})

+        # >>>
+        # self.debug_main(ITERATION, "after/inside gather_params.", 0)
+        # self.debug_model(ITERATION, "after/inside gather_params.", 0)
+        # if ITERATION == 2:
+        #     pax(1, {
+        #         "ITERATION" : ITERATION,
+        #         # "gbufs" : [
+        #         #     tp(b.data)
+        #         #     for m in self.models
+        #         #     for b in m._grad_buffers.values()
+        #         # ],
+        #         "param_gbuf_map" : [ str(tuple(p.shape)) for p in self.param_gbuf_map ],
+        #     })
+        # <<<
     def _collect_main_grad_data_for_unscaling(self):
         return [g.data for g in self.get_main_grads()]
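This hunk is the fix named in the commit message: after the all-gather, every model parameter's .main_grad view aliases the fully gathered buffer, so the copy-back must walk every parameter registered in each model's grad-buffer index map, not just the subset this data-parallel rank owns, which is what iterating self.param_gbuf_map amounted to. A hedged sketch of that copy-back (structures mirror the diff; standalone names are illustrative):

# Sketch only: copy gathered values from each param's .main_grad view back
# into the param itself, covering all params on every DP rank.
def copy_gathered_params_back(models):
    for model in models:
        for dtype, param_map in model._grad_buffer_param_index_map.items():
            for param in param_map:                 # all params, not a shard
                param.detach().copy_(param.main_grad)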
@@ -1400,24 +1456,29 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 # pax(0, {
                 #     "group_index" : group_index,
                 #     "group_shard" : group_shard,
-                #     "param" : tp(param),
+                #     # "param" : tp(param),
                 #     "model_index" : model_index,
-                #     "gbuf_dtype" : str(gbuf_dtype),
-                #     "model_grad_tensor" : tp(model_grad_tensor),
-                #     "main_grad_tensor" : tp(main_grad_tensor),
-                #     "model_grad_view" : tp(model_grad_view),
-                #     "main_grad_view" : tp(main_grad_view),
+                #     "dtype" : str(dtype),
+                #     "model_grad" : tp(model_grad),
+                #     "main_grad" : tp(main_grad),
+                #     "model_view" : tp(model_view),
+                #     "main_view" : tp(main_view),
                 #     "model_shard" : str(model_shard),
                 #     "main_shard" : str(main_shard),
                 # })

         # >>>
-        # if ITERATION == DEBUG_ITERATION:
+        # if 1 or ITERATION == DEBUG_ITERATION:
         #     pax(0, {
         #         "** branch **" : "** fix. **",
         #         "ITERATION" : ITERATION,
         #         # "model grads" : self.get_world_model_grads(),
         #         "main_grads" : self.get_main_grads(),
+        #         "group shards" : [
+        #             "group %d; %s" % (grp_idx, main_shard)
+        #             for grp_idx, grp_shard in enumerate(self.opt_group_shards)
+        #             for model_param, main_shard in grp_shard["param_map"].items()
+        #         ],
         #     })
         # <<<
megatron/training.py
@@ -431,6 +431,10 @@ def train_step(forward_step_func, data_iterator,
     if args.empty_unused_memory_level >= 1:
         torch.cuda.empty_cache()

+    # >>>
+    # optimizer.debug_model(ITERATION, "before reduce grads.", 0)
+    # <<<
     # >>>
     # Reduce gradients. (with distributed optimizer option, optimizer
     # now responsible for reducing gradients)
@@ -465,7 +469,11 @@ def train_step(forward_step_func, data_iterator,
...
@@ -465,7 +469,11 @@ def train_step(forward_step_func, data_iterator,
# >>>
# >>>
# Gather params gradients. (with distributed optimizer option, optimizer
# Gather params gradients. (with distributed optimizer option, optimizer
# now responsible for gathering updated params)
# now responsible for gathering updated params)
optimizer
.
gather_params
()
optimizer
.
gather_params
(
ITERATION
)
# <<<
# >>>
# optimizer.debug_model(ITERATION, "after gather params.", 0)
# <<<
# <<<
# Update learning rate.
# Update learning rate.
...
...
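Taken together, the train_step tail this commit touches runs in three phases under the distributed optimizer: the optimizer reduces (reduce-scatters) gradients itself, steps on its local main-param shard, then all-gathers updated params back into the model, with ITERATION threaded through purely for the debug hooks. A schematic sketch; only gather_params(ITERATION) is taken from this diff, and reduce_grads stands in for whatever reduction hook the optimizer actually exposes:

# Schematic only; `reduce_grads` is an assumed name, not named in this diff.
def train_step_tail(optimizer, ITERATION):
    optimizer.reduce_grads()              # optimizer-owned gradient reduction
    optimizer.step()                      # update the local main-param shard
    optimizer.gather_params(ITERATION)    # all-gather updated params to ranks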