OpenDAS / deepspeed
"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "d44e31aec2a450a31bc80662942094d617a585c2"
Commit 54c0267e
Authored Jun 20, 2020 by Tunji Ruwase
Parent: 2e6d93e0

Load non-DeepSpeed checkpoints into ZeRO optimizer
Showing 3 changed files with 27 additions and 2 deletions (+27 / -2):
    deepspeed/pt/deepspeed_light.py         +0   -2
    deepspeed/pt/fp16_optimizer.py          +1   -0
    deepspeed/pt/zero_optimizer_stage1.py   +26  -0
deepspeed/pt/deepspeed_light.py

@@ -561,7 +561,6 @@ class DeepSpeedLight(Module):
         if zero_stage == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
             assert self.zero_reduce_scatter(), 'Stage 1 only supports reduce scatter mode'
-            logger.info('Creating fp16 ZeRO Optimizer Stage 1')
             optimizer = FP16_DeepSpeedZeroOptimizer_Stage1(
                 optimizer,
                 static_loss_scale=self.loss_scale(),
@@ -593,7 +592,6 @@ class DeepSpeedLight(Module):
                 gradient_predivide_factor=self.gradient_predivide_factor())
         else:
             raise NotImplementedError("ZeRO stage {} not implemented".format(zero_stage))
-        logger.info('Creating fp16 zero stage {} optimizer'.format(zero_stage))
         return optimizer
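For context, the Stage-1 branch above is only reached when the DeepSpeed config requests ZeRO optimizer-state partitioning with reduce-scatter enabled. A minimal config sketch that would take this path; the key names follow DeepSpeed's documented config schema, are not part of this commit, and may differ slightly in this version:

# Hedged sketch: a DeepSpeed config dict that would route optimizer creation
# through the Stage-1 branch shown above. Key names are assumptions based on
# DeepSpeed's general config schema, not taken from this diff.
ds_config = {
    "train_batch_size": 8,
    "fp16": {
        "enabled": True           # the fp16 ZeRO optimizer wrappers require fp16 training
    },
    "zero_optimization": {
        "stage": 1,               # corresponds to ZERO_OPTIMIZATION_OPTIMIZER_STATES
        "reduce_scatter": True    # the assert above requires reduce-scatter mode for stage 1
    }
}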
deepspeed/pt/fp16_optimizer.py

@@ -353,6 +353,7 @@ class FP16_Optimizer(object):
         state_dict['clip_grad'] = self.clip_grad
         return state_dict
 
+    # Refresh fp32 master params from fp16 copies
     def refresh_fp32_params(self):
         for current, saved in zip(self.fp32_groups_flat, self.fp16_groups_flat):
             current.data.copy_(saved.data)
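The comment added here documents the direction of the copy: each flat fp32 master group is overwritten from the corresponding flat fp16 group. A toy, non-DeepSpeed illustration of why that refresh matters once externally loaded weights replace the fp16 copies (the tensor names are made up for the example):

# Toy illustration only; not the DeepSpeed API. Shows the copy direction that
# refresh_fp32_params performs: fp16 group -> fp32 master group.
import torch

fp16_flat = torch.zeros(4, dtype=torch.float16)         # stand-in for a flattened fp16 param group
fp32_flat = fp16_flat.clone().float()                   # fp32 master copy made when the optimizer wrapped the model

fp16_flat.copy_(torch.arange(4, dtype=torch.float16))   # an external checkpoint overwrites the fp16 weights
fp32_flat.copy_(fp16_flat.float())                      # the refresh; otherwise the next step would use stale masters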
deepspeed/pt/zero_optimizer_stage1.py

@@ -10,6 +10,24 @@ from deepspeed.pt.loss_scaler import LossScaler, DynamicLossScaler
 from deepspeed.pt.deepspeed_utils import get_grad_norm, CheckOverflow
 
+
+def flatten_dense_tensors_sub_partition_aligned_(tensor_list,
+                                                 dp,
+                                                 max_elements_per_comm,
+                                                 pg):
+    num_elements = sum(t.numel() for t in tensor_list)
+    log_dist(
+        "Total number of elements in model: {}, max elements per com: {}".format(
+            num_elements, max_elements_per_comm),
+        ranks=[0])
+
+    # Compute aligned partition size based on parameter count
+    aligned_param_partition_size = math.ceil(num_elements / dp)
+
+    # Compute aligned partition size based on communication size
+    aligned_comm_partition_size = int(max_elements_per_comm // dp)
+
 def flatten_dense_tensors_sub_partition_aligned(tensor_list,
                                                 dp,
                                                 max_elements_per_comm,
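Plugging concrete numbers into the two sizes computed by the helper above (illustration only; the values are made up):

# Worked example for the two partition sizes above, with made-up numbers.
import math

num_elements = 1_000_003          # total elements across the tensor list
dp = 4                            # data-parallel world size
max_elements_per_comm = 500_000   # cap on elements moved per communication call

aligned_param_partition_size = math.ceil(num_elements / dp)     # ceil(1000003 / 4) = 250001
aligned_comm_partition_size = int(max_elements_per_comm // dp)  # 500000 // 4 = 125000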
@@ -780,6 +798,14 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             'local_sub_partitions_of_fp32_groups'] = self.local_sub_partitions_of_fp32_groups
         return state_dict
 
+    # Refresh the fp32 master params from the fp16 copies.
+    def refresh_fp32_params(self):
+        partition_id = dist.get_rank(group=self.dp_process_group)
+        for fp16_all_sub_partitions, fp32_local_sub_partitions in zip(
+                self.parallel_sub_partitioned_fp16_groups, self.local_sub_partitions_of_fp32_groups):
+            for local_sub_partition_param_fp16, local_sub_partition_param_fp32 in zip(
+                    fp16_all_sub_partitions[partition_id], fp32_local_sub_partitions):
+                local_sub_partition_param_fp32.data.copy_(local_sub_partition_param_fp16.data)
+
     def load_state_dict(self, state_dict, load_optimizer_states=True):
         """
         Loads a state_dict created by an earlier call to state_dict().
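Taken together, these pieces support the flow named in the commit title: load an ordinary PyTorch checkpoint into the module, then rebuild each rank's fp32 master sub-partitions from the new fp16 weights. A hedged sketch of that flow; the engine handle, checkpoint path, and attribute access are placeholders and are not confirmed by this diff:

# Hedged sketch: loading a non-DeepSpeed (plain PyTorch) checkpoint into a
# ZeRO stage-1 run. `model_engine` is assumed to be the object returned by
# deepspeed.initialize(); the checkpoint path is a placeholder.
import torch

state = torch.load("pretrained_pytorch_model.pt", map_location="cpu")
model_engine.module.load_state_dict(state)      # overwrite the fp16 module weights in place
model_engine.optimizer.refresh_fp32_params()    # resync this rank's fp32 master sub-partitions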