Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
1a359941
Unverified
Commit
1a359941
authored
Sep 07, 2022
by
YuliangLiu0306
Committed by
GitHub
Sep 07, 2022
Browse files
[autoparallel] support function in operator handler (#1529)
parent
44c866a3
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
22 additions
and
12 deletions
+22
-12
colossalai/auto_parallel/solver/conv_handler.py
colossalai/auto_parallel/solver/conv_handler.py
+10
-10
colossalai/auto_parallel/solver/operator_handler.py
colossalai/auto_parallel/solver/operator_handler.py
+5
-1
colossalai/auto_parallel/solver/strategies_constructor.py
colossalai/auto_parallel/solver/strategies_constructor.py
+7
-1
No files found.
colossalai/auto_parallel/solver/conv_handler.py
View file @
1a359941
...
...
@@ -9,7 +9,7 @@ __all__ = ['ConvHandler']
 class ConvHandler(OperatorHandler):
     """
-    A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
+    An OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
     """

     def __init__(self, *args, **kwargs):
...
...
@@ -67,7 +67,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
...
...
@@ -106,7 +106,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
...
...
@@ -145,7 +145,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
...
...
@@ -184,7 +184,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
...
...
@@ -223,7 +223,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
...
...
@@ -261,7 +261,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
...
...
@@ -301,7 +301,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
        # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
...
...
@@ -340,7 +340,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
...
...
@@ -380,7 +380,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
...
...
colossalai/auto_parallel/solver/operator_handler.py
View file @
1a359941
...
...
@@ -15,7 +15,7 @@ __all__ = ['OperatorHandler']
 class OperatorHandler(ABC):
     '''
-    The OperatorHandler is an abstract class used to generate every possible strategies for a operator node.
+    The OperatorHandler is an abstract class used to generate every possible strategies for an operator node.

     Argument:
         input_node(Node): the input node in node argument list.
...
...
@@ -43,6 +43,10 @@ class OperatorHandler(ABC):
             named_parameters = list(module.named_parameters(recurse=False))
             # convert named parameters from list to dict
             named_parameters = {k: v for k, v in named_parameters}
+        elif self.node.op == 'call_function':
+            module = None
+            parameters = list(self.node.args)[1]
+            named_parameters = {'weight': parameters._meta_data}
         else:
             module = None
             named_parameters = None
...
...
colossalai/auto_parallel/solver/strategies_constructor.py
View file @
1a359941
...
...
@@ -27,7 +27,13 @@ class StrategiesConstructor:
         Generate the sharding spec of the tensor based on the given dim_partition_dict
         where the key is the tensor dimension and the value is the mesh dimension for sharding.
         """
-        meta_tensor = node._meta_data
+        if hasattr(node, '_meta_data'):
+            meta_tensor = node._meta_data
+        elif isinstance(node, torch.Tensor):
+            meta_tensor = node
+        else:
+            raise RuntimeError(f'We cannot generate sharding spec for {type(node)} type.')
         sharding_spec = ShardingSpec(device_mesh=self.device_mesh, entire_shape=meta_tensor.shape, dim_partition_dict=dim_partition_dict)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment