OpenDAS / Megatron-LM / Commits

Commit 371a8828, authored Feb 23, 2022 by Lawrence McAfee
Parent: e7f0cdee

    isolated grad discrepancy to BaseFloat16Optimizer.step()
Showing 2 changed files, with 118 additions and 113 deletions:

    megatron/optimizer/optimizer.py   +117  -112
    megatron/training.py              +1    -1
megatron/optimizer/optimizer.py
@@ -32,7 +32,7 @@ from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
# >>>
from lutil import pax, tp
DEBUG_ITERATION = 10
DEBUG_ITERATION = 1 # 10
# <<<
@@ -278,6 +278,14 @@ class BaseFloat16Optimizer(MegatronOptimizer):
        self._copy_model_grads_to_main_grads(ITERATION)
        timers('optimizer-copy-to-main-grad').stop()
        # >>>
        # pax(0, {
        #     "[LOC]" : "[** BEFORE UNSCALE **]",
        #     "param_group / params" : [ p for g in self.optimizer.param_groups for p in g["params"] ],
        #     "param_group / grads" : [ p.grad for g in self.optimizer.param_groups for p in g["params"] ],
        # })
        # <<<
        # pax(0, {
        #     "params" : self.get_parameters(), # self.main_param_shards,
        #     "grads" : [ p.grad for p in self.get_parameters() ], # self.main_param_shards ],
@@ -305,6 +313,14 @@ class BaseFloat16Optimizer(MegatronOptimizer):
            })
            return False, None, None
        # >>>
        pax(0, {
            "[LOC]" : "[** BEFORE CLIP **]",
            "param_group / params" : [ p for g in self.optimizer.param_groups for p in g["params"] ],
            "param_group / grads" : [ p.grad for g in self.optimizer.param_groups for p in g["params"] ],
        })
        # <<<
        # Clip the main gradients.
        timers('optimizer-clip-main-grad').start()
        grad_norm = None
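Note: the live and commented-out dumps in these hunks go through pax() from lutil, a private debugging helper that is not part of Megatron-LM. For readers following along without that dependency, a minimal stand-in might look like the sketch below; the name pax_stub and its exact behavior (print labeled summaries, then stop) are assumptions, not the actual tool.

    import sys
    import torch

    def pax_stub(rank, entries):
        """Rough stand-in for lutil.pax: print labeled summaries, then stop."""
        for key, value in entries.items():
            if torch.is_tensor(value):
                print("%s : shape=%s dtype=%s norm=%.6e"
                      % (key, tuple(value.shape), value.dtype,
                         value.float().norm().item()))
            elif isinstance(value, (list, tuple)):
                print("%s : list of %d items" % (key, len(value)))
            else:
                print("%s : %s" % (key, value))
        sys.exit("pax_stub: stopping for inspection (rank arg %s unused here)" % rank)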
@@ -316,16 +332,18 @@ class BaseFloat16Optimizer(MegatronOptimizer):
        num_zeros_in_grad = self.count_zeros() if \
            self.log_num_zeros_in_grad else None
        # Step the optimizer.
        self.optimizer.step()
        # >>>
        # pax(0, {
        #     "main params" : self.get_main_params(),
        #     "main grads" : self.get_main_grads(),
        # })
        pax(0, {
            # "main params" : self.get_main_params(),
            # "main grads" : self.get_main_grads(),
            **{"param_groups / %d" % i : g for i, g in enumerate(self.optimizer.param_groups)},
            "param_group / grads" : [ p.grad for g in self.optimizer.param_groups for p in g["params"] ],
        })
        # <<<
        # Step the optimizer.
        self.optimizer.step()
        # Update params from main params.
        timers('optimizer-copy-main-to-model-params').start()
        self._copy_main_params_to_model_params(ITERATION)
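This hunk is the crux of the commit message: the pax() dump sits immediately before self.optimizer.step(), so the optimizer's param_groups and their grads can be compared across the two code paths at the point where they start to diverge. A generic way to make the same comparison without pax, sketched below with a hypothetical fingerprint() helper that is not part of the diff, is to log parameter and gradient norms right before and after the step:

    import torch

    def fingerprint(optimizer, label):
        """Print one norm pair per param, for diffing two runs line by line."""
        for i, group in enumerate(optimizer.param_groups):
            for j, p in enumerate(group["params"]):
                g_norm = 0.0 if p.grad is None else p.grad.float().norm().item()
                print("%s group %d param %d : |p|=%.6e |g|=%.6e"
                      % (label, i, j, p.float().norm().item(), g_norm))

    # usage around the step in question:
    #   fingerprint(self.optimizer, "[before step]")
    #   self.optimizer.step()
    #   fingerprint(self.optimizer, "[after step]")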
@@ -415,6 +433,9 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
                        fp32_from_float16_params_this_group.append(main_param)
                        # Reset existing state dict key to the new main param.
                        if param in self.optimizer.state:
                            # >>>
                            raise Exception("hi.")
                            # <<<
                            self.optimizer.state[main_param] \
                                = self.optimizer.state.pop(param)
@@ -483,6 +504,12 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
        timers = get_timers()
        # <<<
        # >>>
        # pax(0, {
        #     "grads" : [ p.main_grad for m in model for p in m.parameters() ],
        # })
        # <<<
        # All-reduce if needed.
        if args.DDP_impl == 'local':
            timers('backward-params-all-reduce').start()
@@ -490,6 +517,12 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
                model_module.allreduce_gradients()
            timers('backward-params-all-reduce').stop()
        # >>>
        # pax(0, {
        #     "grads" : [ p.main_grad for m in model for p in m.parameters() ],
        # })
        # <<<
        # All-reduce word_embeddings' grad across first and last stages to ensure
        # that word_embeddings parameters stay in sync.
        # This should only run for models that support pipelined model parallelism
@@ -497,6 +530,9 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
        timers('backward-embedding-all-reduce').start()
        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            # >>>
            raise Exception("hi.")
            # <<<
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = model[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
@@ -576,6 +612,16 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
                if not self.use_contiguous_buffers_in_local_ddp:
                    model_param.main_grad = None
        # >>>
        # if ITERATION == DEBUG_ITERATION:
        #     pax(0, {
        #         "** branch **" : "** main. **",
        #         "ITERATION" : ITERATION,
        #         "model grads" :
        #         [ p.main_grad for m in self.models for p in m.parameters() ],
        #     })
        # <<<

    def _collect_main_grad_data_for_unscaling(self):
        main_grads = []
@@ -623,7 +669,7 @@ class Float16OptimizerWithFloat16Params(BaseFloat16Optimizer):
            pax(0, {
                "** branch **" : "** main. **",
                "ITERATION" : ITERATION,
                "model params" : [ p for m in self.models for p in m.parameters() ],
                "model params" : [ p for m in self.models for p in m.parameters()],
            })
            # <<<
@@ -984,9 +1030,18 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            [ g["orig_group"] for g in self.opt_group_shards ]
        self.optimizer.load_state_dict(self.optimizer.state_dict())
        # pax(0, {
        #     # "opt_group_shards" : self.opt_group_shards,
        #     # "param_groups" : self.optimizer.param_groups,
        #     "optimizer" : self.optimizer,
        #     "optimizer / state" : self.optimizer.state,
        # })
        # pax(1, {
        #     "opt_group_shards" : self.opt_group_shards,
        #     "param_groups" : self.optimizer.param_groups,
        #     "optimizer" : self.optimizer,
        #     **{"optimizer / param_groups / %d" % i : g
        #        for i, g in enumerate(self.optimizer.param_groups)},
        #     "optimizer / state" : self.optimizer.state,
        #     "optimizer / state_dict" : self.optimizer.state_dict(),
        # })

        # Initialize main params.
@@ -1028,6 +1083,9 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
    def get_world_model_params(self):
        '''** FOR DEBUGGING. **'''
        return [ p for m in self.models for p in m.parameters() ]
    def get_world_model_grads(self):
        '''** FOR DEBUGGING. **'''
        return [ p.main_grad for p in self.get_world_model_params() ]
    def get_main_params(self):
        return [ g["params"][0] for g in self.optimizer.param_groups ]
@@ -1075,20 +1133,25 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        for model_index, model in enumerate(self.models):
            for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
                world_shards = gbuf_shard["world_all"]
                gbuf = model._grad_buffers[dtype]
                gbuf_views = []
                for shard in world_shards:
                    gbuf_views.append(gbuf.data[shard.start:shard.end])
                gbuf = model._grad_buffers[dtype].data
                gbuf_views = [ gbuf[s.start:s.end] for s in world_shards ]
                gbuf_view_items.append((model_index, dtype, gbuf_views))
                # pax(0, {
                #     "world_shards" : world_shards,
                #     "gbuf_views" : gbuf_views,
                # })
        # pax(0, {"gbuf_view_items": gbuf_view_items})
        return gbuf_view_items

    def reduce_grads(self, model):
        # >>>
        timers = get_timers()
        # <<<
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Sync word embedding params.
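The change in get_model_grad_buffer_dp_views() above replaces the explicit loop over world_shards with a list comprehension, but the idea is the same either way: slice each model's contiguous grad buffer into one view per data-parallel rank, without copying. A minimal sketch of that idea, assuming (unlike the real code, which reads the ranges from self.model_gbuf_shards) that the buffer splits into equal contiguous shards:

    import torch

    def split_buffer_into_rank_views(gbuf: torch.Tensor, world_size: int):
        """Return one non-copying view of `gbuf` per data-parallel rank."""
        assert gbuf.numel() % world_size == 0
        shard_size = gbuf.numel() // world_size
        return [gbuf[r * shard_size : (r + 1) * shard_size]
                for r in range(world_size)]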
@@ -1101,6 +1164,9 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        timers('backward-embedding-all-reduce').start()
        if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
                mpu.get_pipeline_model_parallel_world_size() > 1:
            # >>>
            raise Exception("hi.")
            # <<<
            if mpu.is_pipeline_first_stage(ignore_virtual=True):
                unwrapped_model = model[0]
            elif mpu.is_pipeline_last_stage(ignore_virtual=True):
@@ -1116,6 +1182,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                if args.DDP_impl == 'local':
                    grad = word_embeddings_weight.main_grad
                else:
                    raise Exception("only 'main_grad' supported for distrib-opt.")
                    grad = word_embeddings_weight.grad
                torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
                # +++
@@ -1123,7 +1190,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                # torch.distributed.all_reduce(grad_shard,
                #                              group=mpu.get_embedding_group())
                # <<<
        timers('backward-embedding-all-reduce').stop()
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Sync T5 position embedding params.
@@ -1133,18 +1200,30 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Reduce-scatter.
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_group = mpu.get_data_parallel_group()
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        # pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})
        for model_index, dtype, gbuf_views in gbuf_view_items:
            # coalesced /= mpu.get_data_parallel_world_size()
            gbuf = self.models[model_index]._grad_buffers[dtype].data
            torch.mul(gbuf.data, 1. / data_parallel_world_size, out=gbuf.data)
            # gbuf_views = [ t / data_parallel_world_size for t in gbuf_views ]
            # gbuf_d
            # pax(0, {
            #     "data_parallel_world_size" : data_parallel_world_size,
            #     "gbuf" : tp(gbuf),
            # })
            torch.distributed.reduce_scatter(
                gbuf_views[data_parallel_rank],
                gbuf_views,
                group = data_parallel_group,
            )
        # pax(0, {"gbuf_view_items": gbuf_view_items})
        # pax(0, {"gbuf_views": [g for item in gbuf_view_items for g in item[2]]})

    def gather_params(self):
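The reduce-scatter in the hunk above follows a common pattern: because torch.distributed.reduce_scatter sums by default, the whole grad buffer is first scaled by 1/data_parallel_world_size so that each rank ends up with the average of its own shard. A self-contained sketch of that pattern (the helper name and even-split assumption are mine, not from the diff):

    import torch
    import torch.distributed as dist

    def reduce_scatter_grad_buffer(gbuf, gbuf_views, group):
        """Average-reduce `gbuf` across ranks; each rank keeps only its shard."""
        world_size = dist.get_world_size(group=group)
        rank = dist.get_rank(group=group)
        gbuf.mul_(1.0 / world_size)            # pre-scale so the SUM becomes a mean
        dist.reduce_scatter(gbuf_views[rank],  # output: this rank's own view
                            gbuf_views,        # inputs: one view per rank
                            group=group)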
@@ -1161,24 +1240,12 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                group = data_parallel_group,
            )
        # Each model param now contains its updated values in it's
        # Each model param now contains its updated values in its
        # '.main_grad' field.
        for param in self.param_gbuf_map:
            param.detach().copy_(param.main_grad)
            # pax(0, {
            #     "param" : tp(param),
            #     "main_grad" : tp(param.main_grad),
            #     # "grad" : tp(param.grad),
            # })
        # pax(1, {
        #     "data_parallel_rank" : data_parallel_rank,
        #     "main params" : self.get_main_params(),
        #     "model params / world" : self.get_world_model_params(),
        #     **{"gbuf_view_items / %d"%i:v[2] for i,v in enumerate(gbuf_view_items)},
        #     # "gbuf_view_item" : tp(gbuf_view[data_parallel_rank]),
        #     # "model params / local" : self.get_local_model_param_views(),
        # })
        # pax(0, {"gbuf_view_items": gbuf_view_items})

    def _collect_main_grad_data_for_unscaling(self):
        return [ g.data for g in self.get_main_grads() ]
@@ -1199,51 +1266,29 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                # Copy shard data.
                main_view = main_param[main_shard.start:main_shard.end]
                model_view = model_param.view(-1)[model_shard.start:model_shard.end]
                # try:
                main_view.detach().copy_(model_view)
                # except:
                #     pax({
                #         "main_param" : tp(main_param),
                #         "model_param" : tp(model_param),
                #         "main_view" : tp(main_view),
                #         "model_view" : tp(model_view),
                #         "main_shard" : str(main_shard),
                #         "model_shard" : str(model_shard),
                #     })
        # pax(0, {
        #     **{
        #         "opt_group_shards / %d" % i : s
        #         for i, s in enumerate(self.opt_group_shards)
        #     },
        #     "main_params" : self.get_main_params(),
        # })

    def _copy_model_grads_to_main_grads(self, ITERATION):
        # >>>
        model_grads = self.get_local_model_grad_views()
        model_has_nan = self.has_nan_debug(model_grads)
        if model_has_nan:
            pax(1, {
                "ITERATION" : ITERATION,
                "model grads" : model_grads,
                "model_has_nan" : model_has_nan,
                "model params / local" : self.get_local_model_param_views(),
                # "model params / world" : [ list(self.param_gbuf_map),
                # "main grads" : self.get_main_grads(),
            })
        # <<<
        for group_index, group_shard in enumerate(self.opt_group_shards):
            for model_param, main_shard in group_shard["param_map"].items():
                # Model shard.
                model_index, dtype = self.param_gbuf_map[model_param]
                model_shard = self.model_gbuf_shards \
                    [model_index][dtype]["param_map"][model_param]["gbuf_world"]
                assert main_shard.size == model_shard.size
                # pax(0, {
                #     "model_param" : tp(model_param),
                #     "main_shard" : str(main_shard),
                #     "param shard" : self.model_gbuf_shards \
                #         [model_index][dtype]["param_map"][model_param],
                # })
                # Copy from DDP's contiguous buffer to main shard's grad.
                model_grad = self.models[model_index]._grad_buffers[dtype].data
                main_grad = self.get_main_grad(group_index)
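The has_nan_debug() helper called above is not shown in this diff. Judging from the explicit check that appears in a later hunk (not torch.all(torch.isfinite(param)).item()), it is most likely a simple finiteness scan over a list of tensors; the sketch below is that assumption written out, not the committed implementation.

    import torch

    def has_nan_debug(tensors):
        """Return True if any tensor contains a NaN or Inf value."""
        return any(not torch.all(torch.isfinite(t)).item() for t in tensors)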
@@ -1269,38 +1314,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        #     })
        # >>>
        # pax(1, {
        #     # "model_gbuf_shards" : self.model_gbuf_shards,
        #     **{
        #         "opt_group_shards / %d" % i : s
        #         for i, s in enumerate(self.opt_group_shards)
        #     },
        #     "main_grads" : self.get_main_grads(),
        # })
        # for group_index, main_grad in enumerate(self.get_main_grads()):
        #     # is_nan = torch.any(torch.isnan(main_grad)).item()
        #     if is_nan:
        #         # opt_group_shard = self.opt_group_shards[group_index]
        #         # param_views = []
        #         # for param, shard in opt_group_shard["param_map"].items():
        #         #     ddd
        #         pax(0, {
        #             "opt_group_shard" : self.opt_group_shards[group_index],
        #             "param_map" : [ (str(p.shape), str(d)) for p, d in self.opt_group_shards[group_index]["param_map"].items() ],
        #             "gbufs" : [ b.data for m in self.models for d, b in m._grad_buffers.items() ],
        #             "group_index" : group_index,
        #             "main_param" : tp(self.get_main_param(group_index)),
        #             "main_grad" : tp(main_grad),
        #             "is_nan" : is_nan,
        #         })
        main_grads = self.get_main_grads()
        main_has_nan = self.has_nan_debug(main_grads)
        if main_has_nan:
            raise Exception("hi.")
        # pax(1, {
        #     "model grads" : self.get_local_model_grad_views(),
        # })
        # if ITERATION == DEBUG_ITERATION:
        #     pax(0, {
        #         "** branch **" : "** fix. **",
        #         "ITERATION" : ITERATION,
        #         # "model grads" : self.get_world_model_grads(),
        #         "main_grads" : self.get_main_grads(),
        #     })
        # <<<
@@ -1340,27 +1360,12 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                #     "main_shard" : str(main_shard),
                #     })
        # pax(0, {
        #     "model_gbuf_shards" : self.model_gbuf_shards,
        #     "opt_group_shards" : self.opt_group_shards,
        # })
        # >>>
        for param in self.param_gbuf_map:
            # is_nan = torch.any(torch.isnan(param)).item()
            is_nan = not torch.all(torch.isfinite(param)).item()
            if is_nan:
                pax({
                    "param" : tp(param),
                    "is_nan" : is_nan,
                })
        if ITERATION == DEBUG_ITERATION:
            pax(0, {
                "** branch **" : "** fix. **",
                "ITERATION" : ITERATION,
                # "main params" : self.get_main_params(),
                # "model params / local" : self.get_local_model_param_views(),
                "model params" : [ p for m in self.models for p in m.parameters()],
                "model params" : self.get_world_model_params(),
            })
        # <<<
megatron/training.py
@@ -432,7 +432,7 @@ def train_step(forward_step_func, data_iterator,
    # >>>
    # Reduce gradients. (with distributed optimizer option, optimizer
    # now responsible for reducing gradients)
    optimizer.reduce_gradients(model)
    optimizer.reduce_grads(model)
    # <<<
    # >>>