OpenDAS / ColossalAI / Commits / ede32629

Commit ede32629 (unverified), authored Aug 23, 2022 by Frank Lee, committed by GitHub on Aug 23, 2022

[autoparallel] integrate auto parallel with torch fx (#1479)

parent 8fb09a95

Showing 7 changed files with 132 additions and 120 deletions (+132 -120)
colossalai/auto_parallel/solver/__init__.py           +6   -0
colossalai/auto_parallel/solver/conv_handler.py       +30  -30
colossalai/auto_parallel/solver/dot_handler.py        +31  -29
colossalai/auto_parallel/solver/operator_handler.py   +37  -19
colossalai/auto_parallel/solver/sharding_strategy.py  +10  -12
tests/test_auto_parallel/test_conv_handler.py         +9   -16
tests/test_auto_parallel/test_dot_handler.py          +9   -14
colossalai/auto_parallel/solver/__init__.py (+6 -0)

+from .operator_handler import OperatorHandler
+from .dot_handler import DotHandler
+from .conv_handler import ConvHandler
+from .sharding_strategy import ShardingStrategy, StrategiesVector
+
+__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'StrategiesVector', 'ShardingStrategy']
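With these package-level exports in place, the handlers can be driven directly from a torch.fx node. A minimal sketch of the new workflow, assuming a traced GraphModule `gm` whose input nodes already carry the `_meta_data` and `strategies_vector` attributes attached in the tests below, plus a `DeviceMesh` and `ShapeConsistencyManager` built as in those tests; the function name is illustrative only:

    from colossalai.auto_parallel.solver import ConvHandler, StrategiesVector


    def enumerate_conv_strategies(gm, device_mesh, shape_consistency_manager):
        # pick the call_module node of the convolution out of the traced graph
        conv_node = [node for node in gm.graph.nodes if node.op == 'call_module'][0]

        # one StrategiesVector per node; the handler fills it in place
        strategies_vector = StrategiesVector(node=conv_node)
        conv_handler = ConvHandler(node=conv_node,
                                   device_mesh=device_mesh,
                                   strategies_vector=strategies_vector,
                                   shape_consistency_manager=shape_consistency_manager)

        # register_strategy() now returns the populated StrategiesVector
        return conv_handler.register_strategy()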
colossalai/auto_parallel/solver/conv_handler.py (+30 -30)

 import operator
 from functools import reduce
 import torch
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
-from .operator_handler import OperatorHanlder
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from .operator_handler import OperatorHandler


-class ConvHandler(OperatorHanlder):
+class ConvHandler(OperatorHandler):
     """
     A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
     """

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
+        self.input_data = self.predecessor_node[0]._meta_data
+        self.weight = self.module_named_parameters['weight']
+        self.output_data = self.node._meta_data
         self._sanity_check()

     def _sanity_check(self):

@@ -42,7 +45,7 @@ class ConvHandler(OperatorHanlder):
         # 1D: (L) * N * Cout * Cin * kernel
         # 2D: (H * W) * N * Cout * Cin * kernel
         # 3D: (H * W * D) * N * Cout * Cin * kernel
-        output_size = self.output.shape[2:]
+        output_size = self.output_data.shape[2:]
         output_size_product = reduce(operator.mul, output_size, 1)
         kernel_size = self.weight.shape[2:]
         kernel_size_product = reduce(operator.mul, kernel_size, 1)

@@ -59,11 +62,10 @@ class ConvHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]

@@ -73,7 +75,7 @@ class ConvHandler(OperatorHanlder):
         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -87,7 +89,7 @@ class ConvHandler(OperatorHanlder):
                                                  memory_cost=memory_cost,
                                                  resharding_costs=resharding_costs,
                                                  input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'

@@ -99,11 +101,10 @@ class ConvHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]

@@ -113,7 +114,7 @@ class ConvHandler(OperatorHanlder):
         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -127,7 +128,7 @@ class ConvHandler(OperatorHanlder):
                                                  memory_cost=memory_cost,
                                                  resharding_costs=resharding_costs,
                                                  input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

@@ -139,11 +140,10 @@ class ConvHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]

@@ -153,7 +153,7 @@ class ConvHandler(OperatorHanlder):
         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -167,7 +167,7 @@ class ConvHandler(OperatorHanlder):
                                                  memory_cost=memory_cost,
                                                  resharding_costs=resharding_costs,
                                                  input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def split_weight_out_channel(self, mesh_dim_0):
         name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'

@@ -179,11 +179,10 @@ class ConvHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]

@@ -193,7 +192,7 @@ class ConvHandler(OperatorHanlder):
         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -208,7 +207,7 @@ class ConvHandler(OperatorHanlder):
                                                  memory_cost=memory_cost,
                                                  resharding_costs=resharding_costs,
                                                  input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def non_split(self):
         name = f'RR = RR x RR'

@@ -220,11 +219,10 @@ class ConvHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]

@@ -234,7 +232,7 @@ class ConvHandler(OperatorHanlder):
         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         memory_cost = numel * size_per_elem_bytes

@@ -248,9 +246,9 @@ class ConvHandler(OperatorHanlder):
                                                  memory_cost=memory_cost,
                                                  resharding_costs=resharding_costs,
                                                  input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

-    def register_strategy_into_strategies_vector(self):
+    def register_strategy(self) -> StrategiesVector:
         '''
         Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.

@@ -315,3 +313,5 @@ class ConvHandler(OperatorHanlder):
         # RR= RR x RR
         self.non_split()
+
+        return self.strategies_vector
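Every strategy method in ConvHandler now runs the same memory-cost bookkeeping against self.output_data. A small self-contained sketch of that per-strategy calculation, with an illustrative function name and example shapes that are not part of the commit:

    import operator
    from functools import reduce

    import torch


    def sharded_memory_cost(output_shape, dtype, mesh_dims):
        # numel of the output times bytes per element, divided by the product of the
        # device-mesh dimensions that the output is sharded over (1 for non_split)
        numel = reduce(operator.mul, output_shape, 1)
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
        sharding_size = reduce(operator.mul, mesh_dims, 1)
        return numel * size_per_elem_bytes / sharding_size


    # a (16, 64, 32, 32) fp32 output sharded over a 2 x 2 mesh:
    # 1,048,576 elements * 4 bytes / 4 devices = 1,048,576 bytes per device
    print(sharded_memory_cost((16, 64, 32, 32), torch.float32, [2, 2]))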
colossalai/auto_parallel/solver/dot_handler.py (+31 -29)

 import operator
 import torch
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy
-from .operator_handler import OperatorHanlder
+from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from .operator_handler import OperatorHandler
 from functools import reduce


-class DotHandler(OperatorHanlder):
+class DotHandler(OperatorHandler):
     """
     A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
     """

+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.input_data = self.predecessor_node[0]._meta_data
+        self.weight = self.module_named_parameters['weight']
+        self.output_data = self.node._meta_data
+
     def _generate_compute_cost(self, input_shape, weight_shape):
         # TODO: consider bias addition
         compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2

@@ -27,18 +33,17 @@ class DotHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute computation cost
         compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -55,7 +60,7 @@ class DotHandler(OperatorHanlder):
                                                  memory_cost=memory_cost,
                                                  resharding_costs=resharding_costs,
                                                  input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
         # handle the case SR = SS x SR

@@ -70,18 +75,17 @@ class DotHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {0: [mesh_dim_0]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -95,7 +99,7 @@ class DotHandler(OperatorHanlder):
                                                  memory_cost=memory_cost,
                                                  resharding_costs=resharding_costs,
                                                  input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
         name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'

@@ -107,18 +111,17 @@ class DotHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim_1]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_input)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim_0]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -132,7 +135,7 @@ class DotHandler(OperatorHanlder):
                                                  memory_cost=memory_cost,
                                                  resharding_costs=resharding_costs,
                                                  input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def recompute_split_both_contract(self, mesh_dim):
         name = f'RR = RS{mesh_dim} x S{mesh_dim}R'

@@ -144,18 +147,17 @@ class DotHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         memory_cost = numel * size_per_elem_bytes

@@ -168,7 +170,7 @@ class DotHandler(OperatorHanlder):
                                                  memory_cost=memory_cost,
                                                  resharding_costs=resharding_costs,
                                                  input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

     def split_rhs_space_only(self, mesh_dim):
         name = f'RS{mesh_dim} = RR x RS{mesh_dim}'

@@ -180,18 +182,17 @@ class DotHandler(OperatorHanlder):
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

         dim_partition_dict_for_output = {1: [mesh_dim]}
-        sharding_spec_for_ouput = self._generate_sharding_spec(self.output, dim_partition_dict_for_output)
+        sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
-        resharding_costs = {}
-        self._generate_resharding_costs(resharding_costs, sharding_spec_for_input)
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

         # compute the computation cost of this strategy
         compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)

         # compute the memory cost of this strategy
         dtype = self.input_data.dtype
-        numel = self.output.numel()
+        numel = self.output_data.numel()
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         sharding_size = self.device_mesh.shape[mesh_dim]
         memory_cost = numel * size_per_elem_bytes / sharding_size

@@ -205,9 +206,9 @@ class DotHandler(OperatorHanlder):
                                                  memory_cost=memory_cost,
                                                  resharding_costs=resharding_costs,
                                                  input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
-        self.strategies_vector.strategies.append(sharding_strategies)
+        self.strategies_vector.append(sharding_strategies)

-    def register_strategy_into_strategies_vector(self):
+    def register_strategy(self) -> StrategiesVector:
         '''
         Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.

@@ -233,3 +234,4 @@ class DotHandler(OperatorHanlder):
         # RS = RR x RS
         self.split_rhs_space_only(0)
         self.split_rhs_space_only(1)
+        return self.strategies_vector
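DotHandler's compute-cost estimate is unchanged by this commit, but every strategy above relies on it: the cost of a linear layer is counted as two FLOPs per multiply-accumulate. A worked sketch with illustrative shapes:

    import operator
    from functools import reduce


    def linear_compute_cost(input_shape, weight_shape):
        # same formula as DotHandler._generate_compute_cost: every output element
        # of the matmul needs one multiply and one add, hence the factor of 2
        return reduce(operator.mul, input_shape) * weight_shape[0] * 2


    # a (32, 1024) activation against a (4096, 1024) weight:
    # 32 * 1024 * 4096 * 2 = 268,435,456 FLOPs
    print(linear_compute_cost((32, 1024), (4096, 1024)))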
colossalai/auto_parallel/solver/operator_handler.py (+37 -19)

 import torch
+import torch.nn as nn
 from abc import ABC, abstractmethod
 from torch.fx.node import Node
-import torch.nn as nn
+from typing import Dict
 from colossalai.device.device_mesh import DeviceMesh
-from .sharding_strategy import StrategiesVector
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
+from .sharding_strategy import StrategiesVector


-class OperatorHanlder(ABC):
+class OperatorHandler(ABC):
     '''
-    The OperatorHanlder is an abstract class used to generate every possible strategies for a operator node.
+    The OperatorHandler is an abstract class used to generate every possible strategies for a operator node.

     Argument:
         input_node(Node): the input node in node argument list.

@@ -21,30 +24,43 @@ class OperatorHanlder(ABC):
         shape_consistency_manager(ShapeConsistencyManager): ShapeConsistencyManager will give the resharding costs of the different sharding specs.
     '''

-    def __init__(self, input_node: Node, input_index: int, weight: nn.Parameter, output_node: Node,
-                 device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
+    def __init__(self, node: Node, device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
                  shape_consistency_manager: ShapeConsistencyManager):
-        self.input_node = input_node
-        self.input_data = self.input_node._meta_data
-        self.weight = weight
-        self.input_index = input_index
-        self.output_node = output_node
-        self.output = self.output_node._meta_data
+        self.node = node
+        self.predecessor_node = list(node._input_nodes.keys())
+        self.successor_node = list(node.users.keys())
         self.device_mesh = device_mesh
         self.strategies_vector = strategies_vector
         self.shape_consistency_manager = shape_consistency_manager

+        # find the module and its parameters associated with this node
+        # this can be used to compute the compute/communication/sharding cost
+        if self.node.op == 'call_module':
+            module = node.graph.owning_module.get_submodule(node.target)
+            named_parameters = list(module.named_parameters(recurse=False))
+            # convert named parameters from list to dict
+            named_parameters = {k: v for k, v in named_parameters}
+        else:
+            module = None
+            named_parameters = None
+        self.module = module
+        self.module_named_parameters = named_parameters
+
     @abstractmethod
-    def register_strategy_into_strategies_vector(self):
+    def register_strategy(self) -> StrategiesVector:
         pass

-    def _generate_sharding_spec(self, tensor, dim_partition_dict):
+    def _generate_sharding_spec(self, tensor: torch.Tensor, dim_partition_dict: Dict[int, int]) -> ShardingSpec:
+        """
+        Generate the sharding spec of the tensor based on the given dim_partition_dict
+        where the key is the tensor dimension and the value is the mesh dimension for sharding.
+        """
         sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                      entire_shape=tensor.shape,
                                      dim_partition_dict=dim_partition_dict)
         return sharding_spec

-    def _generate_resharding_costs(self, resharding_costs, sharding_spec_for_input):
+    def _generate_resharding_costs(self, sharding_spec_for_input):
         '''
         Compute the resharding costs with this specific strategy.

@@ -58,8 +74,10 @@ class OperatorHanlder(ABC):
             sharding_spec_for_input(ShardingSpec): ShardingSpec of the input node.
         '''
         # The resharding_cost of weight is counted due to sharing weight cases.
-        resharding_costs[self.input_index] = []
-        for stategy in self.input_node.strategies_vector.strategies:
-            _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(stategy, sharding_spec_for_input)
-            resharding_costs[self.input_index].append(resharding_cost)
+        resharding_costs = {}
+        for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
+            resharding_costs[input_node] = []
+            for strategy in input_node.strategies_vector:
+                _, _, resharding_cost = self.shape_consistency_manager.shape_consistency(strategy, input_spec)
+                resharding_costs[input_node].append(resharding_cost)
         return resharding_cost
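The rewritten constructor pulls the operator's module and parameters straight out of the fx graph for call_module nodes. A standalone sketch of that lookup, using a hypothetical toy model (ConvModel is illustrative, not part of ColossalAI):

    import torch.nn as nn
    from torch.fx import symbolic_trace


    class ConvModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(4, 16, 3)

        def forward(self, x):
            return self.conv(x)


    gm = symbolic_trace(ConvModel())
    conv_node = [node for node in gm.graph.nodes if node.op == 'call_module'][0]

    # the lookup OperatorHandler.__init__ now performs for call_module nodes
    module = conv_node.graph.owning_module.get_submodule(conv_node.target)
    named_parameters = {k: v for k, v in module.named_parameters(recurse=False)}
    print(list(named_parameters.keys()))    # ['weight', 'bias']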
colossalai/auto_parallel/solver/sharding_strategy.py (+10 -12)

 from dataclasses import dataclass
 from colossalai.tensor.sharding_spec import ShardingSpec
 from typing import Dict, List
+from torch.fx.node import Node
+
+__all__ = ['ShardingStrategy', 'StrategiesVector']


 @dataclass

@@ -30,26 +33,21 @@ class ShardingStrategy:
     input_shardings: ShardingSpec = None


-class StrategiesVector:
+class StrategiesVector(list):
     '''
     Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
     strategies of the node.

     Argument:
-        node(Node): node to build corresponding strategies_vector.
-        in_nodes(List[Node]): input nodes in the argument list of the node.
-        following_nodes(List[Node]): the nodes take the target node as their argument.
-        strategies(List[ShardingStrategy]): enumerate all the possible sharding strategies of the node.
+        node (Node): node for which the list of sharding strategies are generated.
     '''

-    def __init__(self, node, in_nodes, following_nodes=None, strategies=None):
+    def __init__(self, node: Node):
+        super().__init__()
         self.node = node
-        self.in_nodes = in_nodes
-        self.following_nodes = following_nodes
-        if strategies is None:
-            strategies = []
-        self.strategies = strategies
+        # fetch its input and output nodes
+        self.predecessor_nodes = list(node._input_nodes.keys())
+        self.successor_ndoes = list(node.users.keys())

     def check_merge(self):
         pass
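Because StrategiesVector now subclasses list, the separate .strategies attribute disappears and handlers append strategies directly. A small sketch, assuming this revision of ColossalAI is importable and using a throwaway traced Linear module:

    import torch.nn as nn
    from torch.fx import symbolic_trace

    from colossalai.auto_parallel.solver import StrategiesVector

    gm = symbolic_trace(nn.Sequential(nn.Linear(8, 8)))
    linear_node = [node for node in gm.graph.nodes if node.op == 'call_module'][0]

    strategies_vector = StrategiesVector(node=linear_node)
    print(isinstance(strategies_vector, list))    # True -- handlers simply append() into it
    print(strategies_vector.predecessor_nodes)    # the placeholder node feeding the Linear layer
    print(strategies_vector.successor_ndoes)      # note: this is the attribute spelling in this revision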
tests/test_auto_parallel/test_conv_handler.py (+9 -16)

@@ -47,7 +47,9 @@ def test_conv_handler():
     # [x, mul, conv, output]
     nodes = [node for node in gm.graph.nodes]

-    strategies_for_input = []
+    # find the sharding strategies for the input node of the conv node
+    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
+    strategies_vector_for_input = StrategiesVector(nodes[1])
     sharding_option = (None, 0, 1)
     for first_sharding_index in sharding_option:
         for second_sharding_index in sharding_option:

@@ -68,28 +70,19 @@ def test_conv_handler():
             sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                          entire_shape=entire_shape,
                                          sharding_sequence=sharding_sequence)
-            strategies_for_input.append(sharding_spec)
+            strategies_vector_for_input.append(sharding_spec)

-    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
-    strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input)
     setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)

-    strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[
-        nodes[1],
-    ])
-    conv_handler = ConvHandler(input_node=nodes[1],
-                               input_index=0,
-                               weight=dict(gm.named_modules())[nodes[2].name].weight,
-                               output_node=nodes[2],
+    # generate conv strategy
+    strategies_vector = StrategiesVector(node=nodes[2])
+    conv_handler = ConvHandler(node=nodes[2],
                                device_mesh=device_mesh,
                                strategies_vector=strategies_vector,
                                shape_consistency_manager=shape_consistency_manager)
-    conv_handler.register_strategy_into_strategies_vector()
+    conv_handler.register_strategy()
     # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
-    strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector.strategies]
+    strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector]

     # SS = SR x RS
     assert 'S0S1 = S0R x RS1' in strategy_name_list
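A note on the strategy names the test asserts, inferred from the f-strings used in the handlers: R marks a replicated tensor dimension and S{i} a dimension sharded along device-mesh axis i. A tiny illustration of how such a name is formed:

    mesh_dim_0, mesh_dim_1 = 0, 1
    name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
    assert name == 'S0S1 = S0R x RS1'    # output = input x weight, each operand sharded along one mesh axis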
tests/test_auto_parallel/test_dot_handler.py (+9 -14)

@@ -47,7 +47,9 @@ def test_dot_handler():
     # [x, mul, linear, output]
     nodes = [node for node in gm.graph.nodes]

-    strategies_for_input = []
+    # find the sharding strategies for the input node of the conv node
+    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
+    strategies_vector_for_input = StrategiesVector(node=nodes[1])
     sharding_option = (None, 0, 1)
     for first_sharding_index in sharding_option:
         for second_sharding_index in sharding_option:

@@ -67,26 +69,19 @@ def test_dot_handler():
             sharding_spec = ShardingSpec(device_mesh=device_mesh,
                                          entire_shape=entire_shape,
                                          sharding_sequence=sharding_sequence)
-            strategies_for_input.append(sharding_spec)
+            strategies_vector_for_input.append(sharding_spec)

-    # strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
-    strategies_vector_for_input = StrategiesVector(node=nodes[1], in_nodes=nodes[0], strategies=strategies_for_input)
     setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)

-    strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[
-        nodes[1],
-    ])
-    dot_handler = DotHandler(input_node=nodes[1],
-                             input_index=0,
-                             weight=dict(gm.named_modules())[nodes[2].name].weight,
-                             output_node=nodes[2],
+    # generate dot strategy
+    strategies_vector = StrategiesVector(node=nodes[2])
+    dot_handler = DotHandler(node=nodes[2],
                              device_mesh=device_mesh,
                              strategies_vector=strategies_vector,
                              shape_consistency_manager=shape_consistency_manager)
-    dot_handler.register_strategy_into_strategies_vector()
+    strategies_vector = dot_handler.register_strategy()
     # ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
-    strategy_name_list = [strategy.name for strategy in dot_handler.strategies_vector.strategies]
+    strategy_name_list = [strategy.name for strategy in strategies_vector]

     # SS = SR x RS
     assert 'S0S1 = S0R x RS1' in strategy_name_list