OpenDAS / ColossalAI · Commits

Commit 1a359941 (unverified), authored Sep 07, 2022 by YuliangLiu0306, committed by GitHub on Sep 07, 2022

[autoparallel] support function in operator handler (#1529)

parent 44c866a3
Showing 3 changed files with 22 additions and 12 deletions (+22 −12)
colossalai/auto_parallel/solver/conv_handler.py          +10 −10
colossalai/auto_parallel/solver/operator_handler.py       +5  −1
colossalai/auto_parallel/solver/strategies_constructor.py +7  −1
colossalai/auto_parallel/solver/conv_handler.py
@@ -9,7 +9,7 @@ __all__ = ['ConvHandler']

 class ConvHandler(OperatorHandler):
     """
-    A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
+    An OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
     """

     def __init__(self, *args, **kwargs):
@@ -67,7 +67,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@@ -106,7 +106,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@@ -145,7 +145,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
@@ -184,7 +184,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -223,7 +223,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -261,7 +261,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -301,7 +301,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
@@ -340,7 +340,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
@@ -380,7 +380,7 @@ class ConvHandler(OperatorHandler):
         sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)

         # generate resharding cost for this strategy
-        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
+        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])

         # compute the computation cost of this strategy
         bs = self.input_data.shape[0]
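Every strategy branch in ConvHandler receives the same one-line fix: the resharding cost is now computed over both the input's and the weight's sharding specs, not the input's alone, which lines up with the call_function support added in operator_handler.py below, where the weight arrives as a node argument rather than a module parameter. The surrounding bs lines estimate the per-device batch size; a minimal numeric sketch of that arithmetic follows, with shapes and mesh assumed for illustration, not taken from the diff:

# Assumed example values, only to illustrate the bs arithmetic in the
# hunks above: global batch 64 on a 2 x 4 device mesh.
input_shape = (64, 3, 224, 224)
device_mesh_shape = (2, 4)
mesh_dim_0, mesh_dim_1 = 0, 1

# batch dimension sharded along one mesh axis (most hunks above)
bs_1d = input_shape[0] // device_mesh_shape[mesh_dim_0]
# batch dimension sharded along both mesh axes (the @@ -340 hunk)
bs_2d = input_shape[0] // (device_mesh_shape[mesh_dim_0] * device_mesh_shape[mesh_dim_1])
print(bs_1d, bs_2d)  # 32 8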
colossalai/auto_parallel/solver/operator_handler.py
@@ -15,7 +15,7 @@ __all__ = ['OperatorHandler']

 class OperatorHandler(ABC):
     '''
-    The OperatorHandler is an abstract class used to generate every possible strategies for a operator node.
+    The OperatorHandler is an abstract class used to generate every possible strategies for an operator node.

     Argument:
         input_node(Node): the input node in node argument list.
@@ -43,6 +43,10 @@ class OperatorHandler(ABC):
             named_parameters = list(module.named_parameters(recurse=False))
             # convert named parameters from list to dict
             named_parameters = {k: v for k, v in named_parameters}
+        elif self.node.op == 'call_function':
+            module = None
+            parameters = list(self.node.args)[1]
+            named_parameters = {'weight': parameters._meta_data}
         else:
             module = None
             named_parameters = None
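The new elif branch is the core of this commit: when the traced node is a call_function (e.g. F.conv2d called directly rather than through an nn.Conv2d module), there is no submodule to query via named_parameters(), so the handler reads the weight from the node's second positional argument instead. A minimal sketch of why that argument position holds the weight, using plain torch.fx (ColossalAI's _meta_data annotation pass is omitted here):

import torch
import torch.fx
import torch.nn as nn
import torch.nn.functional as F


class FunctionalConv(nn.Module):
    """Toy module that calls conv as a function, so fx records a
    call_function node instead of a call_module node."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(8, 4, 3, 3))

    def forward(self, x):
        return F.conv2d(x, self.weight)


traced = torch.fx.symbolic_trace(FunctionalConv())
for node in traced.graph.nodes:
    print(node.op, node.target, node.args)
# The conv shows up with op='call_function' and args == (x, weight),
# so list(node.args)[1] is the weight node, exactly what the new
# branch reads its _meta_data from.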
colossalai/auto_parallel/solver/strategies_constructor.py
@@ -27,7 +27,13 @@ class StrategiesConstructor:
         Generate the sharding spec of the tensor based on the given dim_partition_dict
         where the key is the tensor dimension and the value is the mesh dimension for sharding.
         """
-        meta_tensor = node._meta_data
+        if hasattr(node, '_meta_data'):
+            meta_tensor = node._meta_data
+        elif isinstance(node, torch.Tensor):
+            meta_tensor = node
+        else:
+            raise RuntimeError(f'We cannot generate sharding spec for {type(node)} type.')
         sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
                                      entire_shape=meta_tensor.shape,
                                      dim_partition_dict=dim_partition_dict)
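Restated outside the class for illustration, the new lookup is a small three-way dispatch: an fx Node annotated with _meta_data, a bare torch.Tensor (e.g. a weight pulled straight off a call_function node's args, as above), or an error. The helper name below is made up for this sketch; only the if/elif/else body comes from the diff:

import torch


def resolve_meta_tensor(node) -> torch.Tensor:
    # hypothetical standalone helper; the real logic lives inline in
    # StrategiesConstructor's sharding-spec generation
    if hasattr(node, '_meta_data'):
        return node._meta_data  # fx Node carrying a meta tensor
    elif isinstance(node, torch.Tensor):
        return node  # plain tensor passed directly
    else:
        raise RuntimeError(f'We cannot generate sharding spec for {type(node)} type.')


# a meta-device weight tensor can now be handled without an fx Node wrapper
weight = torch.empty(8, 4, 3, 3, device='meta')
assert resolve_meta_tensor(weight) is weight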