Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
197d0bf4
Unverified
Commit
197d0bf4
authored
Feb 28, 2023
by
YuliangLiu0306
Committed by
GitHub
Feb 28, 2023
Browse files
[autoparallel] apply repeat block to reduce solving time (#2912)
parent
a8480911
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
57 additions
and
28 deletions
+57
-28
colossalai/auto_parallel/tensor_shard/initialize.py
colossalai/auto_parallel/tensor_shard/initialize.py
+5
-3
colossalai/auto_parallel/tensor_shard/solver/solver.py
colossalai/auto_parallel/tensor_shard/solver/solver.py
+25
-14
colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
...to_parallel/tensor_shard/solver/strategies_constructor.py
+21
-0
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
...test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
+3
-5
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
...uto_parallel/test_tensor_shard/test_node_handler/utils.py
+1
-3
tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
..._parallel/test_tensor_shard/test_solver_with_resnet_v2.py
+2
-3
No files found.
colossalai/auto_parallel/tensor_shard/initialize.py
View file @
197d0bf4
...
@@ -112,11 +112,13 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
...
@@ -112,11 +112,13 @@ def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstruc
This method is used to solve the best solution for the given graph.
This method is used to solve the best solution for the given graph.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
The solution is a list of integers, each integer represents the best strategy index of the corresponding node.
'''
'''
graph_analyser
=
GraphAnalyser
(
gm
)
# temporarily we use all nodes as liveness list, we count the backward memory cost together with
liveness_list
=
graph_analyser
.
liveness_analysis
()
# forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
# graph_analyser = GraphAnalyser(gm)
# liveness_list = graph_analyser.liveness_analysis()
cost_graph
=
CostGraph
(
strategy_constructor
.
leaf_strategies
)
cost_graph
=
CostGraph
(
strategy_constructor
.
leaf_strategies
)
cost_graph
.
simplify_graph
()
cost_graph
.
simplify_graph
()
solver
=
Solver
(
gm
.
graph
,
strategy_constructor
,
cost_graph
,
graph_analyser
,
memory_budget
=
memory_budget
)
solver
=
Solver
(
gm
.
graph
,
strategy_constructor
,
cost_graph
,
memory_budget
=
memory_budget
)
ret
=
solver
.
call_solver_serialized_args
()
ret
=
solver
.
call_solver_serialized_args
()
solution
=
list
(
ret
[
0
])
solution
=
list
(
ret
[
0
])
...
...
colossalai/auto_parallel/tensor_shard/solver/solver.py
View file @
197d0bf4
...
@@ -32,7 +32,7 @@ class Solver:
...
@@ -32,7 +32,7 @@ class Solver:
graph
:
Graph
,
graph
:
Graph
,
strategies_constructor
:
StrategiesConstructor
,
strategies_constructor
:
StrategiesConstructor
,
cost_graph
:
CostGraph
,
cost_graph
:
CostGraph
,
graph_analyser
:
GraphAnalyser
,
graph_analyser
:
GraphAnalyser
=
None
,
memory_budget
:
float
=
-
1.0
,
memory_budget
:
float
=
-
1.0
,
solution_numbers
:
int
=
1
,
solution_numbers
:
int
=
1
,
forward_only
:
bool
=
False
,
forward_only
:
bool
=
False
,
...
@@ -63,7 +63,10 @@ class Solver:
...
@@ -63,7 +63,10 @@ class Solver:
self
.
memory_increasing_coefficient
=
memory_increasing_coefficient
self
.
memory_increasing_coefficient
=
memory_increasing_coefficient
else
:
else
:
self
.
memory_increasing_coefficient
=
1
self
.
memory_increasing_coefficient
=
1
self
.
liveness_list
=
self
.
graph_analyser
.
liveness_analysis
()
# temporarily we use all nodes as liveness list, we count the backward memory cost together with
# forward memory cost into the node memory cost, and no activation checkpoint is used in this phase.
# self.liveness_list = self.graph_analyser.liveness_analysis()
self
.
liveness_list
=
self
.
nodes
self
.
node_index_dict
=
self
.
_generate_node_index_dict
()
self
.
node_index_dict
=
self
.
_generate_node_index_dict
()
# The last solution vector of auto sharding.
# The last solution vector of auto sharding.
self
.
last_s_val
=
None
self
.
last_s_val
=
None
...
@@ -140,7 +143,7 @@ class Solver:
...
@@ -140,7 +143,7 @@ class Solver:
liveness_set
=
self
.
liveness_list
liveness_set
=
self
.
liveness_list
# omit alias_set now
# omit alias_set now
alias_set
=
None
alias_set
=
self
.
strategies_constructor
.
alias_set
alias_convert_costs
=
None
alias_convert_costs
=
None
# prepare compute_costs, communication_costs and memory_costs
# prepare compute_costs, communication_costs and memory_costs
...
@@ -230,6 +233,7 @@ class Solver:
...
@@ -230,6 +233,7 @@ class Solver:
# 0. Unpack flatten numpy arrays
# 0. Unpack flatten numpy arrays
s_follow
=
following_nodes
s_follow
=
following_nodes
s_alias
=
alias_set
E
=
edge_pairs
.
reshape
((
-
1
,
2
))
# noqa
E
=
edge_pairs
.
reshape
((
-
1
,
2
))
# noqa
r
=
[]
r
=
[]
...
@@ -294,8 +298,11 @@ class Solver:
...
@@ -294,8 +298,11 @@ class Solver:
if
strategies_len
[
i
]
==
1
:
if
strategies_len
[
i
]
==
1
:
s
.
append
([
1
])
s
.
append
([
1
])
else
:
else
:
num_nodes
+=
1
if
i
not
in
s_alias
:
s
.
append
(
LpVariable
.
matrix
(
f
"s[
{
i
}
]"
,
(
range
(
strategies_len
[
i
]),),
cat
=
"Binary"
))
num_nodes
+=
1
s
.
append
(
LpVariable
.
matrix
(
f
"s[
{
i
}
]"
,
(
range
(
strategies_len
[
i
]),),
cat
=
"Binary"
))
else
:
s
.
append
(
s
[
s_alias
[
i
]])
else
:
else
:
if
s_follow
[
i
]
<
len
(
s
):
if
s_follow
[
i
]
<
len
(
s
):
s
.
append
(
s
[
s_follow
[
i
]])
s
.
append
(
s
[
s_follow
[
i
]])
...
@@ -311,15 +318,20 @@ class Solver:
...
@@ -311,15 +318,20 @@ class Solver:
#############################
#############################
e
=
[]
e
=
[]
num_edges
=
0
num_edges
=
0
map_edge_to_idx
=
{}
for
(
idx
,
(
i
,
j
))
in
enumerate
(
E
):
for
(
idx
,
(
i
,
j
))
in
enumerate
(
E
):
if
len
(
s
[
i
])
==
1
:
if
len
(
s
[
i
])
==
1
:
e
.
append
(
s
[
j
])
e
.
append
(
s
[
j
])
elif
len
(
s
[
j
])
==
1
:
elif
len
(
s
[
j
])
==
1
:
e
.
append
(
s
[
i
])
e
.
append
(
s
[
i
])
else
:
else
:
num_edges
+=
1
if
i
in
s_alias
and
j
in
s_alias
and
(
s_alias
[
i
],
s_alias
[
j
])
in
map_edge_to_idx
:
e
.
append
(
LpVariable
.
matrix
(
f
"e[
{
i
}
,
{
j
}
]"
,
(
range
(
len
(
s
[
i
])
*
len
(
s
[
j
])),),
cat
=
"Binary"
))
e
.
append
(
e
[
map_edge_to_idx
[(
s_alias
[
i
],
s_alias
[
j
])]])
else
:
num_edges
+=
1
e
.
append
(
LpVariable
.
matrix
(
f
"e[
{
i
}
,
{
j
}
]"
,
(
range
(
len
(
s
[
i
])
*
len
(
s
[
j
])),),
cat
=
"Binary"
))
assert
len
(
e
[
idx
])
==
len
(
r
[
idx
])
assert
len
(
e
[
idx
])
==
len
(
r
[
idx
])
map_edge_to_idx
[(
i
,
j
)]
=
idx
for
element
in
s
:
for
element
in
s
:
assert
len
(
element
)
>
0
assert
len
(
element
)
>
0
# 2. Set initial value
# 2. Set initial value
...
@@ -371,13 +383,12 @@ class Solver:
...
@@ -371,13 +383,12 @@ class Solver:
# compute memory consumption with liveness set #
# compute memory consumption with liveness set #
#################################################
#################################################
if
memory_budget
>
0
:
if
memory_budget
>
0
:
for
liveness_stage
in
liveness_set
:
mem
=
0
mem
=
0
for
node
in
liveness_set
:
for
live_variable
in
liveness_stage
.
unique_live_vars
:
if
node
not
in
self
.
node_index_dict
:
if
live_variable
.
node
not
in
self
.
node_index_dict
:
continue
continue
node_index
=
self
.
node_index_dict
[
node
]
node_index
=
self
.
node_index_dict
[
live_variable
.
node
]
mem
+=
lpSum
(
s
[
node_index
][
j
]
*
m
[
node_index
][
j
]
for
j
in
range
(
len
(
s
[
node_index
])))
mem
+=
lpSum
(
s
[
node_index
][
j
]
*
m
[
node_index
][
j
]
for
j
in
range
(
len
(
s
[
node_index
])))
prob
+=
mem
<=
memory_budget
prob
+=
mem
<=
memory_budget
# (d). specified by `cat="Binary"`
# (d). specified by `cat="Binary"`
...
...
colossalai/auto_parallel/tensor_shard/solver/strategies_constructor.py
View file @
197d0bf4
...
@@ -15,6 +15,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
...
@@ -15,6 +15,7 @@ from colossalai.auto_parallel.tensor_shard.node_handler import (
)
)
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
StrategiesVector
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
StrategiesVector
from
colossalai.auto_parallel.tensor_shard.utils
import
generate_resharding_costs
,
generate_sharding_spec
from
colossalai.auto_parallel.tensor_shard.utils
import
generate_resharding_costs
,
generate_sharding_spec
from
colossalai.auto_parallel.tensor_shard.utils.factory
import
find_repeat_blocks
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.device.device_mesh
import
DeviceMesh
from
..options
import
DataloaderOption
,
SolverOptions
from
..options
import
DataloaderOption
,
SolverOptions
...
@@ -42,6 +43,7 @@ class StrategiesConstructor:
...
@@ -42,6 +43,7 @@ class StrategiesConstructor:
self
.
strategy_map
=
{}
self
.
strategy_map
=
{}
self
.
solver_options
=
solver_options
self
.
solver_options
=
solver_options
self
.
no_strategy_nodes
=
[]
self
.
no_strategy_nodes
=
[]
self
.
alias_set
=
None
def
remove_duplicated_strategy
(
self
,
strategies_vector
):
def
remove_duplicated_strategy
(
self
,
strategies_vector
):
'''
'''
...
@@ -59,6 +61,22 @@ class StrategiesConstructor:
...
@@ -59,6 +61,22 @@ class StrategiesConstructor:
for
strategy
in
remove_list
:
for
strategy
in
remove_list
:
strategies_vector
.
remove
(
strategy
)
strategies_vector
.
remove
(
strategy
)
def generate_alias_set(self):
    """
    Map every node of a repeated block onto its counterpart in the first repeated block.

    The nodes of ``self.leaf_strategies`` are scanned for repeated blocks with
    ``find_repeat_blocks``. For each position in the first block, the node found at the
    same position in every later block is recorded as an alias of the first block's node.

    Returns:
        dict: ``{alias_node_index: original_node_index}``, where both indices are positions
        in the node list built from ``self.leaf_strategies``. The dict is empty when no
        repeated block is found.

    Raises:
        KeyError: if a node reported by ``find_repeat_blocks`` is not one of the
        leaf-strategy nodes.
    """
    node_list = [strategy_vector.node for strategy_vector in self.leaf_strategies]
    common_blocks = find_repeat_blocks(node_list, self.root_module, common_length_threshold=10)

    repeat_block_nums = len(common_blocks)
    alias_set = {}
    if repeat_block_nums == 0:
        return alias_set

    # Resolve node -> position once, instead of running ``node_list.index`` (a linear scan)
    # twice for every aliased node. ``setdefault`` keeps the first position of a node,
    # which is what ``list.index`` would have returned.
    node_to_index = {}
    for position, node in enumerate(node_list):
        node_to_index.setdefault(node, position)

    for index, common_node in enumerate(common_blocks[0]):
        for i in range(1, repeat_block_nums):
            # The node at the same offset in the i-th block aliases the first block's node.
            alias_set[node_to_index[common_blocks[i][index]]] = node_to_index[common_node]

    return alias_set
def
build_strategies_and_cost
(
self
):
def
build_strategies_and_cost
(
self
):
"""
"""
This method is to build the strategy vector for each node in the computation graph.
This method is to build the strategy vector for each node in the computation graph.
...
@@ -175,3 +193,6 @@ class StrategiesConstructor:
...
@@ -175,3 +193,6 @@ class StrategiesConstructor:
self
.
leaf_strategies
.
remove
(
node
.
strategies_vector
)
self
.
leaf_strategies
.
remove
(
node
.
strategies_vector
)
if
node
in
self
.
strategy_map
:
if
node
in
self
.
strategy_map
:
self
.
strategy_map
.
pop
(
node
)
self
.
strategy_map
.
pop
(
node
)
alias_set
=
self
.
generate_alias_set
()
self
.
alias_set
=
alias_set
tests/test_auto_parallel/test_tensor_shard/test_gpt/test_solver_with_gpt_module.py
View file @
197d0bf4
...
@@ -15,13 +15,13 @@ from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2
...
@@ -15,13 +15,13 @@ from tests.test_auto_parallel.test_tensor_shard.test_gpt.gpt_modules import GPT2
BATCH_SIZE
=
1
BATCH_SIZE
=
1
SEQ_LENGTH
=
32
SEQ_LENGTH
=
32
HIDDEN_DIM
=
768
HIDDEN_DIM
=
384
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
@
parameterize
(
'model_cls'
,
[
GPT2Block
,
GPT2Attention
,
GPT2MLP
,
GPT2Model
])
@
parameterize
(
'model_cls'
,
[
GPT2Block
,
GPT2Attention
,
GPT2MLP
,
GPT2Model
])
def
test_self_attention_block
(
model_cls
):
def
test_self_attention_block
(
model_cls
):
config
=
transformers
.
GPT2Config
(
n_position
=
64
,
n_layer
=
4
,
n_head
=
16
,
n_embd
=
HIDDEN_DIM
)
config
=
transformers
.
GPT2Config
(
n_position
=
64
,
n_layer
=
12
,
n_head
=
16
,
n_embd
=
HIDDEN_DIM
)
if
model_cls
==
GPT2MLP
:
if
model_cls
==
GPT2MLP
:
model
=
model_cls
(
intermediate_size
=
4
*
config
.
hidden_size
,
config
=
config
)
model
=
model_cls
(
intermediate_size
=
4
*
config
.
hidden_size
,
config
=
config
)
else
:
else
:
...
@@ -54,15 +54,13 @@ def test_self_attention_block(model_cls):
...
@@ -54,15 +54,13 @@ def test_self_attention_block(model_cls):
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
print
(
gm
.
graph
)
print
(
gm
.
graph
)
gm
.
recompile
()
gm
.
recompile
()
graph_analyser
=
GraphAnalyser
(
gm
)
liveness_list
=
graph_analyser
.
liveness_analysis
()
solver_options
=
SolverOptions
()
solver_options
=
SolverOptions
()
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
solver_options
)
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
solver_options
)
strategies_constructor
.
build_strategies_and_cost
()
strategies_constructor
.
build_strategies_and_cost
()
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
.
simplify_graph
()
cost_graph
.
simplify_graph
()
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
,
graph_analyser
,
memory_budget
=-
1
)
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
,
memory_budget
=-
1
)
ret
=
solver
.
call_solver_serialized_args
()
ret
=
solver
.
call_solver_serialized_args
()
strategies_list
=
solver
.
last_s_val
strategies_list
=
solver
.
last_s_val
nodes
=
[
strategies_vector
.
node
for
strategies_vector
in
strategies_constructor
.
leaf_strategies
]
nodes
=
[
strategies_vector
.
node
for
strategies_vector
in
strategies_constructor
.
leaf_strategies
]
...
...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py
View file @
197d0bf4
...
@@ -9,7 +9,6 @@ from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_pre
...
@@ -9,7 +9,6 @@ from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_pre
from
colossalai.auto_parallel.tensor_shard.options
import
SolverOptions
from
colossalai.auto_parallel.tensor_shard.options
import
SolverOptions
from
colossalai.auto_parallel.tensor_shard.solver
import
StrategiesConstructor
from
colossalai.auto_parallel.tensor_shard.solver
import
StrategiesConstructor
from
colossalai.auto_parallel.tensor_shard.solver.cost_graph
import
CostGraph
from
colossalai.auto_parallel.tensor_shard.solver.cost_graph
import
CostGraph
from
colossalai.auto_parallel.tensor_shard.solver.graph_analysis
import
GraphAnalyser
from
colossalai.auto_parallel.tensor_shard.solver.solver
import
Solver
from
colossalai.auto_parallel.tensor_shard.solver.solver
import
Solver
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx.tracer.tracer
import
ColoTracer
from
colossalai.fx.tracer.tracer
import
ColoTracer
...
@@ -109,8 +108,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
...
@@ -109,8 +108,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
# solution construction
# solution construction
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
.
simplify_graph
()
cost_graph
.
simplify_graph
()
graph_analyser
=
GraphAnalyser
(
gm
)
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
,
verbose
=
False
)
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
,
graph_analyser
,
verbose
=
False
)
ret
=
solver
.
call_solver_serialized_args
()
ret
=
solver
.
call_solver_serialized_args
()
solution
=
list
(
ret
[
0
])
solution
=
list
(
ret
[
0
])
gm
,
sharding_spec_dict
,
origin_spec_dict
,
comm_actions_dict
=
runtime_preparation_pass
(
gm
,
sharding_spec_dict
,
origin_spec_dict
,
comm_actions_dict
=
runtime_preparation_pass
(
...
...
tests/test_auto_parallel/test_tensor_shard/test_solver_with_resnet_v2.py
View file @
197d0bf4
...
@@ -51,15 +51,14 @@ def test_cost_graph():
...
@@ -51,15 +51,14 @@ def test_cost_graph():
# return fc
# return fc
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
=
GraphModule
(
model
,
graph
,
model
.
__class__
.
__name__
)
gm
.
recompile
()
gm
.
recompile
()
graph_analyser
=
GraphAnalyser
(
gm
)
liveness_list
=
graph_analyser
.
liveness_analysis
()
solver_options
=
SolverOptions
()
solver_options
=
SolverOptions
()
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
solver_options
)
strategies_constructor
=
StrategiesConstructor
(
graph
,
device_mesh
,
solver_options
)
strategies_constructor
.
build_strategies_and_cost
()
strategies_constructor
.
build_strategies_and_cost
()
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
=
CostGraph
(
strategies_constructor
.
leaf_strategies
)
cost_graph
.
simplify_graph
()
cost_graph
.
simplify_graph
()
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
,
graph_analyser
)
solver
=
Solver
(
gm
.
graph
,
strategies_constructor
,
cost_graph
)
ret
=
solver
.
call_solver_serialized_args
()
ret
=
solver
.
call_solver_serialized_args
()
print
(
ret
[
0
])
print
(
ret
[
0
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment