Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
197d0bf4
Unverified
Commit
197d0bf4
authored
Feb 28, 2023
by
YuliangLiu0306
Committed by
GitHub
Feb 28, 2023
Browse files
[autoparallel] apply repeat block to reduce solving time (#2912)
parent
a8480911
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
57 additions
and
28 deletions
+57
-28
colossalai/auto_parallel/tensor_shard/initialize.py
colossalai/auto_parallel/tensor_shard/initialize.py
+5
-3
colossalai/auto_parallel/tensor_shard/solver/solver.py
colossalai/auto_parallel/tensor_shard/solver/solver.py
+25
-14
colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
...to_parallel/tensor_shard/solver/strategies_constructor.py
+21
-0
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
...test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
+3
-5
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
...uto_parallel/test_tensor_shard/test_node_handler/utils.py
+1
-3
tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
..._parallel/test_tensor_shard/test_solver_with_resnet_v2.py
+2
-3
No files found.
colossalai/auto_parallel/tensor_shard/initialize.py
View file @
197d0bf4
...
...
@@ -112,11 +112,13 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
This method is used to solve the best solution for the given graph.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
'''
graph_analyser
=
GraphAnalyser
(
gm
)
liveness_list
=
graph_analyser
.
liveness_analysis
()
# temporarily we use all nodes as liveness list, we count the backward memory cost together with
# forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
# graph_analyser = GraphAnalyser(gm)
# liveness_list = graph_analyser.liveness_analysis()
cost_graph
=
CostGraph
(
strategy_constructor
.
leaf_strategies
)
cost_graph
.
simplify_graph
()
solver
=
Solver
(
gm
.
graph
,
strategy_constructor
,
cost_graph
,
graph_analyser
,
memory_budget
=
memory_budget
)
solver
=
Solver
(
gm
.
graph
,
strategy_constructor
,
cost_graph
,
memory_budget
=
memory_budget
)
ret
=
solver
.
call_solver_serialized_args
()
solution
=
list
(
ret
[
0
])
...
...
colossalai/auto_parallel/tensor_shard/solver/solver.py
View file @
197d0bf4
...
...
@@ -32,7 +32,7 @@ class Solver:
graph
:
Graph
,
strategies_constructor
:
StrategiesConstructor
,
cost_graph
:
CostGraph
,
graph_analyser
:
GraphAnalyser
,
graph_analyser
:
GraphAnalyser
=
None
,
memory_budget
:
float
=
-
1.0
,
solution_numbers
:
int
=
1
,
forward_only
:
bool
=
False
,
...
...
@@ -63,7 +63,10 @@ class Solver:
self
.
memory_increasing_coefficient
=
memory_increasing_coefficient
else
:
self
.
memory_increasing_coefficient
=
1
self
.
liveness_list
=
self
.
graph_analyser
.
liveness_analysis
()
# temporarily we use all nodes as liveness list, we count the backward memory cost together with
# forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
# self.liveness_list = self.graph_analyser.liveness_analysis()
self
.
liveness_list
=
self
.
nodes
self
.
node_index_dict
=
self
.
_generate_node_index_dict
()
# The last solution vector of auto sharding.
self
.
last_s_val
=
None
...
...
@@ -140,7 +143,7 @@ class Solver:
liveness_set
=
self
.
liveness_list
# omit alias_set now
alias_set
=
None
alias_set
=
self
.
strategies_constructor
.
alias_set
alias_convert_costs
=
None
# prepare compute_costs, communication_costs and memory_costs
...
...
@@ -230,6 +233,7 @@ class Solver:
# 0. Unpack flatten numpy arrays
s_follow
=
following_nodes
s_alias
=
alias_set
E
=
edge_pairs
.
reshape
((
-
1
,
2
))
# noqa
r
=
[]
...
...
@@ -294,8 +298,11 @@ class Solver:
if
strategies_len
[
i
]
==
1
:
s
.
append
([
1
])
else
:
if
i
not
in
s_alias
:
num_nodes
+=
1
s
.
append
(
LpVariable
.
matrix
(
f
"s[
{
i
}
]"
,
(
range
(
strategies_len
[
i
]),),
cat
=
"Binary"
))
else
:
s
.
append
(
s
[
s_alias
[
i
]])
else
:
if
s_follow
[
i
]
<
len
(
s
):
s
.
append
(
s
[
s_follow
[
i
]])
...
...
@@ -311,15 +318,20 @@ class Solver:
#############################
e
=
[]
num_edges
=
0
map_edge_to_idx
=
{}
for
(
idx
,
(
i
,
j
))
in
enumerate
(
E
):
if
len
(
s
[
i
])
==
1
:
e
.
append
(
s
[
j
])
elif
len
(
s
[
j
])
==
1
:
e
.
append
(
s
[
i
])
else
:
if
i
in
s_alias
and
j
in
s_alias
and
(
s_alias
[
i
],
s_alias
[
j
])
in
map_edge_to_idx
:
e
.
append
(
e
[
map_edge_to_idx
[(
s_alias
[
i
],
s_alias
[
j
])]])
else
:
num_edges
+=
1
e
.
append
(
LpVariable
.
matrix
(
f
"e[
{
i
}
,
{
j
}
]"
,
(
range
(
len
(
s
[
i
])
*
len
(
s
[
j
])),),
cat
=
"Binary"
))
assert
len
(
e
[
idx
])
==
len
(
r
[
idx
])
map_edge_to_idx
[(
i
,
j
)]
=
idx
for
element
in
s
:
assert
len
(
element
)
>
0
# 2. Set initial value
...
...
@@ -371,12 +383,11 @@ class Solver:
# compute memory consumption with liveness set #
#################################################
if
memory_budget
>
0
:
for
liveness_stage
in
liveness_set
:
mem
=
0
for
live_variable
in
liveness_stage
.
unique_live_vars
:
if
live_variable
.
node
not
in
self
.
node_index_dict
:
for
node
in
liveness_set
:
if
node
not
in
self
.
node_index_dict
:
continue
node_index
=
self
.
node_index_dict
[
live_variable
.
node
]
node_index
=
self
.
node_index_dict
[
node
]
mem
+=
lpSum
(
s
[
node_index
][
j
]
*
m
[
node_index
][
j
]
for
j
in
range
(
len
(
s
[
node_index
])))
prob
+=
mem
<=
memory_budget
...
...
colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
View file @
197d0bf4
...
...
@@ -15,6 +15,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
)
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
StrategiesVector
from
colossalai.auto_parallel.tensor_shard.utils
import
generate_resharding_costs
,
generate_sharding_spec
from
colossalai.auto_parallel.tensor_shard.utils.factory
import
find_repeat_blocks
from
colossalai.device.device_mesh
import
DeviceMesh
from
..options
import
DataloaderOption
,
SolverOptions
...
...
@@ -42,6 +43,7 @@ class StrategiesConstructor:
self
.
strategy_map
=
{}
self
.
solver_options
=
solver_options
self
.
no_strategy_nodes
=
[]
self
.
alias_set
=
None
def
remove_duplicated_strategy
(
self
,
strategies_vector
):
'''
...
...
@@ -59,6 +61,22 @@ class StrategiesConstructor:
for
strategy
in
remove_list
:
strategies_vector
.
remove
(
strategy
)
def
generate_alias_set
(
self
):
node_list
=
[
strategy_vector
.
node
for
strategy_vector
in
self
.
leaf_strategies
]
common_blocks
=
find_repeat_blocks
(
node_list
,
self
.
root_module
,
common_length_threshold
=
10
)
repeat_block_nums
=
len
(
common_blocks
)
alias_set
=
{}
if
repeat_block_nums
==
0
:
return
alias_set
for
index
,
common_node
in
enumerate
(
common_blocks
[
0
]):
for
i
in
range
(
1
,
repeat_block_nums
):
alias_set
[
node_list
.
index
(
common_blocks
[
i
][
index
])]
=
node_list
.
index
(
common_node
)
return
alias_set
def
build_strategies_and_cost
(
self
):
"""
This method is to build the strategy vector for each node in the computation graph.
...
...
@@ -175,3 +193,6 @@ class StrategiesConstructor:
self
.
leaf_strategies
.
remove
(
node
.
strategies_vector
)
if
node
in
self
.
strategy_map
:
self
.
strategy_map
.
pop
(
node
)
alias_set
=
self
.
generate_alias_set
()
self
.
alias_set
=
alias_set
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
View file @
197d0bf4
...
...
@@ -15,13 +15,13 @@ from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2
BATCH_SIZE
=
1
SEQ_LENGTH
=
32
HIDDEN_DIM
=
768
HIDDEN_DIM
=
384
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
@
parameterize
(
'model_cls'
,
[
GPT2Block
,
GPT2Attention
,
GPT2MLP
,
GPT2Model
])
def
test_self_attention_block
(
model_cls
):
config
=
transformers
.
GPT2Config
(
n_position
=
64
,
n_layer
=
4
,
n_head
=
16
,
n_embd
=
HIDDEN_DIM
)
config
=
transformers
.
GPT2Config
(
n_position
=
64
,
n_layer
=
12
,
n_head
=
16
,
n_embd
=
HIDDEN_DIM
)
if
model_cls
==
GPT2MLP
:
model
=
model_cls
(
intermediate_size
=
4
*
config
.
hidden_size
,
config
=
config
)
else
:
...
...
@@ -54,15 +54,13 @@ def test_self_attention_block(model_cls):
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
print
(
gm
.
graph
)
gm
.
recompile
()
graph_analyser
=
GraphAnalyser
(
gm
)
liveness_list
=
graph_analyser
.
liveness_analysis
()
solver_options
=
SolverOptions
()
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
solver_options
)
strategies_constructor
.
build_strategies_and_cost
()
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
.
simplify_graph
()
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
,
graph_analyser
,
memory_budget
=-
1
)
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
,
memory_budget
=-
1
)
ret
=
solver
.
call_solver_serialized_args
()
strategies_list
=
solver
.
last_s_val
nodes
=
[
strategies_vector
.
node
for
strategies_vector
in
strategies_constructor
.
leaf_strategies
]
...
...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
View file @
197d0bf4
...
...
@@ -9,7 +9,6 @@ from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_pre
from
colossalai.auto_parallel.tensor_shard.options
import
SolverOptions
from
colossalai.auto_parallel.tensor_shard.solver
import
StrategiesConstructor
from
colossalai.auto_parallel.tensor_shard.solver.cost_graph
import
CostGraph
from
colossalai.auto_parallel.tensor_shard.solver.graph_analysis
import
GraphAnalyser
from
colossalai.auto_parallel.tensor_shard.solver.solver
import
Solver
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx.tracer.tracer
import
ColoTracer
...
...
@@ -109,8 +108,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
# solution construction
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
.
simplify_graph
()
graph_analyser
=
GraphAnalyser
(
gm
)
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
,
graph_analyser
,
verbose
=
False
)
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
,
verbose
=
False
)
ret
=
solver
.
call_solver_serialized_args
()
solution
=
list
(
ret
[
0
])
gm
,
sharding_spec_dict
,
origin_spec_dict
,
comm_actions_dict
=
runtime_preparation_pass
(
...
...
tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
View file @
197d0bf4
...
...
@@ -51,15 +51,14 @@ def test_cost_graph():
# return fc
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
graph_analyser
=
GraphAnalyser
(
gm
)
liveness_list
=
graph_analyser
.
liveness_analysis
()
solver_options
=
SolverOptions
()
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
solver_options
)
strategies_constructor
.
build_strategies_and_cost
()
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
.
simplify_graph
()
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
,
graph_analyser
)
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
)
ret
=
solver
.
call_solver_serialized_args
()
print
(
ret
[
0
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment