OpenDAS / ColossalAI · Commits

Unverified commit f8eec98f, authored Jun 22, 2022 by Frank Lee, committed by GitHub on Jun 22, 2022
[tensor] fixed non-serializable colo parameter during model checkpointing (#1153)
parent ffa025e1

Changes: 1 changed file with 39 additions and 3 deletions

colossalai/utils/model/colo_init_context.py (+39 −3)
```diff
 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
-from colossalai.tensor import ColoTensor, ColoParameter
+from colossalai.tensor import ColoTensor, ColoParameter, distspec, TensorSpec
 from colossalai.nn.parallel.layers import register_colo_module, \
     ColoLinear, ColoEmbedding
 from copy import copy
 from torch import nn
 from typing import Iterator, Tuple, Union
+from functools import partialmethod

 # find named_params includes replica
```
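The commit title names the failure this diff addresses: a ColoParameter whose TensorSpec references a live process group cannot be pickled, so saving the model's state dict raises. A minimal sketch of the same failure mode, using a thread lock as a stand-in for a process-group handle (the lock and the dict key are illustrative, not ColossalAI types):

```python
import pickle
import threading

# A process group, like a lock, wraps live OS/communication state and
# cannot be pickled. Any state-dict entry that still references such a
# handle makes pickle.dumps (and hence torch.save) raise.
try:
    pickle.dumps({"weight": threading.Lock()})
except TypeError as err:
    print(err)  # cannot pickle '_thread.lock' object
```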
```diff
@@ -34,6 +34,38 @@ def ColoModulize(module):
     module._colo_visited = True


+def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_dict_func=None):
+    # build param to spec mapping
+    mapping = dict()
+
+    # gather all params
+    has_dist_parameter = False
+    with torch.no_grad():
+        for param in self.parameters():
+            if isinstance(param, ColoParameter) and param.has_spec():
+                has_dist_parameter = True
+                mapping[id(param)] = copy(param.spec)
+                param.set_spec(TensorSpec(distspec.replicate()))
+
+    # TODO: fix when keep_vars = True
+    # when keep_vars = False, the state_dict_func will call detach to create
+    # new tensors, but when keep_vars = True, the recovery of spec will be reflected
+    # in the `ret`, such that the final state dict will still contain process group,
+    # raising exception as it is not serializable
+    assert not (keep_vars and has_dist_parameter), 'keep_vars cannot be True when there are distributed ColoParameters.'
+
+    ret = state_dict_func(self, destination, prefix, keep_vars)
+
+    # recover
+    with torch.no_grad():
+        for param in self.parameters():
+            param_id = id(param)
+            if param_id in mapping:
+                spec = mapping[id(param)]
+                param.set_spec(spec)
+
+    return ret
+
+
 class ColoInitContext(InsertPostInitMethodToModuleSubClasses):

     def __init__(self, lazy_memory_allocate: bool = False, device: torch.device = torch.device('cpu')):
```
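With colo_state_dict in place, each distributed spec is swapped for a replicated one before the original state_dict runs, then restored afterwards, so the returned tensors carry no process-group handles. A usage sketch under assumptions: the import path follows the changed file, the model and checkpoint name are made up, and whether the state_dict patch is reverted on context exit is not shown in this diff:

```python
import torch
from colossalai.utils.model.colo_init_context import ColoInitContext

# Hypothetical usage; the model and file name are illustrative.
with ColoInitContext(device=torch.device('cpu')):
    model = torch.nn.Linear(16, 16)

# The patched state_dict temporarily replaces distributed specs with
# replicated ones, so the result pickles cleanly.
state = model.state_dict()
torch.save(state, 'checkpoint.pt')
```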
```diff
@@ -52,6 +84,10 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         register_colo_module(torch.nn.Linear, ColoLinear())
         register_colo_module(torch.nn.Embedding, ColoEmbedding())

+    def _pre_context_exec(self):
+        self.state_dict_func = nn.Module.state_dict
+        nn.Module.state_dict = partialmethod(colo_state_dict, state_dict_func=self.state_dict_func)
+
     def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
         """
         The function to call at the end of the constructor of each module.
         ...
```
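The _pre_context_exec hook relies on functools.partialmethod to save the original unbound method and splice it back in as an extra keyword argument. The same pattern in isolation, with toy names (Greeter and logged_greet are invented for illustration):

```python
from functools import partialmethod

class Greeter:
    def greet(self, name):
        return f"hello, {name}"

def logged_greet(self, name, original_func=None):
    # Do extra work, then delegate to the saved original method,
    # mirroring how colo_state_dict delegates to state_dict_func.
    print("about to greet")
    return original_func(self, name)

# Save the original and install the wrapper, the same shape as the
# patch applied to nn.Module.state_dict above.
original = Greeter.greet
Greeter.greet = partialmethod(logged_greet, original_func=original)

print(Greeter().greet("world"))
# about to greet
# hello, world
```

Using partialmethod rather than functools.partial matters here: partialmethod keeps descriptor behaviour, so self is still bound when the patched attribute is looked up on an instance.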