OpenDAS / Megatron-LM · Commits

Commit 525a8351, authored Feb 15, 2022 by Lawrence McAfee
created optimizer group shards
parent f498a6fe
Showing 1 changed file with 78 additions and 18 deletions
megatron/optimizer/optimizer.py  +78  -18  (view file @ 525a8351)
...
...
@@ -1016,6 +1016,68 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        return shard_map

    # @classmethod
    # def get_param_size_map(cls, model_gbuf_shards):
    #     param_size_map = {}
    #     for model_gbuf_shard_map in model_gbuf_shards:
    #         for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
    #             for param, param_shard_map in gbuf_shard_map["param_map"].items():
    #                 assert param not in param_size_map
    #                 param_size_map[param] = param_shard_map["local"].size
    #                 # pax(0, {
    #                 #     "dtype" : dtype,
    #                 #     "gbuf_shard_map" : gbuf_shard_map,
    #                 #     "param" : tp(param),
    #                 #     "param_shard_map" : param_shard_map,
    #                 # })
    #     pax(0, {
    #         "model_gbuf_shards" : model_gbuf_shards,
    #         "param_size_map" : [ (str(p.shape), s) for p, s in param_size_map.items() ],
    #     })
    #     return param_size_map
    @classmethod
    def get_optimizer_group_shards(cls, param_groups, model_gbuf_shards):

        num_groups = len(param_groups)

        # Param group map.
        param_group_map = {}
        for group_index, group in enumerate(param_groups):
            for param in group["params"]:
                assert param.requires_grad
                param_group_map[param] = group_index

        # Optimizer group shards.
        group_shards = [ {"size": 0, "param_map": {}} for _ in param_groups ]
        for model_gbuf_shard_map in model_gbuf_shards:
            for dtype, gbuf_shard_map in model_gbuf_shard_map.items():
                for param in gbuf_shard_map["param_map"]:

                    group_index = param_group_map[param]
                    group_shard = group_shards[group_index]
                    param_size = gbuf_shard_map["param_map"][param]["local"].size
                    param_group_start = group_shard["size"]
                    param_group_end = param_group_start + param_size
                    param_group_shard = Shard(param_group_start, param_group_end)

                    group_shard["size"] += param_size
                    group_shard["param_map"][param] = param_group_shard

        # raise Exception("hi.")
        # pax(0, {"param_group_map": [
        #     (g, str(p.shape))
        #     for p, g in param_group_map.items()
        # ]})
        # pax(0, {"group_shards": group_shards})

        return group_shards

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 bf16, grad_scaler, models):
...
...
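The heart of the new classmethod is how it packs each parameter's local shard contiguously into its optimizer group. Below is a minimal standalone sketch of that packing, using a hypothetical Shard namedtuple and made-up parameter sizes (neither the namedtuple nor pack_group is the repository's actual code):

from collections import namedtuple

# Assumption: Shard behaves like a simple [start, end) range.
Shard = namedtuple("Shard", ["start", "end"])

def pack_group(param_sizes):
    # Assign each param a contiguous [start, end) shard inside its group,
    # mirroring the {"size", "param_map"} dict built by get_optimizer_group_shards.
    group_shard = {"size": 0, "param_map": {}}
    for name, size in param_sizes:
        start = group_shard["size"]
        group_shard["param_map"][name] = Shard(start, start + size)
        group_shard["size"] += size
    return group_shard

# Two params with local sizes 10 and 6 land at [0, 10) and [10, 16);
# the group's total shard size becomes 16.
print(pack_group([("w", 10), ("b", 6)]))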
@@ -1037,17 +1099,19 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        # self.data_parallel_rank = mpu.get_data_parallel_rank()
        # self.data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Param group map.
        self.param_group_map = {}
        for group_index, group in enumerate(self.optimizer.param_groups):
            for param in group["params"]:
                assert param.requires_grad
                self.param_group_map[param] = group_index

        # Model grad buffer shards.
        self.model_gbuf_shards = []
        for model_index, model in enumerate(self.models):
            self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))

        # pax(0, {"param_group_map": [
        #     (g, str(p.shape))
        #     for p, g in self.param_group_map.items()
        # ]})

        # Optimizer shards.
        self.opt_group_shards = self.get_optimizer_group_shards(
            self.optimizer.param_groups,
            self.model_gbuf_shards)

        pax(0, {"opt_group_shards": self.opt_group_shards})

        # Allocate main param/grad shard.
        # Shard allocator.
        # ** torch.nn.Parameter ??
...
...
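The param_group_map built in the constructor is the standard mapping from each parameter to the index of its torch optimizer param group. A small self-contained illustration against a stock PyTorch optimizer (the tensors and optimizer here are toy stand-ins, not Megatron's):

import torch

# Two param groups, as a torch optimizer holds them.
w = torch.nn.Parameter(torch.zeros(4, 4))
b = torch.nn.Parameter(torch.zeros(4))
opt = torch.optim.SGD([{"params": [w]}, {"params": [b], "lr": 0.01}], lr=0.1)

# Same mapping the constructor builds: param -> index of its optimizer group.
param_group_map = {}
for group_index, group in enumerate(opt.param_groups):
    for param in group["params"]:
        assert param.requires_grad
        param_group_map[param] = group_index

print(param_group_map[w], param_group_map[b])  # 0 1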
@@ -1058,18 +1122,14 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        #     device = torch.cuda.current_device(),
        #     requires_grad = True)

        # Model grad buffer shards.
        self.model_gbuf_shards = []
        for model_index, model in enumerate(self.models):
            self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))

        # Allocate main param/grad shard.
        param_shard_map = self.get_param_shard_map(self.model_gbuf_shards)
        # >>>
        param_size_map = self.get_param_size_map(self.model_gbuf_shards)
        pax(0, {
            "model_gbuf_shards" : self.model_gbuf_shards,
            "param_shard_map" : param_shard_map,
            "param_size_map" : param_size_map,
        })
        # <<<

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
...
...
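The trailing WIP comments ("Allocate main param/grad shard", "Shard allocator", "** torch.nn.Parameter ??") hint at the next step: backing each group shard with a flat main-parameter buffer. A hedged sketch of one way that could look, assuming PyTorch, a Shard exposing start/end, and the {"size", "param_map"} layout built above; allocate_group_buffer is an illustrative name, not a function in this repository:

import torch

def allocate_group_buffer(group_shard, dtype=torch.float32, device="cpu"):
    # One flat buffer sized to the packed total of the group shard.
    flat = torch.zeros(group_shard["size"], dtype=dtype, device=device)
    # Each param's main copy is a view into the flat buffer at its [start, end) range.
    views = {param: flat[shard.start:shard.end]
             for param, shard in group_shard["param_map"].items()}
    return flat, views

# Example with the toy group shard from the earlier sketch:
# flat, views = allocate_group_buffer(
#     {"size": 16, "param_map": {"w": Shard(0, 10), "b": Shard(10, 16)}})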