OpenDAS / ColossalAI · Commits

Commit dd7cc582
Authored Jul 06, 2023 by LuGY; committed by Hongxin Liu, Jul 31, 2023

[zero] add state dict for low level zero (#4179)

* add state dict for zero
* fix unit test
* polish

Parent: c668801d

Showing 2 changed files with 188 additions and 1 deletion (+188, -1)

  colossalai/zero/low_level/low_level_optim.py      +67  -1
  tests/test_zero/test_low_level/test_zero_ckpt.py  +121 -0
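
In short, the commit gives LowLevelZeroOptimizer a state_dict() that all-gathers each rank's shard of the optimizer states into a full, DDP-style dictionary, and a load_state_dict() that re-shards a PyTorch-form dictionary so that each rank keeps only its own slice. A minimal usage sketch, not part of the commit, assuming the process group has already been initialized via colossalai.launch and that `optimizer` is a LowLevelZeroOptimizer instance ('zero_optim.pt' is an arbitrary example path):

    import torch

    # save: state_dict() gathers the sharded optimizer states from all ranks,
    # so the result is a plain PyTorch-form dict
    torch.save(optimizer.state_dict(), 'zero_optim.pt')

    # load: load_state_dict() accepts a PyTorch-form dict and re-shards it,
    # keeping only the local rank's partition of each state tensor
    optimizer.load_state_dict(torch.load('zero_optim.pt'))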

colossalai/zero/low_level/low_level_optim.py @ dd7cc582

 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
+import copy
 from contextlib import contextmanager
 from functools import partial
 from typing import Optional
...
@@ -198,7 +199,7 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         params_current_rank = []
         device = 'cpu' if self._cpu_offload else get_current_device()
-        for param in reversed(param_list):
+        for param in param_list:
             padding_size = (self._world_size - param.numel() % self._world_size) % self._world_size
             self._param_store.record_param_padding_size(param, padding_size)
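
Note that the loop change (reversed(param_list) -> param_list) means each rank now partitions parameters in their original registration order. The state-dict packing added below follows the same index-by-position convention that torch.optim.Optimizer.state_dict() uses; a small illustration of that convention (a sketch, not part of the commit):

    import torch

    # PyTorch's packing convention: in a state_dict, parameter tensors are
    # replaced by their position across the concatenated param_groups
    model = torch.nn.Linear(4, 2)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    model(torch.randn(3, 4)).sum().backward()
    opt.step()

    sd = opt.state_dict()
    print(list(sd['state'].keys()))          # [0, 1] -> integer indices, not tensors
    print(sd['param_groups'][0]['params'])   # [0, 1]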
...
@@ -468,3 +469,68 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
             yield
         finally:
             self.require_grad_sync = old_require_grad_sync
+
+    ##############
+    # State Dict #
+    ##############
+
+    def _pack_state(self, state: dict) -> dict:
+        # comes from pytorch optimizer.state_dict()
+        param_mappings = {}
+        start_index = 0
+
+        def pack_group(group):
+            nonlocal start_index
+            packed = {k: v for k, v in group.items() if k != 'params'}
+            param_mappings.update(
+                {id(p): i for i, p in enumerate(group['params'], start_index) if id(p) not in param_mappings})
+            packed['params'] = [param_mappings[id(p)] for p in group['params']]
+            start_index += len(packed['params'])
+            return packed
+
+        param_groups = [pack_group(g) for g in self.param_groups]
+        # Remap state to use order indices as keys
+        packed_state = {(param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v for k, v in state.items()}
+
+        return {'state': packed_state, 'param_groups': param_groups}
+
+    def state_dict(self) -> dict:
+        """Return a state_dict in the same form as DDP.
+
+        Returns:
+            dict: the PyTorch-form state_dict
+        """
+        zero_state = dict()
+        for param, state in self.optim.state.items():
+            zero_state[param] = copy.deepcopy(state)
+            for k, v in state.items():
+                if isinstance(v, torch.Tensor) and k != 'step':
+                    working_param = self._param_store.master_to_working_param[id(param)]
+                    gather_tensor = [torch.zeros_like(v) for _ in range(self._world_size)]
+                    dist.all_gather(gather_tensor, v, group=self.dp_pg)
+                    param_state = torch.stack(gather_tensor).view(-1)[:working_param.numel()].reshape_as(working_param)
+                    zero_state[param][k] = param_state
+
+        states_dict = self._pack_state(zero_state)
+
+        return states_dict
+
+    def load_state_dict(self, state_dict: dict):
+        """Load a state dict; it must be in the PyTorch form.
+
+        Args:
+            state_dict (dict): a PyTorch-form state_dict
+        """
+        zero_state_dict = copy.deepcopy(state_dict)
+        for param_idx, state in zero_state_dict['state'].items():
+            for k, v in state.items():
+                if isinstance(v, torch.Tensor) and k != 'step':
+                    padding_size = (self._world_size - v.numel() % self._world_size) % self._world_size
+                    with torch.no_grad():
+                        v = v.flatten()
+                        if padding_size > 0:
+                            v = torch.nn.functional.pad(v, [0, padding_size])
+                        v_list = v.split(v.numel() // self._world_size)
+                        zero_state_dict['state'][param_idx][k] = v_list[self._local_rank].detach()
+
+        self.optim.load_state_dict(zero_state_dict)
+        zero_state_dict = dict()
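
The per-rank sharding that load_state_dict performs, and that state_dict inverts with all_gather, boils down to: pad the flattened tensor to a multiple of world_size, split it into equal chunks, and keep the local one. A small standalone sketch, not from the commit, with illustrative values for world_size and local_rank (no process group needed):

    import torch

    world_size = 4           # illustrative number of ranks
    local_rank = 1           # illustrative rank of "this" process

    v = torch.arange(10.0)   # toy optimizer state tensor (e.g. an Adam moment)

    # pad the flattened tensor so it divides evenly across ranks
    padding_size = (world_size - v.numel() % world_size) % world_size   # 2 here
    v_flat = torch.nn.functional.pad(v.flatten(), [0, padding_size])

    # load path: split into world_size equal chunks, keep only the local one
    v_list = v_flat.split(v_flat.numel() // world_size)
    local_shard = v_list[local_rank].detach()                           # tensor([3., 4., 5.])

    # the save path does the reverse: gather the shards from all ranks,
    # stack, flatten, and truncate the padding to recover the original tensor
    recovered = torch.stack(list(v_list)).view(-1)[:v.numel()]
    assert torch.equal(recovered, v)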

tests/test_zero/test_low_level/test_zero_ckpt.py @ dd7cc582 (new file, 0 → 100644)
import copy

import pytest
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing.random import seed_all
from colossalai.zero import LowLevelZeroOptimizer


class MlpModel(nn.Module):

    def __init__(self):
        super(MlpModel, self).__init__()
        self.linear1 = nn.Linear(12, 24)
        self.linear2 = nn.Linear(24, 12)

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


def loose_close(a, b, dtype: torch.dtype = torch.float32):
    rtol = None
    atol = None
    if dtype is torch.float16:
        rtol = 5e-2
        atol = 5e-4
    elif dtype is torch.bfloat16:
        rtol = 4e-3
        atol = 4e-3

    a = a.detach().to(dtype)
    b = b.detach().to(dtype)

    assert_close(a, b, rtol=rtol, atol=atol)


def exam_zero_1_torch_ddp_ckpt():
    """
    We compare the state_dict of zero against DDP.
    Moreover, we examine zero's loading of a torch checkpoint.
    """
    local_rank = torch.distributed.get_rank()
    seed_all(1453)

    # create models
    torch_model = MlpModel().cuda()
    zero_model = copy.deepcopy(torch_model)

    torch_model = DDP(torch_model.cuda(), static_graph=True).cuda()

    # create optimizer
    zero_optimizer = torch.optim.Adam(zero_model.parameters(), lr=1)

    # we only test stage 1 here
    # the state dicts of stage 1 and stage 2 are the same
    zero_optimizer = LowLevelZeroOptimizer(zero_optimizer,
                                           overlap_communication=True,
                                           initial_scale=1,
                                           reduce_bucket_size=262144)

    torch_optimizer = torch.optim.Adam(torch_model.parameters(), lr=1)

    seed_all(1453 + local_rank)
    # create input data
    input_data = torch.rand(4, 12).cuda()

    # forward
    zero_output = zero_model(input_data)
    torch_output = torch_model(input_data)

    # backward
    zero_optimizer.backward(zero_output.mean().float())
    torch_output.mean().backward()

    # step
    zero_optimizer.step()
    torch_optimizer.step()

    torch_state_dict = torch_optimizer.state_dict()
    zero_state_dict = zero_optimizer.state_dict()

    # examine the original state dict
    for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()):
        for t_v, z_v in zip(torch_state.values(), zero_state.values()):
            loose_close(t_v, z_v)

    # empty the optimizer state
    zero_optimizer.optim.state = []

    # zero loads a torch checkpoint
    zero_optimizer.load_state_dict(copy.deepcopy(torch_state_dict))
    zero_state_dict = zero_optimizer.state_dict()

    # examine the loaded state dict
    for torch_state, zero_state in zip(torch_state_dict['state'].values(), zero_state_dict['state'].values()):
        for t_v, z_v in zip(torch_state.values(), zero_state.values()):
            loose_close(t_v, z_v)


def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')

    exam_zero_1_torch_ddp_ckpt()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_zero_ckpt():
    spawn(run_dist, 2)


if __name__ == '__main__':
    test_zero_ckpt()
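
The test spawns its own two worker processes via spawn(run_dist, 2), so at least two CUDA devices are needed. Depending on how the project's pytest configuration handles the `dist` marker, it should be runnable either directly through its __main__ guard or via pytest (these invocations are an assumption, not stated in the commit):

    python tests/test_zero/test_low_level/test_zero_ckpt.py
    pytest tests/test_zero/test_low_level/test_zero_ckpt.py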