OpenDAS / ColossalAI · Commits

Commit 628c7e3f (unverified), authored Aug 22, 2022 by Frank Lee, committed by GitHub on Aug 22, 2022
Parent: d08566fb

[autoparallel] added dot handler (#1475)

Showing 5 changed files with 364 additions and 26 deletions (+364, -26):
colossalai/auto_parallel/solver/conv_handler.py        +1    -22
colossalai/auto_parallel/solver/dot_handler.py         +226  -3
colossalai/auto_parallel/solver/operator_handler.py    +20   -0
colossalai/auto_parallel/solver/sharding_strategy.py   +4    -1
tests/test_auto_parallel/test_dot_handler.py           +113  -0
colossalai/auto_parallel/solver/conv_handler.py (view file @ 628c7e3f)
from lib2to3.pytree import Base
import operator
from functools import reduce
import torch
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
from .operator_handler import OperatorHanlder
...
@@ -26,25 +24,6 @@ class ConvHandler(OperatorHanlder):
        assert self.input_data.dim() in (3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'

    def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
        '''
        Compute the resharding costs with this specific strategy.

        Note: The resharding_cost of weight is NOT counted.

        Argument:
            resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be appended into this dictionary.
                Resharding_cost[i][j] means the cost of i-th argument in the output node argument list
                with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
                strategy.
            sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
        '''
        # The resharding_cost of weight is counted due to sharing weight cases.
        resharding_costs[self.input_index] = []
        for stategy in self.input_node.strategies_vector.strategies:
            _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(stategy, sharding_spec_for_input)
            resharding_costs[self.input_index].append(resharding_cost)

    def _generate_compute_cost(self, bs, channel_in, channel_out):
        '''
        Compute the computation cost per device with this specific strategy.
...
colossalai/auto_parallel/solver/dot_handler.py (view file @ 628c7e3f)
import operator
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
from .operator_handler import OperatorHanlder
from functools import reduce


class DotHandler(OperatorHanlder):
...
@@ -6,7 +10,226 @@ class DotHandler(OperatorHanlder):
    A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _generate_compute_cost(self, input_shape, weight_shape):
        # TODO: consider bias addition
        compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
        return compute_cost

    # TODO: refactor the dot handler in my local branch to align with the latest main branch
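For orientation, _generate_compute_cost above counts the multiply-accumulate work of the full, unsharded linear layer: the product of all input dimensions times the output feature count, doubled for the multiply and the add. A minimal worked sketch, assuming the same shapes as the new test below (input of shape (4, 8), nn.Linear(8, 16) whose stored weight is (16, 8)):

import operator
from functools import reduce

input_shape = (4, 8)       # (batch, in_features), as in the test case below
weight_shape = (16, 8)     # nn.Linear stores its weight as (out_features, in_features)

# 2 * prod(input_shape) * out_features = 2 * 4 * 8 * 16
compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2
assert compute_cost == 1024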
    def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
        # handle case SS = SR x RS
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'

        dim_partition_dict_for_input = {0: [mesh_dim_0]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        # linear layer weight is transposed during init
        dim_partition_dict_for_weight = {0: [mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        # NOTE: dim_partition_dict_for_input is passed here; dim_partition_dict_for_output appears to be intended
        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)

        # generate resharding cost for this strategy
        resharding_costs = {}
        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)

        # compute computation cost
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel = self.output.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
        memory_cost = numel * size_per_elem_bytes / sharding_size

        # compute the communication cost
        # no all-reduce required for this case
        communication_cost = 0

        # create and register strategy
        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_ouput,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.strategies.append(sharding_strategies)
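In this SS = SR x RS strategy the input rows are split over one mesh dimension and the weight's output features over the other, so every device multiplies its row block by its column block and ends up with a disjoint tile of the output, which is why no all-reduce is charged. A minimal sketch of the idea with hand-chunked tensors (this assumes the 2x2 mesh used in the test and bypasses the ShardingSpec machinery entirely):

import torch

x = torch.rand(4, 8)      # input, rows sharded along mesh dim 0
w = torch.rand(16, 8)     # transposed linear weight, rows (out_features) sharded along mesh dim 1
full = x @ w.t()          # reference result, shape (4, 16)

x_shards = x.chunk(2, dim=0)
w_shards = w.chunk(2, dim=0)

# device (0, 1) holds x rows 0-1 and w rows 8-15 and computes its own (2, 8) output tile
tile = x_shards[0] @ w_shards[1].t()
assert torch.allclose(tile, full[:2, 8:])   # disjoint tile, no reduction needed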
    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        # handle the case SR = SS x SR
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

        dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        # since weight of the linear layer is transposed
        # the actual dim to be sharded is 1
        dim_partition_dict_for_weight = {1: [mesh_dim_0]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {0: [mesh_dim_0]}
        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = {}
        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)

        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel = self.output.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim_0]
        memory_cost = numel * size_per_elem_bytes / sharding_size

        # compute the communication cost of this strategy
        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_ouput,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.strategies.append(sharding_strategies)
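Here the contraction (in_features) dimension is split, so each device only sees part of every dot product and produces a partial sum of the whole output block; the partials must be summed with an all-reduce over the mesh dimension that splits the contraction, which is what device_mesh.all_reduce_cost(memory_cost, mesh_dim_1) prices in. A minimal sketch of why the reduction is required (hand-chunked tensors again, not the handler's sharding specs):

import torch

x = torch.rand(4, 8)
w = torch.rand(16, 8)          # transposed linear weight
full = x @ w.t()

# split the in_features dimension in two; each "device" holds one half
x_parts = x.chunk(2, dim=1)
w_parts = w.chunk(2, dim=1)
partials = [xp @ wp.t() for xp, wp in zip(x_parts, w_parts)]   # each is (4, 16), but only a partial sum

assert not torch.allclose(partials[0], full)
assert torch.allclose(partials[0] + partials[1], full)         # the all-reduce restores the true output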
    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

        dim_partition_dict_for_input = {1: [mesh_dim_0]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim_1]}
        # NOTE: dim_partition_dict_for_input is passed here; dim_partition_dict_for_output appears to be intended
        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)

        # generate resharding cost for this strategy
        resharding_costs = {}
        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)

        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel = self.output.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim_0]
        memory_cost = numel * size_per_elem_bytes / sharding_size

        # compute the communication cost of this strategy
        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_ouput,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.strategies.append(sharding_strategies)
    def recompute_split_both_contract(self, mesh_dim):
        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'

        dim_partition_dict_for_input = {1: [mesh_dim]}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {1: [mesh_dim]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = {}
        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)

        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel = self.output.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        memory_cost = numel * size_per_elem_bytes

        # compute the communication cost of this strategy
        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim)

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_ouput,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.strategies.append(sharding_strategies)
    def split_rhs_space_only(self, mesh_dim):
        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'

        dim_partition_dict_for_input = {}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {0: [mesh_dim]}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {1: [mesh_dim]}
        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = {}
        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)

        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel = self.output.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = self.device_mesh.shape[mesh_dim]
        memory_cost = numel * size_per_elem_bytes / sharding_size

        # compute the communication cost of this strategy
        communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim)

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_ouput,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
        self.strategies_vector.strategies.append(sharding_strategies)
    def register_strategy_into_strategies_vector(self):
        '''
        Generate every possible strategy for a linear (dot) node, and record all strategies into the strategies_vector.

        Output:
        '''
        # SS = SR x RS
        self.split_lhs_space_rhs_space(0, 1)
        self.split_lhs_space_rhs_space(1, 0)

        # SR = SS x SR
        self.split_lhs_space_both_contract(0, 1)
        self.split_lhs_space_both_contract(1, 0)

        # RS = RS x SS
        self.split_rhs_space_both_contract(0, 1)
        self.split_rhs_space_both_contract(1, 0)

        # RR = RS x SR
        self.recompute_split_both_contract(0)
        self.recompute_split_both_contract(1)

        # RS = RR x RS
        self.split_rhs_space_only(0)
        self.split_rhs_space_only(1)
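Each helper call above registers one named strategy, and the name is simply the corresponding f-string with the mesh dimensions substituted, so a 2D mesh yields ten strategies in total. A small sketch that reproduces just the naming (no cost modelling, and not the handler itself):

def dot_strategy_names():
    names = []
    for m0, m1 in [(0, 1), (1, 0)]:
        names.append(f'S{m0}S{m1} = S{m0}R x RS{m1}')    # split_lhs_space_rhs_space
        names.append(f'S{m0}R = S{m0}S{m1} x S{m1}R')    # split_lhs_space_both_contract
        names.append(f'RS{m1} = RS{m0} x S{m0}S{m1}')    # split_rhs_space_both_contract
    for m in (0, 1):
        names.append(f'RR = RS{m} x S{m}R')              # recompute_split_both_contract
        names.append(f'RS{m} = RR x RS{m}')              # split_rhs_space_only
    return names

print(len(dot_strategy_names()))   # 10 strategies on a 2D mesh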
colossalai/auto_parallel/solver/operator_handler.py (view file @ 628c7e3f)

...
@@ -43,3 +43,23 @@ class OperatorHanlder(ABC):
                                     entire_shape=tensor.shape,
                                     dim_partition_dict=dim_partition_dict)
        return sharding_spec

    def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
        '''
        Compute the resharding costs with this specific strategy.

        Note: The resharding_cost of weight is NOT counted.

        Argument:
            resharding_costs(Dict[int, List[float]]): The resharding cost generated in this method will be appended into this dictionary.
                Resharding_cost[i][j] means the cost of i-th argument in the output node argument list
                with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
                strategy.
            sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
        '''
        # The resharding_cost of weight is counted due to sharing weight cases.
        resharding_costs[self.input_index] = []
        for stategy in self.input_node.strategies_vector.strategies:
            _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(stategy, sharding_spec_for_input)
            resharding_costs[self.input_index].append(resharding_cost)
        return resharding_cost
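As the docstring describes, the dictionary filled by _generate_resharding_costs is keyed by the argument's index in the consuming node's argument list, and each value holds one cost per candidate sharding of that argument. A toy sketch of the resulting shape (the numbers are invented, not produced by shape_consistency):

resharding_costs = {}
input_index = 0                      # position of the input node in the argument list
candidate_costs = [0.0, 1.5, 3.0]    # hypothetical cost for each of three input strategies

resharding_costs[input_index] = []
for cost in candidate_costs:
    resharding_costs[input_index].append(cost)

# resharding_costs[i][j]: cost of moving argument i away from its j-th candidate sharding
# to the sharding spec required by the current strategy
assert resharding_costs == {0: [0.0, 1.5, 3.0]}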
colossalai/auto_parallel/solver/sharding_strategy.py (view file @ 628c7e3f)

...
@@ -42,10 +42,13 @@ class StrategiesVector:
        strategies(List[ShardingStrategy]): enumerate all the possible sharding strategies of the node.
    '''

-    def __init__(self, node, in_nodes, following_nodes=None, strategies=[]):
+    def __init__(self, node, in_nodes, following_nodes=None, strategies=None):
        self.node = node
        self.in_nodes = in_nodes
        self.following_nodes = following_nodes
        if strategies is None:
            strategies = []
        self.strategies = strategies

    def check_merge(self):
...
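The signature change above swaps the mutable default strategies=[] for strategies=None. A default list is created once, when the function is defined, so with the old signature every StrategiesVector built without an explicit strategies argument would silently share the same list. A standalone sketch of the pitfall, using toy classes rather than the ColossalAI ones:

class Buggy:
    def __init__(self, strategies=[]):      # one list object shared by every instance
        self.strategies = strategies

class Fixed:
    def __init__(self, strategies=None):
        if strategies is None:
            strategies = []                  # a fresh list per instance
        self.strategies = strategies

a, b = Buggy(), Buggy()
a.strategies.append('S0R = S0S1 x S1R')
print(b.strategies)    # ['S0R = S0S1 x S1R'], the strategy leaked into an unrelated instance

c, d = Fixed(), Fixed()
c.strategies.append('S0R = S0S1 x S1R')
print(d.strategies)    # []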
tests/test_auto_parallel/test_dot_handler.py  0 → 100644 (new file, view @ 628c7e3f)

import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest

from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.dot_handler import DotHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh


class LinearModel(nn.Module):

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)

    def forward(self, x):
        x = x * 2
        x = self.linear(x)
        return x


def test_dot_handler():
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    entire_shape = torch.Size((4, 8))
    shape_consistency_manager = ShapeConsistencyManager()

    tracer = ColoTracer()
    model = LinearModel(8, 16)
    input_sample = {'x': torch.rand(4, 8).to('meta')}
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
    #     %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
    #     return conv
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    # [x, mul, linear, output]
    nodes = [node for node in gm.graph.nodes]

    strategies_for_input = []
    sharding_option = (None, 0, 1)
    for first_sharding_index in sharding_option:
        for second_sharding_index in sharding_option:
            if first_sharding_index is not None and second_sharding_index == first_sharding_index:
                continue
            if first_sharding_index is None:
                first_dim_spec = _DimSpec([])
            else:
                first_dim_spec = _DimSpec([first_sharding_index])

            if second_sharding_index is None:
                second_dim_spec = _DimSpec([])
            else:
                second_dim_spec = _DimSpec([second_sharding_index])

            sharding_sequence = [first_dim_spec, second_dim_spec]
            sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                         entire_shape=entire_shape,
                                         sharding_sequence=sharding_sequence)
            strategies_for_input.append(sharding_spec)

    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
    strategies_vector_for_input = StrategiesVector(node=nodes[1], in_nodes=nodes[0], strategies=strategies_for_input)
    setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)

    strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1]])
    dot_handler = DotHandler(input_node=nodes[1],
                             input_index=0,
                             weight=dict(gm.named_modules())[nodes[2].name].weight,
                             output_node=nodes[2],
                             device_mesh=device_mesh,
                             strategies_vector=strategies_vector,
                             shape_consistency_manager=shape_consistency_manager)
    dot_handler.register_strategy_into_strategies_vector()

    # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
    strategy_name_list = [strategy.name for strategy in dot_handler.strategies_vector.strategies]

    # SS = SR x RS
    assert 'S0S1 = S0R x RS1' in strategy_name_list
    assert 'S1S0 = S1R x RS0' in strategy_name_list

    # SR = SS x SR
    assert 'S0R = S0S1 x S1R' in strategy_name_list
    assert 'S1R = S1S0 x S0R' in strategy_name_list

    # RS = RS x SS
    assert 'RS0 = RS1 x S1S0' in strategy_name_list
    assert 'RS1 = RS0 x S0S1' in strategy_name_list

    # RR = RS x SR
    assert 'RR = RS0 x S0R' in strategy_name_list
    assert 'RR = RS1 x S1R' in strategy_name_list

    # RS = RR x RS
    assert 'RS0 = RR x RS0' in strategy_name_list
    assert 'RS1 = RR x RS1' in strategy_name_list


if __name__ == '__main__':
    test_dot_handler()
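To exercise the new handler locally, the test can be invoked through pytest or run directly thanks to the __main__ guard. The sketch below assumes it is executed from the repository root with torch and colossalai importable:

import pytest

# equivalent to running: pytest tests/test_auto_parallel/test_dot_handler.py
pytest.main(['tests/test_auto_parallel/test_dot_handler.py'])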