ColossalAI commit 26a37b5c (unverified)
Authored Aug 19, 2022 by YuliangLiu0306; committed by GitHub on Aug 19, 2022
Parent: 1b491ad7

[autoparallel] Add conv handler to generate strategies and costs info for conv (#1467)
Showing 4 changed files with 554 additions and 1 deletion:

- colossalai/auto_parallel/solver/conv_handler.py (+384, -0)
- colossalai/auto_parallel/solver/sharding_strategy.py (+54, -0)
- colossalai/tensor/sharding_spec.py (+1, -1)
- tests/test_auto_parallel/test_conv_handler.py (+115, -0)
colossalai/auto_parallel/solver/conv_handler.py (new file, mode 100644, +384 -0)

[This diff is collapsed on the page; the file contents are not shown.]
colossalai/auto_parallel/solver/sharding_strategy.py (new file, mode 100644, +54 -0)
```python
class ShardingStrategy:
    '''
    ShardingStrategy is a structure containing the sharding strategies of the inputs and output
    of this node, and the cost information used in the solver.

    Argument:
        name(str): expresses the sharding strategy as a string, such as 'S0S1 = S0R x RS1'.
        output_sharding_spec(ShardingSpec): ShardingSpec of the output node.
        compute_cost(float): computation cost to complete this strategy. (default to 0)
        communication_cost(float): communication cost to complete this strategy. (default to 0)
        memory_cost(float): memory cost of the output node using this strategy. (default to 0)
        resharding_costs(Dict[int, List[float]]): resharding_costs[i][j] means the cost of transforming
            the i-th argument in the node's argument list, under the j-th strategy in its
            strategies_vector, to the sharding spec required by this strategy. (default to None)
        input_shardings(List[ShardingSpec]): the ShardingSpecs of the input nodes.
    '''

    def __init__(self,
                 name,
                 output_sharding_spec,
                 compute_cost=0,
                 communication_cost=0,
                 memory_cost=0,
                 resharding_costs=None,
                 input_shardings=None):
        self.name = name
        self.output_sharding_spec = output_sharding_spec
        self.compute_cost = compute_cost
        self.communication_cost = communication_cost
        self.memory_cost = memory_cost
        self.resharding_costs = resharding_costs
        self.input_shardings = input_shardings


class StrategiesVector:
    '''
    Each node in an fx graph has a corresponding StrategiesVector, which stores all the possible
    sharding strategies of the node.

    Argument:
        node(Node): the node to build this strategies_vector for.
        in_nodes(List[Node]): input nodes in the argument list of the node.
        following_nodes(List[Node]): the nodes that take the target node as their argument.
        strategies(List[ShardingStrategy]): all the possible sharding strategies of the node.
    '''

    def __init__(self, node, in_nodes, following_nodes=None, strategies=None):
        self.node = node
        self.in_nodes = in_nodes
        self.following_nodes = following_nodes
        # Avoid a shared mutable default argument: create a fresh list per instance.
        self.strategies = strategies if strategies is not None else []

    def check_merge(self):
        # Stub: merging check not implemented yet.
        pass
```
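For orientation, here is a minimal sketch of how these two structures fit together, using the classes above. The node, spec, and cost values are placeholders invented for illustration; a real solver would pass fx `Node` and `ShardingSpec` objects.

```python
# Minimal usage sketch of the classes above. All values are placeholders.
strategy = ShardingStrategy(
    name='S0S1 = S0R x RS1',        # output = input x weight, in sharding notation
    output_sharding_spec=None,      # placeholder for a ShardingSpec
    compute_cost=1.0,
    communication_cost=0.5,
    memory_cost=2.0,
    # resharding_costs[i][j]: cost of transforming the i-th argument, currently
    # under the j-th strategy of its own strategies_vector, into the sharding
    # spec this strategy expects.
    resharding_costs={0: [0.0, 0.3, 0.3]},
)

vector = StrategiesVector(node=None, in_nodes=[])    # placeholder node
vector.strategies.append(strategy)
assert vector.strategies[0].name == 'S0S1 = S0R x RS1'
```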
colossalai/tensor/sharding_spec.py (+1 -1)
```diff
@@ -199,7 +199,7 @@ class ShardingSpec:
             if not dim_spec.is_replica:
                 if index not in new_dim_partition_dict:
                     new_dim_partition_dict[index] = []
-                new_dim_partition_dict[index].append(dim_spec.shard_list)
+                new_dim_partition_dict[index].extend(dim_spec.shard_list)
         self.dim_partition_dict = new_dim_partition_dict

     def sharding_sequence_difference(self, other):
```
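The one-line fix above replaces `append` with `extend`: `append` pushes the whole `shard_list` as a single nested element, whereas `extend` splices its items into the per-dimension list. A standalone illustration (not part of the commit):

```python
# Standalone illustration of the append -> extend difference fixed above.
shard_list = [0, 1]             # a tensor dim sharded over mesh dims 0 and 1

nested = {0: []}
nested[0].append(shard_list)    # old behaviour: nests the whole list
assert nested[0] == [[0, 1]]

flat = {0: []}
flat[0].extend(shard_list)      # new behaviour: splices the mesh dims in
assert flat[0] == [0, 1]
```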
tests/test_auto_parallel/test_conv_handler.py (new file, mode 100644, +115 -0)
```python
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest

from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh


class ConvModel(nn.Module):

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)

    def forward(self, x):
        x = x * 2
        x = self.conv(x)
        return x


def test_conv_handler():
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    entire_shape = torch.Size((4, 16, 64, 64))
    shape_consistency_manager = ShapeConsistencyManager()

    tracer = ColoTracer()
    model = ConvModel(16, 32)
    input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
    #     %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
    #     return conv
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    # [x, mul, conv, output]
    nodes = [node for node in gm.graph.nodes]

    # Enumerate sharding specs for the conv input: each of the first two tensor
    # dims may be replicated (None) or sharded along mesh dim 0 or 1, and the
    # two dims may not share a mesh dim.
    strategies_for_input = []
    sharding_option = (None, 0, 1)
    for first_sharding_index in sharding_option:
        for second_sharding_index in sharding_option:
            if first_sharding_index is not None and second_sharding_index == first_sharding_index:
                continue
            if first_sharding_index is None:
                first_dim_spec = _DimSpec([])
            else:
                first_dim_spec = _DimSpec([first_sharding_index])

            if second_sharding_index is None:
                second_dim_spec = _DimSpec([])
            else:
                second_dim_spec = _DimSpec([second_sharding_index])

            replica_dim_spec = _DimSpec([])
            sharding_sequence = [first_dim_spec, second_dim_spec, replica_dim_spec, replica_dim_spec]
            sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                         entire_shape=entire_shape,
                                         sharding_sequence=sharding_sequence)
            strategies_for_input.append(sharding_spec)

    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R],
    #                         [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
    strategies_vector_for_input = StrategiesVector(node=nodes[0],
                                                   in_nodes=[nodes[1], 2],
                                                   strategies=strategies_for_input)
    setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)

    strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1]])
    conv_handler = ConvHandler(input_node=nodes[1],
                               input_index=0,
                               weight=dict(gm.named_modules())[nodes[2].name].weight,
                               output_node=nodes[2],
                               device_mesh=device_mesh,
                               strategies_vector=strategies_vector,
                               shape_consistency_manager=shape_consistency_manager)
    conv_handler.register_strategy_into_strategies_vector()

    # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R',
    #  'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
    strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector.strategies]

    # SS = SR x RS
    assert 'S0S1 = S0R x RS1' in strategy_name_list
    assert 'S1S0 = S1R x RS0' in strategy_name_list

    # SR = SS x SR
    assert 'S0R = S0S1 x S1R' in strategy_name_list
    assert 'S1R = S1S0 x S0R' in strategy_name_list

    # RS = RS x SS
    assert 'RS0 = RS1 x S1S0' in strategy_name_list
    assert 'RS1 = RS0 x S0S1' in strategy_name_list

    # RS = RR x RS
    assert 'RS0 = RR x RS0' in strategy_name_list
    assert 'RS1 = RR x RS1' in strategy_name_list

    # RR = RR x RR
    assert 'RR = RR x RR' in strategy_name_list


if __name__ == '__main__':
    test_conv_handler()
```
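The strategy names asserted above follow the convention `output = input x weight`, where `R` marks a tensor dimension replicated across the device mesh and `S{i}` one sharded along mesh dimension `i`. A small hypothetical helper, not part of the commit, that splits a name into its three specs:

```python
# Hypothetical helper, not in the commit: split a strategy name such as
# 'S0S1 = S0R x RS1' into its output, input, and weight sharding specs.
def decode_strategy_name(name: str) -> dict:
    output_spec, operands = name.split(' = ')
    input_spec, weight_spec = operands.split(' x ')
    return {'output': output_spec, 'input': input_spec, 'weight': weight_spec}

# Here the output is sharded along both mesh dims because the input carries
# the S0 shard and the weight carries the S1 shard.
assert decode_strategy_name('S0S1 = S0R x RS1') == \
    {'output': 'S0S1', 'input': 'S0R', 'weight': 'RS1'}
```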