OpenDAS / Megatron-LM / Commits

Commit 291592e4
authored Feb 22, 2022 by Lawrence McAfee

removed zero-size optimizer group shards.

parent 23f9238d

Showing 1 changed file with 75 additions and 42 deletions

megatron/optimizer/optimizer.py  (+75, -42)  view file @ 291592e4
@@ -875,7 +875,12 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
        #     ))
        # <<<
        # pax(1, {

        # Squeeze zero-size group shards.
        for group_index, group_shard in enumerate(group_shards):
            group_shard["orig_group"] = param_groups[group_index]
        group_shards = [ g for g in group_shards if g["size"] > 0 ]

        # pax(0, {
        #     "param_group_map": [
        #         (g, str(p.shape))
        #         for p, g in param_group_map.items()
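Annotation (not part of this commit): a minimal, self-contained sketch of the squeeze step above, assuming each group shard is a plain dict with a "size" key and that param_groups lines up index-for-index with group_shards, as in the hunk. The data values are made up.

    # Hypothetical toy data, for illustration only.
    param_groups = [{"lr": 1e-4}, {"lr": 1e-4}, {"lr": 3e-5}]
    group_shards = [{"size": 128}, {"size": 0}, {"size": 64}]

    # Remember which optimizer param group each shard came from ...
    for group_index, group_shard in enumerate(group_shards):
        group_shard["orig_group"] = param_groups[group_index]

    # ... then drop the shards that own no elements on this rank.
    group_shards = [g for g in group_shards if g["size"] > 0]

    print([g["size"] for g in group_shards])  # [128, 64]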
@@ -885,6 +890,47 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):

        return group_shards

    @classmethod
    def allocate_main_param_shards(cls, opt_group_shards):

        # Allocate main param/grad shard.
        # ** torch.nn.Parameter ??
        # ** MemoryBuffer ??
        allocate_shard = lambda shard_size, dtype : torch.empty(
            (shard_size,),
            dtype = dtype,
            device = torch.cuda.current_device(),
            requires_grad = True)

        # main_param_shards = []
        for group_index, group_shard in enumerate(opt_group_shards):

            group_size = group_shard["size"]
            assert group_size != 0, "temporary check ... remove me."

            # ** todo: for dtype in model_main_dtypes ........ **

            # Allocate shard.
            # if group_size == 0:
            #     main_param = None
            # else:
            main_param = allocate_shard(group_size, torch.float)
            main_param.grad = allocate_shard(group_size, torch.float)
            mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
            # main_param_shards.append(main_param)

            group_shard["orig_group"]["params"] = [ main_param ]

            # # Update optimizer group.
            # self.optimizer.param_groups[group_index]["params"] = [ main_param ]

        # pax(1, {
        #     "opt_group_shards" : opt_group_shards,
        #     "main_param_shards" : main_param_shards,
        # })

        # return main_param_shards

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 bf16, grad_scaler, models):
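Annotation (not part of this commit): the allocate_shard lambda above reduces to a flat 1-D fp32 buffer with a gradient tensor attached up front, so the optimizer can step on the shard without autograd ever producing its grad. A standalone PyTorch sketch of that pattern follows; it allocates on CPU for portability, leaves out the Megatron-specific mpu.set_tensor_model_parallel_attributes call, and the shard size is made up.

    import torch

    def allocate_shard(shard_size: int, dtype: torch.dtype) -> torch.Tensor:
        # Flat 1-D main-parameter shard; the commit additionally passes
        # device=torch.cuda.current_device().
        return torch.empty((shard_size,), dtype=dtype, requires_grad=True)

    group_size = 1024  # hypothetical shard size for one optimizer group
    main_param = allocate_shard(group_size, torch.float)
    # Pre-attach a gradient buffer of the same shape and dtype.
    main_param.grad = torch.zeros(group_size, dtype=torch.float)

    assert main_param.shape == main_param.grad.shape == (group_size,)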
@@ -910,52 +956,36 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            self.model_gbuf_shards.append(self.get_model_gbuf_shard_map(model))
        self.param_gbuf_map = self.get_param_gbuf_map(self.model_gbuf_shards)
        # pax(0, {"param_gbuf_map": [ (str(tuple(p.shape)), d) for p, d in self.param_gbuf_map.items() ]})

        # Optimizer shards.
        self.opt_group_shards = self.get_optimizer_group_shards(
            self.optimizer.param_groups,
            self.model_gbuf_shards)

        # Allocate main param/grad shard.
        # ** torch.nn.Parameter ??
        # ** MemoryBuffer ??
        allocate_shard = lambda shard_size, dtype : torch.empty(
            (shard_size,),
            dtype = dtype,
            device = torch.cuda.current_device(),
            requires_grad = True)

        self.main_param_shards = []
        for group_index, group_shard in enumerate(self.opt_group_shards):

            group_size = group_shard["size"]
            # ** todo: for dtype in model_main_dtypes ........ **
            # pax(0, {**{"opt_group_shards / %d" % i : g for i, g in enumerate(self.opt_group_shards)}})

            # Allocate shard.
            if group_size == 0:
                main_param = None
            else:
                main_param = allocate_shard(group_size, torch.float)
                main_param.grad = allocate_shard(group_size, torch.float)
                mpu.set_tensor_model_parallel_attributes(main_param, True, 0, 1)
            self.main_param_shards.append(main_param)

            # Update optimizer group.
            self.optimizer.param_groups[group_index]["params"] = [ main_param ]

        # Allocate main param shards.
        # self.main_param_shards = \
        #     self.allocate_main_param_shards(self.opt_group_shards)
        self.allocate_main_param_shards(self.opt_group_shards)

        # >>>
        pax(0, {
            "model_gbuf_shards" : self.model_gbuf_shards,
            "opt_group_shards" : self.opt_group_shards,
            "main_param_shards" : self.main_param_shards,
        })
        # pax(0, {
        #     "model_gbuf_shards" : self.model_gbuf_shards,
        #     "opt_group_shards" : self.opt_group_shards,
        #     "main_param_shards" : self.main_param_shards,
        # })
        # <<<

        # Initialize main params.
        self._copy_model_params_to_main_params()

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
        self.optimizer.param_groups = \
            [ g["orig_group"] for g in self.opt_group_shards ]
        self.optimizer.load_state_dict(self.optimizer.state_dict())
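Annotation (not part of this commit): the last few lines above swap the optimizer's param groups over to the surviving shards' "orig_group" dicts and then round-trip the state dict so the optimizer rebuilds its per-param state against the new parameter list. A toy reproduction of that idiom with plain torch.optim.Adam and made-up shapes:

    import torch

    # Hypothetical stand-ins: two fp16 model params and one flat fp32 main shard.
    model_params = [torch.nn.Parameter(torch.randn(4, 4, dtype=torch.float16))
                    for _ in range(2)]
    main_param = torch.nn.Parameter(torch.zeros(32, dtype=torch.float32))
    main_param.grad = torch.zeros(32, dtype=torch.float32)

    optimizer = torch.optim.Adam(model_params, lr=1e-4)

    # Point the (single) param group at the flat main shard instead of the
    # model params, mirroring the param_groups reassignment above.
    optimizer.param_groups[0]["params"] = [main_param]

    # Round-trip the state dict so the optimizer's bookkeeping (step counts,
    # exp_avg / exp_avg_sq slots) is keyed to the new parameter list.
    optimizer.load_state_dict(optimizer.state_dict())

    optimizer.step()  # steps on the flat fp32 shard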
@@ -1069,11 +1099,14 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):

    def _collect_main_grad_data_for_unscaling(self):
        return [ p.grad.data for p in self.main_param_shards ]
        # return [ p.grad.data for p in self.main_param_shards if p is not None ]

    def _copy_model_params_to_main_params(self):

        for group_index, group_shard in enumerate(self.opt_group_shards):

            main_param = self.main_param_shards[group_index]
            # main_param = self.main_param_shards[group_index]
            main_param = self.optimizer.param_groups[group_index]["params"][0]
            pax(0, {"main_param": tp(main_param)})

            for model_param, main_shard in group_shard["param_map"].items():

                # Model shard.
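Annotation (not part of this commit): the body of _copy_model_params_to_main_params is truncated above, so this is only a rough sketch of the per-parameter copy it appears to perform, under the assumption (not shown in this hunk) that each "param_map" entry resolves to a contiguous [start, end) range of the flat main shard. Shapes and offsets are made up.

    import torch

    # Hypothetical model parameter and flat fp32 main shard.
    model_param = torch.randn(3, 4, dtype=torch.float16)
    main_param = torch.zeros(32, dtype=torch.float32)

    # Assumed bookkeeping: this param occupies main_param[start:end].
    start = 8
    end = start + model_param.nelement()

    # Copy (and upcast) the fp16 model values into the fp32 shard slice.
    main_param[start:end].copy_(model_param.detach().view(-1))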
@@ -1098,13 +1131,13 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
            #     "model_shard" : str(model_shard),
            # })

        pax(1, {
            **{
                "opt_group_shards / %d" % i : s
                for i, s in enumerate(self.opt_group_shards)
            },
            "main_param_shards" : self.main_param_shards,
        })
        # pax(1, {
        #     **{
        #         "opt_group_shards / %d" % i : s
        #         for i, s in enumerate(self.opt_group_shards)
        #     },
        #     "main_param_shards" : self.main_param_shards,
        # })

    def _copy_model_grads_to_main_grads(self):