OpenDAS / Megatron-LM · Commit 33626179

allocated main params/grads

Authored Feb 15, 2022 by Lawrence McAfee · parent 525a8351
Showing 1 changed file with 25 additions and 16 deletions.
megatron/optimizer/optimizer.py (+25, −16) · view file @ 33626179
@@ -1109,27 +1109,34 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
             self.optimizer.param_groups,
             self.model_gbuf_shards)
 
-        pax(0, {"opt_group_shards": self.opt_group_shards})
+        # pax(0, {"opt_group_shards": self.opt_group_shards})
 
         # Allocate main param/grad shard.
+
+        # Shard allocator.
         # ** torch.nn.Parameter ??
         # ** MemoryBuffer ??
-        # allocate_shard = lambda shard_size, dtype : torch.empty(
-        #     (shard_size,),
-        #     dtype = dtype,
-        #     device = torch.cuda.current_device(),
-        #     requires_grad = True)
+        allocate_shard = lambda shard_size, dtype : torch.empty(
+            (shard_size,),
+            dtype = dtype,
+            device = torch.cuda.current_device(),
+            requires_grad = True)
 
-        # >>>
-        param_size_map = self.get_param_size_map(self.model_gbuf_shards)
-        pax(0, {
-            "model_gbuf_shards" : self.model_gbuf_shards,
-            "param_size_map" : param_size_map,
-        })
-        # <<<
+        self.main_param_shards = []
+        for group_index, group_shard in enumerate(self.opt_group_shards):
+
+            # pax(0, {
+            #     "group_index" : group_index,
+            #     "group_shard" : group_shard,
+            # })
+
+            group_size = group_shard["size"]
+
+            # for dtype in model_main_dtypes ........
+
+            # Allocate shard.
+            main_param = allocate_shard(group_size, torch.float)
+            main_param.grad = allocate_shard(group_size, torch.float)
+            self.main_param_shards.append(main_param)
+
+            # Update optimizer group.
+            self.optimizer.param_groups[group_index]["params"] = [main_param]
 
         # Leverage state_dict() and load_state_dict() to
         # recast preexisting per-param state tensors
@@ -1137,7 +1144,9 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # >>>
         pax(0, {
             "world_shard_infos" : self.world_shard_infos,
             "model_gbuf_shards" : self.model_gbuf_shards,
+            "opt_group_shards" : self.opt_group_shards,
+            "main_param_shards" : self.main_param_shards,
         })
         # <<<
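For readers following the diff, below is a minimal, self-contained sketch of the pattern the new code introduces: for each optimizer parameter group, allocate a contiguous fp32 "main" parameter shard plus a matching gradient buffer, then re-point that group's entry in optimizer.param_groups at the main shard so the optimizer steps on fp32 storage. This is not the Megatron-LM implementation: the opt_group_shards sizes, the placeholder model parameters, and the CPU device default are stand-ins (the commit allocates on torch.cuda.current_device() and derives group sizes from its gradient-buffer sharding).

    # Minimal sketch (not Megatron-LM code): allocate fp32 main param/grad
    # shards per optimizer group and point the optimizer at them.
    import torch

    def allocate_shard(shard_size, dtype, device="cpu"):
        # The commit uses device=torch.cuda.current_device(); CPU is used here
        # so the sketch runs anywhere.
        return torch.empty((shard_size,), dtype=dtype, device=device,
                           requires_grad=True)

    # Hypothetical per-group sizes; in the commit this is group_shard["size"].
    opt_group_shards = [{"size": 8}, {"size": 4}]

    # Placeholder model params, one optimizer group per shard.
    model_params = [torch.nn.Parameter(torch.randn(s["size"]))
                    for s in opt_group_shards]
    optimizer = torch.optim.SGD([{"params": [p]} for p in model_params], lr=0.1)

    main_param_shards = []
    for group_index, group_shard in enumerate(opt_group_shards):
        group_size = group_shard["size"]

        # Allocate the main (fp32) param shard and a same-sized grad buffer.
        main_param = allocate_shard(group_size, torch.float)
        main_param.grad = allocate_shard(group_size, torch.float)
        main_param_shards.append(main_param)

        # Update the optimizer group: it now steps on the main shard,
        # not on the original model parameter.
        optimizer.param_groups[group_index]["params"] = [main_param]

    # Usage: fill the main shards and their grads, then step the optimizer.
    with torch.no_grad():
        for p in main_param_shards:
            p.zero_()
            p.grad.fill_(1.0)
    optimizer.step()  # each main shard element becomes -0.1 (= 0 - lr * 1.0)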