OpenDAS / Megatron-LM · Commits

Commit 2c1660e7, authored Mar 14, 2022 by Lawrence McAfee

cleaned distrib_optimizer.py.

Parent: efa3cbcf
Showing 1 changed file with 5 additions and 54 deletions.
megatron/optimizer/distrib_optimizer.py (+5, -54)

--- a/megatron/optimizer/distrib_optimizer.py
+++ b/megatron/optimizer/distrib_optimizer.py
@@ -25,11 +25,6 @@ from megatron import mpu
 from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
-# >>>
-from lutil import pax, tp
-DEBUG_ITERATION = 2 # 10
-# <<<
 
 class Shard:
 
     def __init__(self, start, end):
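
Note: the Shard helper visible in the context lines above tracks a contiguous index range over a flattened parameter buffer. Below is a minimal sketch consistent with what the diff shows (the "%d,%d [%d]" formatting and the size attribute appear in the next hunk; the method name and the size computation are assumptions, not the verbatim class):

class Shard:
    # Sketch of the range helper seen in this diff's context lines.
    # Only start/end and the "%d,%d [%d]" format are visible in the commit;
    # size = end - start and the __str__ name are assumed.

    def __init__(self, start, end):
        self.start = start           # first flat index covered by this shard
        self.end = end               # one past the last flat index
        self.size = end - start      # assumed: number of elements in the range

    def __str__(self):
        return "%d,%d [%d]" % (self.start, self.end, self.size)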
@@ -42,10 +37,6 @@ class Shard:
         return "%d,%d [%d]" % (self.start, self.end, self.size)
 
-# class Float16DistributedOptimizer(Float16OptimizerWithFloat16Params):
-# class Float16DistributedOptimizer(MegatronOptimizer):
-# class Float16DistributedOptimizer(BaseFloat16Optimizer):
-# class DistributedOptimizer(MegatronOptimizer):
 class DistributedOptimizer(MixedPrecisionOptimizer):
 
     @classmethod
@@ -177,18 +168,11 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def allocate_main_param_shards(cls, opt_group_shards):
         # Allocator method.
-        # >>>
         allocate_shard = lambda shard_size, dtype : torch.empty(
             (shard_size,),
             dtype = dtype,
             device = torch.cuda.current_device(),
             requires_grad = True)
-        # allocate_shard = lambda shard_size, dtype : torch.zeros(
-        #     (shard_size,),
-        #     dtype = dtype,
-        #     device = torch.cuda.current_device(),
-        #     requires_grad = True)
-        # <<<
 
         # Allocate each group's param/grad shard.
         for group_index, group_shard in enumerate(opt_group_shards):
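
Note: the commit keeps the torch.empty allocator and drops the commented-out torch.zeros variant; uninitialized memory is fine here because every shard element is overwritten when model data is copied in. A hedged sketch of how such an allocator might be exercised (requires a CUDA device; the shard size and the param/grad pairing below are illustrative, not the exact Megatron usage):

import torch

# Same allocator shape as in the diff: an uninitialized, trainable CUDA tensor.
allocate_shard = lambda shard_size, dtype: torch.empty(
    (shard_size,),
    dtype=dtype,
    device=torch.cuda.current_device(),
    requires_grad=True)

shard_size = 1024                                   # hypothetical shard size
main_param = allocate_shard(shard_size, torch.float32)
main_param.grad = torch.empty_like(main_param)      # assumed pairing of param and grad shard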
@@ -295,29 +279,12 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
     def get_main_grad(self, group_index):
         return self.get_main_param(group_index).grad
 
-    # def load_state_dict(self):
-    #     raise Exception("hi.")
-    # # def reload_model_params(self): # ... done in MixedPrecisionOptimizer
-    # #     raise Exception("hi.")
-    # def state_dict(self):
-    #     raise Exception("hi.")
     def state_dict(self):
         state_dict = {}
         state_dict['optimizer'] = self.optimizer.state_dict()
         if self.grad_scaler:
             state_dict['grad_scaler'] = self.grad_scaler.state_dict()
-        # state_dict['params'] = \
-        #     [ p for g in self.optimizer.param_groups for p in g["params"] ]
         state_dict['groups'] = [ g['params'] for g in self.optimizer.param_groups ]
-        # pax(0, { # ... only called on model rank 0
-        #     # "optimizer" : self.optimizer,
-        #     "state_dict" : state_dict,
-        #     "state_dict / param_groups" : state_dict["optimizer"]["param_groups"],
-        #     "optimizer / groups" : self.optimizer.param_groups,
-        #     "state_dict / params" : [ p.shape for p in state_dict["params"] ],
-        #     "optimizer / params" :
-        #         [ p.shape for g in self.optimizer.param_groups for p in g["params"] ],
-        #     })
         return state_dict
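
Note: after this hunk, state_dict() returns the inner optimizer state, the grad-scaler state (when a scaler is configured), and per-group lists of main-param shards under the new 'groups' key, replacing the commented-out flat 'params' list. An illustrative-only sketch of the resulting structure (key names come from the diff; all values are placeholders):

import torch

example_state_dict = {
    'optimizer': {'state': {}, 'param_groups': []},  # placeholder for self.optimizer.state_dict()
    'grad_scaler': {'scale': 2.0 ** 12},             # placeholder; only present if self.grad_scaler is set
    'groups': [                                      # one list of main-param shard tensors per param group
        [torch.zeros(8), torch.zeros(4)],            # hypothetical shards, group 0
        [torch.zeros(16)],                           # hypothetical shards, group 1
    ],
}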
@@ -330,10 +297,6 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                          'an old checkpoint ...')
         self.optimizer.load_state_dict(state_dict[optimizer_key])
-        # pax(0, {
-        #     "state_dict" : state_dict,
-        #     "params" : state_dict["params"],
-        #     })
 
         # Grad scaler.
         if 'grad_scaler' not in state_dict:
             print_rank_0('***WARNING*** found an old checkpoint, will not '
@@ -349,32 +312,20 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         # Copy data for the main params.
         current_groups = [ g["params"] for g in self.optimizer.param_groups ]
         assert "groups" in state_dict, "key 'groups' not in state_dict."
-        # pax(0, {
-        #     "state_dict" : state_dict,
-        #     "current_groups" : current_groups,
-        #     "saved_groups" : state_dict[params_key],
-        #     })
         for current_group, saved_group in zip(current_groups, state_dict["groups"]):
-            # pax(0, {
-            #     "current_group" : current_group,
-            #     "saved_group" : saved_group,
-            #     })
             for current_param, saved_param in zip(current_group, saved_group):
                 current_param.data.copy_(saved_param.data)
 
     def zero_grad(self, set_to_none=True):
+        # Collect model params.
         model_params = []
         for model in self.models:
             for dtype, param_map in model._grad_buffer_param_index_map.items():
                 model_params.extend(param_map.keys())
-        # main_params = []
-        # for main_group in self.optimizer.param_groups:
-        #     main_params.extend(main_group["params"])
-        # ** using contiguous buffer; don't set_to_none **
+        # Distributed optimizer requires contiguous buffer; don't set to None.
         _zero_grad_group_helper(model_params, set_to_none = False)
-        # _zero_grad_group_helper(params, set_to_none = False)
 
     def get_model_grad_buffer_dp_views(self):
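
Note: zero_grad() passes set_to_none=False because, as the rewritten comment in this hunk says, the distributed optimizer's grads live in a contiguous buffer that must be zeroed in place rather than dropped. _zero_grad_group_helper is imported from .optimizer at the top of the file; the sketch below is a rough paraphrase of that kind of helper, not the verbatim Megatron code:

def _zero_grad_group_helper_sketch(group, set_to_none):
    # Paraphrase of the imported helper: either drop each .grad entirely,
    # or zero it in place so the underlying (contiguous) storage is kept.
    for param in group:
        if param.grad is None:
            continue
        if set_to_none:
            param.grad = None
        else:
            if param.grad.grad_fn is not None:
                param.grad.detach_()
            else:
                param.grad.requires_grad_(False)
            param.grad.zero_()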
@@ -422,7 +373,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
         timers('backward-params-all-reduce').stop()
 
-    def gather_model_params(self, args, timers, ITERATION):
+    def gather_model_params(self, args, timers):
         timers('backward-params-all-gather').start()
@@ -471,7 +422,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 main_view.detach().copy_(model_view)
 
-    def _copy_model_grads_to_main_grads(self, ITERATION):
+    def _copy_model_grads_to_main_grads(self):
         for group_index, group_shard in enumerate(self.opt_group_shards):
             for model_param, main_shard in group_shard["param_map"].items():
@@ -494,7 +445,7 @@ class DistributedOptimizer(MixedPrecisionOptimizer):
                 main_view.detach().copy_(model_view)
 
-    def _copy_main_params_to_model_params(self, ITERATION):
+    def _copy_main_params_to_model_params(self):
         for group_index, group_shard in enumerate(self.opt_group_shards):
             for model_param, main_shard in group_shard["param_map"].items():
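
Note: the last three hunks only drop the debug ITERATION argument from gather_model_params, _copy_model_grads_to_main_grads, and _copy_main_params_to_model_params; the bodies are unchanged. The copy helpers iterate each group_shard["param_map"] and move data between a flattened model-side view and the main shard via detach().copy_(). A hedged, self-contained illustration of that view-and-copy idiom (the flat range and shapes are assumptions, not the actual Megatron mapping):

import torch

model_param = torch.randn(4, 8, dtype=torch.float16)   # hypothetical fp16 model weight
main_shard = torch.empty(16, dtype=torch.float32)       # hypothetical fp32 optimizer shard

start, end = 8, 24                                       # assumed flat range owned by this rank
model_view = model_param.detach().view(-1)[start:end]    # flat view of the owned slice
main_view = main_shard                                    # here the shard covers exactly that range

# Same call pattern as the context lines in the diff: copy without autograd tracking
# (copy_ also handles the fp16 -> fp32 cast).
main_view.detach().copy_(model_view)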