OpenDAS / deepspeed · Commits

Commit 6021b702 (Unverified)
Authored Nov 21, 2020 by Olatunji Ruwase; committed by GitHub on Nov 21, 2020
Parent: 0178e6cc

Support non-tensor state in checkpoint (#548)

Showing 3 changed files with 163 additions and 58 deletions (+163 -58):

  deepspeed/runtime/zero/stage1.py     +9   -4
  tests/unit/simple_model.py           +31  -0
  tests/unit/test_checkpointing.py     +123 -54
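Context for the change: ZeRO's checkpoint restore previously assumed every per-parameter optimizer state value is a tensor and restored it with an in-place .data.copy_(), which fails for optimizers that also keep plain Python values such as an integer step counter. A minimal sketch of the failure mode and of the tensor/non-tensor dispatch this commit introduces; the two-entry state dict below is hypothetical, not code from the commit:

    import torch

    # Hypothetical per-parameter optimizer state: a tensor plus a plain int.
    saved_state = {'exp_avg': torch.zeros(4), 'step': 3}
    current_state = {'exp_avg': torch.ones(4), 'step': 0}

    for key, saved in saved_state.items():
        # The old restore path did current_state[key].data.copy_(saved.data)
        # unconditionally, which raises AttributeError for the int entry.
        if torch.is_tensor(current_state[key]):
            current_state[key].data.copy_(saved.data)   # tensor state: copy in place
        else:
            current_state[key] = saved                  # non-tensor state: assign directly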
deepspeed/runtime/zero/stage1.py

@@ -947,9 +947,10 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
                                           state_key,
                                           all_partition_states,
                                           max_elems_per_comm):
-        partition_id = dist.get_rank(group=self.dp_process_group)
-        alignment = dist.get_world_size(group=self.dp_process_group)
+        if not torch.is_tensor(all_partition_states[0]):
+            return all_partition_states[0]
+
+        alignment = dist.get_world_size(group=self.dp_process_group)
         flat_merged_partitions = flatten_dense_tensors_sub_partition_aligned(
             tensor_list=all_partition_states,
             dp=dist.get_world_size(group=self.dp_process_group),
@@ -964,6 +965,7 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
             dp_process_group=self.dp_process_group)

+        partition_id = dist.get_rank(group=self.dp_process_group)
         return [sub_partition for sub_partition in dp_sub_partitions[partition_id]]

     # Compute the optimizer state partitions for the group by
@@ -1013,8 +1015,11 @@ class FP16_DeepSpeedZeroOptimizer_Stage1(object):
         for group_idx, group in enumerate(self.optimizer.param_groups):
             for param_idx, param in enumerate(group['params']):
                 for key, saved in base_optimizer_group_states[group_idx].items():
-                    current = self.optimizer.state[param][key]
-                    current.data.copy_(saved[param_idx].data)
+                    if torch.is_tensor(self.optimizer.state[param][key]):
+                        current = self.optimizer.state[param][key]
+                        current.data.copy_(saved[param_idx].data)
+                    else:
+                        self.optimizer.state[param][key] = saved

     # Restore base optimizer fp32 weights from ZeRO fp16 weights
     def _restore_from_fp16_weights(self):
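For reference, the partition-side behavior added above: non-tensor state cannot be flattened into the data-parallel sub-partitions, and it is identical on every rank, so the new early return hands back the first rank's value untouched. A minimal sketch of that dispatch, with a simplified stand-in for the aligned flatten/partition path (the helper below is illustrative, not the real flatten_dense_tensors_sub_partition_aligned):

    import torch

    def partition_optimizer_state(all_partition_states, num_partitions):
        # Illustrative stand-in for _partition_base_optimizer_state.
        if not torch.is_tensor(all_partition_states[0]):
            # e.g. an int step counter: the same on every rank, nothing to split
            return all_partition_states[0]
        flat = torch.cat(all_partition_states)           # simplified flatten
        return list(torch.chunk(flat, num_partitions))   # simplified partitioning

    print(partition_optimizer_state([torch.ones(4), torch.zeros(4)], num_partitions=2))
    print(partition_optimizer_state([7, 7], num_partitions=2))  # -> 7, passed through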
tests/unit/simple_model.py

@@ -101,6 +101,37 @@ class SimpleOptimizer(torch.optim.Optimizer):
         return loss


+class HybridStateOptimizer(torch.optim.Optimizer):
+    def __init__(self, params, lr=0.11072018):
+        defaults = dict(lr=lr)
+        super(HybridStateOptimizer, self).__init__(params, defaults)
+
+    def __setstate__(self, state):
+        super(HybridStateOptimizer, self).__setstate__(state)
+
+    def step(self, closure=None):
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        for group in self.param_groups:
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+
+                state = self.state[p]
+                if len(state) == 0:
+                    state['integer_step'] = 0
+                    state['tensor_step'] = torch.zeros(1)
+
+                d_p = p.grad.data
+                p.data.add_(-group['lr'], d_p)
+                state['integer_step'] += 1
+                state['tensor_step'] += 1
+
+        return loss
+
+
 class PLD_SimpleModel(SimpleModel):
     def __init__(self, hidden_dim, empty_grad=False, rank=0):
         super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad, rank)
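To make the intent of HybridStateOptimizer concrete: after a single step its per-parameter state holds both a plain Python int and a tensor, which is exactly the mix the checkpoint round trip now has to preserve. An illustrative snippet, assuming the test module is importable from the repository root:

    import torch
    from tests.unit.simple_model import HybridStateOptimizer

    p = torch.nn.Parameter(torch.ones(2))
    opt = HybridStateOptimizer([p])
    p.grad = torch.full((2,), 0.5)
    opt.step()

    state = opt.state[p]
    assert isinstance(state['integer_step'], int)   # plain Python int
    assert torch.is_tensor(state['tensor_step'])    # 1-element tensor counter
    # Both entries must survive save_checkpoint()/load_checkpoint() under ZeRO.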
tests/unit/test_checkpointing.py

@@ -36,6 +36,7 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
     compare_deepspeed_states(saved_model, loaded_model)

     for p0, p1 in zip(saved_model.module.parameters(), loaded_model.module.parameters()):
+        assert id(p0) != id(p1), f'Comparing fp16 model state tensor against itself : {id(p0)} <====> {id(p1)}'
         assert torch.allclose(p0, p1, atol=1e-07), f"FP16 model state {p0} is not equal to {p1}"

     if not compare_optimizer:
@@ -43,20 +44,24 @@ def compare_model_states(saved_model, loaded_model, compare_optimizer=True):
     if isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer):
         for p0, p1 in zip(saved_model.optimizer.single_partition_of_fp32_groups,
                           loaded_model.optimizer.single_partition_of_fp32_groups):
+            assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
             assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"

     elif isinstance(saved_model.optimizer, FP16_DeepSpeedZeroOptimizer_Stage1):
         for partition0, partition1 in zip(saved_model.optimizer.local_sub_partitions_of_fp32_groups,
                                           loaded_model.optimizer.local_sub_partitions_of_fp32_groups):
             for p0, p1 in zip(partition0, partition1):
+                assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
                 assert torch.allclose(p0, p1, atol=1e-07), f"Fp32 model states {p0} is not equal to {p1}"

     elif isinstance(saved_model.optimizer, FP16_Optimizer):
         for p0, p1 in zip(saved_model.optimizer.fp32_groups_flat,
                           loaded_model.optimizer.fp32_groups_flat):
+            assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
             assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"

     elif isinstance(saved_model.optimizer, FP16_UnfusedOptimizer):
         for params0, params1 in zip(saved_model.optimizer.fp32_groups,
                                     loaded_model.optimizer.fp32_groups):
             for p0, p1 in zip(params0, params1):
+                assert id(p0) != id(p1), f'Comparing fp32 model state tensor against itself: {id(p0)} <====> {id(p1)}'
                 assert torch.allclose(p0, p1, atol=1e-07), f"FP32 model states {p0} is not equal to {p1}"

     elif isinstance(saved_model.optimizer, torch.optim.Optimizer):
         pass
@@ -72,6 +77,7 @@ def compare_optimizer_states(saved_model, loaded_model, hidden_dim, fp16=True):
                               loaded_optimizer.state.values()):
         for s0, s1 in zip(state0.values(), state1.values()):
+            if isinstance(s0, torch.Tensor) and isinstance(s1, torch.Tensor):
                 assert id(s0) != id(s1), f'Comparing optimizer state tensor against itself: {id(s0)} <====> {id(s1)}'
                 assert torch.equal(s0, s1)
+            else:
+                assert s0 == s1
@@ -100,18 +106,34 @@ def compare_lr_scheduler_states(saved_model, loaded_model):
         assert state0 == state1


+def create_deepspeed_model(args, model, base_optimizer):
+    if base_optimizer is None:
+        ds_model, _, _, _ = deepspeed.initialize(args=args,
+                                                 model=model,
+                                                 model_parameters=model.parameters())
+    else:
+        ds_model, _, _, _ = deepspeed.initialize(args=args,
+                                                 model=model,
+                                                 optimizer=base_optimizer)
+
+    return ds_model
+
+
 def checkpoint_correctness_verification(args,
-                                        model,
+                                        models,
                                         hidden_dim,
                                         tmpdir,
                                         load_optimizer_states=False,
                                         load_lr_scheduler_states=False,
                                         fp16=True,
-                                        train_batch=False):
+                                        train_batch=False,
+                                        base_optimizers=[None, None]):
     dtype = torch.half if fp16 else torch.float32
-    ds_model, _, _, _ = deepspeed.initialize(args=args,
-                                             model=model,
-                                             model_parameters=model.parameters())
+    ds_model = create_deepspeed_model(args=args,
+                                      model=models[0],
+                                      base_optimizer=base_optimizers[0])

     data_loader = random_dataloader(model=ds_model,
                                     total_samples=50,
                                     hidden_dim=hidden_dim,
@@ -125,7 +147,6 @@ def checkpoint_correctness_verification(args,
     else:
         for n, batch in enumerate(data_loader):
             loss = ds_model(batch[0], batch[1])
-            print(loss)
             ds_model.backward(loss)
             ds_model.step()
@@ -136,9 +157,9 @@ def checkpoint_correctness_verification(args,
     trained_model.save_checkpoint(save_folder, save_tag)

-    loaded_model, _, _, _ = deepspeed.initialize(args=args,
-                                                 model=model,
-                                                 model_parameters=model.parameters())
+    loaded_model = create_deepspeed_model(args=args,
+                                          model=models[1],
+                                          base_optimizer=base_optimizers[1])

     loaded_model.load_checkpoint(save_folder,
                                  save_tag,
@@ -191,25 +212,26 @@ def test_checkpoint_unfused_optimizer(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

     @distributed_test(world_size=[2])
-    def _test_checkpoint_unfused_optimizer(args, model, hidden_dim, load_optimizer_states):
+    def _test_checkpoint_unfused_optimizer(args, models, hidden_dim, load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)

     _test_checkpoint_unfused_optimizer(args=args,
-                                       model=model,
+                                       models=models,
                                        hidden_dim=hidden_dim,
                                        load_optimizer_states=True)
     _test_checkpoint_unfused_optimizer(args=args,
-                                       model=model,
+                                       models=models,
                                        hidden_dim=hidden_dim,
                                        load_optimizer_states=False)
@@ -236,22 +258,26 @@ def test_checkpoint_fused_optimizer(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

     @distributed_test(world_size=[2])
-    def _test_checkpoint_fused_optimizer(args, model, hidden_dim, load_optimizer_states):
+    def _test_checkpoint_fused_optimizer(args, models, hidden_dim, load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)

     _test_checkpoint_fused_optimizer(args=args,
-                                     model=model,
+                                     models=models,
                                      hidden_dim=hidden_dim,
                                      load_optimizer_states=True)
     _test_checkpoint_fused_optimizer(args=args,
-                                     model=model,
+                                     models=models,
                                      hidden_dim=hidden_dim,
                                      load_optimizer_states=False)
@@ -293,18 +319,18 @@ def test_checkpoint_zero_optimizer(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

     @distributed_test(world_size=[2])
-    def _test_checkpoint_zero_optimizer(args, model, hidden_dim, load_optimizer_states):
+    def _test_checkpoint_zero_optimizer(args, models, hidden_dim, load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)

     _test_checkpoint_zero_optimizer(args=args,
-                                    model=model,
+                                    models=models,
                                     hidden_dim=hidden_dim,
                                     load_optimizer_states=True)
@@ -346,21 +372,21 @@ def test_checkpoint_zero_no_optimizer(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

     @distributed_test(world_size=[2])
-    def _test_checkpoint_zero_no_optimizer(args, model, hidden_dim, load_optimizer_states):
+    def _test_checkpoint_zero_no_optimizer(args, models, hidden_dim, load_optimizer_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states)

     _test_checkpoint_zero_no_optimizer(args=args,
-                                       model=model,
+                                       models=models,
                                        hidden_dim=hidden_dim,
                                        load_optimizer_states=False)
@@ -412,24 +438,24 @@ def test_checkpoint_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

     @distributed_test(world_size=[2])
-    def _test_checkpoint_lr_scheduler(args, model, hidden_dim, load_optimizer_states, load_lr_scheduler_states):
+    def _test_checkpoint_lr_scheduler(args, models, hidden_dim, load_optimizer_states, load_lr_scheduler_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states,
                                             load_lr_scheduler_states=load_lr_scheduler_states)

     _test_checkpoint_lr_scheduler(args=args,
-                                  model=model,
+                                  models=models,
                                   hidden_dim=hidden_dim,
                                   load_optimizer_states=False,
                                   load_lr_scheduler_states=True)
@@ -478,24 +504,24 @@ def test_checkpoint_no_lr_scheduler(tmpdir, zero_stage, use_cpu_offload):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

     @distributed_test(world_size=[2])
-    def _test_checkpoint_no_lr_scheduler(args, model, hidden_dim, load_optimizer_states, load_lr_scheduler_states):
+    def _test_checkpoint_no_lr_scheduler(args, models, hidden_dim, load_optimizer_states, load_lr_scheduler_states):
         checkpoint_correctness_verification(args,
-                                            model,
-                                            hidden_dim,
-                                            tmpdir,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
                                             load_optimizer_states=load_optimizer_states,
                                             load_lr_scheduler_states=load_lr_scheduler_states)

     _test_checkpoint_no_lr_scheduler(args=args,
-                                     model=model,
+                                     models=models,
                                      hidden_dim=hidden_dim,
                                      load_optimizer_states=False,
                                      load_lr_scheduler_states=False)
@@ -523,13 +549,17 @@ def test_checkpoint_fp32_optimizer(tmpdir):
     args = args_from_dict(tmpdir, config_dict)
     hidden_dim = 10

-    model = SimpleModel(hidden_dim, empty_grad=False)
+    models = [SimpleModel(hidden_dim, empty_grad=False) for _ in range(2)]

     @distributed_test(world_size=[2])
-    def _test_checkpoint_fp32_optimizer(args, model, hidden_dim):
-        checkpoint_correctness_verification(args, model, hidden_dim, tmpdir, fp16=False)
+    def _test_checkpoint_fp32_optimizer(args, models, hidden_dim):
+        checkpoint_correctness_verification(args,
+                                            models=models,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
+                                            fp16=False)

-    _test_checkpoint_fp32_optimizer(args=args, model=model, hidden_dim=hidden_dim)
+    _test_checkpoint_fp32_optimizer(args=args, models=models, hidden_dim=hidden_dim)


 @pytest.mark.parametrize("zero_stage", [0, 1])
@@ -571,10 +601,10 @@ def test_checkpoint_pipe_engine(zero_stage, tmpdir, stages=2):
     @distributed_test(world_size=4)
     def _test(save_folder, num_stages):
         args = args_from_dict(tmpdir, config_dict)
-        model = LinearStackPipe(num_stages=num_stages)
+        models = [LinearStackPipe(num_stages=num_stages) for _ in range(2)]
         checkpoint_correctness_verification(args=args,
-                                            model=model,
-                                            hidden_dim=model.hidden_dim,
+                                            models=models,
+                                            hidden_dim=models[0].hidden_dim,
                                             tmpdir=save_folder,
                                             fp16=config_dict['fp16']['enabled'],
                                             load_optimizer_states=True,
@@ -635,3 +665,42 @@ def test_checkpoint_pipe_module(base_topo, test_topo, tmpdir):
             assert torch.allclose(p0, p1, atol=1e-07), f"Model state {p0} is not equal to {p1}"

     _test(base_topo, test_topo, save_folder=tmpdir)
+
+
+@pytest.mark.parametrize('zero_stage', [1, 2])
+def test_checkpoint_zero_hybrid_optimizer_state(tmpdir, zero_stage):
+    config_dict = {
+        "train_micro_batch_size_per_gpu": 2,
+        "gradient_accumulation_steps": 2,
+        "steps_per_print": 1,
+        "zero_optimization": {
+            "stage": zero_stage
+        },
+        "zero_allow_untested_optimizer": True,
+        "fp16": {
+            "enabled": True,
+            "initial_scale_power": 8
+        }
+    }
+
+    args = args_from_dict(tmpdir, config_dict)
+    hidden_dim = 10
+    models = [SimpleModel(hidden_dim=hidden_dim) for _ in range(2)]
+    optimizers = [HybridStateOptimizer(model.parameters()) for model in models]
+
+    @distributed_test(world_size=[2])
+    def _test_checkpoint_zero_hybrid_optimizer_state(args,
+                                                     models,
+                                                     optimizers,
+                                                     hidden_dim):
+        checkpoint_correctness_verification(args,
+                                            models=models,
+                                            base_optimizers=optimizers,
+                                            hidden_dim=hidden_dim,
+                                            tmpdir=tmpdir,
+                                            load_optimizer_states=True)
+
+    _test_checkpoint_zero_hybrid_optimizer_state(args=args,
+                                                 models=models,
+                                                 optimizers=optimizers,
+                                                 hidden_dim=hidden_dim)