OpenDAS / ColossalAI · Commits

Unverified commit 41429b9b, authored Jan 11, 2023 by YuliangLiu0306, committed by GitHub on Jan 11, 2023
Parent: 1b7587d9

[autoparallel] add shard option (#2423)
Showing 4 changed files with 149 additions and 1 deletion:

    colossalai/auto_parallel/tensor_shard/node_handler/__init__.py       +2   -1
    colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py   +18  -0
    colossalai/auto_parallel/tensor_shard/node_handler/option.py         +17  -0
    tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py  +112  -0
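In short, the commit adds a three-level ShardOption enum, threads it through NodeHandler.__init__ as a shard_option keyword defaulting to the old behavior, and prunes registered strategies that fall below the requested level:

    Requested option    Strategies removed after registration
    ------------------  -------------------------------------------------
    STANDARD            none (previous behavior, the default)
    SHARD               those with total shard level 0 (fully replicated)
    FULL_SHARD          those with total shard level <= 1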
colossalai/auto_parallel/tensor_shard/node_handler/__init__.py

@@ -11,6 +11,7 @@ from .layer_norm_handler import LayerNormModuleHandler
 from .linear_handler import LinearFunctionHandler, LinearModuleHandler
 from .matmul_handler import MatMulHandler
 from .normal_pooling_handler import NormPoolingHandler
+from .option import ShardOption
 from .output_handler import OutputHandler
 from .placeholder_handler import PlaceholderHandler
 from .registry import operator_registry
@@ -27,5 +28,5 @@ __all__ = [
     'UnaryElementwiseHandler', 'ReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
     'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
     'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
-    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler'
+    'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption'
 ]
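With the re-export above in place, ShardOption is importable from the node_handler package root alongside the handlers themselves. A minimal sketch (both paths resolve to the same enum object):

    # The package-level import is the one this commit makes public via __all__.
    from colossalai.auto_parallel.tensor_shard.node_handler import ShardOption
    from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption as _ShardOption

    assert ShardOption is _ShardOption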
colossalai/auto_parallel/tensor_shard/node_handler/node_handler.py

@@ -5,6 +5,7 @@ import torch
 from torch.fx.node import Node

 from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
+from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
     OperationData,
     OperationDataType,
@@ -35,12 +36,14 @@ class NodeHandler(ABC):
         node: Node,
         device_mesh: DeviceMesh,
         strategies_vector: StrategiesVector,
+        shard_option: ShardOption = ShardOption.STANDARD,
     ) -> None:
         self.node = node
         self.predecessor_node = list(node._input_nodes.keys())
         self.successor_node = list(node.users.keys())
         self.device_mesh = device_mesh
         self.strategies_vector = strategies_vector
+        self.shard_option = shard_option

     def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
         """
@@ -181,6 +184,21 @@ class NodeHandler(ABC):
             if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
                 check_sharding_spec_validity(sharding_spec, op_data.data)

+        remove_strategy_list = []
+        for strategy in self.strategies_vector:
+            shard_level = 0
+            for op_data, sharding_spec in strategy.sharding_specs.items():
+                if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
+                    for dim, shard_axis in sharding_spec.dim_partition_dict.items():
+                        shard_level += len(shard_axis)
+
+            if self.shard_option == ShardOption.SHARD and shard_level == 0:
+                remove_strategy_list.append(strategy)
+            if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
+                remove_strategy_list.append(strategy)
+
+        for strategy in remove_strategy_list:
+            self.strategies_vector.remove(strategy)

         return self.strategies_vector

     def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
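To see the new filtering rule in isolation, here is a self-contained sketch that re-implements the shard_level count and the two drop conditions outside ColossalAI. The dicts mimic the shape of dim_partition_dict (tensor dim -> list of device mesh axes); the helper names are illustrative, not library API:

    from enum import Enum

    class ShardOption(Enum):
        STANDARD = 0
        SHARD = 1
        FULL_SHARD = 2

    def shard_level(dim_partition_dicts):
        # Total number of device mesh axes used across all operands of a
        # strategy, mirroring the nested loop added to register_strategy.
        return sum(len(axes) for spec in dim_partition_dicts for axes in spec.values())

    def keep(option, dim_partition_dicts):
        level = shard_level(dim_partition_dicts)
        if option == ShardOption.SHARD and level == 0:
            return False    # SHARD drops fully replicated strategies
        if option == ShardOption.FULL_SHARD and level <= 1:
            return False    # FULL_SHARD also drops strategies using a single axis
        return True

    # 'RR = RR x RR': nothing is sharded anywhere, so SHARD drops it
    assert not keep(ShardOption.SHARD, [{}, {}, {}])
    # 'S01R = S01R x RR': two operands each shard dim 0 over both mesh
    # axes (level 4), so even FULL_SHARD keeps it
    assert keep(ShardOption.FULL_SHARD, [{0: [0, 1]}, {0: [0, 1]}, {}])
    # A strategy whose only sharding is one axis on one operand (level 1)
    # survives SHARD but is dropped by FULL_SHARD
    assert keep(ShardOption.SHARD, [{0: [0]}, {}, {}])
    assert not keep(ShardOption.FULL_SHARD, [{0: [0]}, {}, {}])

Note that the level is summed over every operand's sharding spec, so FULL_SHARD removes only strategies whose combined use of the mesh is at most one axis, not strategies in which some single operand happens to be replicated.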
colossalai/auto_parallel/tensor_shard/node_handler/option.py (new file, mode 100644)

from enum import Enum

__all__ = ['ShardOption']


class ShardOption(Enum):
    """
    This enum class defines the shard level required in node strategies.

    Notes:
        STANDARD: We do not add any extra shard requirements.
        SHARD: We require the node to be sharded using at least one device mesh axis.
        FULL_SHARD: We require the node to be sharded using all device mesh axes.
    """
    STANDARD = 0
    SHARD = 1
    FULL_SHARD = 2
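The docstring's levels line up with the numeric check added to node_handler.py: on the 2D device mesh the test below uses, "at least one device mesh axis" corresponds to shard_level >= 1 and "all device mesh axes" to shard_level >= 2, which is exactly what removing the == 0 and <= 1 cases enforces. For reference, a spec name like 'S01R' shards one tensor dim across both mesh axes; the dim_partition_dict behind it would look like this illustrative literal:

    # Illustrative only: tensor dim 0 partitioned over mesh axes 0 and 1.
    dim_partition_dict = {0: [0, 1]}
    # Each listed axis counts toward shard_level in the filter above.
    assert sum(len(axes) for axes in dim_partition_dict.values()) == 2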
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py (new file, mode 100644)

from functools import partial

import torch
import torch.multiprocessing as mp
import torch.nn as nn

from colossalai.auto_parallel.tensor_shard.node_handler import LinearFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing import parameterize
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize


class LinearModel(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, input, others, bias=None):
        x = nn.functional.linear(input, others, bias=bias)
        return x


def check_shard_option(shard_option):
    model = LinearModel().cuda()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    tracer = ColoTracer()
    graph = tracer.trace(model,
                         meta_args={
                             "input": torch.rand(4, 4, 4, 16).to('meta'),
                             'others': torch.rand(32, 16).to('meta')
                         })
    gm = ColoGraphModule(model, graph)
    linear_func_node = list(graph.nodes)[2]
    strategies_vector = StrategiesVector(linear_func_node)

    # build handler
    handler = LinearFunctionHandler(node=linear_func_node,
                                    device_mesh=device_mesh,
                                    strategies_vector=strategies_vector,
                                    shard_option=shard_option)
    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
    strategy_name_list = [val.name for val in strategies_vector]

    # SS = SR x RS
    assert 'S1S0 = S1R x RS0_0' in strategy_name_list
    assert 'S0S1 = S0R x RS1_1' in strategy_name_list
    assert 'S0S1 = S0R x RS1_2' in strategy_name_list
    assert 'S0S1 = S0R x RS1_0' in strategy_name_list
    assert 'S1S0 = S1R x RS0_1' in strategy_name_list
    assert 'S1S0 = S1R x RS0_2' in strategy_name_list

    # SR = SS x SR
    assert 'S0R = S0S1 x S1R_1' in strategy_name_list
    assert 'S0R = S0S1 x S1R_2' in strategy_name_list
    assert 'S1R = S1S0 x S0R_0' in strategy_name_list
    assert 'S0R = S0S1 x S1R_0' in strategy_name_list
    assert 'S1R = S1S0 x S0R_1' in strategy_name_list
    assert 'S1R = S1S0 x S0R_2' in strategy_name_list

    # RS = RS x SS
    assert 'RS0 = RS1 x S1S0' in strategy_name_list
    assert 'RS1 = RS0 x S0S1' in strategy_name_list

    # S01R = S01R x RR
    assert 'S01R = S01R x RR_0' in strategy_name_list
    assert 'S01R = S01R x RR_1' in strategy_name_list
    assert 'S01R = S01R x RR_2' in strategy_name_list

    # RR = RS01 x S01R
    assert 'RR = RS01 x S01R' in strategy_name_list

    # RS01 = RR x RS01
    assert 'RS01 = RR x RS01' in strategy_name_list

    if shard_option == ShardOption.SHARD:
        # RR = RS x SR
        assert 'RR = RS0 x S0R' in strategy_name_list
        assert 'RR = RS1 x S1R' in strategy_name_list
        # RS = RR x RS
        assert 'RS0 = RR x RS0' in strategy_name_list
        assert 'RS1 = RR x RS1' in strategy_name_list

    if shard_option == ShardOption.STANDARD:
        # RR = RS x SR
        assert 'RR = RS0 x S0R' in strategy_name_list
        assert 'RR = RS1 x S1R' in strategy_name_list
        # RS = RR x RS
        assert 'RS0 = RR x RS0' in strategy_name_list
        assert 'RS1 = RR x RS1' in strategy_name_list
        # RR = RR x RR
        assert 'RR = RR x RR' in strategy_name_list


@run_on_environment_flag(name='AUTO_PARALLEL')
def test_shard_option():
    for shard_option in [ShardOption.STANDARD, ShardOption.SHARD, ShardOption.FULL_SHARD]:
        check_shard_option(shard_option)


if __name__ == '__main__':
    test_shard_option()
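A note on running this test: run_on_environment_flag gates it, so it is skipped under plain pytest collection. Assuming the decorator keys on an environment variable matching its name argument (the flag semantics live in colossalai.testing, not in this diff), an invocation along the lines of AUTO_PARALLEL=1 pytest tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_shard_option.py would enable it on a machine with 4 visible CUDA devices arranged as the 2x2 mesh the test constructs.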