OpenDAS / Megatron-LM / Commits

Commit a74e245c
Authored Feb 14, 2022 by Lawrence McAfee

built local shard param index map

parent f7232502
Showing 1 changed file with 57 additions and 13 deletions (+57, -13)
megatron/optimizer/optimizer.py (view file @ a74e245c)
@@ -773,6 +773,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        self.main_param_shard_groups = []
        for group_index, model_param_group in enumerate(self.model_param_groups):

            # pax(0, {
            #     "model_param_group" : model_param_group,
            #     # "offset_map" : {str(p.shape):o for p, o in model_param_group["offset_map"].items()},
            #     "offset_map" : [(o,tp(p)) for p, o in model_param_group["offset_map"].items()],
            # })

            # Group sizes.
            model_param_size = model_param_group["size"]
            max_world_shard_size = int(math.ceil(model_param_size / self.data_parallel_world_size))
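The construction of `world_shard_infos` itself falls in the context elided between this hunk and the next, so the sketch below is only an illustration of how the ceil-divided `max_world_shard_size` would translate into per-rank shard ranges; the helper name `build_world_shard_infos` is hypothetical, but the "start"/"end"/"size" keys mirror the fields the later hunks read.

# Minimal sketch (not the commit's code): split a flattened param group of
# `model_param_size` elements into per-data-parallel-rank shard ranges.
import math

def build_world_shard_infos(model_param_size, data_parallel_world_size):
    max_world_shard_size = int(math.ceil(model_param_size / data_parallel_world_size))
    world_shard_infos = []
    for rank in range(data_parallel_world_size):
        start = min(rank * max_world_shard_size, model_param_size)
        end = min(start + max_world_shard_size, model_param_size)
        world_shard_infos.append({"start": start, "end": end, "size": end - start})
    return world_shard_infos

# Example: 10 elements over 4 ranks -> shard sizes 3, 3, 3, 1:
# [{'start': 0, 'end': 3, 'size': 3}, {'start': 3, 'end': 6, 'size': 3},
#  {'start': 6, 'end': 9, 'size': 3}, {'start': 9, 'end': 10, 'size': 1}]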
@@ -790,20 +797,49 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                })
            self.world_shard_info_groups.append(world_shard_infos)

            # pax(0, {"world_shard_infos": world_shard_infos})

            # DP local shard info.
            local_shard_info = world_shard_infos[self.data_parallel_rank]
            local_shard_start_index = local_shard_info["start"]
            local_shard_end_index = local_shard_info["end"]
            local_shard_size = local_shard_info["size"]

            # Shard param index map.
            local_shard_info["param_index_map"] = {}
            for param, offset_dict in model_param_group["offset_map"].items():
                param_start_index = offset_dict["start"]
                param_end_index = offset_dict["end"]
                param_shard_start_index = max(local_shard_start_index, param_start_index)
                param_shard_end_index = min(local_shard_end_index, param_end_index)
                if param_shard_end_index > param_shard_start_index:
                    local_shard_info["param_index_map"][param] = {
                        "start" : param_shard_start_index - local_shard_start_index,
                        "end" : param_shard_end_index - local_shard_start_index,
                    }

                # pax(0, {
                #     "local index" : "%d, %d" % (
                #         local_shard_start_index,
                #         local_shard_end_index,
                #     ),
                #     "param index" : "%s, %d" % (
                #         param_start_index,
                #         param_end_index,
                #     ),
                #     "param" : tp(param),
                #     "shard_param_index_map" : shard_param_index_map,
                #     "local_shard_info" : local_shard_info,
                # })

            pax(0, {"local_shard_info": local_shard_info})
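The inner loop intersects each parameter's [start, end) range within the flattened group with this rank's shard range, and stores the overlap as shard-relative offsets. A small worked example of the same logic, with illustrative numbers not taken from the commit:

# Suppose this rank's shard covers group elements [300, 600) and a parameter
# occupies group elements [250, 450):
local_shard_start_index, local_shard_end_index = 300, 600
param_start_index, param_end_index = 250, 450

param_shard_start_index = max(local_shard_start_index, param_start_index)  # 300
param_shard_end_index = min(local_shard_end_index, param_end_index)        # 450

# The overlap is non-empty, so the parameter gets a shard-relative slice:
entry = {
    "start": param_shard_start_index - local_shard_start_index,  # 0
    "end": param_shard_end_index - local_shard_start_index,      # 150
}
# i.e. the first 150 elements of this rank's shard belong to this parameter.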
            # Allocate shards.
            # (Non-fp32 shards are for convenience; e.g., intermediaries
            # between model params and main fp32 shard. Necessary???)
            local_shard_size = world_shard_infos[self.data_parallel_rank]["size"]
            # # self.main_param_shard = allocate_shard(torch.float)
            # # self.main_grad_shard = allocate_shard(torch.float)
            # self.param_shard_map = {ty:allocate_shard(ty) for ty in dtypes}
            # self.grad_shard_map = {ty:allocate_shard(ty) for ty in dtypes}
            # self.adam_m_shard = allocate_shard(torch.float)
            # self.adam_v_shard = allocate_shard(torch.float)
            main_param_shards = {ty : allocate_shard(local_shard_size, ty)
                                 for ty in model_main_dtypes}
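`allocate_shard` is called here but not defined in this diff. A plausible minimal version, assuming it only reserves a flat, contiguous buffer of the shard's size in the requested dtype on the current device (an assumption, not the commit's implementation):

import torch

def allocate_shard(shard_size, dtype):
    # Flat buffer holding this rank's slice of the flattened param group.
    return torch.empty((shard_size,),
                       dtype=dtype,
                       device=torch.cuda.current_device(),
                       requires_grad=False)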
@@ -863,9 +899,17 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

        # Copy model grads to main shard.
        self.world_shard_info_groups = []  # world_group_shard_infos ?
        self.main_param_shard_groups = []

        pax(0, {"main_shard_info_groups": self.main_shard_info_groups})

        local_shard_info_groups = [g[self.data_parallel_rank]
                                   for g in self.world_shard_info_groups]

        pax(0, {
            # "world_shard_info_groups" : self.world_shard_info_groups,
            # **{"world_shard_info_groups / %d" % i : v
            #    for i, v in enumerate(self.world_shard_info_groups)},
            "local_shard_info_groups" : local_shard_info_groups,
            "main_param_shard_groups" : self.main_param_shard_groups,
            # "main_param_shard_groups" : self.main_param_shard_groups,
        })

        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Reduce-scatter.
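The reduce-scatter step itself is not part of this hunk. As a rough sketch of what "copy model grads to main shard" via reduce-scatter could look like, assuming a flat per-group gradient buffer padded so every rank's slice has the same size (function and parameter names here are illustrative, not the commit's API):

import torch

def reduce_scatter_group_grads(world_grad_buffer, local_grad_shard,
                               data_parallel_group):
    # Each rank contributes the full flattened group gradient and receives
    # only its own shard, already summed across data-parallel ranks.
    world_size = torch.distributed.get_world_size(group=data_parallel_group)
    input_list = list(world_grad_buffer.chunk(world_size))
    torch.distributed.reduce_scatter(local_grad_shard, input_list,
                                     group=data_parallel_group)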