OpenDAS / ColossalAI
Commit 3abf98a6 (unverified)

[autoparallel] added all non-bcast matmul strategies (#1603)

Authored by Frank Lee on Sep 16, 2022; committed via GitHub on Sep 16, 2022.
Parent commit: db98b695
Showing 2 changed files, with 251 additions and 3 deletions:

- colossalai/auto_parallel/solver/op_handler/dot_handler.py (+240, -2)
- colossalai/auto_parallel/solver/op_handler/strategy_generator.py (+11, -1)
colossalai/auto_parallel/solver/op_handler/dot_handler.py
@@ -13,9 +13,238 @@ from typing import List
```python
__all__ = ['DotHandler']


class DotProductStrategyGenerator(StrategyGenerator):
    """
    DotProductStrategyGenerator is used to generate the sharding strategies for two 1D tensors in dot product computation.
    This is created for torch.matmul where both tensors are 1D. As torch.matmul does not include a bias argument,
    we do not consider bias here.
    """

    def validate(self, input, other):
        assert input.dim() == 1 and other.dim() == 1

    def no_split(self):
        name = f'R = R dot R'
        dim_partition_dict = {"input": {}, "other": {}, "output": {}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_one_dim(self, mesh_dim):
        name = f'S{mesh_dim} = S{mesh_dim} dot S{mesh_dim}'
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])

    def generate(self) -> List[IntermediateStrategy]:
        strategy_list = []

        # do not split dimensions for dot product
        # R = R dot R
        strategy_list.append(self.no_split())

        # split both tensors along the same dimension
        # S = S dot S
        strategy_list.append(self.split_one_dim(0))
        strategy_list.append(self.split_one_dim(1))

        return strategy_list


class MatVecStrategyGenerator(StrategyGenerator):
    """
    MatVecStrategyGenerator generates the sharding strategies for matrix-vector multiplication,
    i.e. when the first tensor has more than one dimension and the second tensor is 1D.
    """

    def validate(self, input, other) -> bool:
        assert input.dim() > 1 and other.dim() == 1

    def no_split(self):
        name = "R = R x R"
        dim_partition_dict = {"input": {}, "other": {}, "output": {}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_input_batch(self, mesh_dim):
        name = f'S{mesh_dim}R = S{mesh_dim}R x R'
        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def generate(self) -> List[IntermediateStrategy]:
        strategy_list = []

        # no split
        strategy_list.append(self.no_split())

        # split the batch dim of the first tensor only
        strategy_list.append(self.split_input_batch(0))
        strategy_list.append(self.split_input_batch(1))

        return strategy_list


class MatMulStrategyGenerator(StrategyGenerator):
    """
    MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
    a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.

    A matmul can be formulated as [n, p] x [p, q] = [n, q].

    Args:
        is_linear (bool): whether this generator is used for nn.Linear and F.linear.
            This incurs an extra transformation of the dim partitioning as the weight is transposed.
    """

    def __init__(self, is_linear: bool, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_linear = is_linear

        # as the weight of the linear module is transposed, we compute
        # the corresponding dimension indices for convenience
        if is_linear:
            self.dim_q = 0
            self.dim_p = 1
        else:
            self.dim_q = 1
            self.dim_p = 0

    def validate(self, input, other, bias) -> bool:
        # make sure the second tensor is a 2D tensor
        assert input.dim() > 0 and other.dim() == 2

        # make sure the bias has a matching dimension
        if self.is_linear:
            assert bias is None or bias.shape[-1] == other.shape[0]
        else:
            assert bias is None or bias.shape[-1] == other.shape[1]

    def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
        # handle the case SS = SR x RS
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
        dim_partition_dict = {
            "input": {0: [mesh_dim_0]},
            "other": {self.dim_q: [mesh_dim_1]},
            "bias": {-1: [mesh_dim_1]},
            "output": {0: [mesh_dim_0], -1: [mesh_dim_1]},
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        # handle the case SR = SS x SR
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
        dim_partition_dict = {
            "input": {0: [mesh_dim_0], -1: [mesh_dim_1]},
            "other": {self.dim_p: [mesh_dim_1]},
            "bias": {},
            "output": {0: [mesh_dim_0]},
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])

    def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
        dim_partition_dict = {
            "input": {-1: [mesh_dim_0]},
            "other": {self.dim_p: [mesh_dim_0], self.dim_q: [mesh_dim_1]},
            "bias": {-1: [mesh_dim_1]},
            "output": {-1: [mesh_dim_1]},
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def recompute_split_both_contract(self, mesh_dim):
        name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
        dim_partition_dict = {
            "input": {-1: [mesh_dim]},
            "other": {self.dim_p: [mesh_dim]},
            "bias": {},
            "output": {},
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])

    def split_rhs_space_only(self, mesh_dim):
        name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
        dim_partition_dict = {
            "input": {},
            "other": {self.dim_q: [mesh_dim]},
            "bias": {-1: [mesh_dim]},
            "output": {-1: [mesh_dim]},
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])

    def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
        dim_partition_dict = {
            "input": {0: [mesh_dim_0, mesh_dim_1]},
            "other": {},
            "bias": {},
            "output": {0: [mesh_dim_0, mesh_dim_1]},
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
        dim_partition_dict = {
            "input": {-1: [mesh_dim_0, mesh_dim_1]},
            "other": {self.dim_p: [mesh_dim_0, mesh_dim_1]},
            "bias": {},
            "output": {},
        }
        return IntermediateStrategy(name=name,
                                    dim_partition_dict=dim_partition_dict,
                                    all_reduce_axis=[mesh_dim_0, mesh_dim_1])

    def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
        dim_partition_dict = {
            "input": {},
            "other": {self.dim_q: [mesh_dim_0, mesh_dim_1]},
            "bias": {-1: [mesh_dim_0, mesh_dim_1]},
            "output": {-1: [mesh_dim_0, mesh_dim_1]},
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)


class BatchedMatMulStrategyGenerator(StrategyGenerator):
```
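The strategy names above use the S/R notation of the auto-parallel solver: `S{i}` means a tensor dimension is sharded along device-mesh axis `i`, and `R` means it is replicated. Each `dim_partition_dict` records, per operand, which tensor dimension is split over which mesh axes; for `nn.Linear` the weight is stored transposed, which is why `dim_q` and `dim_p` swap when `is_linear` is set. Below is a minimal, self-contained sketch of what one such strategy implies for local shard shapes; `SketchStrategy` and `shard_shape` are illustrative stand-ins, not part of ColossalAI.

```python
# Illustrative sketch only. SketchStrategy stands in for IntermediateStrategy,
# whose fields are inferred from the diff above; shard_shape is a hypothetical helper.
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class SketchStrategy:
    name: str
    # per operand: tensor dim index -> list of device-mesh axes it is sharded over
    dim_partition_dict: Dict[str, Dict[int, List[int]]]
    all_reduce_axis: Optional[List[int]] = None


def shard_shape(shape: Tuple[int, ...],
                partition: Dict[int, List[int]],
                mesh_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Local (per-device) shape implied by a dim partition on a given device mesh."""
    local = list(shape)
    for dim, mesh_axes in partition.items():
        for axis in mesh_axes:
            local[dim] //= mesh_shape[axis]
    return tuple(local)


# The "S0S1 = S0R x RS1" strategy for [n, p] x [p, q] = [n, q] on a 2 x 2 mesh,
# mirroring split_lhs_space_rhs_space(0, 1) with dim_q = 1 (the torch.matmul case).
mesh_shape = (2, 2)
strategy = SketchStrategy(
    name="S0S1 = S0R x RS1",
    dim_partition_dict={
        "input": {0: [0]},            # n is sharded over mesh axis 0
        "other": {1: [1]},            # q is sharded over mesh axis 1
        "output": {0: [0], -1: [1]},  # the output inherits both splits
    },
)

n, p, q = 8, 6, 4
print(shard_shape((n, p), strategy.dim_partition_dict["input"], mesh_shape))   # (4, 6)
print(shard_shape((p, q), strategy.dim_partition_dict["other"], mesh_shape))   # (6, 2)
print(shard_shape((n, q), strategy.dim_partition_dict["output"], mesh_shape))  # (4, 2)
```

Strategies such as `split_lhs_space_both_contract` additionally set `all_reduce_axis`, because splitting the contracted dimension `p` leaves every device with a partial sum that must be all-reduced over that mesh axis.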
@@ -30,6 +259,15 @@ class BatchedMatMulStrategyGenerator(StrategyGenerator):

```python
        super().__init__(*args, **kwargs)
        self.is_torch_bmm = is_torch_bmm

    def validate(self, input, other, bias) -> bool:
        if self.is_torch_bmm:
            assert input.shape == other.shape
            assert input.dim() > 2
            assert other.shape[-1] == bias.shape[0]
        else:
            # TODO: validate these inputs are broadcastable
            pass

    def split_one_batch_dim(self):
        if 1 in self.device_mesh.mesh_shape:
            mesh_dim = self.device_mesh.mesh_shape.index(1)
            # ... (the remainder of this hunk is collapsed in the diff view)
```
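The truncated `split_one_batch_dim` above starts by checking whether the device mesh has an axis of size 1 and, if so, looks up its index with `mesh_shape.index(1)`. A standalone sketch of just that check, assuming `device_mesh.mesh_shape` behaves like an ordinary tuple (an assumption, not confirmed by this diff):

```python
# Hypothetical helper mirroring the check at the start of split_one_batch_dim.
from typing import Optional, Tuple


def degenerate_mesh_axis(mesh_shape: Tuple[int, ...]) -> Optional[int]:
    """Return the index of a size-1 mesh axis, or None if every axis holds devices."""
    if 1 in mesh_shape:
        return mesh_shape.index(1)
    return None


print(degenerate_mesh_axis((1, 4)))  # 0 -> mesh axis 0 has size 1
print(degenerate_mesh_axis((4, 1)))  # 1 -> mesh axis 1 has size 1
print(degenerate_mesh_axis((2, 2)))  # None -> both axes hold more than one device
```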
colossalai/auto_parallel/solver/op_handler/strategy_generator.py
@@ -32,4 +32,14 @@ class StrategyGenerator(ABC):
```python
    @abstractmethod
    def generate(self) -> List[IntermediateStrategy]:
        """
        Generate all sharding strategies for the current operation.
        """
        pass

    @abstractmethod
    def validate(self, *args, **kwargs) -> bool:
        """
        Validate whether the operands are of the desired shape.
        If this returns True, the generator can be used for the current operation.
        """
        pass
```
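The new abstract `validate` method documents the dispatch contract: a handler asks each generator whether the operand ranks fit before using it. Based purely on the assertions in `dot_handler.py` above, a hypothetical rank-based dispatch might look like the sketch below; `pick_generator_kind` is illustrative and not part of the library.

```python
# Hypothetical dispatch mirroring the validate() assertions in dot_handler.py.
import torch


def pick_generator_kind(input: torch.Tensor, other: torch.Tensor) -> str:
    if input.dim() == 1 and other.dim() == 1:
        return "DotProductStrategyGenerator"   # 1D dot 1D
    if input.dim() > 1 and other.dim() == 1:
        return "MatVecStrategyGenerator"       # matrix x vector
    if other.dim() == 2:
        return "MatMulStrategyGenerator"       # second operand is a 2D weight
    return "BatchedMatMulStrategyGenerator"    # batched / broadcast matmul


print(pick_generator_kind(torch.rand(4), torch.rand(4)))               # DotProductStrategyGenerator
print(pick_generator_kind(torch.rand(8, 4), torch.rand(4)))            # MatVecStrategyGenerator
print(pick_generator_kind(torch.rand(8, 4), torch.rand(4, 16)))        # MatMulStrategyGenerator
print(pick_generator_kind(torch.rand(2, 8, 4), torch.rand(2, 4, 16)))  # BatchedMatMulStrategyGenerator
```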