Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
828b9e5e
Unverified
Commit
828b9e5e
authored
Jul 28, 2022
by
ver217
Committed by
GitHub
Jul 28, 2022
Browse files
[hotfix] fix zero optim save/load state dict (#1381)
parent
b6fd165f
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
154 additions
and
69 deletions
+154
-69
colossalai/tensor/process_group.py
colossalai/tensor/process_group.py
+2
-2
colossalai/zero/zero_optimizer.py
colossalai/zero/zero_optimizer.py
+98
-12
tests/test_zero/test_zero_optim_state_dict.py
tests/test_zero/test_zero_optim_state_dict.py
+54
-55
No files found.
colossalai/tensor/process_group.py
View file @
828b9e5e
...
...
@@ -104,8 +104,8 @@ class ProcessGroup:
def set_cpu_groups(self):
    """Request ``'gloo'`` process groups for this group's TP and DP rank lists.

    Returns immediately when ``self.has_cpu_groups`` is already truthy.
    Otherwise one ``'gloo'`` entry is requested from ``PYTORCHPGDICT_`` for
    ``self._tp_rank_list`` and one for ``self._dp_rank_list`` (presumably the
    registry creates and caches them -- confirm against ``PYTORCHPGDICT_``),
    and ``self._has_cpu_groups`` is set to ``True``.
    """
    if self.has_cpu_groups:
        return
    # NOTE: the per-rank "Gloo initialize TP group on ..., DP group on ..."
    # info message is intentionally not logged here.
    PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
    PYTORCHPGDICT_.get(self._dp_rank_list, 'gloo')
    self._has_cpu_groups = True
...
...
colossalai/zero/zero_optimizer.py
View file @
828b9e5e
...
...
@@ -8,6 +8,9 @@ from colossalai.amp.naive_amp.grad_scaler import DynamicGradScaler
from
colossalai.logging
import
get_dist_logger
from
colossalai.nn.optimizer
import
ColossalaiOptimizer
from
colossalai.utils
import
get_current_device
,
disposable
from
collections
import
defaultdict
,
abc
as
container_abcs
from
copy
import
deepcopy
from
itertools
import
chain
class
OptimState
(
Enum
):
...
...
@@ -191,22 +194,105 @@ class ZeroOptimizer(ColossalaiOptimizer):
self
.
chunk_manager
.
add_extern_static_tensor
(
val
)
def state_dict(self):
    r"""Returns the state of the optimizer as a :class:`dict`. For DP rank != 0, this function returns None.

    It contains the following entries:

    * state - a dict holding current optimization state. Its content
        differs between optimizer classes.
    * param_groups - a list containing all parameter groups where each
        parameter group is a dict
    * scaler - the state dict of ``self.grad_scaler``

    When ``chunk_manager.enable_distributed_storage`` is set, every DP rank
    takes part in a ``gather_object`` of its non-empty per-parameter states
    (moved to CPU first) and the first rank of the DP rank list merges them.
    """
    process_group = self.chunk_manager.process_group
    distributed = self.chunk_manager.enable_distributed_storage
    is_rank_0 = process_group.dp_local_rank() == 0

    # Without distributed storage only DP rank 0 produces a state dict.
    if not distributed and not is_rank_0:
        return None

    optim_state_dict = super().state_dict()
    optim_state_dict['scaler'] = self.grad_scaler.state_dict()
    if not distributed:
        return optim_state_dict

    # Only non-empty per-parameter states are sent to the destination rank.
    local_state = {
        k: convert_state_dict_to_cpu(v) for k, v in optim_state_dict['state'].items() if len(v) > 0
    }
    # The gather below runs on the CPU DP group, so make sure it exists.
    if not process_group.has_cpu_groups:
        process_group.set_cpu_groups()

    dst_rank = process_group.dp_rank_list()[0]
    output = [None for _ in range(process_group.dp_world_size())]
    dist.gather_object(local_state,
                       output if is_rank_0 else None,
                       dst=dst_rank,
                       group=process_group.cpu_dp_process_group())
    if not is_rank_0:
        return None

    for state in output:
        optim_state_dict['state'].update(state)
    return optim_state_dict
def load_state_dict(self, state_dict):
    r"""Loads the optimizer state.

    The passed ``state_dict`` is not modified: everything taken from it is
    deep-copied first.

    Args:
        state_dict (dict): optimizer state. Should be an object returned
            from a call to :meth:`state_dict`.

    Raises:
        ValueError: if the number of parameter groups, or the number of
            parameters in any group, differs from this optimizer's.
    """
    if 'scaler' not in state_dict:
        self._logger.warning('Missing scaler when loading optimizer state dict', ranks=[0])
    else:
        self.grad_scaler.load_state_dict(deepcopy(state_dict['scaler']))

    # Validate the state_dict
    groups = self.param_groups
    saved_groups = deepcopy(state_dict['param_groups'])

    if len(groups) != len(saved_groups):
        raise ValueError("loaded state dict has a different number of "
                         "parameter groups")
    param_lens = (len(g['params']) for g in groups)
    saved_lens = (len(g['params']) for g in saved_groups)
    if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
        raise ValueError("loaded state dict contains a parameter group "
                         "that doesn't match the size of optimizer's group")

    # Update the state: map every saved parameter id to the parameter at the
    # same position in this optimizer's groups.
    id_map = {
        old_id: p for old_id, p in zip(chain.from_iterable((g['params'] for g in saved_groups)),
                                       chain.from_iterable((g['params'] for g in groups)))
    }

    def cast(param, value):
        r"""Make a deep copy of value, casting all tensors to device of param."""
        if isinstance(value, torch.Tensor):
            # Floating-point types are a bit special here. They are the only ones
            # that are assumed to always match the type of params.
            if param.is_floating_point():
                value = value.to(param.dtype)
            value = value.to(param.device)
            return value
        elif isinstance(value, (str, bytes)):
            # Strings are iterable, but rebuilding one from a generator would
            # yield the generator's repr instead of the original text.
            return value
        elif isinstance(value, dict):
            return {k: cast(param, v) for k, v in value.items()}
        elif isinstance(value, container_abcs.Iterable):
            return type(value)(cast(param, v) for v in value)
        else:
            return value

    # Copy state assigned to params (and cast tensors to appropriate types).
    # State that is not assigned to params is copied as is (needed for
    # backward compatibility).
    state = defaultdict(dict)
    for k, v in state_dict['state'].items():
        if k in id_map:
            param = self.fp16_param_to_fp32_param[id_map[k]]
            # State is only stored for fp32 params whose storage is non-empty.
            if param.storage().size() > 0:
                state[param] = cast(param, deepcopy(v))
        else:
            state[k] = deepcopy(v)

    # Update parameter groups, setting their 'params' value
    def update_group(group, new_group):
        new_group['params'] = group['params']
        return new_group

    param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
    self.__setstate__({'state': state, 'param_groups': param_groups})
def convert_state_dict_to_cpu(state: Dict[str, torch.Tensor]):
    """Return a new dict with every tensor value of ``state`` moved to CPU.

    Values that are not ``torch.Tensor`` instances are passed through
    unchanged; the input mapping itself is not modified.
    """
    return {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in state.items()}
tests/test_zero/test_zero_optim_state_dict.py
View file @
828b9e5e
import
pytest
import
colossalai
import
torch
from
colossalai.context.parallel_mode
import
ParallelMode
import
torch.multiprocessing
as
mp
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.
core
import
global_context
as
gpc
from
colossalai.
gemini
import
ChunkManager
from
functools
import
partial
from
tests.test_tensor.common_utils
import
set_seed
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
colossalai.nn.parallel.data_parallel
import
ZeroDDP
from
colossalai.gemini
import
ChunkManager
,
GeminiManager
from
colossalai.testing
import
parameterize
from
colossalai.nn.parallel
import
ZeroDDP
from
colossalai.nn.optimizer
import
HybridAdam
from
colossalai.zero
import
ZeroOptimizer
from
colossalai.testing
import
parameterize
from
colossalai.gemini.gemini_mgr
import
GeminiManager
from
colossalai.tensor
import
ProcessGroup
def init_zero(model, use_chunk, use_zero, placement_policy):
    """Wrap ``model`` in a ``ZeroDDP`` driven by a freshly built ``GeminiManager``.

    Args:
        model: the module to wrap.
        use_chunk: when truthy, the chunk size comes from
            ``ChunkManager.search_chunk_size(model, 8192, 8)``; otherwise
            ``None`` is passed as the chunk size.
        use_zero: forwarded as ``enable_distributed_storage``.
        placement_policy: forwarded to ``GeminiManager`` and used to pick the
            chunk manager's initial device.
    """
    pg = ProcessGroup()
    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size,
                                 pg,
                                 enable_distributed_storage=use_zero,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    return ZeroDDP(model, gemini_manager)
def check_state(s1, s2):
    """Assert that two per-parameter state dicts hold equal values, pairwise.

    Values are paired by iteration order, not by key. Tensor values from
    ``s1`` are moved to the device of the matching ``s2`` value and compared
    with ``torch.equal``; everything else is compared with ``==``.
    ``zip`` stops at the shorter dict, so surplus entries are not checked.
    """
    for v1, v2 in zip(s1.values(), s2.values()):
        if isinstance(v1, torch.Tensor):
            v1 = v1.to(v2.device)
            assert torch.equal(v1, v2), f'{torch.sum((v1-v2).abs())}'
        else:
            assert v1 == v2
def run_step(model, optim, criterion, data, label):
    """Run one training step: zero grads, forward, loss, backward, step.

    The backward pass is triggered through ``optim.backward(loss)`` rather
    than ``loss.backward()``.
    """
    optim.zero_grad()
    logits = model(data)
    loss = criterion(logits, label)
    optim.backward(loss)
    optim.step()
def check_load_state_dict(optim, torch_optim):
    """Compare the state held by ``optim.optim`` against ``torch_optim``'s state.

    Parameters are paired positionally, group by group. For a parameter whose
    storage is empty the wrapped optimizer must hold no state at all; the
    states are then compared value by value with ``check_state``.
    """
    for group, torch_group in zip(optim.optim.param_groups, torch_optim.param_groups):
        for p, torch_p in zip(group['params'], torch_group['params']):
            state = optim.optim.state[p]
            torch_state = torch_optim.state[torch_p]
            if p.storage().size() == 0:
                assert len(state) == 0
            check_state(state, torch_state)
def check_state_dict_eq(state_dict, other):
    """Assert that every entry of ``state_dict['state']`` matches ``other['state']``.

    Tensor values are compared with ``torch.allclose`` (``atol=1e-3``), all
    other values with ``==``. Only keys present in ``state_dict`` are
    visited; a parameter or state key missing from ``other`` raises
    ``KeyError``.
    """
    for p, state in state_dict['state'].items():
        other_state = other['state'][p]
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                assert torch.allclose(v, other_state[k], atol=1e-3), f'{v} vs {other_state[k]}'
            else:
                assert v == other_state[k]
def check_state_dict(state_dict, torch_state_dict):
    """Assert that two optimizer state dicts carry the same per-parameter state.

    The ``'state'`` entries are walked in lockstep: keys must be equal at
    each position and the associated states must pass ``check_state``.
    """
    for (k1, s1), (k2, s2) in zip(state_dict['state'].items(), torch_state_dict['state'].items()):
        assert k1 == k2
        check_state(s1, s2)
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy):
    """Check that ``ZeroOptimizer`` can load and re-export a ``torch.optim.Adam`` state.

    A 'gpt2' test model is wrapped in ``ZeroDDP``/``ZeroOptimizer``; a second,
    plain copy is stepped once with ``torch.optim.Adam`` on random gradients.
    The resulting torch state dict is loaded into the ``ZeroOptimizer``, and
    both the loaded in-memory state and the state dict exported afterwards
    are compared against the torch reference.
    """
    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

    set_seed(42)
    with ColoInitContext(device=get_current_device()):
        model = model_builder()
    model = model.cuda()
    torch_model = model_builder().cuda()

    pg = ProcessGroup()
    chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
    chunk_manager = ChunkManager(chunk_size,
                                 pg,
                                 enable_distributed_storage=use_zero,
                                 init_device=GeminiManager.get_default_device(placement_policy))
    gemini_manager = GeminiManager(placement_policy, chunk_manager)
    model = ZeroDDP(model, gemini_manager)

    optim = HybridAdam(model.parameters(), lr=1e-3)
    optim = ZeroOptimizer(optim, model, initial_scale=1)

    # Produce a reference Adam state by stepping once on random gradients.
    torch_optim = torch.optim.Adam(torch_model.parameters(), lr=1e-3)
    for p in torch_model.parameters():
        p.grad = torch.rand_like(p)
    torch_optim.step()
    torch_state_dict = torch_optim.state_dict()

    optim.load_state_dict(torch_state_dict)
    check_load_state_dict(optim, torch_optim)

    # state_dict() is called on every rank, but only rank 0 verifies it.
    state_dict = optim.state_dict()
    if pg.rank() == 0:
        check_state_dict(state_dict, torch_state_dict)
def run_dist(rank, world_size, port):
    """Per-process entry point: launch colossalai on localhost, then run the test body.

    Args:
        rank: rank of this process.
        world_size: total number of processes.
        port: port passed to ``colossalai.launch`` for the 'localhost' host.
    """
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_zero_optim_state_dict()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_zero_optim_state_dict(world_size):
    """Spawn ``world_size`` processes that each execute ``run_dist`` on a free port."""
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_zero_optim_state_dict(2)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment