OpenDAS / ColossalAI · Commits · af32022f

Commit af32022f (unverified), authored Jan 03, 2023 by Jiarui Fang, committed via GitHub on Jan 03, 2023
Parent: 879df8b9

[Gemini] fix the convert_to_torch_module bug (#2269)

Showing 4 changed files with 48 additions and 21 deletions:
colossalai/gemini/gemini_mgr.py          +1  -1
colossalai/gemini/placement_policy.py    +1  -1
colossalai/nn/parallel/data_parallel.py  +39 -13
colossalai/nn/parallel/utils.py          +7  -6
colossalai/gemini/gemini_mgr.py

@@ -30,7 +30,7 @@ class GeminiManager:
     def __init__(self, placement_policy: str, chunk_manager: ChunkManager, memstats: Optional[MemStats] = None) -> None:
-        assert placement_policy in PlacementPolicyFactory.get_polocy_names()
+        assert placement_policy in PlacementPolicyFactory.get_policy_names()
         self.policy_name = placement_policy
         policy_cls = PlacementPolicyFactory.create(placement_policy)
         self._chunk_manager = chunk_manager
colossalai/gemini/placement_policy.py

@@ -236,7 +236,7 @@ class PlacementPolicyFactory:
         return PlacementPolicyFactory.policies[policy_name]

     @staticmethod
-    def get_polocy_names():
+    def get_policy_names():
         return tuple(PlacementPolicyFactory.policies.keys())

     @staticmethod
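The two hunks above fix a spelling typo in the factory accessor ("polocy" to "policy") at both the definition and its call site. A minimal usage sketch of the corrected API pair, mirroring the GeminiManager.__init__ logic; the concrete policy name 'auto' is an assumption for illustration, not something this diff guarantees:

# Hedged sketch: both calls below appear in this commit; the policy name is assumed.
from colossalai.gemini.placement_policy import PlacementPolicyFactory

policy_name = 'auto'                                               # illustrative assumption
assert policy_name in PlacementPolicyFactory.get_policy_names()   # corrected spelling
policy_cls = PlacementPolicyFactory.create(policy_name)           # look up the policy class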
colossalai/nn/parallel/data_parallel.py

@@ -360,24 +360,20 @@ class ZeroDDP(ColoDDP):
             destination = hook_result
         return destination

-    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
-        r"""Saves module state to `destination` dictionary, containing a state
-        of the module, but not its descendants. This is called on every
-        submodule in :meth:`~torch.nn.Module.state_dict`.
-        In rare cases, subclasses can achieve class-specific behavior by
-        overriding this method with custom logic.
-        Args:
-            destination (dict): a dict where state will be stored
-            prefix (str): the prefix for parameters and buffers used in this
-                module
-        """
-        assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
+    def _get_param_to_save_data(self, param_list: List[torch.nn.Parameter], only_rank_0: bool) -> Dict:
+        """
+        get param content from chunks.
+
+        Args:
+            param_list (_type_): a list of torch.nn.Parameters
+            only_rank_0 (_type_): _description_
+
+        Returns:
+            Dict: a dict whose key is param name and value is param with correct payload
+        """
         # save parameters
         param_to_save_data = dict()
-        chunk_list = self.chunk_manager.get_chunks(self.fp32_params)
+        chunk_list = self.chunk_manager.get_chunks(param_list)
         for chunk in chunk_list:
             temp_chunk = get_temp_total_chunk_on_cuda(chunk)

@@ -391,7 +387,37 @@ class ZeroDDP(ColoDDP):
                 param_to_save_data[tensor] = record_tensor
             del temp_chunk
+        return param_to_save_data
+
+    def torch_named_parameters(self):
+        """
+        get named_parameters() of self.module. It is used the same of PyTorch param and returns the real param.data payload.
+        It works the same as torch.Module named_parameters
+        """
+        params_list = [p for p in self.parameters(recurse=True)]
+        param_to_save_data = self._get_param_to_save_data(params_list, False)
+        for (name, _), p in zip(self.named_parameters(recurse=True), params_list):
+            if p is not None:
+                assert p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
+                record_parameter = param_to_save_data[p]
+                yield name, record_parameter
+
+    def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):
+        r"""Saves module state to `destination` dictionary, containing a state
+        of the module, but not its descendants. This is called on every
+        submodule in :meth:`~torch.nn.Module.state_dict`.
+        In rare cases, subclasses can achieve class-specific behavior by
+        overriding this method with custom logic.
+        Args:
+            destination (dict): a dict where state will be stored
+            prefix (str): the prefix for parameters and buffers used in this
+                module
+        """
+        assert keep_vars is False, "`state_dict` with parameter, `keep_vars=True`, is not supported now."
+        param_to_save_data = self._get_param_to_save_data(self.fp32_params, only_rank_0)
         for (name, p), fp32_p in zip(self.named_parameters(), self.fp32_params):
             if p is not None:
                 assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
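This hunk factors the chunk-gathering logic out of _save_to_state_dict into the reusable _get_param_to_save_data helper and adds the torch_named_parameters generator, which yields (name, param) pairs whose payloads have been gathered out of Gemini's chunks. A hedged usage sketch; `gemini_ddp_model` is assumed to be an already-constructed GeminiDDP/ZeroDDP instance, and the printed fields are illustrative:

# Usage sketch (assumes gemini_ddp_model wraps a module with Gemini/ZeroDDP):
for name, param in gemini_ddp_model.torch_named_parameters():
    # param carries the real payload collected from chunks, not a Gemini placeholder
    print(name, tuple(param.shape))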
colossalai/nn/parallel/utils.py

@@ -2,7 +2,6 @@ import torch
 import torch.distributed as dist

 from colossalai.gemini.chunk import Chunk
-from colossalai.tensor import ColoTensor
 from colossalai.utils import get_current_device

@@ -22,6 +21,7 @@ def get_temp_total_chunk_on_cuda(chunk: Chunk):
     return total_temp

+# TODO() not work for module where two params share the same tensor.
 def _add_param(model, name, param):
     name_list = name.split('.')
     module = model._modules[name_list[0]]

@@ -30,7 +30,7 @@ def _add_param(model, name, param):
     module._parameters[name_list[-1]] = param

-def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
+def convert_to_torch_module(gemini_ddp_model: 'GeminiDDP') -> torch.nn.Module:
     """convert_to_torch_module
     Args:

@@ -39,11 +39,12 @@ def convert_to_torch_module(gemini_ddp_model) -> torch.nn.Module:
     Returns:
         torch.nn.Module: a torch model contains the params of gemini_ddp_model
     """
+    from colossalai.nn.parallel import GeminiDDP
+    assert isinstance(gemini_ddp_model, GeminiDDP)
     module = gemini_ddp_model.module
-    for n, p in module.named_parameters():
-        # replace ColoTensor to torch.nn.Tensor in module
-        if isinstance(p, ColoTensor):
-            p.to_replicate_()
-            _add_param(module, n, p)
+    for n, p in gemini_ddp_model.torch_named_parameters():
+        _add_param(module, n, p.data)
     return module
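Taken together, the fix routes convert_to_torch_module through the new torch_named_parameters generator instead of mutating ColoTensors in place, so the returned module holds plain tensors with real payloads. A hedged end-to-end sketch; the model setup and output file name are illustrative assumptions, not prescribed by this diff:

# End-to-end sketch (convert_to_torch_module is defined in this file;
# gemini_ddp_model is assumed to be a GeminiDDP instance built elsewhere):
import torch
from colossalai.nn.parallel.utils import convert_to_torch_module

torch_model = convert_to_torch_module(gemini_ddp_model)   # back to a plain torch.nn.Module
torch.save(torch_model.state_dict(), 'model.pt')          # stock PyTorch saving now works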