OpenDAS / ColossalAI · Commits

Commit db98b695 (unverified)
[autoparallel] added strategy generator and bmm strategies (#1602)
Authored Sep 15, 2022 by Frank Lee; committed by GitHub on Sep 15, 2022.
Parent: a19eb809

Showing 3 changed files with 190 additions and 3 deletions:
- colossalai/auto_parallel/solver/op_handler/dot_handler.py: +155, -2
- colossalai/auto_parallel/solver/op_handler/operator_handler.py: +0, -1
- colossalai/auto_parallel/solver/op_handler/strategy_generator.py: +35, -0
colossalai/auto_parallel/solver/op_handler/dot_handler.py (view file @ db98b695)
The new strategy generator classes added at the top of the file:

```python
import operator
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
from functools import reduce
from enum import Enum
from .strategy_generator import StrategyGenerator, IntermediateStrategy
from typing import List

__all__ = ['DotHandler']


class MatMulStrategyGenerator(StrategyGenerator):
    # TODO: to be implemented
    pass


class BatchedMatMulStrategyGenerator(StrategyGenerator):
    """
    Generate sharding strategies for the batched matrix multiplication.

    A batched matrix multiplication can be viewed as
    [b, i, k] x [b, k, j] -> [b, i, j]
    """

    def __init__(self, is_torch_bmm: bool, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_torch_bmm = is_torch_bmm

    def split_one_batch_dim(self, mesh_dim=None):
        # when no mesh dimension is given, this strategy only applies to a
        # 1D device mesh: the degenerate (size-1) mesh axis is used, and
        # None is returned if no such axis exists
        if mesh_dim is None:
            if 1 not in self.device_mesh.mesh_shape:
                return None
            mesh_dim = self.device_mesh.mesh_shape.index(1)
        name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
        dim_partition_dict = {
            "input": {0: [mesh_dim]},
            "other": {0: [mesh_dim]},
            "bias": {},
            "output": {0: [mesh_dim]}
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
        name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
        dim_partition_dict = {
            "input": {0: [mesh_dim_0, mesh_dim_1]},
            "other": {0: [mesh_dim_0, mesh_dim_1]},
            "bias": {},
            "output": {0: [mesh_dim_0, mesh_dim_1]}
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
        name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
        dim_partition_dict = {
            "input": {0: [mesh_dim_0], -2: [mesh_dim_1]},
            "other": {0: [mesh_dim_0]},
            "bias": {},
            "output": {0: [mesh_dim_0], -2: [mesh_dim_1]}
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
        name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
        dim_partition_dict = {
            "input": {0: [mesh_dim_0]},
            "other": {0: [mesh_dim_0], -1: [mesh_dim_1]},
            "bias": {-1: [mesh_dim_1]},
            "output": {0: [mesh_dim_0], -1: [mesh_dim_1]}
        }
        return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)

    def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
        name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
        dim_partition_dict = {
            "input": {0: [mesh_dim_0], -1: [mesh_dim_1]},
            "other": {0: [mesh_dim_0], -2: [mesh_dim_1]},
            "bias": {},
            "output": {0: [mesh_dim_0], -2: [mesh_dim_1]}
        }
        return IntermediateStrategy(name=name,
                                    dim_partition_dict=dim_partition_dict,
                                    all_reduce_axis=[mesh_dim_1])

    def generate(self) -> List[IntermediateStrategy]:
        strategy_list = []

        # split only the batch dimension
        # Sb = Sb x Sb
        # can be None as it is only for 1D device mesh
        strategy = self.split_one_batch_dim()
        if strategy:
            strategy_list.append(strategy)

        # split batch dim of two inputs and the i dim of the first tensor
        # SbSi = SbSi x Sb
        strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
        strategy_list.append(self.split_batch_dim_lhs_space(1, 0))

        # split batch dim of two inputs and the j dim of the second tensor
        # SbSj = Sb x SbSj
        strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
        strategy_list.append(self.split_batch_dim_rhs_space(1, 0))

        # split batch dim of two inputs and the k dim of two inputs
        # Sb = SbSk x SbSk, need to all-reduce by k dim
        strategy_list.append(self.split_batch_dim_both_contract(0, 1))
        strategy_list.append(self.split_batch_dim_both_contract(1, 0))

        # split two batch dims
        strategy_list.append(self.split_two_batch_dim(0, 1))
        strategy_list.append(self.split_two_batch_dim(1, 0))

        return strategy_list
```
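For reference, the `[b, i, k] x [b, k, j] -> [b, i, j]` view from the class docstring can be checked with plain `torch.bmm`. The snippet below is an illustrative sketch, not part of the commit; it also shows why the `SbSk x SbSk` strategy from `split_batch_dim_both_contract` needs an all-reduce over the contracted `k` dimension: each shard of `k` produces only a partial sum.

```python
import torch

# Shapes for [b, i, k] x [b, k, j] -> [b, i, j]
b, i, k, j = 4, 8, 16, 32
x = torch.randn(b, i, k)
w = torch.randn(b, k, j)
out = torch.bmm(x, w)
assert out.shape == (b, i, j)

# Shard the contracted dim k in halves, as in the SbSk x SbSk strategy:
# each "device" computes a partial result, and summing them plays the
# role of the all-reduce along the mesh axis that shards k.
x0, x1 = x.chunk(2, dim=-1)
w0, w1 = w.chunk(2, dim=-2)
partial_sum = torch.bmm(x0, w0) + torch.bmm(x1, w1)
assert torch.allclose(partial_sum, out, atol=1e-4)
```

By contrast, the batch-dim splits (`Sb`, `SbSi`, `SbSj`) shard only non-contracted dimensions, so they need no communication for the forward matmul itself.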
Further down the file, DotHandler's docstrings are updated (elisions from the original view kept as `...`):

```diff
 class DotHandler(OperatorHandler):
     """
-    A OperatorHandler which deals with the sharding strategies of linear matrix multiplication.
+    A OperatorHandler which deals with the sharding strategies for nn.Linear and F.linear.
     """

     def __init__(self, *args, **kwargs):
     ...

@@ -297,7 +450,7 @@ class DotHandler(OperatorHandler):
     def register_strategy(self) -> StrategiesVector:
         '''
-        Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
+        Generate every possible strategies for a linear node, and record all strategies into the strategies_vector.
         Output:
     ...
```
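As context for the docstring change, `DotHandler` targets `nn.Linear` / `F.linear`. Below is a minimal sketch (not from the commit) of one such sharding strategy in plain PyTorch: splitting the weight's output dimension and concatenating the per-shard outputs, column-parallel style.

```python
import torch
import torch.nn.functional as F

# F.linear computes x @ w.T: [batch, in_features] x [out_features, in_features]
x = torch.randn(2, 16)
w = torch.randn(24, 16)

# Shard the out_features dim of the weight in halves ("column parallel"):
# each shard produces a slice of the output, so the results are
# concatenated along the last dim instead of being all-reduced.
w0, w1 = w.chunk(2, dim=0)
out = torch.cat([F.linear(x, w0), F.linear(x, w1)], dim=-1)
assert torch.allclose(out, F.linear(x, w), atol=1e-5)
```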
colossalai/auto_parallel/solver/op_handler/operator_handler.py (view file @ db98b695)
A single import is removed:

```diff
@@ -5,7 +5,6 @@ from abc import ABC, abstractmethod
 from torch.fx.node import Node
 from typing import Dict, List
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
 from .._utils import generate_resharding_costs, generate_sharding_spec
 from colossalai.auto_parallel.solver.constants import *
```
colossalai/auto_parallel/solver/op_handler/strategy_generator.py (new file @ db98b695, mode 100644)
```python
from dataclasses import dataclass
from abc import ABC, abstractmethod
from typing import List, Dict
from colossalai.device.device_mesh import DeviceMesh

__all__ = ['IntermediateStrategy', 'StrategyGenerator']


@dataclass
class IntermediateStrategy:
    """
    IntermediateStrategy contains the subset of meta information for ShardingStrategy. It
    stores the essential information regarding tensor sharding and leaves other meta
    information to OperatorHandler.

    Args:
        name (str): name of the sharding strategy.
        dim_partition_dict (Dict[str, Dict[int, List[int]]]): maps a tensor name to its dim partition dict.
        all_reduce_axis (List[int]): stores the mesh axes which require an all-reduce operation.
    """
    name: str
    dim_partition_dict: Dict[str, Dict[int, List[int]]]
    all_reduce_axis: List[int] = None


class StrategyGenerator(ABC):
    """
    StrategyGenerator is used to generate the same group of sharding strategies.
    """

    def __init__(self, device_mesh: DeviceMesh):
        self.device_mesh = device_mesh

    @abstractmethod
    def generate(self) -> List[IntermediateStrategy]:
        pass
```
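To illustrate the interface, here is a toy subclass (hypothetical, not in the commit; it assumes the definitions above are in scope) that emits a single batch-sharding strategy. `device_mesh` is passed as `None` purely as a placeholder, since `__init__` only stores it and this toy `generate()` never consults it.

```python
from typing import List

class ToySbGenerator(StrategyGenerator):
    """A hypothetical generator emitting one batch-dim sharding strategy."""

    def generate(self) -> List[IntermediateStrategy]:
        # shard dim 0 of both operands and the output along mesh axis 0
        return [
            IntermediateStrategy(name='Sb0 = Sb0 x Sb0',
                                 dim_partition_dict={
                                     "input": {0: [0]},
                                     "other": {0: [0]},
                                     "output": {0: [0]}
                                 })
        ]

strategies = ToySbGenerator(device_mesh=None).generate()
print([s.name for s in strategies])    # ['Sb0 = Sb0 x Sb0']
```

An OperatorHandler such as BatchedMatMulStrategyGenerator's consumer would then turn each IntermediateStrategy into a full ShardingStrategy by filling in the remaining meta information.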