OpenDAS / ColossalAI · Commits · 3a4d6f63

Unverified commit 3a4d6f63, authored Sep 28, 2022 by Frank Lee, committed by GitHub on Sep 28, 2022
Parent: 09585447

[autoparallel] added node handler for bmm (#1655)

Showing 4 changed files with 210 additions and 27 deletions (+210, -27)
colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py             +29, -2
colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py    +25, -20
colossalai/auto_parallel/solver/strategy/strategy_generator.py           +6,  -5
tests/test_auto_parallel/test_node_handler/test_bmm_handler.py           +150, -0
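For context, `torch.bmm` performs batched matrix multiplication; below is a quick illustration of its shape semantics using the same shapes as the new test file (standard PyTorch behavior, not code from this commit):

import torch

# torch.bmm: (b, n, k) @ (b, k, m) -> (b, n, m); the batch dims must match exactly
x1 = torch.rand(4, 8, 16)
x2 = torch.rand(4, 16, 8)
out = torch.bmm(x1, x2)
assert out.shape == torch.Size([4, 8, 8])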
colossalai/auto_parallel/solver/op_handler/dot_handler_v2.py

@@ -2,11 +2,11 @@ import torch
 import torch.nn.functional as F
 from .node_handler import ModuleHandler, NodeHandler
 from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
-from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator_V2
+from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator_V2, BatchedMatMulStrategyGenerator
 from typing import List, Dict
 from .registry import operator_registry

-__all__ = ['LinearModuleHandler', 'LinearFunctionHandler']
+__all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']

 @operator_registry.register(torch.nn.Linear)

@@ -133,3 +133,30 @@ class LinearFunctionHandler(NodeHandler):
         # re-init the sharding spec
         sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
         return strategy
+
+
+@operator_registry.register(torch.bmm)
+@operator_registry.register(torch.Tensor.bmm)
+class BMMFunctionHandler(NodeHandler):
+
+    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
+        # use transposed shape for strategies
+        # the strategies will be transformed back to its original shape in self.post_process
+        physical_input_operand = OperationData(name=str(self.node.args[0]),
+                                               type=OperationDataType.ARG,
+                                               data=self.node.args[0]._meta_data)
+        physical_other_operand = OperationData(name=str(self.node.args[1]),
+                                               type=OperationDataType.ARG,
+                                               data=self.node.args[1]._meta_data)
+        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
+
+        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
+        return mapping
+
+    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
+        op_data_mapping = self.get_operation_data_mapping()
+        generators = []
+        generators.append(BatchedMatMulStrategyGenerator(op_data_mapping, self.device_mesh))
+        return generators
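The two `@operator_registry.register(...)` decorators bind both `torch.bmm` and the `torch.Tensor.bmm` method to the same handler class. For readers unfamiliar with the pattern, here is a minimal, hypothetical sketch of how such a decorator-based registry typically works; the actual `operator_registry` imported from `.registry` may differ:

class Registry:
    """Hypothetical dict-backed operator registry, for illustration only."""

    def __init__(self):
        self._store = {}

    def register(self, source):
        # returns a class decorator that maps `source` (e.g. torch.bmm) to the handler class
        def wrapper(handler_cls):
            self._store[source] = handler_cls
            return handler_cls
        return wrapper

    def get(self, source):
        return self._store[source]


operator_registry = Registry()

# usage mirrors the commit: one handler class can serve several call targets
@operator_registry.register('torch.bmm')
@operator_registry.register('torch.Tensor.bmm')
class DummyHandler:
    pass

assert operator_registry.get('torch.bmm') is DummyHandler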
colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py

@@ -483,6 +483,9 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         other_op_data = self.op_data['other']
         assert input_op_data.data.dim() > 2 or other_op_data.data.dim() > 2

+    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+        return self.op_data['input'].data.shape[-1] * reduce(operator.mul, self.op_data['output'].data.shape)
+
     def split_one_batch_dim(self):
         device_mesh_is_1d = True
         if len(self.device_mesh.mesh_shape) == 1:

@@ -552,7 +555,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             },
             "bias": {},
-            "output": {0: mesh_dim_0, -2: [mesh_dim_1]}
+            "output": {0: [mesh_dim_0], -2: [mesh_dim_1]}
         }

@@ -635,8 +638,10 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         # can be None as it is only for 1D device mesh
         strategy = self.split_one_batch_dim()
         if strategy:
+            # only for 1D device mesh
             strategy_list.append(strategy)
         else:
+            # for 2D device mesh
             # split batch dim of two inputs and the i dim of the first tensor
             # SbSi = SbSi x Sb
             strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
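As a quick sanity check on the arithmetic in the new `update_compute_cost`: the cost is the contracted dimension k times the number of output elements, i.e. one multiply-accumulate chain of length k per output entry. A sketch using the tensor shapes from the new test file below (`reduce` and `operator` are assumed to be the `functools.reduce` and `operator` module already imported by this file):

from functools import reduce
import operator

# bmm: input (4, 8, 16) x other (4, 16, 8) -> output (4, 8, 8)
input_shape = (4, 8, 16)
output_shape = (4, 8, 8)

# cost = k * prod(output shape): each output element needs k = 16 multiply-accumulates
compute_cost = input_shape[-1] * reduce(operator.mul, output_shape)
assert compute_cost == 16 * 4 * 8 * 8 == 4096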
colossalai/auto_parallel/solver/strategy/strategy_generator.py

@@ -49,6 +49,7 @@ class StrategyGenerator_V2(ABC):
         """
         results = {}
         for op_data_name, dim_partition_dict in mapping.items():
-            op_data = self.op_data[op_data_name]
-            sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
-                                         entire_shape=op_data.logical_shape,
+            if op_data_name in self.op_data:
+                op_data = self.op_data[op_data_name]
+                sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
+                                             entire_shape=op_data.logical_shape,
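The new `if op_data_name in self.op_data` guard lets a strategy's dim-partition mapping mention operands an op does not actually have, such as the `"bias"` entry emitted by `BatchedMatMulStrategyGenerator` when handling the bias-free `torch.bmm`. A minimal sketch of the lookup error it avoids, with hypothetical stand-in data:

# hypothetical stand-ins for self.op_data and a strategy's dim-partition mapping
op_data = {"input": object(), "other": object(), "output": object()}  # bmm has no "bias"
mapping = {"input": {0: [0]}, "bias": {}, "output": {0: [0]}}

specs = {}
for op_data_name, dim_partition_dict in mapping.items():
    if op_data_name in op_data:  # without this guard, op_data["bias"] would raise KeyError
        specs[op_data_name] = dim_partition_dict
assert "bias" not in specs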
tests/test_auto_parallel/test_node_handler/test_bmm_handler.py (new file, 0 → 100644)

import pytest
import torch
import torch.nn as nn

from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import BMMFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh


class BMMTensorMethodModule(nn.Module):

    def forward(self, x1, x2):
        return x1.bmm(x2)


class BMMTorchFunctionModule(nn.Module):

    def forward(self, x1, x2):
        return torch.bmm(x1, x2)


@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_2d_device_mesh(module):
    model = module()
    tracer = ColoTracer()
    graph = tracer.trace(model,
                         meta_args={
                             "x1": torch.rand(4, 8, 16).to('meta'),
                             "x2": torch.rand(4, 16, 8).to('meta')
                         })
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    # the third node in the traced graph is the bmm call (after the two placeholders)
    bmm_node = list(graph.nodes)[2]
    strategies_vector = StrategiesVector(bmm_node)

    # build handler
    handler = BMMFunctionHandler(node=bmm_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "x1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 8, 16])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 8, 16])

    assert mapping['other'].name == "x2"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([4, 16, 8])
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size([4, 16, 8])

    assert mapping['output'].name == "bmm"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 8, 8])
    assert mapping['output'].type == OperationDataType.OUTPUT

    strategies_vector = handler.register_strategy()
    strategy_name_list = [val.name for val in strategies_vector]

    # one batch dim
    assert 'Sb0 = Sb0 x Sb0' not in strategy_name_list

    # two batch dims
    assert 'Sb01 = Sb01 x Sb01' in strategy_name_list

    # SbSi = SbSi x Sb
    assert 'Sb0Si1 = Sb0Si1 x Sb0' in strategy_name_list
    assert 'Sb1Si0 = Sb1Si0 x Sb1' in strategy_name_list

    # SbSj = SbR x SbSj
    assert 'Sb0Sj1 = Sb0R x Sb0Sj1' in strategy_name_list
    assert 'Sb1Sj0 = Sb1R x Sb1Sj0' in strategy_name_list

    # SbR = SbSk x SbSk
    assert 'Sb0R = Sb0Sk1 x Sb0Sk1' in strategy_name_list
    assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list


@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_1d_device_mesh(module):
    model = module()
    tracer = ColoTracer()
    graph = tracer.trace(model,
                         meta_args={
                             "x1": torch.rand(4, 8, 16).to('meta'),
                             "x2": torch.rand(4, 16, 8).to('meta')
                         })
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (1, 4)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

    bmm_node = list(graph.nodes)[2]
    strategies_vector = StrategiesVector(bmm_node)

    # build handler
    handler = BMMFunctionHandler(node=bmm_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "x1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 8, 16])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 8, 16])

    assert mapping['other'].name == "x2"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([4, 16, 8])
    assert mapping['other'].type == OperationDataType.ARG
    assert mapping['other'].logical_shape == torch.Size([4, 16, 8])

    assert mapping['output'].name == "bmm"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 8, 8])
    assert mapping['output'].type == OperationDataType.OUTPUT

    strategies_vector = handler.register_strategy()
    strategy_name_list = [val.name for val in strategies_vector]
    assert len(strategy_name_list) == 1

    # one batch dim
    assert 'Sb0 = Sb0 x Sb0' in strategy_name_list


if __name__ == '__main__':
    test_1d_device_mesh(BMMTensorMethodModule)
    test_1d_device_mesh(BMMTorchFunctionModule)
    test_2d_device_mesh(BMMTensorMethodModule)
    test_2d_device_mesh(BMMTorchFunctionModule)
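For readers decoding the strategy names asserted above, a short sketch of the notation as inferred from these tests (not from ColossalAI documentation): S = sharded, R = replicated; b is the batch dim, i/j the row/column dims of the matmul, k the contracted dim; trailing digits name device-mesh axes. For example, under 'Sb01 = Sb01 x Sb01' on the (2, 2) mesh, the batch dim of size 4 is split across both mesh axes, so each of the 4 devices holds one batch slice:

import torch

full_in, full_other = torch.rand(4, 8, 16), torch.rand(4, 16, 8)
per_device_in = full_in[0:1]        # (1, 8, 16): one batch slice out of 2 x 2 = 4 devices
per_device_other = full_other[0:1]  # (1, 16, 8)
per_device_out = per_device_in.bmm(per_device_other)
assert per_device_out.shape == torch.Size([1, 8, 8])  # local shard of the (4, 8, 8) output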