OpenDAS / ColossalAI · Commits

Commit 8dced41a (Unverified)
Authored Jul 29, 2022 by ver217; committed by GitHub on Jul 29, 2022
[zero] zero optim state_dict takes only_rank_0 (#1384)

* zero optim state_dict takes only_rank_0
* fix unit test
Parent: 7d5d628e
Showing 2 changed files with 18 additions and 13 deletions (+18, -13):

* colossalai/zero/zero_optimizer.py (+14, -10)
* tests/test_zero/test_zero_optim_state_dict.py (+4, -3)
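The practical effect of the new flag: with `only_rank_0=True` (the default), only data-parallel rank 0 materializes the gathered optimizer state, so the other ranks avoid holding a full copy. A minimal usage sketch, not part of this commit (`save_checkpoint` and `path` are illustrative names; `optim` is assumed to be a `ZeroOptimizer` with the default process group initialized):

```python
import torch
import torch.distributed as dist

def save_checkpoint(optim, path):
    # With only_rank_0=True (the default), state_dict() returns None on every
    # DP rank except rank 0, so only rank 0 holds and writes the full state.
    state = optim.state_dict(only_rank_0=True)
    if dist.get_rank() == 0:
        torch.save(state, path)
```

Passing `only_rank_0=False` instead makes every rank return the full state dict, trading memory for the ability to save or inspect the checkpoint from any rank.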
colossalai/zero/zero_optimizer.py
```diff
@@ -193,8 +193,9 @@ class ZeroOptimizer(ColossalaiOptimizer):
         if isinstance(val, torch.Tensor):
             self.chunk_manager.add_extern_static_tensor(val)
 
-    def state_dict(self):
-        r"""Returns the state of the optimizer as a :class:`dict`. For DP rank != 0, this function returns None.
+    def state_dict(self, only_rank_0: bool = True):
+        r"""Returns the state of the optimizer as a :class:`dict`. If only_rank_0 is True, for DP rank != 0, this function returns None.
+            This saves memory usage.
 
         It contains two entries:
 
@@ -204,7 +205,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
         parameter group is a dict
         """
         is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
-        if not self.chunk_manager.enable_distributed_storage and not is_rank_0:
+        if not self.chunk_manager.enable_distributed_storage and only_rank_0 and not is_rank_0:
             return
         optim_state_dict = super().state_dict()
         scaler_state_dict = self.grad_scaler.state_dict()
@@ -214,14 +215,17 @@ class ZeroOptimizer(ColossalaiOptimizer):
         local_state = {k: convert_state_dict_to_cpu(v) for k, v in optim_state_dict['state'].items() if len(v) > 0}
         if not self.chunk_manager.process_group.has_cpu_groups:
             self.chunk_manager.process_group.set_cpu_groups()
-        dst_rank = self.chunk_manager.process_group.dp_rank_list()[0]
         output = [None for _ in range(self.chunk_manager.process_group.dp_world_size())]
-        dist.gather_object(local_state,
-                           output if self.chunk_manager.process_group.dp_local_rank() == 0 else None,
-                           dst=dst_rank,
-                           group=self.chunk_manager.process_group.cpu_dp_process_group())
-        if not is_rank_0:
-            return
+        if only_rank_0:
+            dst_rank = self.chunk_manager.process_group.dp_rank_list()[0]
+            dist.gather_object(local_state,
+                               output if self.chunk_manager.process_group.dp_local_rank() == 0 else None,
+                               dst=dst_rank,
+                               group=self.chunk_manager.process_group.cpu_dp_process_group())
+            if not is_rank_0:
+                return
+        else:
+            dist.all_gather_object(output, local_state, group=self.chunk_manager.process_group.cpu_dp_process_group())
         for state in output:
             optim_state_dict['state'].update(state)
         return optim_state_dict
```
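For context on the two collectives the new branch chooses between: `dist.gather_object` delivers every rank's object to a single destination rank, while `dist.all_gather_object` delivers the full list to all ranks. A standalone sketch of the same pattern, assuming an initialized default process group (`local_state` and `collect_state` are illustrative names; `local_state` stands in for one rank's shard of optimizer state):

```python
import torch.distributed as dist

def collect_state(local_state, only_rank_0=True):
    # One slot per data-parallel rank.
    output = [None for _ in range(dist.get_world_size())]
    if only_rank_0:
        # Only the destination rank supplies a gather list; the other ranks
        # pass None and never receive the combined state.
        dist.gather_object(local_state,
                           output if dist.get_rank() == 0 else None,
                           dst=0)
        if dist.get_rank() != 0:
            return None
    else:
        # Every rank receives all shards.
        dist.all_gather_object(output, local_state)
    return output
```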
tests/test_zero/test_zero_optim_state_dict.py
```diff
@@ -45,7 +45,8 @@ def check_state_dict(state_dict, torch_state_dict):
 
 @parameterize('use_chunk', [False, True])
 @parameterize('use_zero', [False, True])
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
-def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy):
+@parameterize('only_rank_0', [False, True])
+def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy, only_rank_0):
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -76,8 +77,8 @@ def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy):
     optim.load_state_dict(torch_state_dict)
     check_load_state_dict(optim, torch_optim)
 
-    state_dict = optim.state_dict()
-    if pg.rank() == 0:
+    state_dict = optim.state_dict(only_rank_0)
+    if not only_rank_0 or pg.rank() == 0:
         check_state_dict(state_dict, torch_state_dict)
```
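Stacked `@parameterize` decorators run the test body once per argument combination, so adding `only_rank_0` doubles the number of configurations exercised. A rough sketch of the assumed expansion (illustration only; the decorator in colossalai's test utilities handles this itself):

```python
from itertools import product

# 2 (use_chunk) x 2 (use_zero) x 3 (placement_policy) x 2 (only_rank_0) = 24 runs
for use_chunk, use_zero, only_rank_0 in product([False, True], repeat=3):
    for placement_policy in ('cuda', 'cpu', 'auto'):
        run_zero_optim_state_dict(use_chunk, use_zero, placement_policy, only_rank_0)
```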