OpenDAS / ColossalAI · Commits

Commit faa23b9d (unverified)
Authored Sep 14, 2022 by YuliangLiu0306; committed by GitHub on Sep 14, 2022
Parent: c938dda0

[autoparallel] add reshape handler (#1594)

* [autoparallel] add reshape handler
* polish code

Showing 8 changed files with 148 additions and 14 deletions (+148 / -14):
  colossalai/auto_parallel/solver/constants.py                    +5   -3
  colossalai/auto_parallel/solver/op_handler/__init__.py          +2   -1
  colossalai/auto_parallel/solver/op_handler/operator_handler.py  +2   -1
  colossalai/auto_parallel/solver/op_handler/reshape_handler.py   +66  -0
  colossalai/auto_parallel/solver/sharding_strategy.py            +6   -1
  colossalai/auto_parallel/solver/strategies_constructor.py       +9   -6
  tests/test_auto_parallel/test_reshape_handler.py                +55  -0
  tests/test_auto_parallel/test_solver_with_resnet.py             +3   -2
colossalai/auto_parallel/solver/constants.py

@@ -2,16 +2,17 @@ import torch
 import operator
 
 __all__ = [
-    'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP',
-    'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP'
+    'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
+    'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP'
 ]
 
 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
 # TODO: flatten should not be added into this group
 ELEMENTWISE_FUNC_OP = [
     torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
-    operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout, torch.flatten
+    operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout
 ]
+RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]
 CONV_MODULE_OP = [
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
     torch.nn.ConvTranspose3d

@@ -23,5 +24,6 @@ LINEAR_MODULE_OP = [torch.nn.Linear]
 LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
 BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
 POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
+NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP
 
 INFINITY_COST = 1e13
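For context, a minimal standalone sketch (not part of this commit; pick_handler is a hypothetical helper) of how these op groups are typically consulted when dispatching traced call_function nodes, mirroring the dispatch order added to StrategiesConstructor further below. The constants are copied locally so the sketch runs on its own.

    import operator
    import torch

    # Local copies of the op groups defined above (assumption: same contents).
    RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]
    ELEMENTWISE_FUNC_OP = [torch.add, operator.add, torch.nn.functional.relu]
    NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP


    def pick_handler(target):
        # Hypothetical helper: map a call_function target to the kind of
        # handler StrategiesConstructor would dispatch it to.
        if target in RESHAPE_FUNC_OP:
            return 'ReshapeHandler'
        if target in ELEMENTWISE_FUNC_OP:
            return 'element-wise strategy'
        return 'other'


    assert pick_handler(torch.flatten) == 'ReshapeHandler'
    assert torch.flatten in NON_PARAM_FUNC_OP    # reshape ops carry no parameters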
colossalai/auto_parallel/solver/op_handler/__init__.py

@@ -2,5 +2,6 @@ from .operator_handler import OperatorHandler
 from .dot_handler import DotHandler
 from .conv_handler import ConvHandler
 from .batch_norm_handler import BatchNormHandler
+from .reshape_handler import ReshapeHandler
 
-__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler']
+__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler']
\ No newline at end of file
colossalai/auto_parallel/solver/op_handler/operator_handler.py

@@ -8,6 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
 from .._utils import generate_resharding_costs, generate_sharding_spec
+from colossalai.auto_parallel.solver.constants import *
 from ..sharding_strategy import StrategiesVector
 

@@ -44,7 +45,7 @@ class OperatorHandler(ABC):
             named_parameters = list(module.named_parameters(recurse=False))
             # convert named parameters from list to dict
             named_parameters = {k: v for k, v in named_parameters}
-        elif self.node.op == 'call_function':
+        elif self.node.op == 'call_function' and self.node.target not in NON_PARAM_FUNC_OP:
             module = None
             parameters = list(self.node.args)[1]
             named_parameters = {'weight': parameters._meta_data}
colossalai/auto_parallel/solver/op_handler/reshape_handler.py  (new file, 0 → 100644)

from .operator_handler import OperatorHandler
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from copy import deepcopy
import math


class ReshapeHandler(OperatorHandler):
    """
    An OperatorHandler which deals with the sharding strategies of reshape operators, such as torch.reshape, torch.flatten, etc.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_data = self.predecessor_node[0]._meta_data
        self.output_data = self.node._meta_data

    def _generate_compute_cost(self, *args, **kwargs):
        return super()._generate_compute_cost(*args, **kwargs)

    def register_strategy(self):
        input_node = self.strategies_vector.predecessor_nodes[0]
        # For a reshape function, to preserve computing correctness we keep the sharding
        # spec of the input fully replicated. In addition, we keep the output in replica
        # status and let the successor node choose how to reshard the output. Therefore,
        # different strategies of the input node with the same output sharding spec will
        # generate the same strategy for the reshape function.
        sharding_spec_checklist = []
        for strategy in input_node.strategies_vector:
            # It looks a little bit confusing: the input of the processing node
            # is the output of the input_node.
            input_sharding_spec = strategy.output_sharding_spec
            assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensors.'
            if input_sharding_spec in sharding_spec_checklist:
                continue
            sharding_spec_checklist.append(input_sharding_spec)
            dim_partition_dict_for_output = {}
            output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
            name = f'{input_sharding_spec.sharding_sequence} -> FULLY REPLICATED'
            # TODO: use meta_info_prop to profile memory cost and compute cost
            compute_cost = 0
            memory_cost = self.node._meta_data.numel()

            # Compute the communication cost: for a reshape op, communication happens
            # when casting the input sharding spec to fully replicated.
            dim_partition_dict_for_replicate_input = {}
            replicate_input_sharding_spec = self._generate_sharding_spec(self.input_data,
                                                                         dim_partition_dict_for_replicate_input)
            # shape consistency manager is a singleton class
            shape_consistency_manager = ShapeConsistencyManager()
            _, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec,
                                                                                   replicate_input_sharding_spec)

            # generate resharding cost
            resharding_costs = self._generate_resharding_costs([input_sharding_spec])
            # To prevent resharding from happening, set the cost of every non-matching
            # input strategy to inf.
            resharding_costs[input_node] = [0 if cost == 0 else math.inf for cost in resharding_costs[input_node]]
            sharding_strategy = ShardingStrategy(name,
                                                 output_sharding_spec,
                                                 compute_cost=compute_cost,
                                                 communication_cost=communication_cost,
                                                 memory_cost=memory_cost,
                                                 resharding_costs=resharding_costs,
                                                 input_shardings=[input_sharding_spec])
            self.strategies_vector.append(sharding_strategy)
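A minimal standalone sketch (not part of the commit, with made-up cost values) of the resharding-cost masking used in register_strategy above: only the input strategy whose output sharding spec already matches keeps cost 0, every other choice is priced at infinity, so the solver never inserts an extra reshard in front of a reshape node.

    import math

    # Hypothetical resharding costs for three candidate strategies of the input node;
    # a cost of 0 means that strategy's output sharding spec already matches.
    resharding_costs = [0.0, 12.5, 3.0]

    # Same masking expression as in ReshapeHandler.register_strategy.
    masked = [0 if cost == 0 else math.inf for cost in resharding_costs]

    assert masked == [0, math.inf, math.inf]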
colossalai/auto_parallel/solver/sharding_strategy.py

@@ -61,12 +61,17 @@ class StrategiesVector(list):
             root_module = self.node.graph.owning_module
             submod = root_module.get_submodule(target)
             submod_type = type(submod)
-            # merge elementwise module node into following nodes
+            # merge elementwise module node into source nodes
+            # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
             if submod_type in ELEMENTWISE_MODULE_OP:
                 merge_label = True
 
         if self.node.op == 'call_function':
+            # we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
             if self.node.target in ELEMENTWISE_FUNC_OP:
                 merge_label = True
+            # we could merge reshape op, because the output sharding spec of reshape op is always fully replicated.
+            if self.node.target in RESHAPE_FUNC_OP:
+                merge_label = True
 
         return merge_label
colossalai/auto_parallel/solver/strategies_constructor.py

@@ -157,8 +157,7 @@ class StrategiesConstructor:
                 # print(node, node.op, node.target, node.args)
                 # create sharding strategy for element-wise module
                 # input_node = strategies_vector.predecessor_nodes[0]
-                norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector,
-                                                self.shape_consistency_manager)
+                norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector)
                 norm_handler.register_strategy()
                 # for strategy in norm_handler.strategies_vector:
                 #     print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')

@@ -214,18 +213,22 @@ class StrategiesConstructor:
                 if target in CONV_FUNC_OP:
                     # use ConvHandler to create sharding strategies for conv node
                     # TODO: the operator_handler does NOT support function node processing now.
-                    conv_handler = ConvHandler(node, self.device_mesh, strategies_vector,
-                                               self.shape_consistency_manager)
+                    conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
                     conv_handler.register_strategy()
 
                 # linear function
                 elif target in LINEAR_FUNC_OP:
                     # use DotHandler to create sharding strategies for linear node
                     # TODO: the operator_handler does NOT support function node processing now.
-                    linear_handler = DotHandler(node, self.device_mesh, strategies_vector,
-                                                self.shape_consistency_manager)
+                    linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
                     linear_handler.register_strategy()
 
+                # reshape function
+                elif target in RESHAPE_FUNC_OP:
+                    # use ReshapeHandler to create sharding strategies for reshape node
+                    reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
+                    reshape_handler.register_strategy()
+
                 # element-wise function
                 elif target in ELEMENTWISE_FUNC_OP:
                     # TODO: integrate element-wise func and module together
tests/test_auto_parallel/test_reshape_handler.py  (new file, 0 → 100644)

import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest

from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh


class ConvModel(nn.Module):

    def __init__(self, c_in, c_out):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)

    def forward(self, x):
        x = self.conv(x)
        x = torch.flatten(x)
        return x


def test_conv_handler():
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    tracer = ColoTracer()
    model = ConvModel(16, 32)
    input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %conv : [#users=1] = call_module[target=conv](args = (%x,), kwargs = {})
    #     %flatten : [#users=1] = call_function[target=torch.flatten](args = (%conv,), kwargs = {})
    #     return flatten
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    # [x, conv, flatten, output]
    nodes = [node for node in gm.graph.nodes]
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()
    strategy_map = strategies_constructor.strategy_map
    conv_strategies = strategy_map[nodes[1]]
    flatten_strategies = strategy_map[nodes[2]]
    flatten_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in flatten_strategies]
    for strategy in conv_strategies:
        assert strategy.output_sharding_spec.sharding_sequence in flatten_strategies_cover_list


if __name__ == '__main__':
    test_conv_handler()
tests/test_auto_parallel/test_solver_with_resnet.py

@@ -14,6 +14,7 @@ from colossalai.auto_parallel.solver import Solver
 from torchvision.models import resnet34, resnet50
 from colossalai.auto_parallel.solver.constants import *
 from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
+from colossalai.auto_parallel.solver.options import SolverOptions
 
 
 class ConvModel(nn.Module):

@@ -81,8 +82,8 @@ def test_cost_graph():
     liveness_list = graph_analyser.liveness_analysis()
     # print(len(liveness_dict[0].unique_live_vars))
     # assert False
-    solver_options = {'fast_mode': True}
-    strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
+    solver_options = SolverOptions(fast=True)
+    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)