OpenDAS / Megatron-LM · Commits

Commit e5db0fda
authored Feb 18, 2022 by Lawrence McAfee

    modularized reduce_gradients, gather params; training runs, but loss==nan
parent a7782b21
Showing 1 changed file with 138 additions and 28 deletions.
megatron/optimizer/optimizer.py  (view file @ e5db0fda)
@@ -1318,6 +1318,29 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         #     "local_shard_info_groups" : [ g["data"] for g in local_shard_info_groups ],
         # })

+    def get_model_grad_buffer_dp_views(self):
+
+        # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
+        args = get_args()
+        assert args.use_contiguous_buffers_in_local_ddp
+
+        # Grad buffer views.
+        gbuf_view_items = []
+        for model_index, model in enumerate(self.models):
+            for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
+                world_shards = gbuf_shard["world_all"]
+                gbuf = model._grad_buffers[dtype]
+                gbuf_views = []
+                for shard in world_shards:
+                    gbuf_views.append(gbuf.data[shard.start:shard.end])
+                gbuf_view_items.append((model_index, dtype, gbuf_views))
+
+        # pax(0, {"gbuf_view_items": gbuf_view_items})
+
+        return gbuf_view_items
+
     def reduce_gradients(self, model):

         # >>>
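Note: the new helper returns, for each model chunk and gradient dtype, one view per data-parallel rank into the contiguous grad buffer. Below is a minimal standalone sketch of that slicing pattern; the Shard namedtuple, buffer length, and even four-way split are illustrative assumptions, not the layout actually computed in model_gbuf_shards. The important property is that the views alias the buffer's storage, so a collective that writes into a view writes into the buffer itself.

    from collections import namedtuple
    import torch

    # Hypothetical stand-in for the shard ranges stored in gbuf_shard["world_all"].
    Shard = namedtuple("Shard", ["start", "end"])

    data_parallel_size = 4
    gbuf = torch.zeros(16)  # flat, contiguous grad buffer (illustrative size)

    # One contiguous range per data-parallel rank (assumed even split).
    shard_size = gbuf.numel() // data_parallel_size
    world_shards = [Shard(r * shard_size, (r + 1) * shard_size)
                    for r in range(data_parallel_size)]

    # Views alias the buffer's storage; writing to a view writes into gbuf.
    gbuf_views = [gbuf[s.start:s.end] for s in world_shards]
    gbuf_views[0].fill_(1.0)
    assert gbuf[:shard_size].eq(1.0).all()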
@@ -1338,43 +1361,87 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
...
@@ -1338,43 +1361,87 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Reduce-scatter.
# Reduce-scatter.
# ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
# # ** only contiguous grad buffer supported, for now [ TEMPORARY ] **
assert
args
.
use_contiguous_buffers_in_local_ddp
# assert args.use_contiguous_buffers_in_local_ddp
# data_parallel_rank = mpu.get_data_parallel_rank()
# data_parallel_group = mpu.get_data_parallel_group()
# for model_index, model in enumerate(self.models):
# for dtype, gbuf_shard in self.model_gbuf_shards[model_index].items():
# world_shards = gbuf_shard["world_all"]
# gbuf = model._grad_buffers[dtype]
# gbuf_views = []
# for shard in world_shards:
# gbuf_views.append(gbuf.data[shard.start:shard.end])
# torch.distributed.reduce_scatter(
# gbuf_views[data_parallel_rank],
# gbuf_views,
# group = data_parallel_group,
# )
# # pax(0, {
# # "model_index" : model_index,
# # "model" : model,
# # "dtype" : str(dtype),
# # "gbuf_shard" : gbuf_shard,
# # "world_shards" : world_shards,
# # "gbuf_views" : gbuf_views,
# # })
data_parallel_rank
=
mpu
.
get_data_parallel_rank
()
data_parallel_rank
=
mpu
.
get_data_parallel_rank
()
data_parallel_group
=
mpu
.
get_data_parallel_group
()
data_parallel_group
=
mpu
.
get_data_parallel_group
()
for
model_index
,
model
in
enumerate
(
self
.
models
):
for
dtype
,
gbuf_shard
in
self
.
model_gbuf_shards
[
model_index
].
items
():
world_shards
=
gbuf_shard
[
"world_all"
]
gbuf
=
model
.
_grad_buffers
[
dtype
]
gbuf_view_items
=
self
.
get_model_grad_buffer_dp_views
()
gbuf_views
=
[]
for
shard
in
world_shards
:
gbuf_views
.
append
(
gbuf
.
data
[
shard
.
start
:
shard
.
end
])
torch
.
distributed
.
reduce_scatter
(
for
model_index
,
dtype
,
gbuf_views
in
gbuf_view_items
:
gbuf_views
[
data_parallel_rank
],
torch
.
distributed
.
reduce_scatter
(
gbuf_views
,
gbuf_views
[
data_parallel_rank
],
group
=
data_parallel_group
,
gbuf_views
,
)
group
=
data_parallel_group
,
)
# pax(0, {"gbuf_view_items": gbuf_view_items})
# pax(0, {
def
gather_params
(
self
):
# "model_index" : model_index,
# "model" : model,
# "dtype" : str(dtype),
# "gbuf_shard" : gbuf_shard,
# "world_shards" : world_shards,
# "gbuf_views" : gbuf_views,
# })
# >>>
data_parallel_rank
=
mpu
.
get_data_parallel_rank
()
# torch.distributed.barrier()
data_parallel_group
=
mpu
.
get_data_parallel_group
()
# raise Exception("hi.")
# <<<
def
gather_params
(
self
):
gbuf_view_items
=
self
.
get_model_grad_buffer_dp_views
()
for
model_index
,
dtype
,
gbuf_views
in
gbuf_view_items
:
torch
.
distributed
.
all_gather
(
gbuf_views
,
gbuf_views
[
data_parallel_rank
],
group
=
data_parallel_group
,
)
# for param, (model_index, dtype) in self.param_gbuf_map.items():
# gbuf = self.model_gbuf_shards[model_index][dtype]
# pax(0, {
# "param" : tp(param),
# "model_index" : model_index,
# "dtype" : str(dtype),
# "gbuf" : gbuf,
# })
for
param
in
self
.
param_gbuf_map
:
param
.
detach
().
copy_
(
param
.
main_grad
)
# pax(0, {
# "param" : tp(param),
# "main_grad" : tp(param.main_grad),
# # "grad" : tp(param.grad),
# })
raise
Exception
(
"gather params."
)
# pax(0, {
# "gbuf_view_items" : gbuf_view_items,
# "param_gbuf_map" : [
# (str(tuple(p.shape)), d)
# for p, d in self.param_gbuf_map.items()
# ],
# })
# def step(self):
# def step(self):
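To follow the collective calls above: reduce_scatter(gbuf_views[rank], gbuf_views, ...) leaves each rank holding the elementwise sum of its own shard taken across all ranks, and all_gather(gbuf_views, gbuf_views[rank], ...) then redistributes every owned shard so each rank ends up with the same full buffer. The sketch below is a single-process emulation of those semantics with plain tensor arithmetic; the 2-rank, 8-element buffers are made up and no process group is created.

    import torch

    world_size = 2
    # One full-length grad buffer per data-parallel rank (values differ per rank).
    bufs = [torch.arange(8, dtype=torch.float32) * (rank + 1)
            for rank in range(world_size)]
    # One view per rank within each buffer, as get_model_grad_buffer_dp_views builds.
    views = [list(buf.chunk(world_size)) for buf in bufs]

    # Emulate reduce_scatter(gbuf_views[rank], gbuf_views): rank r keeps the
    # elementwise sum of shard r across every rank's buffer.
    for r in range(world_size):
        views[r][r].copy_(
            torch.stack([views[s][r] for s in range(world_size)]).sum(0))

    # Emulate all_gather(gbuf_views, gbuf_views[rank]): every rank receives the
    # shard owned by each other rank, reassembling one identical buffer everywhere.
    for r in range(world_size):
        for s in range(world_size):
            if s != r:
                views[r][s].copy_(views[s][s])

    assert torch.equal(bufs[0], bufs[1])  # all ranks end with the same reduced buffer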
@@ -1429,6 +1496,49 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         #     "opt_group_shards" : self.opt_group_shards,
         # })

+    def _copy_main_params_to_model_params(self):
+
+        for group_index, group_shard in enumerate(self.opt_group_shards):
+            for param, main_shard in group_shard["param_map"].items():
+
+                model_index, gbuf_dtype = self.param_gbuf_map[param]
+                model_shard = self.model_gbuf_shards \
+                    [model_index][gbuf_dtype]["param_map"][param]["world"]
+
+                assert main_shard.size == model_shard.size
+
+                # Use DDP's contiguous buffer to temporarily hold params.
+                model_tensor = \
+                    self.models[model_index]._grad_buffers[gbuf_dtype].data
+                main_tensor = self.main_param_shards[group_index]
+
+                # Copy sub-range within tensor.
+                model_view = model_tensor[model_shard.start:model_shard.end]
+                main_view = main_tensor[main_shard.start:main_shard.end]
+
+                model_view.detach().copy_(main_view)
+
+                # Debug.
+                # pax(0, {
+                #     "group_index" : group_index,
+                #     "group_shard" : group_shard,
+                #     "param" : tp(param),
+                #     "model_index" : model_index,
+                #     "gbuf_dtype" : str(gbuf_dtype),
+                #     "model_grad_tensor" : tp(model_grad_tensor),
+                #     "main_grad_tensor" : tp(main_grad_tensor),
+                #     "model_grad_view" : tp(model_grad_view),
+                #     "main_grad_view" : tp(main_grad_view),
+                #     "model_shard" : str(model_shard),
+                #     "main_shard" : str(main_shard),
+                # })
+
+        # pax(0, {
+        #     "model_gbuf_shards" : self.model_gbuf_shards,
+        #     "opt_group_shards" : self.opt_group_shards,
+        # })

     # <<<
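The copy performed by _copy_main_params_to_model_params is a plain sub-range copy between two flat tensors, with copy_ handling the fp32-to-fp16 cast into the reused grad buffer. A minimal sketch of just that step, using made-up sizes and offsets in place of the real main_shard / model_shard bookkeeping:

    import torch

    # Illustrative stand-ins: an fp32 "main" param shard and an fp16 model buffer.
    main_tensor = torch.randn(64, dtype=torch.float32)
    model_tensor = torch.zeros(256, dtype=torch.float16)

    main_start, main_end = 16, 48      # assumed main_shard.start / .end
    model_start, model_end = 128, 160  # assumed model_shard.start / .end
    assert main_end - main_start == model_end - model_start  # shard sizes must match

    main_view = main_tensor[main_start:main_end]
    model_view = model_tensor[model_start:model_end]

    # copy_ writes through the view into model_tensor and down-casts fp32 -> fp16.
    model_view.detach().copy_(main_view)
    assert torch.allclose(model_tensor[model_start:model_end].float(), main_view,
                          atol=1e-2)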