Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
413c0534
Unverified
Commit
413c0534
authored
Aug 25, 2022
by
YuliangLiu0306
Committed by
GitHub
Aug 25, 2022
Browse files
[autoparallel] add cost graph class (#1481)
* [autoparallel] add cost graph class * polish code
parent
4b03c25f
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
141 additions
and
5 deletions
+141
-5
colossalai/auto_parallel/solver/cost_graph.py
colossalai/auto_parallel/solver/cost_graph.py
+131
-0
colossalai/auto_parallel/solver/operator_handler.py
colossalai/auto_parallel/solver/operator_handler.py
+2
-1
colossalai/auto_parallel/solver/sharding_strategy.py
colossalai/auto_parallel/solver/sharding_strategy.py
+1
-1
colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
...salai/fx/tracer/meta_patch/patched_function/arithmetic.py
+1
-1
tests/test_auto_parallel/test_conv_handler.py
tests/test_auto_parallel/test_conv_handler.py
+3
-1
tests/test_auto_parallel/test_dot_handler.py
tests/test_auto_parallel/test_dot_handler.py
+3
-1
No files found.
colossalai/auto_parallel/solver/cost_graph.py
0 → 100644
View file @
413c0534
from
colossalai.auto_parallel.solver.sharding_strategy
import
ShardingStrategy
,
StrategiesVector
from
typing
import
List
from
torch.fx.node
import
Node
class
CostGraph
:
'''
A graph data structure to simplify the edge cost graph. It has two main functions:
1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list.
2. To reduce the searching space, we merge computationally-trivial operators, such as
element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will
be given by the StrategiesVector depending on the type of target node and following nodes.
Argument:
leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph.
simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True)
'''
def
__init__
(
self
,
leaf_strategies
,
simplify
=
True
):
self
.
leaf_strategies
=
leaf_strategies
# stores number of strategies in each node
self
.
node_lens
=
{
strategies_vector
.
node
:
len
(
strategies_vector
)
for
strategies_vector
in
self
.
leaf_strategies
}
# extra_node_costs will store the extra costs introduced by merging nodes
self
.
extra_node_costs
=
{}
self
.
simplify
=
simplify
self
.
_build_cost_graph
()
def
_build_cost_graph
(
self
):
'''
This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be
set to node.
'''
self
.
edge_costs
=
{}
if
self
.
simplify
:
self
.
merge_pair
=
[]
for
strategies_vector
in
self
.
leaf_strategies
:
# build edge_cost
dst_node
=
strategies_vector
.
node
for
src_node
in
strategies_vector
.
predecessor_nodes
:
node_pair
=
(
src_node
,
dst_node
)
src_index
=
strategies_vector
.
predecessor_nodes
.
index
(
src_node
)
edge_cost
=
{}
for
i
in
range
(
len
(
strategies_vector
)):
for
j
in
range
(
len
(
src_node
.
stategy_vector
)):
edge_cost
[(
i
,
j
)]
=
strategies_vector
[
i
].
resharding_costs
[
src_index
][
j
]
self
.
edge_costs
[
node_pair
]
=
edge_cost
# add parents and children attribute to node
setattr
(
dst_node
,
'parents'
,
strategies_vector
.
predecessor_nodes
)
setattr
(
dst_node
,
'children'
,
strategies_vector
.
successor_nodes
)
if
self
.
simplify
and
strategies_vector
.
check_merge
():
for
following_node
in
strategies_vector
.
successor_nodes
:
self
.
merge_pair
.
append
((
dst_node
,
following_node
))
def
get_edge_cost
(
self
,
src_node
,
dst_node
):
return
self
.
edge_costs
[(
src_node
,
dst_node
)]
def
merge_node
(
self
,
src_node
,
dst_node
):
'''
To merge src_node into dst_node, we need to do it in following steps:
1. For each strategy in dst_node, we need to pick an appropriate strategy
of src_node to merge, it is important because the logical resharding costs
between the parents node of src_node and merged node depend on the src_node
strategies dispatching. For example, for the graph 0->1->2, after merging node 1
into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)]
x represents the picking strategy of node 1 merged into node 2 strategy 0.
2. We need to accumulate the extra costs introduced by merging nodes, the extra costs
contains two parts, one is resharding costs between src_node strategy and dst_node strategy,
another is the origin extra costs in src_node strategy.
3. Build connections between new node pairs, and remove the src_node after all consumer nodes
detached from it.
Argument:
src_node(Node): The node will be merged into dst_node.
dst_node(Node): The node to integrate src_node.
'''
src_node_index
=
dst_node
.
parents
.
index
(
src_node
)
# build merge_map
merge_map
=
{}
for
dst_strate_index
,
strategy
in
enumerate
(
dst_node
.
strategies_vector
):
resharding_costs
=
strategy
.
resharding_costs
resharding_cost_for_src
=
resharding_costs
[
src_node_index
]
lowest_cost_index
=
resharding_cost_for_src
.
index
(
min
(
resharding_cost_for_src
))
merge_map
[
dst_strate_index
]
=
lowest_cost_index
# extra_node_cost for dst node
extra_node_costs
[
dst_node
]
=
[
0.0
for
_
in
range
(
self
.
node_lens
[
dst_node
])]
for
dst_strate_index
,
strategy
in
enumerate
(
dst_node
.
strategies_vector
):
target_strate_index
=
merge_map
[
dst_strate_index
]
extra_node_costs
[
dst_node
][
dst_strate_index
]
+=
strategy
.
resharding_costs
[
src_node_index
][
target_strate_index
]
if
src_node
in
extra_node_costs
:
extra_node_costs
[
dst_node
][
dst_strate_index
]
+=
extra_node_costs
[
src_node
][
target_strate_index
]
# connect dst node and parents of src node
dst_node
.
parents
.
remove
(
src_node
)
src_node
.
children
.
remove
(
dst_node
)
node_pair_to_remove
=
[(
src_node
,
dst_node
)]
for
parent_node
in
src_node
.
parents
:
if
parent_node
not
in
dst_node
.
parents
:
dst_node
.
parents
.
append
(
parent
)
if
dst_node
not
in
parent_node
.
children
:
parent_node
.
children
.
append
(
dst_node
)
# remove src node from cost graph when src node has no consumer.
if
len
(
src_node
.
children
)
==
0
:
parent_node
.
children
.
remove
(
src_node
)
node_pair
=
(
parent_node
,
src_node
)
self
.
edge_costs
.
pop
(
node_pair
)
# add new node pair to cost graph
for
parent_node
in
src_node
.
parents
:
new_node_pair
=
(
parent_node
,
dst_node
)
old_node_pair
=
(
parent_node
,
src_node
)
if
new_node_pair
in
self
.
edge_costs
:
continue
edge_cost
=
{}
for
i
in
range
(
self
.
node_lens
[
dst_node
]):
for
j
in
range
(
self
.
node_lens
[
parent_node
]):
src_strate_index
=
merge_map
[
i
]
edge_cost
[(
i
,
j
)]
=
self
.
edge_costs
[
old_node_pair
][(
j
,
src_strate_index
)]
self
.
edge_costs
[
new_node_pair
]
=
edge_cost
def
simplify_graph
(
self
):
if
not
self
.
simplify
:
return
for
(
src_node
,
dst_node
)
in
self
.
merge_pair
:
self
.
merge_node
(
src_node
,
dst_node
)
colossalai/auto_parallel/solver/operator_handler.py
View file @
413c0534
...
...
@@ -84,6 +84,7 @@ class OperatorHandler(ABC):
for
input_node
,
input_spec
in
zip
(
self
.
predecessor_node
,
sharding_spec_for_input
):
resharding_costs
[
input_node
]
=
[]
for
strategy
in
input_node
.
strategies_vector
:
_
,
_
,
resharding_cost
=
self
.
shape_consistency_manager
.
shape_consistency
(
strategy
,
input_spec
)
_
,
_
,
resharding_cost
=
self
.
shape_consistency_manager
.
shape_consistency
(
strategy
.
output_sharding_spec
,
input_spec
)
resharding_costs
[
input_node
].
append
(
resharding_cost
)
return
resharding_cost
colossalai/auto_parallel/solver/sharding_strategy.py
View file @
413c0534
...
...
@@ -47,7 +47,7 @@ class StrategiesVector(list):
self
.
node
=
node
# fetch its input and output nodes
self
.
predecessor_nodes
=
list
(
node
.
_input_nodes
.
keys
())
self
.
successor_n
d
oes
=
list
(
node
.
users
.
keys
())
self
.
successor_no
d
es
=
list
(
node
.
users
.
keys
())
def
check_merge
(
self
):
pass
colossalai/fx/tracer/meta_patch/patched_function/arithmetic.py
View file @
413c0534
...
...
@@ -15,7 +15,7 @@ def torch_matmul(input, other, *, out=None):
shape
=
(
input
.
size
(
0
),
other
.
size
(
1
))
elif
d1
==
1
and
d2
==
2
:
shape
=
(
other
.
size
(
1
),)
elif
d1
==
2
and
d
1
==
1
:
elif
d1
==
2
and
d
2
==
1
:
shape
=
(
input
.
size
(
0
),)
else
:
max_length
=
max
(
input
.
dim
(),
other
.
dim
())
...
...
tests/test_auto_parallel/test_conv_handler.py
View file @
413c0534
...
...
@@ -70,7 +70,9 @@ def test_conv_handler():
sharding_spec
=
ShardingSpec
(
device_mesh
=
device_mesh
,
entire_shape
=
entire_shape
,
sharding_sequence
=
sharding_sequence
)
strategies_vector_for_input
.
append
(
sharding_spec
)
strategy_name
=
str
(
sharding_spec
.
sharding_sequence
)
sharding_strategy
=
ShardingStrategy
(
name
=
strategy_name
,
output_sharding_spec
=
sharding_spec
)
strategies_vector_for_input
.
append
(
sharding_strategy
)
setattr
(
nodes
[
1
],
'strategies_vector'
,
strategies_vector_for_input
)
# generate conv strategy
...
...
tests/test_auto_parallel/test_dot_handler.py
View file @
413c0534
...
...
@@ -69,7 +69,9 @@ def test_dot_handler():
sharding_spec
=
ShardingSpec
(
device_mesh
=
device_mesh
,
entire_shape
=
entire_shape
,
sharding_sequence
=
sharding_sequence
)
strategies_vector_for_input
.
append
(
sharding_spec
)
strategy_name
=
str
(
sharding_spec
.
sharding_sequence
)
sharding_strategy
=
ShardingStrategy
(
name
=
strategy_name
,
output_sharding_spec
=
sharding_spec
)
strategies_vector_for_input
.
append
(
sharding_strategy
)
setattr
(
nodes
[
1
],
'strategies_vector'
,
strategies_vector_for_input
)
# generate dot strategy
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment