OpenDAS / Megatron-LM / Commits

Commit 9b7854e4
authored Feb 22, 2022 by Lawrence McAfee
more cleanup of main params/grads
parent 291592e4

Showing 1 changed file with 16 additions and 5 deletions

megatron/optimizer/optimizer.py (+16, -5)
...
@@ -989,6 +989,11 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         self.optimizer.load_state_dict(self.optimizer.state_dict())

+    def get_main_param(self, group_index):
+        return self.optimizer.param_groups[group_index]["params"][0]
+    def get_main_grad(self, group_index):
+        return self.get_main_param(group_index).grad
+
     def load_state_dict(self):
         raise Exception("hi.")

     def reload_model_params(self):
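For context, the two new accessors rely on each optimizer param group holding a single flat fp32 main-parameter shard as the only entry in its "params" list. A minimal standalone sketch of that layout (the tensors, sizes, and use of Adam here are assumptions for illustration, not values from this commit):

import torch

# Toy flat fp32 "main parameter" shards, one per optimizer group
# (sizes are made up for illustration).
main_shards = [torch.zeros(1024, requires_grad=True),
               torch.zeros(2048, requires_grad=True)]

# One param group per shard, mirroring the layout the new accessors assume:
# a single flat main-param tensor at param_groups[i]["params"][0].
optimizer = torch.optim.Adam([{"params": [p]} for p in main_shards], lr=1e-3)

def get_main_param(optimizer, group_index):
    return optimizer.param_groups[group_index]["params"][0]

def get_main_grad(optimizer, group_index):
    return get_main_param(optimizer, group_index).grad

main_shards[0].grad = torch.randn(1024)
assert get_main_param(optimizer, 0) is main_shards[0]
assert get_main_grad(optimizer, 0) is main_shards[0].grad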
...
@@ -1098,15 +1103,19 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
         # })

     def _collect_main_grad_data_for_unscaling(self):
-        return [ p.grad.data for p in self.main_param_shards ]
+        # return [ p.grad.data for p in self.main_param_shards ]
         # return [ p.grad.data for p in self.main_param_shards if p is not None ]
+        return [ self.get_main_grad(gi).data
+                 for gi in range(len(self.opt_group_shards)) ]

     def _copy_model_params_to_main_params(self):
         for group_index, group_shard in enumerate(self.opt_group_shards):
             # main_param = self.main_param_shards[group_index]
-            main_param = self.optimizer.param_groups[group_index]["params"][0]
-            pax(0, {"main_param": tp(main_param)})
+            # main_param = self.optimizer.param_groups[group_index]["params"][0]
+            main_param = self.get_main_param(group_index)
+            # if group_index > 0:
+            #     pax({"main_param": tp(main_param)})

             for model_param, main_shard in group_shard["param_map"].items():
                 # Model shard.
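The rewritten _collect_main_grad_data_for_unscaling now derives the grad list from the optimizer groups (via get_main_grad) instead of the main_param_shards list being phased out. The consumer of that list is not shown in this diff; assuming it feeds a standard loss-scale unscaling step, a hedged sketch of what that step might look like (unscale_main_grads, the toy grads, and the scale value are made up):

import torch

def unscale_main_grads(main_grads, inv_scale):
    # Multiply every collected main grad in place by 1/loss_scale and
    # report whether any non-finite values showed up.
    found_inf = False
    for g in main_grads:
        g.mul_(inv_scale)
        if not torch.isfinite(g).all():
            found_inf = True
    return found_inf

grads = [torch.randn(8) * 2.0**14, torch.randn(4) * 2.0**14]
print(unscale_main_grads(grads, inv_scale=1.0 / 2.0**14))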
...
@@ -1152,7 +1161,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 # Copy from DDP's contiguous buffer to main shard's grad.
                 model_grad = self.models[model_index]._grad_buffers[dtype].data
-                main_grad = self.main_param_shards[group_index].grad
+                # main_grad = self.main_param_shards[group_index].grad
+                main_grad = self.get_main_grad(group_index)

                 # Copy sub-range within tensor.
                 model_view = model_grad[model_shard.start:model_shard.end]
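The hunk above copies a sub-range of the DDP contiguous grad buffer into the matching slice of the rank-local main grad. A toy illustration of that slicing and copy, with made-up buffer sizes and shard offsets standing in for model_shard/main_shard:

import torch

# Flat contiguous fp16 grad buffer, as maintained by the DDP wrapper (toy size).
grad_buffer = torch.arange(16, dtype=torch.float16)

# Hypothetical shard ranges: where the param's grad sits in the model buffer,
# and where it sits in this rank's fp32 main grad shard.
model_start, model_end = 4, 10
main_start, main_end = 0, 6

main_grad = torch.zeros(8, dtype=torch.float32)

# Copy sub-range within tensor: fp16 buffer slice -> fp32 main-shard slice.
model_view = grad_buffer[model_start:model_end]
main_view = main_grad[main_start:main_end]
main_view.copy_(model_view)

print(main_grad)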
...
@@ -1203,7 +1213,8 @@ class Float16DistributedOptimizer(BaseFloat16Optimizer):
                 # Use DDP's contiguous buffer to temporarily hold params.
                 model_param = self.models[model_index]._grad_buffers[dtype].data
-                main_param = self.main_param_shards[group_index]
+                # main_param = self.main_param_shards[group_index]
+                main_param = self.get_main_param(group_index)

                 # Copy sub-range within tensor.
                 model_view = model_param[model_shard.start:model_shard.end]
...
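The final hunk goes in the opposite direction: updated fp32 main params are written back into a slice of the contiguous buffer, which is temporarily re-used as scratch space before the values reach the fp16 model params. A sketch of that copy under the same assumed offsets as above:

import torch

# Re-use the contiguous buffer as temporary fp16 storage for params (toy size).
param_buffer = torch.zeros(16, dtype=torch.float16)

# Updated fp32 main-param shard plus assumed shard offsets.
main_param = torch.randn(8, dtype=torch.float32)
model_start, model_end = 4, 10
main_start, main_end = 0, 6

# Copy sub-range within tensor: fp32 main-shard slice -> fp16 buffer slice.
model_view = param_buffer[model_start:model_end]
main_view = main_param[main_start:main_end]
model_view.copy_(main_view)

print(param_buffer)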