OpenDAS / ColossalAI · Commit 451cd72d (Unverified)
Authored Oct 14, 2022 by YuliangLiu0306; committed by GitHub on Oct 14, 2022
Parent: 21962e15

[autoparallel] adapt runtime passes (#1703)

* [autoparallel] adapt runtime passes v2
* polish code
Showing 5 changed files with 204 additions and 6 deletions (+204 −6):

colossalai/auto_parallel/solver/cost_graph.py (+0 −3)
colossalai/auto_parallel/solver/node_handler/node_handler.py (+2 −2)
colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py (+1 −1)
colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py (+115 −0, new file)
tests/test_auto_parallel/test_shape_consistency_pass.py (+86 −0, new file)
colossalai/auto_parallel/solver/cost_graph.py

@@ -58,9 +58,6 @@ class CostGraph:
             edge_cost = {}
             for i in range(len(strategies_vector)):
                 for j in range(len(src_node.strategies_vector)):
-                    if strategies_vector[i].resharding_costs is None:
-                        print(strategies_vector.node.name)
-                        assert False
                     resharding_cost_item = strategies_vector[i].resharding_costs[src_node][j]
                     if self.forward_only:
                         edge_cost[(j, i)] = resharding_cost_item.fwd
colossalai/auto_parallel/solver/node_handler/node_handler.py

@@ -90,8 +90,8 @@ class NodeHandler(ABC):
         # compute the resharding costs based on the previous node
         # strategies if specified
         if compute_resharding_cost:
-            updated_strategies = map(self.update_resharding_cost, strategies)
-            strategies = list(updated_strategies)
+            updated_strategies = map(self.update_resharding_cost, post_processed_strategies)
+            post_processed_strategies = list(updated_strategies)

         self.strategies_vector.extend(post_processed_strategies)
colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py

@@ -52,7 +52,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
         total_compute_cost = forward_compute_cost + backward_compute_cost
         compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
-        return compute_cost
+        strategy.compute_cost = compute_cost

     def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         forward_size_mapping = {
colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py (new file, mode 100644)

import torch
from typing import List
from torch.fx import symbolic_trace
from torch.fx.node import Node
from colossalai.fx.passes.split_module import split_module
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
import builtins
import operator
from copy import deepcopy


def apply(*args, **kwargs):
    shape_consistency_manager = ShapeConsistencyManager()
    return shape_consistency_manager.apply(*args, **kwargs)


def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
    mod_graph = gm.graph
    nodes = tuple(mod_graph.nodes)

    # the dict to get origin sharding spec of node
    origin_node_sharding_spec_dict = {}
    for node_index, (node, strategy_index) in enumerate(zip(nodes, solution)):
        strategies_vector = node.strategies_vector
        setattr(node, 'best_strategy', strategies_vector[strategy_index])
        setattr(node, 'sharding_spec', strategies_vector[strategy_index].get_sharding_spec_by_name(str(node)))
        origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(str(node))

    # apply the sharding spec of parameters
    for node in nodes:
        if node.op == 'call_module':
            target_module = node.graph.owning_module.get_submodule(node.target)
            for name, param in target_module.named_parameters():
                origin_sharding_spec = ShardingSpec(device_mesh, param.shape, {})
                setattr(param, 'sharding_spec', origin_sharding_spec)
                target_weight_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                apply(param, target_weight_sharding_spec)

    # the dict to get input sharding specs of user node
    sharding_spec_convert_dict = {}
    for index, node in enumerate(nodes):
        target_sharding_specs = []
        if node.name == 'bn1':
            print(node.strategies_vector.successor_nodes)
            assert False
        for user_node in node.strategies_vector.successor_nodes:
            # node_index = user_node.strategies_vector.predecessor_nodes.index(node)
            # target_sharding_spec = user_node.best_strategy.input_shardings[node_index]
            target_sharding_spec = user_node.best_strategy.get_sharding_spec_by_name(str(node.name))
            target_sharding_specs.append(target_sharding_spec)
        sharding_spec_convert_dict[index] = target_sharding_specs

    # add above dicts into graph
    for node in nodes:
        if node.op != 'placeholder':
            with mod_graph.inserting_before(node):
                input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
                origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
            break

    return sharding_spec_convert_dict, origin_node_sharding_spec_dict


def shape_consistency_pass(gm: torch.fx.GraphModule):
    mod_graph = gm.graph
    nodes = tuple(mod_graph.nodes)
    input_dict_node = None
    origin_dict_node = None

    # mapping the node into the origin graph index
    node_to_index_dict = {}
    index = 0
    for node in nodes:
        if node.target == 'sharding_spec_convert_dict':
            input_dict_node = node
            continue
        if node.target == 'origin_node_sharding_spec_dict':
            origin_dict_node = node
            continue
        if not hasattr(node, 'best_strategy'):
            continue
        node_to_index_dict[node] = index
        index += 1
    assert input_dict_node is not None

    # add shape consistency apply function into graph
    for node in nodes:
        if not hasattr(node, 'best_strategy'):
            continue
        with mod_graph.inserting_after(node):
            origin_spec_node = mod_graph.create_node('call_function', operator.getitem, args=(origin_dict_node, node_to_index_dict[node]))
        with mod_graph.inserting_after(origin_spec_node):
            set_sharding_spec_node = mod_graph.create_node('call_function', builtins.setattr, args=(node, 'sharding_spec', origin_spec_node))
        for user_node in node.strategies_vector.successor_nodes:
            node_index = user_node.strategies_vector.predecessor_nodes.index(node)
            with mod_graph.inserting_before(user_node):
                input_specs_node = mod_graph.create_node('call_function', operator.getitem, args=(input_dict_node, node_to_index_dict[node]))
            with mod_graph.inserting_before(user_node):
                sharding_spec_node = mod_graph.create_node('call_function', operator.getitem, args=(input_specs_node, node_index))
            with mod_graph.inserting_before(user_node):
                shape_consistency_node = mod_graph.create_node('call_function', apply, args=(node, sharding_spec_node))

    return gm
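
For orientation, a minimal sketch of how these two passes are meant to be chained; it mirrors the test added below, and assumes `gm`, `solution`, `device_mesh`, and `input_tensor` have already been produced by the ColoTracer / Solver_V2 pipeline:

    # Sketch only (not part of the commit): chaining the two runtime passes.
    # `gm`, `solution`, `device_mesh`, and `input_tensor` are assumed to come
    # from the tracer and solver setup shown in the test file below.
    sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
    shape_consistency_pass(gm)
    gm.recompile()
    # after recompile the graph has two extra placeholder inputs, so the dicts
    # produced by the annotation pass are passed in at call time
    output = gm(input_tensor, sharding_spec_dict, origin_spec_dict)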
tests/test_auto_parallel/test_shape_consistency_pass.py (new file, mode 100644)

from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import shape_consistency_pass, solution_annotatation_pass
from colossalai.auto_parallel.solver.solver import Solver_V2
from colossalai.auto_parallel.solver.options import SolverOptions


class ConvModel(nn.Module):

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        x = self.conv(x)
        return x


def check_apply(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    input = torch.rand(4, 4, 4, 4).cuda()
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=False)
    entire_shape = torch.Size((4, 4, 8, 8))

    tracer = ColoTracer()
    model = ConvModel(4, 4).cuda()
    origin_output = model(input)
    input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
    #     return conv
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()
    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    graph_analyser = GraphAnalyser(gm)
    solver = Solver_V2(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    ret = solver.call_solver_serialized_args()
    solution = list(ret[0])
    device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
    sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
    shape_consistency_pass(gm)
    gm.recompile()
    nodes = [node for node in gm.graph.nodes]
    # TODO: wrap the gm to avoid the influence of the user training code
    output = gm(input, sharding_spec_dict, origin_spec_dict)
    assert output.equal(origin_output)


@pytest.mark.skip("for higher testing speed")
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_apply():
    world_size = 4
    run_func = partial(check_apply, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_apply()