OpenDAS / ColossalAI · Commits

Unverified commit 4e98e938, authored Aug 02, 2022 by HELSON, committed by GitHub on Aug 02, 2022
[zero] alleviate memory usage in ZeRODDP state_dict (#1398)
parent 4f5f8f77

Showing 2 changed files with 114 additions and 26 deletions (+114, -26):

  colossalai/nn/parallel/data_parallel.py   +45  -14
  tests/test_ddp/test_ddp_state_dict.py     +69  -12
colossalai/nn/parallel/data_parallel.py
@@ -6,12 +6,13 @@ from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
 from colossalai.gemini.chunk import TensorState, Chunk
 from colossalai.tensor.param_op_hook import ParamOpHookManager
 from colossalai.gemini.gemini_mgr import GeminiManager
-from typing import Dict, Iterable, List, Optional
+from typing import Dict, Iterable, List, Optional, Set
 from colossalai.logging import get_dist_logger
 from collections import OrderedDict
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor import ProcessGroup as ColoProcessGroup
 from .reducer import Reducer
+
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
 except ImportError:
@@ -84,6 +85,18 @@ class ColoDDP(torch.nn.Module):
     def named_parameters(self, prefix: str = '', recurse: bool = True):
         return self.module.named_parameters(prefix, recurse)
 
+    def named_buffers(self, prefix: str = '', recurse: bool = True):
+        return self.module.named_buffers(prefix, recurse)
+
+    def named_children(self):
+        return self.module.named_children()
+
+    def named_modules(self,
+                      memo: Optional[Set[torch.nn.Module]] = None,
+                      prefix: str = '',
+                      remove_duplicate: bool = True):
+        return self.module.named_modules(memo, prefix, remove_duplicate)
+
     def forward(self, *args, **kwargs):
         self.module.zero_grad(set_to_none=True)
         return self.module(*args, **kwargs)
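One reason to forward `named_parameters`, `named_buffers`, `named_children`, and `named_modules` to the wrapped module is checkpoint compatibility: `ZeroDDP.state_dict` walks `self.named_parameters()` and `self.named_buffers()`, and the delegation keeps the saved keys identical to the plain model's keys, with no extra prefix. A minimal sketch of the same idea using a generic, hypothetical `Wrapper` class (not a ColossalAI API):

import torch.nn as nn


class Wrapper(nn.Module):
    """Hypothetical wrapper that forwards parameter/buffer traversal to the
    wrapped module so that saved keys match the plain model's keys."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def named_parameters(self, prefix: str = '', recurse: bool = True):
        return self.module.named_parameters(prefix, recurse)

    def named_buffers(self, prefix: str = '', recurse: bool = True):
        return self.module.named_buffers(prefix, recurse)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


if __name__ == '__main__':
    bn = nn.BatchNorm1d(4)    # has parameters (weight, bias) and buffers (running stats)
    wrapped = Wrapper(bn)
    # Thanks to the delegation, keys carry no 'module.' prefix:
    print([name for name, _ in wrapped.named_parameters()])    # ['weight', 'bias']
    print([name for name, _ in wrapped.named_buffers()])       # ['running_mean', 'running_var', 'num_batches_tracked']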
@@ -274,7 +287,7 @@ class ZeroDDP(ColoDDP):
         for tensor in chunk.get_tensors():
             self.grads_device[tensor] = device
 
-    def state_dict(self, destination=None, prefix='', keep_vars=False):
+    def state_dict(self, destination=None, prefix='', keep_vars=False, only_rank_0: bool = True):
         r"""Returns a dictionary containing a whole state of the module.
 
         Both parameters and persistent buffers (e.g. running averages) are
@@ -291,18 +304,22 @@ class ZeroDDP(ColoDDP):
             ['bias', 'weight']
 
         """
+        is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
+        record_flag = (not only_rank_0) or is_rank_0
+
         if destination is None:
             destination = OrderedDict()
             destination._metadata = OrderedDict()
         destination._metadata[prefix[:-1]] = local_metadata = dict(version=self._version)
-        self._save_to_state_dict(destination, prefix, keep_vars)
+        self._save_to_state_dict(destination, prefix, keep_vars, record_flag)
+
         for hook in self._state_dict_hooks.values():
             hook_result = hook(self, destination, prefix, local_metadata)
             if hook_result is not None:
                 destination = hook_result
         return destination
 
-    def _save_to_state_dict(self, destination, prefix, keep_vars):
+    def _save_to_state_dict(self, destination, prefix, keep_vars, record_flag: bool = True):
         r"""Saves module state to `destination` dictionary, containing a state
         of the module, but not its descendants. This is called on every
         submodule in :meth:`~torch.nn.Module.state_dict`.
@@ -315,22 +332,36 @@ class ZeroDDP(ColoDDP):
             prefix (str): the prefix for parameters and buffers used in this
                 module
         """
-        chunks = self.chunk_manager.get_chunks(self.fp32_params)
-        chunks_orig_device_type = []
-        for chunk in chunks:
-            chunks_orig_device_type.append(chunk.device_type)
+        # save parameters
+        param_to_save_data = dict()
+        chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
+        for chunk in chunk_list:
+            # record the original device of the chunk
+            org_chunk_dev_typ = chunk.device_type
             self.chunk_manager.access_chunk(chunk)
 
+            for tensor in chunk.get_tensors():
+                rec_p = torch.empty([0])
+                if record_flag:
+                    rec_p = tensor.cpu()    # move the whole tensor to CPU mem
+                assert tensor not in param_to_save_data
+                param_to_save_data[tensor] = rec_p
+            # release the actual memory of the chunk
+            self.chunk_manager.release_chunk(chunk)
+            if not chunk.is_empty and org_chunk_dev_typ == 'cpu':
+                self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
+
         for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
             if p is not None:
-                rec_p = fp32_p.clone() if fp32_p.device.type == 'cpu' else fp32_p.cpu()
+                assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
+                rec_p = param_to_save_data[fp32_p]
                 destination[prefix + name] = rec_p if keep_vars else rec_p.detach()
 
-        for orig_dvice_type, chunk in zip(chunks_orig_device_type, chunks):
-            self.chunk_manager.release_chunk(chunk)
-            if not chunk.is_empty and orig_dvice_type == 'cpu':
-                self.chunk_manager.move_chunk(chunk, torch.device('cpu'))
-
+        # save all buffers
         for name, buf in self.named_buffers():
             if buf is not None and name not in self._non_persistent_buffers_set:
                 destination[prefix + name] = buf if keep_vars else buf.detach()
+        # save extra states
         extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
         if getattr(self.__class__, "get_extra_state",
                    torch.nn.Module.get_extra_state) is not torch.nn.Module.get_extra_state:
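The net effect of the `_save_to_state_dict` rewrite above is that parameters are copied to host memory chunk by chunk: each chunk is accessed, its tensors are copied to CPU (only when `record_flag` is set), and the chunk is released before the next one is touched, so checkpointing no longer needs a full extra device-resident copy of the fp32 parameters at once. A self-contained sketch of that pattern follows; `ToyChunk` and `save_chunks_to_cpu` are illustrative names, not ColossalAI APIs.

import torch
from collections import OrderedDict
from typing import Dict, List


class ToyChunk:
    """Hypothetical stand-in for a ColossalAI chunk: a group of parameters
    that is moved between devices as a single unit."""

    def __init__(self, tensors: Dict[str, torch.Tensor]):
        self.tensors = tensors

    def to(self, device: str) -> None:
        # move every tensor in the chunk to the target device
        self.tensors = {name: t.to(device) for name, t in self.tensors.items()}


def save_chunks_to_cpu(chunks: List[ToyChunk], record: bool = True) -> 'OrderedDict[str, torch.Tensor]':
    """Copy parameters into a CPU state dict one chunk at a time.

    Only one chunk is materialized on the compute device at any moment, so the
    extra memory needed for checkpointing is bounded by the largest chunk.
    With record=False (the analogue of a non-zero rank under only_rank_0=True),
    no CPU copies are kept and empty placeholders are stored instead.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    state = OrderedDict()
    for chunk in chunks:
        chunk.to(device)                       # "access" the chunk
        for name, tensor in chunk.tensors.items():
            state[name] = tensor.detach().cpu() if record else torch.empty([0])
        chunk.to('cpu')                        # "release" the chunk before touching the next one
    return state


if __name__ == '__main__':
    chunks = [
        ToyChunk({'fc1.weight': torch.randn(4, 4), 'fc1.bias': torch.randn(4)}),
        ToyChunk({'fc2.weight': torch.randn(4, 4), 'fc2.bias': torch.randn(4)}),
    ]
    print({k: tuple(v.shape) for k, v in save_chunks_to_cpu(chunks).items()})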
tests/test_ddp/test_ddp_state_dict.py
+import copy
 import pytest
 import colossalai
 import torch
@@ -11,9 +13,9 @@ from functools import partial
 from tests.components_to_test.registry import non_distributed_component_funcs
 from colossalai.nn.parallel import ZeroDDP, ColoDDP
 from colossalai.gemini.gemini_mgr import GeminiManager
-from typing import Callable
 from collections import OrderedDict
 from colossalai.tensor import ProcessGroup, ColoParameter
+from colossalai.testing import parameterize
 
 
 def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
@@ -25,7 +27,27 @@ def check_state_dict_equal(state_dict: OrderedDict, other_state_dict: OrderedDict):
         else:
             temp_t2 = t2
 
-        assert torch.equal(t1, temp_t2)
+        assert torch.equal(t1, temp_t2), "\t{}\n\t{}".format(t1, temp_t2)
+
+
+def check_model_equal(model_a, model_b, allow_empty: bool = False, same_dtype: bool = True):
+    for (na, pa), (nb, pb) in zip(model_a.named_parameters(), model_b.named_parameters()):
+        assert na == nb
+
+        if not allow_empty:
+            assert pa.storage().size() > 0
+            assert pb.storage().size() > 0
+        else:
+            if pa.storage().size() == 0 or pb.storage().size() == 0:
+                continue
+
+        if same_dtype:
+            assert pa.dtype == pb.dtype
+            temp_pb = pb
+        else:
+            temp_pb = pb.to(pa.dtype)
+
+        assert torch.equal(pa, temp_pb), "Parameter '{}' is not equal.\n{} {}".format(na, pa, pb)
 
 
 def init_ddp(module: torch.nn.Module) -> ColoDDP:
@@ -33,22 +55,26 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
     return ColoDDP(module, process_group=pg)
 
 
-def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
+def init_ddpv2(module: torch.nn.Module,
+               use_chunk: bool = False,
+               use_zero: bool = False,
+               placement_policy: str = 'cuda') -> ZeroDDP:
     pg = ProcessGroup()
     chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
     chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
-    gemini_manager = GeminiManager('cuda', chunk_manager)
+    gemini_manager = GeminiManager(placement_policy, chunk_manager)
     return ZeroDDP(module, gemini_manager)
 
 
-def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
-    get_components_func = non_distributed_component_funcs.get_callable('nested_model')
+def run_ddp_state_dict():
+    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
     torch_model = model_builder().cuda()
     with ColoInitContext(device=get_current_device()):
         model = model_builder()
-    model = ddp_init_func(model)
+    model = init_ddp(model)
     torch_state_dict = torch_model.state_dict()
     for param in model.parameters():
         if isinstance(param, ColoParameter):
             assert param.get_process_group() is not None
@@ -62,13 +88,44 @@ def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
     check_state_dict_equal(torch_state_dict, state_dict)
 
 
+@parameterize('use_chunk', [False, True])
+@parameterize('placement_policy', ['cuda', 'cpu'])
+@parameterize('use_zero', [False, True])
+@parameterize('only_rank_0', [False, True])
+def run_zero_state_dict(use_chunk, placement_policy, use_zero, only_rank_0):
+    get_components_func = non_distributed_component_funcs.get_callable('gpt2')
+    model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+    torch_model = model_builder().cuda()
+    org_torch_model = copy.deepcopy(torch_model)
+    torch_state_dict = torch_model.state_dict()
+
+    with ColoInitContext(device=get_current_device()):
+        model = model_builder()
+    model = init_ddpv2(model, use_chunk, use_zero, placement_policy)
+
+    for param in model.parameters():
+        if isinstance(param, ColoParameter):
+            assert param.get_process_group() is not None
+
+    model.load_state_dict(torch_state_dict, strict=False)
+    check_model_equal(model, torch_model, allow_empty=True, same_dtype=False)
+
+    for param in model.parameters():
+        if isinstance(param, ColoParameter):
+            assert param.get_process_group() is not None
+
+    pg = ProcessGroup()
+    state_dict = model.state_dict(only_rank_0=only_rank_0)
+    if not only_rank_0 or pg.dp_local_rank() == 0:
+        torch_model.load_state_dict(state_dict, strict=False)
+        check_model_equal(torch_model, org_torch_model, allow_empty=False, same_dtype=True)
+
+
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_state_dict(init_ddp)
-    run_state_dict(partial(init_ddpv2, use_chunk=False, use_zero=False))
-    run_state_dict(partial(init_ddpv2, use_chunk=False, use_zero=True))
-    run_state_dict(partial(init_ddpv2, use_chunk=True, use_zero=False))
-    run_state_dict(partial(init_ddpv2, use_chunk=True, use_zero=True))
+    run_ddp_state_dict()
+    run_zero_state_dict()
 
 
 @pytest.mark.dist
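The test above exercises the new `only_rank_0` flag: with `only_rank_0=True` (the default), only data-parallel rank 0 records full CPU copies of the parameters while the other ranks receive empty placeholders, which is why the gathered state dict is only reloaded on rank 0. A hedged usage sketch follows; the `save_checkpoint` helper is hypothetical and assumes a ZeroDDP model built as in `init_ddpv2` above plus an initialized process group.

import torch
import torch.distributed as dist


def save_checkpoint(model, path: str = 'checkpoint.pt') -> None:
    # Every rank takes part in walking the chunks, but only rank 0 keeps the
    # CPU copies of the parameters and writes them to disk.
    state_dict = model.state_dict(only_rank_0=True)
    if dist.get_rank() == 0:
        torch.save(state_dict, path)
    dist.barrier()    # keep ranks in step before training continues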