Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
943a9632
Unverified
Commit
943a9632
authored
Jul 26, 2022
by
HELSON
Committed by
GitHub
Jul 26, 2022
Browse files
[hotfix] fix no optimizer in save/load (#1363)
parent
cd063ac3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
38 additions
and
36 deletions
+38
-36
colossalai/tensor/__init__.py
colossalai/tensor/__init__.py
+2
-2
colossalai/utils/checkpoint/module_checkpoint.py
colossalai/utils/checkpoint/module_checkpoint.py
+36
-34
No files found.
colossalai/tensor/__init__.py
View file @
943a9632
...
@@ -13,6 +13,6 @@ from . import distspec
...
@@ -13,6 +13,6 @@ from . import distspec
__all__
=
[
__all__
=
[
'ColoTensor'
,
'convert_parameter'
,
'ComputePattern'
,
'ComputeSpec'
,
'named_params_with_colotensor'
,
'ColoParameter'
,
'ColoTensor'
,
'convert_parameter'
,
'ComputePattern'
,
'ComputeSpec'
,
'named_params_with_colotensor'
,
'ColoParameter'
,
'distspec'
,
'DistSpecManager'
,
'ParamOpHook'
,
'ParamOpHookManager'
,
'
ChunkManager'
,
'TensorState'
,
'ProcessGroup
'
,
'distspec'
,
'DistSpecManager'
,
'ParamOpHook'
,
'ParamOpHookManager'
,
'
ProcessGroup'
,
'ColoTensorSpec
'
,
'ColoTensorSpec'
,
'TensorSpec'
,
'ShardSpec'
,
'ReplicaSpec'
'ShardSpec'
,
'ReplicaSpec'
]
]
colossalai/utils/checkpoint/module_checkpoint.py
View file @
943a9632
...
@@ -46,28 +46,29 @@ def save_checkpoint(dire: str,
...
@@ -46,28 +46,29 @@ def save_checkpoint(dire: str,
# synchronize all the processes
# synchronize all the processes
dist
.
barrier
()
dist
.
barrier
()
mapping
=
dict
()
if
optimizer
is
not
None
:
optim_state
=
optimizer
.
state_dict
()
mapping
=
dict
()
for
k
,
v
in
optim_state
[
'state'
].
items
():
optim_state
=
optimizer
.
state_dict
()
for
n
,
t
in
v
.
items
():
for
k
,
v
in
optim_state
[
'state'
].
items
():
if
isinstance
(
t
,
ColoTensor
):
mapping
[(
k
,
n
)]
=
t
.
dist_spec
gather_tensor
(
t
)
if
rank
==
0
:
save_state
=
{
'epoch'
:
epoch
,
'optim'
:
optim_state
}
torch
.
save
(
save_state
,
dire
+
'/epoch_{}_optim.pth'
.
format
(
epoch
))
# recover colo tensors in rank0
for
k
,
v
in
optimizer
.
state_dict
()[
'state'
].
items
():
for
n
,
t
in
v
.
items
():
for
n
,
t
in
v
.
items
():
if
isinstance
(
t
,
ColoTensor
):
if
isinstance
(
t
,
ColoTensor
):
assert
hasattr
(
t
,
'save_ready'
)
mapping
[(
k
,
n
)]
=
t
.
dist_spec
t
.
set_dist_spec
(
mapping
[(
k
,
n
)])
gather_tensor
(
t
)
delattr
(
t
,
'save_ready'
)
del
optim_state
if
rank
==
0
:
del
mapping
save_state
=
{
'epoch'
:
epoch
,
'optim'
:
optim_state
}
dist
.
barrier
()
torch
.
save
(
save_state
,
dire
+
'/epoch_{}_optim.pth'
.
format
(
epoch
))
# recover colo tensors in rank0
for
k
,
v
in
optimizer
.
state_dict
()[
'state'
].
items
():
for
n
,
t
in
v
.
items
():
if
isinstance
(
t
,
ColoTensor
):
assert
hasattr
(
t
,
'save_ready'
)
t
.
set_dist_spec
(
mapping
[(
k
,
n
)])
delattr
(
t
,
'save_ready'
)
del
optim_state
del
mapping
dist
.
barrier
()
def
load_checkpoint
(
dire
,
def
load_checkpoint
(
dire
,
...
@@ -108,21 +109,22 @@ def load_checkpoint(dire,
...
@@ -108,21 +109,22 @@ def load_checkpoint(dire,
delattr
(
p
,
'save_ready'
)
delattr
(
p
,
'save_ready'
)
del
mapping
del
mapping
mapping
=
dict
()
if
optimizer
is
not
None
:
for
k
,
v
in
optimizer
.
state_dict
()[
'state'
].
items
():
mapping
=
dict
()
for
n
,
t
in
v
.
items
():
for
k
,
v
in
optimizer
.
state_dict
()[
'state'
].
items
():
if
isinstance
(
t
,
ColoTensor
):
for
n
,
t
in
v
.
items
():
mapping
[(
k
,
n
)]
=
t
.
dist_spec
if
isinstance
(
t
,
ColoTensor
):
gather_tensor
(
t
)
mapping
[(
k
,
n
)]
=
t
.
dist_spec
gather_tensor
(
t
)
if
rank
==
0
:
if
rank
==
0
:
colo_checkpoint
=
torch
.
load
(
dire
+
'/epoch_{}_optim.pth'
.
format
(
epoch
))
colo_checkpoint
=
torch
.
load
(
dire
+
'/epoch_{}_optim.pth'
.
format
(
epoch
))
optimizer
.
load_state_dict
(
colo_checkpoint
[
'optim'
])
optimizer
.
load_state_dict
(
colo_checkpoint
[
'optim'
])
dist
.
barrier
()
dist
.
barrier
()
for
k
,
v
in
optimizer
.
state_dict
()[
'state'
].
items
():
for
k
,
v
in
optimizer
.
state_dict
()[
'state'
].
items
():
for
n
,
t
in
v
.
items
():
for
n
,
t
in
v
.
items
():
if
isinstance
(
t
,
ColoTensor
):
if
isinstance
(
t
,
ColoTensor
):
scatter_tensor
(
t
,
mapping
[(
k
,
n
)])
scatter_tensor
(
t
,
mapping
[(
k
,
n
)])
del
mapping
del
mapping
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment