OpenDAS / ColossalAI · Commits · 47b11c43
"vscode:/vscode.git/clone" did not exist on "8accecd55bf1a5aaaeb4b84c06fac0d63850fd5e"
Unverified commit 47b11c43, authored Sep 20, 2022 by YuliangLiu0306; committed by GitHub on Sep 20, 2022.

[autoparallel] add bcast matmul strategies (#1605)
parent edb67cb3

Showing 5 changed files, with 492 additions and 31 deletions (+492, -31):
colossalai/auto_parallel/solver/constants.py (+1, -1)
colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py (+428, -27)
colossalai/auto_parallel/solver/op_handler/dot_handler.py (+1, -2)
colossalai/auto_parallel/solver/strategies_constructor.py (+10, -1)
tests/test_auto_parallel/test_bcast_matmul.py (+52, -0)
colossalai/auto_parallel/solver/constants.py

@@ -14,7 +14,7 @@ ELEMENTWISE_FUNC_OP = [
 RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]
 BCAST_FUNC_OP = [
     torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
-    operator.mul, operator.floordiv, operator.truediv
+    operator.mul, operator.floordiv, operator.truediv, torch.matmul
 ]
 CONV_MODULE_OP = [
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
     ...
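The one-line change registers torch.matmul in BCAST_FUNC_OP. That matches PyTorch semantics: when both operands of matmul are more than 2-D, their leading batch dimensions are broadcast against each other. A quick standalone check in plain PyTorch (shapes borrowed from the new test below):

import torch

# matmul broadcasts leading (batch) dims when both operands are > 2-D:
# (4, 4, 8) has batch shape (4,); (4, 1, 8, 4) has batch shape (4, 1).
# The batch shapes broadcast to (4, 4); the matrix dims contract over 8.
lhs = torch.rand(4, 4, 8)
rhs = torch.rand(4, 1, 8, 4)
out = torch.matmul(lhs, rhs)
assert out.shape == torch.Size([4, 4, 4, 4])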
colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py

(This diff is collapsed in the page view and not reproduced here; it carries the bulk of the change, +428, -27.)
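Since that diff is collapsed, here is only the shape rule the broadcast handler has to account for, expressed with plain PyTorch (an illustration, not the handler's code):

import torch

# Broadcast the batch dims of the test's operands, (4, 4, 8) and (4, 1, 8, 4):
batch = torch.broadcast_shapes((4,), (4, 1))
assert batch == torch.Size([4, 4])   # output batch shape; the matrix part is (4, 4)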
colossalai/auto_parallel/solver/op_handler/dot_handler.py

@@ -11,7 +11,6 @@ from enum import Enum
 from .strategy_generator import StrategyGenerator, IntermediateStrategy
 from typing import List

 __all__ = ['DotHandler']

@@ -465,7 +464,7 @@ class DotHandler(OperatorHandler):
         # since weight of the linear layer is transposed
         # the actual dim to be sharded is 1
-        dim_partition_dict_for_weight = {1: [mesh_dim_0]}
+        dim_partition_dict_for_weight = {1: [mesh_dim_1]}
         sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
         dim_partition_dict_for_output = {0: [mesh_dim_0]}
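The second hunk is a one-character bugfix: this strategy is supposed to shard the weight along mesh axis 1, but the spec pointed at mesh_dim_0. As a rough sketch of what a dim_partition_dict encodes (hypothetical helper, not ColossalAI's implementation), {1: [mesh_dim]} splits dimension 1 of the tensor across the devices on that mesh axis:

import torch

# Sketch only: map tensor dim -> mesh axes, chunking the tensor accordingly.
def shard(tensor, dim_partition_dict, mesh_shape):
    shards = [tensor]
    for tensor_dim, mesh_dims in dim_partition_dict.items():
        for mesh_dim in mesh_dims:
            shards = [c for s in shards for c in torch.chunk(s, mesh_shape[mesh_dim], dim=tensor_dim)]
    return shards

weight = torch.rand(8, 8)                            # transposed linear weight: (out, in)
pieces = shard(weight, {1: [1]}, mesh_shape=(2, 2))  # shard dim 1 over mesh axis 1
assert len(pieces) == 2 and pieces[0].shape == (8, 4)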
colossalai/auto_parallel/solver/strategies_constructor.py

@@ -50,6 +50,15 @@ class StrategiesConstructor:
         for strategy in remove_list:
             strategies_vector.remove(strategy)

+    def _is_bcast_matmul(self, node):
+        is_bcast_matmul = False
+        if node.target is torch.matmul and len(node.args) == 2:
+            lhs_data = node.args[0]._meta_data
+            rhs_data = node.args[1]._meta_data
+            if lhs_data.dim() >= 3 and rhs_data.dim() >= 3:
+                is_bcast_matmul = True
+        return is_bcast_matmul
+
     def build_strategies_and_cost(self):
         for node in self.nodes:
             strategies_vector = StrategiesVector(node)

@@ -222,7 +231,7 @@ class StrategiesConstructor:
                 conv_handler.register_strategy()

             # linear function
-            elif target in LINEAR_FUNC_OP:
+            elif target in LINEAR_FUNC_OP and not self._is_bcast_matmul(node):
                 # use DotHandler to create sharding strategies for linear node
                 # TODO: the operator_handler does NOT support function node processing now.
                 linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
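The new predicate changes dispatch: a torch.matmul node whose operands are both at least 3-D is kept away from DotHandler so it can be handled along the broadcast path instead. A condensed sketch of that routing (illustrative names, not the constructor's actual control flow):

import torch

def pick_handler(target, lhs_dim, rhs_dim):
    # Mirrors _is_bcast_matmul: both operands >= 3-D means batch broadcasting.
    if target is torch.matmul and lhs_dim >= 3 and rhs_dim >= 3:
        return 'bcast'
    return 'dot'

assert pick_handler(torch.matmul, 3, 4) == 'bcast'
assert pick_handler(torch.matmul, 2, 2) == 'dot'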
tests/test_auto_parallel/test_bcast_matmul.py (new file, mode 100644)

import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest

from colossalai.auto_parallel.solver.options import SolverOptions
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh


class MatmulModel(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        x = torch.matmul(x1, x2)
        return x


def test_conv_handler():
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    tracer = ColoTracer()
    model = MatmulModel()
    input_sample = {'x1': torch.rand(4, 4, 8).to('meta'), 'x2': torch.rand(4, 1, 8, 4).to('meta')}
    # graph():
    #     %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
    #     %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
    #     %matmul : [#users=1] = call_function[target=torch.matmul](args = (%x1, %x2), kwargs = {})
    #     return matmul
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    # [x1, x2, matmul, output]
    nodes = [node for node in gm.graph.nodes]
    solver_options = SolverOptions(fast=True)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()
    strategy_map = strategies_constructor.strategy_map
    matmul_strategies = strategy_map[nodes[2]]
    assert len(matmul_strategies) == 30


if __name__ == '__main__':
    test_conv_handler()
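Assuming a working ColossalAI checkout with its dependencies installed, the new test can also be run on its own:

pytest tests/test_auto_parallel/test_bcast_matmul.py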