OpenDAS / ColossalAI / Commits / 2c787d7f

Unverified commit 2c787d7f, authored Aug 31, 2023 by Baizhou Zhang, committed by GitHub on Aug 31, 2023.
[shardformer] fix submodule replacement bug when enabling pp (#4544)
Parent: ec18fc73

Showing 5 changed files with 21 additions and 12 deletions:
colossalai/shardformer/shard/sharder.py (+13, -12)
tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py (+2, -0)
tests/test_shardformer/test_model/test_shard_chatglm2.py (+2, -0)
tests/test_shardformer/test_model/test_shard_gpt2.py (+2, -0)
tests/test_shardformer/test_model/test_shard_opt.py (+2, -0)
colossalai/shardformer/shard/sharder.py

@@ -92,22 +92,21 @@ class ModelSharder(object):
             param_replacement (List[Callable]): The function list to get parameter shard information in policy
             method_replacement (Dict[str, Callable]): Key is the method name, value is the method for replacement
             sub_module_replacement ((List[SubModuleReplacementDescription]): The function list to get sub module shard information in policy
+            include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None
         """
-        # released layers are not shardable
-        can_replace_param_or_layer = include is None or module in include
         if (isinstance(origin_cls, str) and origin_cls == module.__class__.__name__) or \
            (module.__class__ == origin_cls):
             if attr_replacement is not None:
                 self._replace_attr(module, attr_replacement)

-            if param_replacement is not None and can_replace_param_or_layer:
+            if param_replacement is not None and (include is None or module in include):
                 self._replace_param(module, param_replacement)

             if method_replacement is not None:
                 self._replace_method(module, method_replacement)

-            if sub_module_replacement is not None and can_replace_param_or_layer:
-                self._replace_sub_module(module, sub_module_replacement)
+            if sub_module_replacement is not None:
+                self._replace_sub_module(module, sub_module_replacement, include)

         for name, child in module.named_children():
             self._recursive_replace_layer(child,
@@ -154,18 +153,17 @@ class ModelSharder(object):
         bound_method = MethodType(new_method, module)
         setattr(module, method_name, bound_method)

     def _replace_sub_module(
         self,
         org_layer: nn.Module,
         sub_module_replacement: List[SubModuleReplacementDescription],
-    ) -> None:
+        include: Optional[Set[nn.Module]] = None
+    ) -> None:
         r"""
         Shard one layer according to the policy, the layer should be the same class as the key in policy's argument_policy return dict

         Args:
             org_layer (torch.nn.Module): The origin layer object to shard
             sub_module_replacement (List[SubModuleReplacementDescription]): The sub module replacement description list
+            include (Set[nn.Module], optional): The set of modules to keep on current device when pipeline parallel is enabled. Defaults to None
         """
         for description in sub_module_replacement:
             suffix = description.suffix
@@ -174,9 +172,12 @@ class ModelSharder(object):
             assert target_module is not None, 'target_module should not be None'

-            # TODO: support different parallel mode
             native_sub_module = getattr_(org_layer, suffix, ignore=True)

+            # Skip replacement if submodule is not kept by current device when pipeline parallel is enabled.
+            if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
+                continue
+
             assert not isinstance(native_sub_module, target_module), \
                 f"The module with suffix {suffix} has been replaced, please check the policy"
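Taken together, the sharder change moves the pipeline-parallel membership check from the parent module down to each submodule: previously `_replace_sub_module` was skipped whenever the parent module was not in `include`, so submodules that the current stage does hold could be left unreplaced; now the check happens per submodule inside `_replace_sub_module`. A minimal standalone sketch of the new gate (`should_replace` is a hypothetical helper for illustration, not part of this commit):

from typing import Optional, Set

import torch.nn as nn


def should_replace(native_sub_module: Optional[nn.Module], include: Optional[Set[nn.Module]]) -> bool:
    # Mirrors the check added in _replace_sub_module: when pipeline parallelism
    # supplies an `include` set, a submodule is only replaced if it is actually
    # kept on the current device; otherwise replacement is skipped.
    if (include is not None) and (native_sub_module is not None) and (native_sub_module not in include):
        return False
    return True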
tests/test_checkpoint_io/test_hybrid_parallel_plugin_checkpoint_io.py

@@ -7,6 +7,7 @@ from utils import shared_tempdir
 import colossalai
 from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import (
     check_state_dict_equal,
@@ -100,6 +101,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
         booster.load_model(new_model, model_ckpt_path)
         check_state_dict_equal(model.unwrap().state_dict(), new_model.unwrap().state_dict(), False)

+    Randomizer.reset_index()
     clear_layout_converter()
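On the test side, every file gains the same two-line change: importing `Randomizer` and calling `Randomizer.reset_index()` once a configuration has finished. The shardformer `Randomizer` keeps a global index used when assigning RNG seeds, so resetting it keeps one parameterized run from leaking state into the next. A small sketch of the pattern (`run_one_config` is a hypothetical stand-in for the per-config test body; only `Randomizer.reset_index()` comes from this diff):

from colossalai.shardformer.layer.utils import Randomizer


def run_all_configs(configs, run_one_config):
    for cfg in configs:
        run_one_config(cfg)
        # Reset the global randomizer index so the next configuration starts
        # from a clean state instead of one advanced by the previous run.
        Randomizer.reset_index()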
tests/test_shardformer/test_model/test_shard_chatglm2.py

@@ -4,6 +4,7 @@ from torch import distributed as dist
 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -105,6 +106,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check grads
     check_all_grad_tensors(grads_to_check)

+    Randomizer.reset_index()
     torch.cuda.empty_cache()
tests/test_shardformer/test_model/test_shard_gpt2.py

@@ -4,6 +4,7 @@ from torch import distributed as dist
 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -97,6 +98,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check grads
     check_all_grad_tensors(grads_to_check)

+    Randomizer.reset_index()
     torch.cuda.empty_cache()
tests/test_shardformer/test_model/test_shard_opt.py

@@ -6,6 +6,7 @@ from torch import distributed as dist
 import colossalai
 from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.d_tensor.api import clear_layout_converter
 from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
@@ -107,6 +108,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # check grads
     check_all_grad_tensors(grads_to_check)

+    Randomizer.reset_index()
     torch.cuda.empty_cache()
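For context, the checkpoint and shardformer tests above reach the fixed code path through the hybrid-parallel booster: with pp_size greater than 1, each rank keeps only the modules of its own pipeline stage, so the sharder receives a non-None `include` set. A rough sketch of such a setup, assuming a distributed environment has already been launched; the plugin arguments, model, and helper name here are illustrative, not taken from this diff:

import torch.nn as nn
from torch.optim import Adam

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin


def boost_with_pipeline(model: nn.Module, criterion, dataloader):
    # pp_size > 1 makes each rank hold only its own pipeline stage, which is the
    # situation where submodule replacement must respect the `include` set.
    plugin = HybridParallelPlugin(tp_size=2, pp_size=2, num_microbatches=4)
    booster = Booster(plugin=plugin)
    optimizer = Adam(model.parameters(), lr=1e-3)
    return booster.boost(model, optimizer, criterion=criterion, dataloader=dataloader)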