Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
deepspeed
Commits
dce054db
Unverified
Commit
dce054db
authored
Nov 19, 2020
by
Jeff Rasley
Committed by
GitHub
Nov 19, 2020
Browse files
backwards compatability w. v020 ckpts, fix issue with zero-1 ckpts (#543)
parent
9de21b72
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
4 additions
and
10 deletions
+4
-10
deepspeed/__init__.py
deepspeed/__init__.py
+2
-0
deepspeed/runtime/zero/stage1.py
deepspeed/runtime/zero/stage1.py
+2
-10
No files found.
deepspeed/__init__.py
View file @
dce054db
...
...
@@ -42,6 +42,8 @@ sys.modules['deepspeed.pt'] = deepspeed.pt
sys
.
modules
[
'deepspeed.pt.deepspeed_utils'
]
=
deepspeed
.
runtime
.
utils
setattr
(
deepspeed
.
pt
,
'deepspeed_config'
,
deepspeed
.
runtime
.
config
)
sys
.
modules
[
'deepspeed.pt.deepspeed_config'
]
=
deepspeed
.
runtime
.
config
setattr
(
deepspeed
.
pt
,
'loss_scaler'
,
deepspeed
.
runtime
.
fp16
.
loss_scaler
)
sys
.
modules
[
'deepspeed.pt.loss_scaler'
]
=
deepspeed
.
runtime
.
fp16
.
loss_scaler
def
initialize
(
args
,
...
...
deepspeed/runtime/zero/stage1.py
View file @
dce054db
...
...
@@ -194,7 +194,6 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
# max elems per param group
self
.
max_elems_per_comm
=
[]
self
.
legacy_max_elements_per_comm
=
max_elements_per_comm
# loop to deal with groups
for
i
,
param_group
in
enumerate
(
self
.
optimizer
.
param_groups
):
...
...
@@ -859,7 +858,6 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
state_dict
[
'zero_stage'
]
=
ZERO_OPTIMIZATION_OPTIMIZER_STATES
state_dict
[
'partition_count'
]
=
self
.
partition_count
state_dict
[
'num_comm_intervals_per_group'
]
=
self
.
num_comm_intervals_per_group
state_dict
[
'max_elems_per_comm'
]
=
self
.
max_elems_per_comm
# Remove paddings for DP alignment to enable loading for other alignment values
fp32_groups_without_padding
=
self
.
_get_groups_without_padding
(
...
...
@@ -933,10 +931,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
sd
[
'local_sub_partitions_of_fp32_groups'
][
group_idx
]
for
sd
in
all_state_dict
]
if
'max_elems_per_comm'
in
all_state_dict
[
0
]:
max_elems_per_comm
=
all_state_dict
[
0
][
'max_elems_per_comm'
][
group_idx
]
else
:
max_elems_per_comm
=
self
.
legacy_max_elements_per_comm
max_elems_per_comm
=
self
.
max_elems_per_comm
[
group_idx
]
sub_partition_weights
=
self
.
_retrieve_group_sub_partition_weights
(
all_partition_fp32_weights
,
...
...
@@ -1009,10 +1004,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
all_partition_group_states
=
[
sd
[
'base_optimizer_state'
][
group_idx
]
for
sd
in
state_dict_list
]
if
'max_elems_per_comm'
in
state_dict_list
[
0
]:
max_elems_per_comm
=
state_dict_list
[
0
][
'max_elems_per_comm'
][
group_idx
]
else
:
max_elems_per_comm
=
self
.
legacy_max_elements_per_comm
max_elems_per_comm
=
self
.
max_elems_per_comm
[
group_idx
]
group_optimizer_states
=
self
.
_retrieve_group_optimizer_states
(
all_partition_group_states
,
max_elems_per_comm
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment