Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
3f0bc681
Commit
3f0bc681
authored
Feb 14, 2022
by
Lawrence McAfee
Browse files
copying model grad slices to main grad
parent
6875dff5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
35 deletions
+32
-35
megatron/optimizer/optimizer.py
megatron/optimizer/optimizer.py
+32
-35
No files found.
megatron/optimizer/optimizer.py
View file @
3f0bc681
...
...
@@ -934,7 +934,7 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
def
reduce_gradients
(
self
,
model
):
# >>>
pax
(
0
,
{
"main param"
:
self
.
world_shard_info_groups
[
0
][
self
.
data_parallel_rank
][
"data"
][
torch
.
float
]})
# pax(0, {"main param" : self.world_shard_info_groups[0][self.data_parallel_rank]["data"][torch.float]})
# <<<
# >>>
...
...
@@ -968,7 +968,6 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Copy model grads to main shard.
local_shard_info_groups
=
[
g
[
self
.
data_parallel_rank
]
for
g
in
self
.
world_shard_info_groups
]
for
group_index
,
local_shard_info
in
enumerate
(
local_shard_info_groups
):
...
...
@@ -986,51 +985,49 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
dtype_model_dict
=
param_model_map
[
param
]
dtype
=
dtype_model_dict
[
"dtype"
]
vmodel
=
dtype_model_dict
[
"model"
]
model_grad_buffer
=
vmodel
.
_grad_buffers
[
dtype
]
model_grad_buffer
=
vmodel
.
_grad_buffers
[
dtype
]
.
data
model_grad_buffer_start_index
=
\
vmodel
.
_grad_buffer_param_index_map
[
dtype
][
param
][
0
]
+
\
main_slice_orig_start_index
# main_grad_view = self.main_param_shard_groups \
# [group_index][torch.float].grad \
# [shard_indexes["shard"][0]:shard_indexes["shard"][1]]
main_grad_view
=
local_shard_info
[
"data"
][
torch
.
float
]
main_grad_view
=
local_shard_info
[
"data"
][
torch
.
float
].
grad
[
main_slice_shard_start_index
:
main_slice_shard_start_index
+
main_slice_size
]
model_grad_view
=
model_grad_buffer
[
model_grad_buffer_start_index
:
model_grad_buffer_start_index
+
main_slice_size
]
pax
(
0
,
{
"local_shard_info"
:
local_shard_info
,
"main_slice_orig_start_index"
:
main_slice_orig_start_index
,
"main_slice_shard_start_index"
:
main_slice_shard_start_index
,
"main_slice_size"
:
main_slice_size
,
"model_grad_buffer_start_index"
:
model_grad_buffer_start_index
,
"main_grad_view"
:
main_grad_view
,
})
main_grad_view
.
detach
().
copy_
(
model_grad_view
)
pax
(
0
,
{
# "dtype" : dtype,
# "vmodel" : vmodel,
"shard_indexes"
:
shard_indexes
,
"grad_buffer_indexes"
:
grad_buffer_indexes
,
"model_grad_view"
:
model_grad_view
,
"main_grad_views"
:
main_grad_view
,
})
# pax(0, {
# # "local_shard_info" : local_shard_info,
# "main_slice_orig_start_index" : main_slice_orig_start_index,
# "main_slice_shard_start_index" : main_slice_shard_start_index,
# "main_slice_size" : main_slice_size,
# "model_grad_buffer_start_index" :
# model_grad_buffer_start_index,
# "main_grad_view" : tp(main_grad_view),
# "main_grad_view / detach" : tp(main_grad_view.detach()),
# "model_grad_view" : tp(model_grad_view),
# })
pax
(
0
,
{
"group_index"
:
group_index
,
"local_shard_info"
:
local_shard_info
,
"shard_param_index_map"
:
shard_param_index_map
,
"param"
:
tp
(
param
),
"shard_indexes"
:
shard_indexes
,
"grad_buffer_indexes"
:
grad_buffer_indexes
,
})
# pax(0, {
#     "group_index" : group_index,
#     "local_shard_info" : local_shard_info,
#     "shard_param_index_map" : shard_param_index_map,
#     "param" : tp(param),
#     "shard_indexes" : shard_indexes,
#     "grad_buffer_indexes" : grad_buffer_indexes,
# })
pax
(
0
,
{
# "world_shard_info_groups" : self.world_shard_info_groups,
# **{"world_shard_info_groups / %d" % i : v
# for i, v in enumerate(self.world_shard_info_groups)},
"local_shard_info_groups"
:
local_shard_info_groups
,
"main_param_shard_groups"
:
self
.
main_param_shard_groups
,
# "main_param_shard_groups" : self.main_param_shard_groups,
# "local_shard_info_groups" : local_shard_info_groups,
"local_shard_info_groups"
:
[
g
[
"data"
]
for
g
in
local_shard_info_groups
],
})
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment