Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
9e768b59
Commit
9e768b59
authored
Oct 10, 2023
by
zhuwenwen
Browse files
Merge branch 'main' of
https://github.com/hpcaitech/ColossalAI
parents
7bc5a8e3
8aed02b9
Changes
436
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1004 additions
and
928 deletions
+1004
-928
colossalai/auto_parallel/tensor_shard/node_handler/registry.py
...salai/auto_parallel/tensor_shard/node_handler/registry.py
+2
-5
colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py
...uto_parallel/tensor_shard/node_handler/softmax_handler.py
+4
-4
colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py
.../auto_parallel/tensor_shard/node_handler/split_handler.py
+4
-4
colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
...o_parallel/tensor_shard/node_handler/strategy/__init__.py
+27
-7
colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
...ensor_shard/node_handler/strategy/batch_norm_generator.py
+90
-96
colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py
...ard/node_handler/strategy/binary_elementwise_generator.py
+22
-20
colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
...or_shard/node_handler/strategy/conv_strategy_generator.py
+134
-101
colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py
...tensor_shard/node_handler/strategy/embedding_generator.py
+65
-51
colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
...l/tensor_shard/node_handler/strategy/getattr_generator.py
+7
-6
colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
...l/tensor_shard/node_handler/strategy/getitem_generator.py
+27
-27
colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py
...ensor_shard/node_handler/strategy/layer_norm_generator.py
+34
-27
colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
..._shard/node_handler/strategy/matmul_strategy_generator.py
+338
-354
colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py
...r_shard/node_handler/strategy/normal_pooling_generator.py
+24
-17
colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py
...el/tensor_shard/node_handler/strategy/output_generator.py
+28
-23
colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py
...nsor_shard/node_handler/strategy/placeholder_generator.py
+24
-19
colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
...l/tensor_shard/node_handler/strategy/reshape_generator.py
+56
-41
colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py
...l/tensor_shard/node_handler/strategy/softmax_generator.py
+21
-30
colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
.../tensor_shard/node_handler/strategy/strategy_generator.py
+64
-46
colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py
...allel/tensor_shard/node_handler/strategy/sum_generator.py
+22
-32
colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py
...ard/node_handler/strategy/tensor_constructor_generator.py
+11
-18
No files found.
Too many changes to show.
To preserve performance only
436 of 436+
files are displayed.
Plain diff
Email patch
colossalai/auto_parallel/tensor_shard/node_handler/registry.py
View file @
9e768b59
class
Registry
:
class
Registry
:
# TODO: refactor the registry classes used in colossalai.registry, colossalai.fx and here
def
__init__
(
self
,
name
):
def
__init__
(
self
,
name
):
self
.
name
=
name
self
.
name
=
name
self
.
store
=
{}
self
.
store
=
{}
def
register
(
self
,
source
):
def
register
(
self
,
source
):
def
wrapper
(
func
):
def
wrapper
(
func
):
if
isinstance
(
source
,
(
list
,
tuple
)):
if
isinstance
(
source
,
(
list
,
tuple
)):
# support register a list of items for this func
# support register a list of items for this func
...
@@ -19,7 +16,7 @@ class Registry:
...
@@ -19,7 +16,7 @@ class Registry:
return
wrapper
return
wrapper
def
get
(
self
,
source
):
def
get
(
self
,
source
):
assert
source
in
self
.
store
,
f
'
{
source
}
not found in the
{
self
.
name
}
registry
'
assert
source
in
self
.
store
,
f
"
{
source
}
not found in the
{
self
.
name
}
registry
"
target
=
self
.
store
[
source
]
target
=
self
.
store
[
source
]
return
target
return
target
...
@@ -27,4 +24,4 @@ class Registry:
...
@@ -27,4 +24,4 @@ class Registry:
return
source
in
self
.
store
return
source
in
self
.
store
operator_registry
=
Registry
(
'
operator
'
)
operator_registry
=
Registry
(
"
operator
"
)
colossalai/auto_parallel/tensor_shard/node_handler/softmax_handler.py
View file @
9e768b59
...
@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
...
@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from
.registry
import
operator_registry
from
.registry
import
operator_registry
from
.strategy
import
SoftmaxGenerator
,
StrategyGenerator
from
.strategy
import
SoftmaxGenerator
,
StrategyGenerator
__all__
=
[
'
SoftmaxHandler
'
]
__all__
=
[
"
SoftmaxHandler
"
]
@
operator_registry
.
register
(
torch
.
nn
.
Softmax
)
@
operator_registry
.
register
(
torch
.
nn
.
Softmax
)
...
@@ -34,14 +34,14 @@ class SoftmaxHandler(NodeHandler):
...
@@ -34,14 +34,14 @@ class SoftmaxHandler(NodeHandler):
input_data
=
self
.
node
.
args
[
0
].
_meta_data
input_data
=
self
.
node
.
args
[
0
].
_meta_data
physical_input_operand
=
OperationData
(
name
=
str
(
self
.
node
.
args
[
0
]),
type
=
data_type
,
data
=
input_data
)
physical_input_operand
=
OperationData
(
name
=
str
(
self
.
node
.
args
[
0
]),
type
=
data_type
,
data
=
input_data
)
softmax_dim
=
self
.
node
.
kwargs
[
'
dim
'
]
softmax_dim
=
self
.
node
.
kwargs
[
"
dim
"
]
num_dims
=
self
.
node
.
args
[
0
].
_meta_data
.
dim
()
num_dims
=
self
.
node
.
args
[
0
].
_meta_data
.
dim
()
# recover negative value to positive
# recover negative value to positive
if
softmax_dim
<
0
:
if
softmax_dim
<
0
:
softmax_dim
+=
num_dims
softmax_dim
+=
num_dims
physical_dim_operand
=
OperationData
(
name
=
'
softmax_dim
'
,
type
=
OperationDataType
.
ARG
,
data
=
softmax_dim
)
physical_dim_operand
=
OperationData
(
name
=
"
softmax_dim
"
,
type
=
OperationDataType
.
ARG
,
data
=
softmax_dim
)
output_data
=
self
.
node
.
_meta_data
output_data
=
self
.
node
.
_meta_data
physical_output_operand
=
OperationData
(
name
=
str
(
self
.
node
),
type
=
OperationDataType
.
OUTPUT
,
data
=
output_data
)
physical_output_operand
=
OperationData
(
name
=
str
(
self
.
node
),
type
=
OperationDataType
.
OUTPUT
,
data
=
output_data
)
...
@@ -49,7 +49,7 @@ class SoftmaxHandler(NodeHandler):
...
@@ -49,7 +49,7 @@ class SoftmaxHandler(NodeHandler):
mapping
=
{
mapping
=
{
"input"
:
physical_input_operand
,
"input"
:
physical_input_operand
,
"softmax_dim"
:
physical_dim_operand
,
"softmax_dim"
:
physical_dim_operand
,
"output"
:
physical_output_operand
"output"
:
physical_output_operand
,
}
}
return
mapping
return
mapping
colossalai/auto_parallel/tensor_shard/node_handler/split_handler.py
View file @
9e768b59
...
@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
...
@@ -7,7 +7,7 @@ from .node_handler import NodeHandler
from
.registry
import
operator_registry
from
.registry
import
operator_registry
from
.strategy
import
SplitGenerator
,
StrategyGenerator
from
.strategy
import
SplitGenerator
,
StrategyGenerator
__all__
=
[
'
SplitHandler
'
]
__all__
=
[
"
SplitHandler
"
]
@
operator_registry
.
register
(
torch
.
Tensor
.
split
)
@
operator_registry
.
register
(
torch
.
Tensor
.
split
)
...
@@ -38,7 +38,7 @@ class SplitHandler(NodeHandler):
...
@@ -38,7 +38,7 @@ class SplitHandler(NodeHandler):
split_dim
=
self
.
node
.
args
[
2
]
split_dim
=
self
.
node
.
args
[
2
]
else
:
else
:
if
self
.
node
.
kwargs
:
if
self
.
node
.
kwargs
:
split_dim
=
self
.
node
.
kwargs
[
'
dim
'
]
split_dim
=
self
.
node
.
kwargs
[
"
dim
"
]
else
:
else
:
split_dim
=
0
split_dim
=
0
...
@@ -48,7 +48,7 @@ class SplitHandler(NodeHandler):
...
@@ -48,7 +48,7 @@ class SplitHandler(NodeHandler):
split_dim
+=
num_dims
split_dim
+=
num_dims
split_info
=
(
split_size
,
split_dim
)
split_info
=
(
split_size
,
split_dim
)
physical_shape_operand
=
OperationData
(
name
=
'
split_info
'
,
type
=
OperationDataType
.
ARG
,
data
=
split_info
)
physical_shape_operand
=
OperationData
(
name
=
"
split_info
"
,
type
=
OperationDataType
.
ARG
,
data
=
split_info
)
output_data
=
self
.
node
.
_meta_data
output_data
=
self
.
node
.
_meta_data
physical_output_operand
=
OperationData
(
name
=
str
(
self
.
node
),
type
=
OperationDataType
.
OUTPUT
,
data
=
output_data
)
physical_output_operand
=
OperationData
(
name
=
str
(
self
.
node
),
type
=
OperationDataType
.
OUTPUT
,
data
=
output_data
)
...
@@ -56,7 +56,7 @@ class SplitHandler(NodeHandler):
...
@@ -56,7 +56,7 @@ class SplitHandler(NodeHandler):
mapping
=
{
mapping
=
{
"input"
:
physical_input_operand
,
"input"
:
physical_input_operand
,
"split_info"
:
physical_shape_operand
,
"split_info"
:
physical_shape_operand
,
"output"
:
physical_output_operand
"output"
:
physical_output_operand
,
}
}
return
mapping
return
mapping
colossalai/auto_parallel/tensor_shard/node_handler/strategy/__init__.py
View file @
9e768b59
...
@@ -29,11 +29,31 @@ from .unary_elementwise_generator import UnaryElementwiseGenerator
...
@@ -29,11 +29,31 @@ from .unary_elementwise_generator import UnaryElementwiseGenerator
from
.where_generator
import
WhereGenerator
from
.where_generator
import
WhereGenerator
__all__
=
[
__all__
=
[
'StrategyGenerator'
,
'DotProductStrategyGenerator'
,
'MatVecStrategyGenerator'
,
'LinearProjectionStrategyGenerator'
,
"StrategyGenerator"
,
'BatchedMatMulStrategyGenerator'
,
'ConvStrategyGenerator'
,
'UnaryElementwiseGenerator'
,
"DotProductStrategyGenerator"
,
'BatchNormStrategyGenerator'
,
'GetItemStrategyGenerator'
,
'TensorStrategyGenerator'
,
'TensorTupleStrategyGenerator'
,
"MatVecStrategyGenerator"
,
'LayerNormGenerator'
,
'PlaceholderGenerator'
,
'OutputGenerator'
,
'WhereGenerator'
,
'NormalPoolStrategyGenerator'
,
"LinearProjectionStrategyGenerator"
,
'BinaryElementwiseStrategyGenerator'
,
'GetattrGenerator'
,
'TensorConstructorGenerator'
,
"BatchedMatMulStrategyGenerator"
,
'EmbeddingStrategyGenerator'
,
'SumGenerator'
,
'SoftmaxGenerator'
,
'ViewGenerator'
,
'PermuteGenerator'
,
"ConvStrategyGenerator"
,
'TransposeGenerator'
,
'SplitGenerator'
,
'DefaultReshapeGenerator'
"UnaryElementwiseGenerator"
,
"BatchNormStrategyGenerator"
,
"GetItemStrategyGenerator"
,
"TensorStrategyGenerator"
,
"TensorTupleStrategyGenerator"
,
"LayerNormGenerator"
,
"PlaceholderGenerator"
,
"OutputGenerator"
,
"WhereGenerator"
,
"NormalPoolStrategyGenerator"
,
"BinaryElementwiseStrategyGenerator"
,
"GetattrGenerator"
,
"TensorConstructorGenerator"
,
"EmbeddingStrategyGenerator"
,
"SumGenerator"
,
"SoftmaxGenerator"
,
"ViewGenerator"
,
"PermuteGenerator"
,
"TransposeGenerator"
,
"SplitGenerator"
,
"DefaultReshapeGenerator"
,
]
]
colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
View file @
9e768b59
...
@@ -14,7 +14,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern
...
@@ -14,7 +14,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern
from
.strategy_generator
import
StrategyGenerator
from
.strategy_generator
import
StrategyGenerator
__all__
=
[
'
BatchNormStrategyGenerator
'
]
__all__
=
[
"
BatchNormStrategyGenerator
"
]
class
BatchNormStrategyGenerator
(
StrategyGenerator
):
class
BatchNormStrategyGenerator
(
StrategyGenerator
):
...
@@ -24,34 +24,37 @@ class BatchNormStrategyGenerator(StrategyGenerator):
...
@@ -24,34 +24,37 @@ class BatchNormStrategyGenerator(StrategyGenerator):
To keep the math consistency, there are two way to do BatchNorm if the input
To keep the math consistency, there are two way to do BatchNorm if the input
shards on batch dimension:
shards on batch dimension:
1. We gather the input partitions through batch dimension, then do the normal BatchNorm.
1. We gather the input partitions through batch dimension, then do the normal BatchNorm.
2. We do the SyncBatchNorm on the each input partition sep
e
rately, the SyncBN op will help
2. We do the SyncBatchNorm on the each input partition sep
a
rately, the SyncBN op will help
us to keep the computing correctness.
us to keep the computing correctness.
In this generator, both methods will be considered.
In this generator, both methods will be considered.
"""
"""
def
validate
(
self
)
->
bool
:
def
validate
(
self
)
->
bool
:
'''
"""
In sanity check, we need make sure the input data having correct dimension size.
In sanity check, we need make sure the input data having correct dimension size.
For BatchNorm1d, the dim of input data should be 3([N, C, L]).
For BatchNorm1d, the dim of input data should be 3([N, C, L]).
For BatchNorm2d, the dim of input data should be 4([N, C, H, W]).
For BatchNorm2d, the dim of input data should be 4([N, C, H, W]).
For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]).
For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]).
'''
"""
input_op_data
=
self
.
op_data
[
'
input
'
]
input_op_data
=
self
.
op_data
[
"
input
"
]
assert
input_op_data
.
data
.
dim
()
in
(
assert
input_op_data
.
data
.
dim
()
in
(
3
,
4
,
5
),
f
'We suppose the dim of input fed into conv op should in range of [3, 5].'
3
,
4
,
5
,
),
f
"We suppose the dim of input fed into conv op should in range of [3, 5]."
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
):
'''
"""
Compute the computation cost per device with this specific strategy.
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be d
e
vided by TFLOPS, now it just shows the computation size.
Note: compute_cost need to be d
i
vided by TFLOPS, now it just shows the computation size.
'''
"""
# TODO: a constant coefficient need to be added.
# TODO: a constant coefficient need to be added.
# 1D: (L) * N * Cin
# 1D: (L) * N * Cin
# 2D: (H * W) * N * Cin
# 2D: (H * W) * N * Cin
# 3D: (H * W * D) * N * Cin
# 3D: (H * W * D) * N * Cin
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
input
'
]].
get_sharded_shape_per_device
()
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
input
"
]].
get_sharded_shape_per_device
()
sharded_output_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
output
'
]].
get_sharded_shape_per_device
()
sharded_output_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
output
"
]].
get_sharded_shape_per_device
()
if
self
.
has_bias
:
if
self
.
has_bias
:
# bias add is an element wise operation, so the cost is equal to product of output shape.
# bias add is an element wise operation, so the cost is equal to product of output shape.
bias_compute_cost
=
reduce
(
operator
.
mul
,
sharded_output_shape
)
bias_compute_cost
=
reduce
(
operator
.
mul
,
sharded_output_shape
)
...
@@ -69,23 +72,24 @@ class BatchNormStrategyGenerator(StrategyGenerator):
...
@@ -69,23 +72,24 @@ class BatchNormStrategyGenerator(StrategyGenerator):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
forward_size_mapping
=
{
forward_size_mapping
=
{
'
input
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
"
input
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
'
other
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"other"
),
"
other
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"other"
),
'
output
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
),
"
output
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
),
'
running_mean
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"running_mean"
),
"
running_mean
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"running_mean"
),
'
running_var
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"running_var"
),
"
running_var
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"running_var"
),
}
}
if
self
.
has_bias
:
if
self
.
has_bias
:
bias_size
=
self
.
_compute_size_in_bytes
(
strategy
,
"bias"
)
bias_size
=
self
.
_compute_size_in_bytes
(
strategy
,
"bias"
)
forward_size_mapping
[
'
bias
'
]
=
bias_size
forward_size_mapping
[
"
bias
"
]
=
bias_size
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
backward_size_mapping
.
pop
(
"output"
)
backward_size_mapping
.
pop
(
"output"
)
# compute fwd cost incurred
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
# fwd_cost = input + other + bias + output
fwd_activation_cost
=
sum
(
fwd_activation_cost
=
sum
(
[
v
for
k
,
v
in
forward_size_mapping
.
items
()
if
not
self
.
is_param
(
k
)
and
not
self
.
is_buffer
(
k
)])
[
v
for
k
,
v
in
forward_size_mapping
.
items
()
if
not
self
.
is_param
(
k
)
and
not
self
.
is_buffer
(
k
)]
)
fwd_parameter_cost
=
sum
([
v
for
k
,
v
in
forward_size_mapping
.
items
()
if
self
.
is_param
(
k
)])
fwd_parameter_cost
=
sum
([
v
for
k
,
v
in
forward_size_mapping
.
items
()
if
self
.
is_param
(
k
)])
fwd_buffer_cost
=
sum
([
v
for
k
,
v
in
forward_size_mapping
.
items
()
if
self
.
is_buffer
(
k
)])
fwd_buffer_cost
=
sum
([
v
for
k
,
v
in
forward_size_mapping
.
items
()
if
self
.
is_buffer
(
k
)])
fwd_mem_cost
=
MemoryCost
(
activation
=
fwd_activation_cost
,
parameter
=
fwd_parameter_cost
,
buffer
=
fwd_buffer_cost
)
fwd_mem_cost
=
MemoryCost
(
activation
=
fwd_activation_cost
,
parameter
=
fwd_parameter_cost
,
buffer
=
fwd_buffer_cost
)
...
@@ -93,36 +97,29 @@ class BatchNormStrategyGenerator(StrategyGenerator):
...
@@ -93,36 +97,29 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# compute bwd cost incurred
# compute bwd cost incurred
# bwd_cost = input_grad + other_grad + bias_grad
# bwd_cost = input_grad + other_grad + bias_grad
bwd_activation_cost
=
sum
(
bwd_activation_cost
=
sum
(
[
v
for
k
,
v
in
backward_size_mapping
.
items
()
if
not
self
.
is_param
(
k
)
and
not
self
.
is_buffer
(
k
)])
[
v
for
k
,
v
in
backward_size_mapping
.
items
()
if
not
self
.
is_param
(
k
)
and
not
self
.
is_buffer
(
k
)]
)
bwd_parameter_cost
=
sum
([
v
for
k
,
v
in
backward_size_mapping
.
items
()
if
self
.
is_param
(
k
)])
bwd_parameter_cost
=
sum
([
v
for
k
,
v
in
backward_size_mapping
.
items
()
if
self
.
is_param
(
k
)])
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
# compute total cost
# compute total cost
total_mem_cost
=
MemoryCost
(
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
total_mem_cost
=
MemoryCost
(
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
,
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
buffer
=
fwd_buffer_cost
)
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
,
buffer
=
fwd_buffer_cost
,
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
strategy
.
memory_cost
=
memory_cost
strategy
.
memory_cost
=
memory_cost
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_input_channel
(
self
,
mesh_dim_0
):
def
split_input_channel
(
self
,
mesh_dim_0
):
name
=
f
'
RS
{
mesh_dim_0
}
= RS
{
mesh_dim_0
}
x S
{
mesh_dim_0
}
'
name
=
f
"
RS
{
mesh_dim_0
}
= RS
{
mesh_dim_0
}
x S
{
mesh_dim_0
}
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
1
:
[
mesh_dim_0
]},
1
:
[
mesh_dim_0
]
"other"
:
{
0
:
[
mesh_dim_0
]},
},
"output"
:
{
1
:
[
mesh_dim_0
]},
"other"
:
{
"running_mean"
:
{
0
:
[
mesh_dim_0
]},
0
:
[
mesh_dim_0
]
"running_var"
:
{
0
:
[
mesh_dim_0
]},
},
"output"
:
{
1
:
[
mesh_dim_0
]
},
"running_mean"
:
{
0
:
[
mesh_dim_0
]
},
"running_var"
:
{
0
:
[
mesh_dim_0
]
},
"num_batches_tracked"
:
{},
"num_batches_tracked"
:
{},
}
}
if
self
.
has_bias
:
if
self
.
has_bias
:
...
@@ -132,29 +129,21 @@ class BatchNormStrategyGenerator(StrategyGenerator):
...
@@ -132,29 +129,21 @@ class BatchNormStrategyGenerator(StrategyGenerator):
communication_action_mapping
=
{}
communication_action_mapping
=
{}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_input_channel_1d
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_input_channel_1d
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
RS
{
mesh_dim_0
}{
mesh_dim_1
}
= RS
{
mesh_dim_0
}{
mesh_dim_1
}
x S
{
mesh_dim_0
}{
mesh_dim_1
}
'
name
=
f
"
RS
{
mesh_dim_0
}{
mesh_dim_1
}
= RS
{
mesh_dim_0
}{
mesh_dim_1
}
x S
{
mesh_dim_0
}{
mesh_dim_1
}
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
1
:
[
mesh_dim_0
,
mesh_dim_1
]},
1
:
[
mesh_dim_0
,
mesh_dim_1
]
"other"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]},
},
"output"
:
{
1
:
[
mesh_dim_0
,
mesh_dim_1
]},
"other"
:
{
"running_mean"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]},
0
:
[
mesh_dim_0
,
mesh_dim_1
]
"running_var"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]},
},
"output"
:
{
1
:
[
mesh_dim_0
,
mesh_dim_1
]
},
"running_mean"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]
},
"running_var"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]
},
"num_batches_tracked"
:
{},
"num_batches_tracked"
:
{},
}
}
if
self
.
has_bias
:
if
self
.
has_bias
:
...
@@ -164,13 +153,15 @@ class BatchNormStrategyGenerator(StrategyGenerator):
...
@@ -164,13 +153,15 @@ class BatchNormStrategyGenerator(StrategyGenerator):
communication_action_mapping
=
{}
communication_action_mapping
=
{}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
non_split
(
self
):
def
non_split
(
self
):
name
=
f
'
RR = RR x R
'
name
=
f
"
RR = RR x R
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{},
"input"
:
{},
"other"
:
{},
"other"
:
{},
...
@@ -186,21 +177,19 @@ class BatchNormStrategyGenerator(StrategyGenerator):
...
@@ -186,21 +177,19 @@ class BatchNormStrategyGenerator(StrategyGenerator):
communication_action_mapping
=
{}
communication_action_mapping
=
{}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_input_batch
(
self
,
mesh_dim_0
):
def
split_input_batch
(
self
,
mesh_dim_0
):
name
=
f
'
S
{
mesh_dim_0
}
R = S
{
mesh_dim_0
}
R x R WITH SYNC_BN
'
name
=
f
"
S
{
mesh_dim_0
}
R = S
{
mesh_dim_0
}
R x R WITH SYNC_BN
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim_0
]},
0
:
[
mesh_dim_0
]
},
"other"
:
{},
"other"
:
{},
"output"
:
{
"output"
:
{
0
:
[
mesh_dim_0
]},
0
:
[
mesh_dim_0
]
},
"running_mean"
:
{},
"running_mean"
:
{},
"running_var"
:
{},
"running_var"
:
{},
"num_batches_tracked"
:
{},
"num_batches_tracked"
:
{},
...
@@ -212,33 +201,32 @@ class BatchNormStrategyGenerator(StrategyGenerator):
...
@@ -212,33 +201,32 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# set communication action
# set communication action
# For SyncBN case, we don't need to do communication for weight and bias.
# For SyncBN case, we don't need to do communication for weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
# TODO: the communication happens inter
n
ally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node.
# to SyncBN operation instead of inserting a communication node.
output_comm_action
=
self
.
get_communication_action
(
output_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
"output"
],
sharding_spec
=
sharding_spec_mapping
[
"output"
],
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
IMPLICIT
)
comm_type
=
CommType
.
IMPLICIT
,
)
# TODO: Temporary solution has no communication cost,
# TODO: Temporary solution has no communication cost,
# above action should be added after the SyncBN replace pass completed.
# above action should be added after the SyncBN replace pass completed.
communication_action_mapping
=
{}
communication_action_mapping
=
{}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_input_batch_1d
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_input_batch_1d
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
S
{
mesh_dim_0
}{
mesh_dim_1
}
R = S
{
mesh_dim_0
}{
mesh_dim_1
}
R x R WITH SYNC_BN
'
name
=
f
"
S
{
mesh_dim_0
}{
mesh_dim_1
}
R = S
{
mesh_dim_0
}{
mesh_dim_1
}
R x R WITH SYNC_BN
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]},
0
:
[
mesh_dim_0
,
mesh_dim_1
]
},
"other"
:
{},
"other"
:
{},
"output"
:
{
"output"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]},
0
:
[
mesh_dim_0
,
mesh_dim_1
]
},
"running_mean"
:
{},
"running_mean"
:
{},
"running_var"
:
{},
"running_var"
:
{},
"num_batches_tracked"
:
{},
"num_batches_tracked"
:
{},
...
@@ -250,25 +238,28 @@ class BatchNormStrategyGenerator(StrategyGenerator):
...
@@ -250,25 +238,28 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# set communication action
# set communication action
# For SyncBN case, we don't need to do communication for gradients of weight and bias.
# For SyncBN case, we don't need to do communication for gradients of weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
# TODO: the communication happens inter
n
ally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node.
# to SyncBN operation instead of inserting a communication node.
output_comm_action
=
self
.
get_communication_action
(
output_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
"output"
],
sharding_spec
=
sharding_spec_mapping
[
"output"
],
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
IMPLICIT
)
comm_type
=
CommType
.
IMPLICIT
,
)
# TODO: Temporary solution has no communication cost,
# TODO: Temporary solution has no communication cost,
# above action should be added after the SyncBN replace pass completed.
# above action should be added after the SyncBN replace pass completed.
communication_action_mapping
=
{}
communication_action_mapping
=
{}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_input_both_dim
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_input_both_dim
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
= S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
x S
{
mesh_dim_1
}
WITH SYNC_BN
'
name
=
f
"
S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
= S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
x S
{
mesh_dim_1
}
WITH SYNC_BN
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim_0
],
0
:
[
mesh_dim_0
],
...
@@ -298,26 +289,29 @@ class BatchNormStrategyGenerator(StrategyGenerator):
...
@@ -298,26 +289,29 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# set communication action
# set communication action
# For SyncBN case, we don't need to do communication for gradients of weight and bias.
# For SyncBN case, we don't need to do communication for gradients of weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
# TODO: the communication happens inter
n
ally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node.
# to SyncBN operation instead of inserting a communication node.
output_comm_action
=
self
.
get_communication_action
(
output_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
"output"
],
sharding_spec
=
sharding_spec_mapping
[
"output"
],
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
logical_process_axis
=
[
mesh_dim_0
],
logical_process_axis
=
[
mesh_dim_0
],
comm_type
=
CommType
.
IMPLICIT
)
comm_type
=
CommType
.
IMPLICIT
,
)
# TODO: Temporary solution has no communication cost,
# TODO: Temporary solution has no communication cost,
# above action should be added after the SyncBN replace pass completed.
# above action should be added after the SyncBN replace pass completed.
communication_action_mapping
=
{}
communication_action_mapping
=
{}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
'''
"""
Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
'''
"""
strategy_list
=
[]
strategy_list
=
[]
# RS = RS x S
# RS = RS x S
...
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/binary_elementwise_generator.py
View file @
9e768b59
...
@@ -14,7 +14,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
...
@@ -14,7 +14,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
from
.strategy_generator
import
StrategyGenerator
from
.strategy_generator
import
StrategyGenerator
__all__
=
[
'
BinaryElementwiseStrategyGenerator
'
]
__all__
=
[
"
BinaryElementwiseStrategyGenerator
"
]
class
BinaryElementwiseStrategyGenerator
(
StrategyGenerator
):
class
BinaryElementwiseStrategyGenerator
(
StrategyGenerator
):
...
@@ -26,36 +26,37 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
...
@@ -26,36 +26,37 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
"""
"""
def
validate
(
self
)
->
bool
:
def
validate
(
self
)
->
bool
:
assert
len
(
self
.
op_data
)
==
3
,
\
assert
(
f
'BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got
{
len
(
self
.
op_data
)
}
'
len
(
self
.
op_data
)
==
3
),
f
"BinaryElementwiseStrategyGenerator only accepts three operation data (input, other and output), but got
{
len
(
self
.
op_data
)
}
"
for
name
,
op_data
in
self
.
op_data
.
items
():
for
name
,
op_data
in
self
.
op_data
.
items
():
if
not
isinstance
(
op_data
.
data
,
(
torch
.
Tensor
,
int
,
float
)):
if
not
isinstance
(
op_data
.
data
,
(
torch
.
Tensor
,
int
,
float
)):
raise
TypeError
(
f
'
The operation data
{
name
}
is not a torch.Tensor/int/float.
'
)
raise
TypeError
(
f
"
The operation data
{
name
}
is not a torch.Tensor/int/float.
"
)
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
input
'
]].
get_sharded_shape_per_device
()
shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
input
"
]].
get_sharded_shape_per_device
()
# since elementwise ops are not compute-intensive,
# since elementwise ops are not compute-intensive,
# we approximate the backward compute cost
# we approximate the backward compute cost
# to be twice the fwd compute cost
# to be twice the fwd compute cost
fwd_compute_cost
=
reduce
(
operator
.
mul
,
shape
)
fwd_compute_cost
=
reduce
(
operator
.
mul
,
shape
)
bwd_compute_cost
=
fwd_compute_cost
*
2
bwd_compute_cost
=
fwd_compute_cost
*
2
compute_cost
=
TrainCycleItem
(
fwd
=
fwd_compute_cost
,
compute_cost
=
TrainCycleItem
(
bwd
=
bwd_compute_cost
,
fwd
=
fwd_compute_cost
,
bwd
=
bwd_compute_cost
,
total
=
fwd_compute_cost
+
bwd_compute_cost
total
=
fwd_compute_cost
+
bwd_compute_cost
)
)
strategy
.
compute_cost
=
compute_cost
strategy
.
compute_cost
=
compute_cost
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
# all input, output and outputs have the same shape
# all input, output and outputs have the same shape
shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
input
'
]].
get_sharded_shape_per_device
()
strategy
.
sharding_specs
[
self
.
op_data
[
"
input
"
]].
get_sharded_shape_per_device
()
# compute fwd memory cost in bytes
# compute fwd memory cost in bytes
# as the elementwise ops are not memory-intensive
# as the elementwise ops are not memory-intensive
# we approximate the fwd mem
r
oy cost to be the output
# we approximate the fwd memo
r
y cost to be the output
# and the backward memory cost to be grad of input and other
# and the backward memory cost to be grad of input and other
input_bytes
=
self
.
_compute_size_in_bytes
(
strategy
,
'
input
'
)
input_bytes
=
self
.
_compute_size_in_bytes
(
strategy
,
"
input
"
)
other_bytes
=
self
.
_compute_size_in_bytes
(
strategy
,
'
other
'
)
other_bytes
=
self
.
_compute_size_in_bytes
(
strategy
,
"
other
"
)
output_bytes
=
self
.
_compute_size_in_bytes
(
strategy
,
'
output
'
)
output_bytes
=
self
.
_compute_size_in_bytes
(
strategy
,
"
output
"
)
fwd_memory_cost
=
MemoryCost
(
activation
=
output_bytes
)
fwd_memory_cost
=
MemoryCost
(
activation
=
output_bytes
)
bwd_memory_cost
=
MemoryCost
(
activation
=
input_bytes
+
other_bytes
)
bwd_memory_cost
=
MemoryCost
(
activation
=
input_bytes
+
other_bytes
)
total_memory_cost
=
MemoryCost
(
activation
=
input_bytes
+
other_bytes
+
output_bytes
)
total_memory_cost
=
MemoryCost
(
activation
=
input_bytes
+
other_bytes
+
output_bytes
)
...
@@ -66,7 +67,7 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
...
@@ -66,7 +67,7 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
def
enumerate_all_possible_output
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
enumerate_all_possible_output
(
self
,
mesh_dim_0
,
mesh_dim_1
):
# we check for the output logical shape to get the number of dimensions
# we check for the output logical shape to get the number of dimensions
dim_partition_list
=
[]
dim_partition_list
=
[]
dim_size
=
len
(
self
.
op_data
[
'
output
'
].
logical_shape
)
dim_size
=
len
(
self
.
op_data
[
"
output
"
].
logical_shape
)
# enumerate all the 2D sharding cases
# enumerate all the 2D sharding cases
sharding_list_2d
=
enumerate_all_possible_2d_sharding
(
mesh_dim_0
,
mesh_dim_1
,
dim_size
)
sharding_list_2d
=
enumerate_all_possible_2d_sharding
(
mesh_dim_0
,
mesh_dim_1
,
dim_size
)
...
@@ -86,21 +87,22 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
...
@@ -86,21 +87,22 @@ class BinaryElementwiseStrategyGenerator(StrategyGenerator):
# convert these dim partition dict to sharding strategy
# convert these dim partition dict to sharding strategy
for
dim_partition_dict
in
dim_partition_list
:
for
dim_partition_dict
in
dim_partition_list
:
dim_partition_dict_mapping
=
dict
(
input
=
dim_partition_dict
,
dim_partition_dict_mapping
=
dict
(
other
=
dim_partition_dict
,
input
=
dim_partition_dict
,
other
=
dim_partition_dict
,
output
=
dim_partition_dict
output
=
dim_partition_dict
)
)
try
:
try
:
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
communication_action_mapping
=
{}
communication_action_mapping
=
{}
# get name
# get name
sharding_seq
=
sharding_spec_mapping
[
'
input
'
].
sharding_sequence
sharding_seq
=
sharding_spec_mapping
[
"
input
"
].
sharding_sequence
name
=
f
'
{
sharding_seq
}
=
{
sharding_seq
}
<binary-elementwise-op>
{
sharding_seq
}
'
name
=
f
"
{
sharding_seq
}
=
{
sharding_seq
}
<binary-elementwise-op>
{
sharding_seq
}
"
sharding_strategy
=
self
.
get_sharding_strategy
(
sharding_strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
name
=
name
,
sharding_spec_mapping
=
sharding_spec_mapping
,
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
)
communication_action_mapping
=
communication_action_mapping
,
)
strategy_list
.
append
(
sharding_strategy
)
strategy_list
.
append
(
sharding_strategy
)
except
ShardingSpecException
:
except
ShardingSpecException
:
continue
continue
...
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
View file @
9e768b59
import
copy
import
copy
import
operator
import
operator
import
warnings
from
functools
import
reduce
from
functools
import
reduce
from
typing
import
List
from
typing
import
List
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
CommAction
,
CommType
,
CommType
,
MemoryCost
,
MemoryCost
,
ShardingStrategy
,
ShardingStrategy
,
...
@@ -24,29 +22,32 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -24,29 +22,32 @@ class ConvStrategyGenerator(StrategyGenerator):
"""
"""
def
validate
(
self
)
->
bool
:
def
validate
(
self
)
->
bool
:
'''
"""
In sanity check, we need make sure the input data having correct dimension size.
In sanity check, we need make sure the input data having correct dimension size.
For Conv1d, the dim of input data should be 3([N, C, L]).
For Conv1d, the dim of input data should be 3([N, C, L]).
For Conv2d, the dim of input data should be 4([N, C, H, W]).
For Conv2d, the dim of input data should be 4([N, C, H, W]).
For Conv3d, the dim of input data should be 5([N, C, H, W, D]).
For Conv3d, the dim of input data should be 5([N, C, H, W, D]).
'''
"""
input_op_data
=
self
.
op_data
[
'
input
'
]
input_op_data
=
self
.
op_data
[
"
input
"
]
assert
input_op_data
.
data
.
dim
()
in
(
assert
input_op_data
.
data
.
dim
()
in
(
3
,
4
,
5
),
f
'We suppose the dim of input fed into conv op should in range of [3, 5].'
3
,
4
,
5
,
),
f
"We suppose the dim of input fed into conv op should in range of [3, 5]."
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
):
'''
"""
Compute the computation cost per device with this specific strategy.
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be d
e
vided by TFLOPS, now it just shows the computation size.
Note: compute_cost need to be d
i
vided by TFLOPS, now it just shows the computation size.
'''
"""
# TODO: compute_cost need to be d
e
vided by TFLOPS, now it just shows the computation size.
# TODO: compute_cost need to be d
i
vided by TFLOPS, now it just shows the computation size.
# 1D: (L) * N * Cout * Cin * kernel
# 1D: (L) * N * Cout * Cin * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
input
'
]].
get_sharded_shape_per_device
()
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
input
"
]].
get_sharded_shape_per_device
()
sharded_other_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
other
'
]].
get_sharded_shape_per_device
()
sharded_other_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
other
"
]].
get_sharded_shape_per_device
()
sharded_output_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
output
'
]].
get_sharded_shape_per_device
()
sharded_output_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
output
"
]].
get_sharded_shape_per_device
()
if
self
.
has_bias
:
if
self
.
has_bias
:
# bias add is an element wise operation, so the cost is equal to product of output shape.
# bias add is an element wise operation, so the cost is equal to product of output shape.
bias_compute_cost
=
reduce
(
operator
.
mul
,
sharded_output_shape
)
bias_compute_cost
=
reduce
(
operator
.
mul
,
sharded_output_shape
)
...
@@ -76,14 +77,14 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -76,14 +77,14 @@ class ConvStrategyGenerator(StrategyGenerator):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
forward_size_mapping
=
{
forward_size_mapping
=
{
'
input
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
"
input
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
'
other
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"other"
),
"
other
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"other"
),
'
output
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
"
output
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
,
}
}
if
self
.
has_bias
:
if
self
.
has_bias
:
bias_size
=
self
.
_compute_size_in_bytes
(
strategy
,
"bias"
)
bias_size
=
self
.
_compute_size_in_bytes
(
strategy
,
"bias"
)
forward_size_mapping
[
'
bias
'
]
=
bias_size
forward_size_mapping
[
"
bias
"
]
=
bias_size
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
backward_size_mapping
.
pop
(
"output"
)
backward_size_mapping
.
pop
(
"output"
)
...
@@ -100,26 +101,20 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -100,26 +101,20 @@ class ConvStrategyGenerator(StrategyGenerator):
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
# compute total cost
# compute total cost
total_mem_cost
=
MemoryCost
(
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
total_mem_cost
=
MemoryCost
(
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
)
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
strategy
.
memory_cost
=
memory_cost
strategy
.
memory_cost
=
memory_cost
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_input_batch_weight_out_channel
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_input_batch_weight_out_channel
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
= S
{
mesh_dim_0
}
R x RS
{
mesh_dim_1
}
'
name
=
f
"
S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
= S
{
mesh_dim_0
}
R x RS
{
mesh_dim_1
}
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim_0
]},
0
:
[
mesh_dim_0
]
"other"
:
{
1
:
[
mesh_dim_1
]},
},
"output"
:
{
0
:
[
mesh_dim_0
],
1
:
[
mesh_dim_1
]},
"other"
:
{
1
:
[
mesh_dim_1
]
},
"output"
:
{
0
:
[
mesh_dim_0
],
1
:
[
mesh_dim_1
]
},
}
}
if
self
.
has_bias
:
if
self
.
has_bias
:
dim_partition_dict_mapping
[
"bias"
]
=
{
0
:
[
mesh_dim_1
]}
dim_partition_dict_mapping
[
"bias"
]
=
{
0
:
[
mesh_dim_1
]}
...
@@ -132,7 +127,8 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -132,7 +127,8 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_1
,
logical_process_axis
=
mesh_dim_1
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
)
communication_action_mapping
=
{
"input"
:
input_comm_action
}
communication_action_mapping
=
{
"input"
:
input_comm_action
}
if
self
.
is_param
(
"other"
):
if
self
.
is_param
(
"other"
):
...
@@ -140,7 +136,8 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -140,7 +136,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
[
"other"
],
sharding_spec_mapping
[
"other"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
...
@@ -148,38 +145,41 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -148,38 +145,41 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
1
)
arg_index
=
1
,
)
communication_action_mapping
[
"other"
]
=
other_comm_action
communication_action_mapping
[
"other"
]
=
other_comm_action
if
self
.
has_bias
:
if
self
.
has_bias
:
if
self
.
is_param
(
'
bias
'
):
if
self
.
is_param
(
"
bias
"
):
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"bias"
],
sharding_spec_mapping
[
"bias"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"bias"
],
sharding_spec_mapping
[
"bias"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
key_for_kwarg
=
'bias'
)
key_for_kwarg
=
"bias"
,
)
communication_action_mapping
[
"bias"
]
=
bias_comm_action
communication_action_mapping
[
"bias"
]
=
bias_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_input_batch
(
self
,
mesh_dim_0
):
def
split_input_batch
(
self
,
mesh_dim_0
):
name
=
f
'
S
{
mesh_dim_0
}
R = S
{
mesh_dim_0
}
R x RR
'
name
=
f
"
S
{
mesh_dim_0
}
R = S
{
mesh_dim_0
}
R x RR
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim_0
]},
0
:
[
mesh_dim_0
]
},
"other"
:
{},
"other"
:
{},
"output"
:
{
"output"
:
{
0
:
[
mesh_dim_0
],
0
:
[
mesh_dim_0
],
...
@@ -196,7 +196,8 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -196,7 +196,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
[
"other"
],
sharding_spec_mapping
[
"other"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
...
@@ -204,42 +205,45 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -204,42 +205,45 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
1
)
arg_index
=
1
,
)
communication_action_mapping
[
"other"
]
=
other_comm_action
communication_action_mapping
[
"other"
]
=
other_comm_action
if
self
.
has_bias
:
if
self
.
has_bias
:
if
self
.
is_param
(
'
bias
'
):
if
self
.
is_param
(
"
bias
"
):
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"bias"
],
sharding_spec_mapping
[
"bias"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"bias"
],
sharding_spec_mapping
[
"bias"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
key_for_kwarg
=
'bias'
)
key_for_kwarg
=
"bias"
,
)
communication_action_mapping
[
"bias"
]
=
bias_comm_action
communication_action_mapping
[
"bias"
]
=
bias_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_input_both_dim_weight_in_channel
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_input_both_dim_weight_in_channel
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
S
{
mesh_dim_0
}
R = S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
x S
{
mesh_dim_1
}
R
'
name
=
f
"
S
{
mesh_dim_0
}
R = S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
x S
{
mesh_dim_1
}
R
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim_0
],
0
:
[
mesh_dim_0
],
1
:
[
mesh_dim_1
],
1
:
[
mesh_dim_1
],
},
},
"other"
:
{
"other"
:
{
0
:
[
mesh_dim_1
]},
0
:
[
mesh_dim_1
]
},
"output"
:
{
"output"
:
{
0
:
[
mesh_dim_0
],
0
:
[
mesh_dim_0
],
},
},
...
@@ -254,7 +258,8 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -254,7 +258,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
[
"output"
],
sharding_spec_mapping
[
"output"
],
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
logical_process_axis
=
mesh_dim_1
,
logical_process_axis
=
mesh_dim_1
,
comm_type
=
CommType
.
AFTER
)
comm_type
=
CommType
.
AFTER
,
)
communication_action_mapping
=
{
"output"
:
output_comm_action
}
communication_action_mapping
=
{
"output"
:
output_comm_action
}
...
@@ -263,7 +268,8 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -263,7 +268,8 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
[
"other"
],
sharding_spec_mapping
[
"other"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
...
@@ -271,7 +277,8 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -271,7 +277,8 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
1
)
arg_index
=
1
,
)
communication_action_mapping
[
"other"
]
=
other_comm_action
communication_action_mapping
[
"other"
]
=
other_comm_action
if
self
.
has_bias
:
if
self
.
has_bias
:
if
self
.
is_param
(
"bias"
):
if
self
.
is_param
(
"bias"
):
...
@@ -279,23 +286,27 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -279,23 +286,27 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
[
"bias"
],
sharding_spec_mapping
[
"bias"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"bias"
],
sharding_spec_mapping
[
"bias"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
key_for_kwarg
=
'bias'
)
key_for_kwarg
=
"bias"
,
)
communication_action_mapping
[
"bias"
]
=
bias_comm_action
communication_action_mapping
[
"bias"
]
=
bias_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_input_in_channel_weight_both_channel
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_input_in_channel_weight_both_channel
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
RS
{
mesh_dim_1
}
= RS
{
mesh_dim_0
}
x S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
'
name
=
f
"
RS
{
mesh_dim_1
}
= RS
{
mesh_dim_0
}
x S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
...
@@ -322,23 +333,27 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -322,23 +333,27 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
[
"output"
],
sharding_spec_mapping
[
"output"
],
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
AFTER
)
comm_type
=
CommType
.
AFTER
,
)
input_comm_action
=
self
.
get_communication_action
(
input_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"input"
],
sharding_spec_mapping
[
"input"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_1
,
logical_process_axis
=
mesh_dim_1
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
)
communication_action_mapping
=
{
"output"
:
output_comm_action
,
"input"
:
input_comm_action
}
communication_action_mapping
=
{
"output"
:
output_comm_action
,
"input"
:
input_comm_action
}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_input_in_channel_weight_in_channel
(
self
,
mesh_dim_0
):
def
split_input_in_channel_weight_in_channel
(
self
,
mesh_dim_0
):
name
=
f
'
RR = RS
{
mesh_dim_0
}
x S
{
mesh_dim_0
}
R
'
name
=
f
"
RR = RS
{
mesh_dim_0
}
x S
{
mesh_dim_0
}
R
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
...
@@ -360,17 +375,20 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -360,17 +375,20 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
[
"output"
],
sharding_spec_mapping
[
"output"
],
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
AFTER
)
comm_type
=
CommType
.
AFTER
,
)
communication_action_mapping
=
{
"output"
:
output_comm_action
}
communication_action_mapping
=
{
"output"
:
output_comm_action
}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_weight_out_channel
(
self
,
mesh_dim_0
):
def
split_weight_out_channel
(
self
,
mesh_dim_0
):
name
=
f
'
RS
{
mesh_dim_0
}
= RR x RS
{
mesh_dim_0
}
'
name
=
f
"
RS
{
mesh_dim_0
}
= RR x RS
{
mesh_dim_0
}
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{},
"input"
:
{},
...
@@ -395,17 +413,20 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -395,17 +413,20 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
)
communication_action_mapping
=
{
"input"
:
input_comm_action
}
communication_action_mapping
=
{
"input"
:
input_comm_action
}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
non_split
(
self
):
def
non_split
(
self
):
name
=
f
'
RR = RR x RR
'
name
=
f
"
RR = RR x RR
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{},
"input"
:
{},
...
@@ -418,13 +439,13 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -418,13 +439,13 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
{}
communication_action_mapping
=
{}
)
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_1d_parallel_on_input_batch
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_1d_parallel_on_input_batch
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
S
{
mesh_dim_0
}{
mesh_dim_1
}
R = S
{
mesh_dim_0
}{
mesh_dim_1
}
R x RR
'
name
=
f
"
S
{
mesh_dim_0
}{
mesh_dim_1
}
R = S
{
mesh_dim_0
}{
mesh_dim_1
}
R x RR
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
...
@@ -447,14 +468,16 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -447,14 +468,16 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
[
"other"
],
sharding_spec_mapping
[
"other"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"other"
],
sharding_spec_mapping
[
"other"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
1
)
arg_index
=
1
,
)
communication_action_mapping
[
"other"
]
=
other_comm_action
communication_action_mapping
[
"other"
]
=
other_comm_action
...
@@ -464,23 +487,27 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -464,23 +487,27 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
[
"bias"
],
sharding_spec_mapping
[
"bias"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"bias"
],
sharding_spec_mapping
[
"bias"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
key_for_kwarg
=
'bias'
)
key_for_kwarg
=
"bias"
,
)
communication_action_mapping
[
"bias"
]
=
bias_comm_action
communication_action_mapping
[
"bias"
]
=
bias_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_1d_parallel_on_in_channel
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_1d_parallel_on_in_channel
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
RR = RS
{
mesh_dim_0
}{
mesh_dim_1
}
x S
{
mesh_dim_0
}{
mesh_dim_1
}
R
'
name
=
f
"
RR = RS
{
mesh_dim_0
}{
mesh_dim_1
}
x S
{
mesh_dim_0
}{
mesh_dim_1
}
R
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
1
:
[
mesh_dim_0
,
mesh_dim_1
],
1
:
[
mesh_dim_0
,
mesh_dim_1
],
...
@@ -501,17 +528,20 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -501,17 +528,20 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
[
"output"
],
sharding_spec_mapping
[
"output"
],
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
AFTER
)
comm_type
=
CommType
.
AFTER
,
)
communication_action_mapping
=
{
"output"
:
output_comm_action
}
communication_action_mapping
=
{
"output"
:
output_comm_action
}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_1d_parallel_on_out_channel
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_1d_parallel_on_out_channel
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
RS
{
mesh_dim_0
}{
mesh_dim_1
}
= RR x RS
{
mesh_dim_0
}{
mesh_dim_1
}
'
name
=
f
"
RS
{
mesh_dim_0
}{
mesh_dim_1
}
= RR x RS
{
mesh_dim_0
}{
mesh_dim_1
}
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{},
"input"
:
{},
"other"
:
{
"other"
:
{
...
@@ -535,13 +565,16 @@ class ConvStrategyGenerator(StrategyGenerator):
...
@@ -535,13 +565,16 @@ class ConvStrategyGenerator(StrategyGenerator):
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
)
communication_action_mapping
=
{
"input"
:
input_comm_action
}
communication_action_mapping
=
{
"input"
:
input_comm_action
}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
strategies
=
[]
strategies
=
[]
...
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/embedding_generator.py
View file @
9e768b59
import
copy
import
copy
import
operator
import
operator
import
warnings
from
functools
import
reduce
from
functools
import
reduce
from
typing
import
List
from
typing
import
List
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
CommAction
,
CommType
,
CommType
,
MemoryCost
,
MemoryCost
,
ShardingStrategy
,
ShardingStrategy
,
...
@@ -27,16 +25,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
...
@@ -27,16 +25,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
return
super
().
validate
()
return
super
().
validate
()
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
):
'''
"""
Compute the computation cost per device with this specific strategy.
Compute the computation cost per device with this specific strategy.
Note: The computation cost for the embedding handler is estimated as dense computing now.
Note: The computation cost for the embedding handler is estimated as dense computing now.
It may not be accurate.
It may not be accurate.
'''
"""
# TODO: estimate the embedding computation cost as sparse operation
# TODO: estimate the embedding computation cost as sparse operation
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
input
'
]].
get_sharded_shape_per_device
()
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
input
"
]].
get_sharded_shape_per_device
()
sharded_other_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
other
'
]].
get_sharded_shape_per_device
()
sharded_other_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
other
"
]].
get_sharded_shape_per_device
()
sharded_output_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
output
'
]].
get_sharded_shape_per_device
()
sharded_output_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
output
"
]].
get_sharded_shape_per_device
()
input_size_product
=
reduce
(
operator
.
mul
,
sharded_input_shape
)
input_size_product
=
reduce
(
operator
.
mul
,
sharded_input_shape
)
other_size_product
=
reduce
(
operator
.
mul
,
sharded_other_shape
)
other_size_product
=
reduce
(
operator
.
mul
,
sharded_other_shape
)
...
@@ -55,9 +53,9 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
...
@@ -55,9 +53,9 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
forward_size_mapping
=
{
forward_size_mapping
=
{
'
input
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
"
input
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
'
other
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"other"
),
"
other
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"other"
),
'
output
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
"
output
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
,
}
}
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
...
@@ -75,14 +73,15 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
...
@@ -75,14 +73,15 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
# compute total cost
# compute total cost
total_mem_cost
=
MemoryCost
(
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
total_mem_cost
=
MemoryCost
(
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
)
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
strategy
.
memory_cost
=
memory_cost
strategy
.
memory_cost
=
memory_cost
@
ignore_sharding_exception
@
ignore_sharding_exception
def
non_split
(
self
):
def
non_split
(
self
):
name
=
f
'
RR = R x RR
'
name
=
f
"
RR = R x RR
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{},
"input"
:
{},
...
@@ -92,18 +91,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
...
@@ -92,18 +91,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
{}
communication_action_mapping
=
{}
)
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_input
(
self
,
mesh_dim_0
):
def
split_input
(
self
,
mesh_dim_0
):
name
=
f
'
S
{
mesh_dim_0
}
R = S
{
mesh_dim_0
}
x RR
'
name
=
f
"
S
{
mesh_dim_0
}
R = S
{
mesh_dim_0
}
x RR
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim_0
]},
0
:
[
mesh_dim_0
]
},
"other"
:
{},
"other"
:
{},
"output"
:
{
"output"
:
{
0
:
[
mesh_dim_0
],
0
:
[
mesh_dim_0
],
...
@@ -118,7 +115,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
...
@@ -118,7 +115,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
[
"other"
],
sharding_spec_mapping
[
"other"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
...
@@ -126,17 +124,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
...
@@ -126,17 +124,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
1
)
arg_index
=
1
,
)
communication_action_mapping
[
"other"
]
=
other_comm_action
communication_action_mapping
[
"other"
]
=
other_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_input_and_embedding_dim
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_input_and_embedding_dim
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
= S
{
mesh_dim_0
}
x RS
{
mesh_dim_1
}
'
name
=
f
"
S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
= S
{
mesh_dim_0
}
x RS
{
mesh_dim_1
}
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
...
@@ -159,7 +160,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
...
@@ -159,7 +160,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_1
,
logical_process_axis
=
mesh_dim_1
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
)
communication_action_mapping
=
{
"input"
:
input_comm_action
}
communication_action_mapping
=
{
"input"
:
input_comm_action
}
if
self
.
is_param
(
"other"
):
if
self
.
is_param
(
"other"
):
...
@@ -167,7 +169,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
...
@@ -167,7 +169,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
[
"other"
],
sharding_spec_mapping
[
"other"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
...
@@ -175,22 +178,23 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
...
@@ -175,22 +178,23 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
1
)
arg_index
=
1
,
)
communication_action_mapping
[
"other"
]
=
other_comm_action
communication_action_mapping
[
"other"
]
=
other_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_1d_parallel_on_input
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_1d_parallel_on_input
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
S
{
mesh_dim_0
}{
mesh_dim_1
}
R = S
{
mesh_dim_0
}{
mesh_dim_1
}
x RR
'
name
=
f
"
S
{
mesh_dim_0
}{
mesh_dim_1
}
R = S
{
mesh_dim_0
}{
mesh_dim_1
}
x RR
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]},
0
:
[
mesh_dim_0
,
mesh_dim_1
]
},
"other"
:
{},
"other"
:
{},
"output"
:
{
"output"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
],
0
:
[
mesh_dim_0
,
mesh_dim_1
],
...
@@ -207,7 +211,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
...
@@ -207,7 +211,8 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
[
"other"
],
sharding_spec_mapping
[
"other"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
...
@@ -215,17 +220,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
...
@@ -215,17 +220,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
1
)
arg_index
=
1
,
)
communication_action_mapping
[
"other"
]
=
other_comm_action
communication_action_mapping
[
"other"
]
=
other_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_embedding_dim
(
self
,
mesh_dim_0
):
def
split_embedding_dim
(
self
,
mesh_dim_0
):
name
=
f
'
RS
{
mesh_dim_0
}
= R x RS
{
mesh_dim_0
}
'
name
=
f
"
RS
{
mesh_dim_0
}
= R x RS
{
mesh_dim_0
}
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{},
"input"
:
{},
...
@@ -245,17 +253,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
...
@@ -245,17 +253,20 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
)
communication_action_mapping
=
{
"input"
:
input_comm_action
}
communication_action_mapping
=
{
"input"
:
input_comm_action
}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_1d_parallel_on_embedding_dim
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_1d_parallel_on_embedding_dim
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
RS
{
mesh_dim_0
}{
mesh_dim_1
}
= R x RS
{
mesh_dim_0
}{
mesh_dim_1
}
'
name
=
f
"
RS
{
mesh_dim_0
}{
mesh_dim_1
}
= R x RS
{
mesh_dim_0
}{
mesh_dim_1
}
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{},
"input"
:
{},
...
@@ -275,13 +286,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
...
@@ -275,13 +286,16 @@ class EmbeddingStrategyGenerator(StrategyGenerator):
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
)
communication_action_mapping
=
{
"input"
:
input_comm_action
}
communication_action_mapping
=
{
"input"
:
input_comm_action
}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
strategies
=
[]
strategies
=
[]
...
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/getattr_generator.py
View file @
9e768b59
...
@@ -10,7 +10,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
...
@@ -10,7 +10,7 @@ from colossalai.tensor.sharding_spec import ShardingSpecException
from
.strategy_generator
import
StrategyGenerator
from
.strategy_generator
import
StrategyGenerator
__all__
=
[
'
GetattrGenerator
'
]
__all__
=
[
"
GetattrGenerator
"
]
class
GetattrGenerator
(
StrategyGenerator
):
class
GetattrGenerator
(
StrategyGenerator
):
...
@@ -26,10 +26,10 @@ class GetattrGenerator(StrategyGenerator):
...
@@ -26,10 +26,10 @@ class GetattrGenerator(StrategyGenerator):
strategy
.
compute_cost
=
compute_cost
strategy
.
compute_cost
=
compute_cost
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
'''
"""
Compute the memory cost per device with this specific strategy.
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping
=
{
'
output
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)}
forward_size_mapping
=
{
"
output
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)}
# compute fwd cost incurred
# compute fwd cost incurred
# fwd_cost = output
# fwd_cost = output
...
@@ -47,7 +47,7 @@ class GetattrGenerator(StrategyGenerator):
...
@@ -47,7 +47,7 @@ class GetattrGenerator(StrategyGenerator):
def
enumerate_all_possible_output
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
enumerate_all_possible_output
(
self
,
mesh_dim_0
,
mesh_dim_1
):
# we check for the output logical shape to get the number of dimensions
# we check for the output logical shape to get the number of dimensions
dim_partition_list
=
[]
dim_partition_list
=
[]
dim_size
=
len
(
self
.
op_data
[
'
output
'
].
logical_shape
)
dim_size
=
len
(
self
.
op_data
[
"
output
"
].
logical_shape
)
# enumerate all the 2D sharding cases
# enumerate all the 2D sharding cases
sharding_list_2d
=
enumerate_all_possible_2d_sharding
(
mesh_dim_0
,
mesh_dim_1
,
dim_size
)
sharding_list_2d
=
enumerate_all_possible_2d_sharding
(
mesh_dim_0
,
mesh_dim_1
,
dim_size
)
...
@@ -78,7 +78,8 @@ class GetattrGenerator(StrategyGenerator):
...
@@ -78,7 +78,8 @@ class GetattrGenerator(StrategyGenerator):
sharding_strategy
=
self
.
get_sharding_strategy
(
sharding_strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
name
=
name
,
sharding_spec_mapping
=
sharding_spec_mapping
,
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
)
communication_action_mapping
=
communication_action_mapping
,
)
strategy_list
.
append
(
sharding_strategy
)
strategy_list
.
append
(
sharding_strategy
)
except
ShardingSpecException
:
except
ShardingSpecException
:
continue
continue
...
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
View file @
9e768b59
import
copy
import
copy
from
typing
import
List
from
typing
import
List
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
ShardingStrategy
,
TrainCycleItem
CommType
,
MemoryCost
,
ShardingStrategy
,
TrainCycleItem
,
)
from
colossalai.logging
import
get_dist_logger
from
colossalai.logging
import
get_dist_logger
from
colossalai.tensor.shape_consistency
import
CollectiveCommPattern
from
colossalai.tensor.sharding_spec
import
ShardingSpecException
from
colossalai.tensor.sharding_spec
import
ShardingSpecException
from
.strategy_generator
import
FollowingStrategyGenerator
from
.strategy_generator
import
FollowingStrategyGenerator
__all__
=
[
'
GetItemStrategyGenerator
'
,
'
TensorStrategyGenerator
'
,
'
TensorTupleStrategyGenerator
'
]
__all__
=
[
"
GetItemStrategyGenerator
"
,
"
TensorStrategyGenerator
"
,
"
TensorTupleStrategyGenerator
"
]
class
GetItemStrategyGenerator
(
FollowingStrategyGenerator
):
class
GetItemStrategyGenerator
(
FollowingStrategyGenerator
):
...
@@ -35,12 +29,12 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
...
@@ -35,12 +29,12 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
strategy
.
compute_cost
=
compute_cost
strategy
.
compute_cost
=
compute_cost
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
'''
"""
Compute the memory cost per device with this specific strategy.
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping
=
{
forward_size_mapping
=
{
'
input
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
"
input
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
'
output
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
"
output
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
,
}
}
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
...
@@ -58,27 +52,29 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
...
@@ -58,27 +52,29 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
# compute total cost
# compute total cost
total_mem_cost
=
MemoryCost
(
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
total_mem_cost
=
MemoryCost
(
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
)
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
strategy
.
memory_cost
=
memory_cost
strategy
.
memory_cost
=
memory_cost
class
TensorStrategyGenerator
(
GetItemStrategyGenerator
):
class
TensorStrategyGenerator
(
GetItemStrategyGenerator
):
'''
"""
Deal with case 1 and 2.
Deal with case 1 and 2.
'''
"""
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
strategy_list
=
[]
strategy_list
=
[]
getitem_index
=
self
.
op_data
[
'
index
'
].
data
getitem_index
=
self
.
op_data
[
"
index
"
].
data
for
index
,
strategy
in
enumerate
(
self
.
predecessor_node
.
strategies_vector
):
for
index
,
strategy
in
enumerate
(
self
.
predecessor_node
.
strategies_vector
):
try
:
try
:
logger
=
get_dist_logger
()
logger
=
get_dist_logger
()
dim_partition_dict_mapping
=
{}
dim_partition_dict_mapping
=
{}
communication_action_mapping
=
{}
communication_action_mapping
=
{}
dim_partition_dict_for_input
=
copy
.
deepcopy
(
dim_partition_dict_for_input
=
copy
.
deepcopy
(
strategy
.
output_sharding_specs
[
self
.
op_data
[
"input"
]].
dim_partition_dict
)
strategy
.
output_sharding_specs
[
self
.
op_data
[
"input"
]].
dim_partition_dict
)
int_index
=
False
int_index
=
False
if
isinstance
(
getitem_index
,
int
):
if
isinstance
(
getitem_index
,
int
):
...
@@ -120,9 +116,11 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
...
@@ -120,9 +116,11 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
name
=
f
'
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
=
{
sharding_spec_mapping
[
"input"
].
sharding_sequence
}
_
{
index
}
'
name
=
f
'
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
=
{
sharding_spec_mapping
[
"input"
].
sharding_sequence
}
_
{
index
}
'
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
except
ShardingSpecException
as
e
:
except
ShardingSpecException
as
e
:
logger
.
debug
(
e
)
logger
.
debug
(
e
)
continue
continue
...
@@ -137,9 +135,9 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
...
@@ -137,9 +135,9 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
class
TensorTupleStrategyGenerator
(
GetItemStrategyGenerator
):
class
TensorTupleStrategyGenerator
(
GetItemStrategyGenerator
):
'''
"""
Deal with case 3.
Deal with case 3.
'''
"""
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
strategy_list
=
[]
strategy_list
=
[]
...
@@ -158,13 +156,15 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
...
@@ -158,13 +156,15 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
sharding_spec_mapping
[
"input"
]
=
sharding_spec_for_input
sharding_spec_mapping
[
"input"
]
=
sharding_spec_for_input
input_sharding_info
=
f
"get the
{
index
}
element from ("
input_sharding_info
=
f
"get the
{
index
}
element from ("
for
sharding_spec
in
sharding_spec_for_input
:
for
sharding_spec
in
sharding_spec_for_input
:
input_sharding_info
+=
f
'
{
sharding_spec
.
sharding_sequence
}
,
'
input_sharding_info
+=
f
"
{
sharding_spec
.
sharding_sequence
}
,
"
input_sharding_info
+=
")"
input_sharding_info
+=
")"
name
=
f
'
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
=
{
input_sharding_info
}
_
{
strategy_index
}
'
name
=
f
'
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
=
{
input_sharding_info
}
_
{
strategy_index
}
'
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
strategy_list
.
append
(
strategy
)
strategy_list
.
append
(
strategy
)
...
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py
View file @
9e768b59
...
@@ -18,7 +18,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern
...
@@ -18,7 +18,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern
from
.strategy_generator
import
StrategyGenerator
from
.strategy_generator
import
StrategyGenerator
__all__
=
[
'
LayerNormGenerator
'
]
__all__
=
[
"
LayerNormGenerator
"
]
class
LayerNormGenerator
(
StrategyGenerator
):
class
LayerNormGenerator
(
StrategyGenerator
):
...
@@ -31,21 +31,21 @@ class LayerNormGenerator(StrategyGenerator):
...
@@ -31,21 +31,21 @@ class LayerNormGenerator(StrategyGenerator):
return
super
().
validate
()
return
super
().
validate
()
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
):
'''
"""
Compute the computation cost per device with this specific strategy.
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be d
e
vided by TFLOPS, now it just shows the computation size.
Note: compute_cost need to be d
i
vided by TFLOPS, now it just shows the computation size.
'''
"""
# TODO: compute_cost need to be d
e
vided by TFLOPS, now it just shows the computation size.
# TODO: compute_cost need to be d
i
vided by TFLOPS, now it just shows the computation size.
# TODO: a constant coefficient need to be added.
# TODO: a constant coefficient need to be added.
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
input
'
]].
get_sharded_shape_per_device
()
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
input
"
]].
get_sharded_shape_per_device
()
sharded_weight_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
other
'
]].
get_sharded_shape_per_device
()
sharded_weight_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
other
"
]].
get_sharded_shape_per_device
()
if
self
.
has_bias
:
if
self
.
has_bias
:
# bias add is an element wise operation, so the cost is equal to product of output shape.
# bias add is an element wise operation, so the cost is equal to product of output shape.
bias_compute_cost
=
reduce
(
operator
.
mul
,
sharded_weight_shape
)
bias_compute_cost
=
reduce
(
operator
.
mul
,
sharded_weight_shape
)
# in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization.
# in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization.
input_batch_shape
=
sharded_input_shape
[:
-
len
(
sharded_weight_shape
)]
input_batch_shape
=
sharded_input_shape
[:
-
len
(
sharded_weight_shape
)]
input_batch_product
=
reduce
(
operator
.
mul
,
input_batch_shape
,
1
)
input_batch_product
=
reduce
(
operator
.
mul
,
input_batch_shape
,
1
)
norm_kernel_product
=
reduce
(
operator
.
mul
,
sharded_weight_shape
,
1
)
norm_kernel_product
=
reduce
(
operator
.
mul
,
sharded_weight_shape
,
1
)
forward_compute_cost
=
input_batch_product
*
norm_kernel_product
forward_compute_cost
=
input_batch_product
*
norm_kernel_product
...
@@ -62,18 +62,18 @@ class LayerNormGenerator(StrategyGenerator):
...
@@ -62,18 +62,18 @@ class LayerNormGenerator(StrategyGenerator):
strategy
.
compute_cost
=
compute_cost
strategy
.
compute_cost
=
compute_cost
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
'''
"""
Compute the memory cost per device with this specific strategy.
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping
=
{
forward_size_mapping
=
{
'
input
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
"
input
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
'
other
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"other"
),
"
other
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"other"
),
'
output
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
"
output
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
,
}
}
if
self
.
has_bias
:
if
self
.
has_bias
:
bias_size
=
self
.
_compute_size_in_bytes
(
strategy
,
"bias"
)
bias_size
=
self
.
_compute_size_in_bytes
(
strategy
,
"bias"
)
forward_size_mapping
[
'
bias
'
]
=
bias_size
forward_size_mapping
[
"
bias
"
]
=
bias_size
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
backward_size_mapping
.
pop
(
"output"
)
backward_size_mapping
.
pop
(
"output"
)
...
@@ -90,8 +90,9 @@ class LayerNormGenerator(StrategyGenerator):
...
@@ -90,8 +90,9 @@ class LayerNormGenerator(StrategyGenerator):
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
# compute total cost
# compute total cost
total_mem_cost
=
MemoryCost
(
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
total_mem_cost
=
MemoryCost
(
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
)
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
strategy
.
memory_cost
=
memory_cost
strategy
.
memory_cost
=
memory_cost
...
@@ -120,7 +121,8 @@ class LayerNormGenerator(StrategyGenerator):
...
@@ -120,7 +121,8 @@ class LayerNormGenerator(StrategyGenerator):
sharding_spec
=
sharding_spec_mapping
[
"other"
],
sharding_spec
=
sharding_spec_mapping
[
"other"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
total_mesh_dim_list
,
logical_process_axis
=
total_mesh_dim_list
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
communication_action_mapping
[
"other"
]
=
other_comm_action
communication_action_mapping
[
"other"
]
=
other_comm_action
if
self
.
has_bias
:
if
self
.
has_bias
:
...
@@ -128,12 +130,15 @@ class LayerNormGenerator(StrategyGenerator):
...
@@ -128,12 +130,15 @@ class LayerNormGenerator(StrategyGenerator):
sharding_spec
=
sharding_spec_mapping
[
"bias"
],
sharding_spec
=
sharding_spec_mapping
[
"bias"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
total_mesh_dim_list
,
logical_process_axis
=
total_mesh_dim_list
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
communication_action_mapping
[
"bias"
]
=
bias_comm_action
communication_action_mapping
[
"bias"
]
=
bias_comm_action
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
return
strategy
return
strategy
...
@@ -155,7 +160,7 @@ class LayerNormGenerator(StrategyGenerator):
...
@@ -155,7 +160,7 @@ class LayerNormGenerator(StrategyGenerator):
@
ignore_sharding_exception
@
ignore_sharding_exception
def
non_split
(
self
):
def
non_split
(
self
):
name
=
f
'
RR = RR x R
'
name
=
f
"
RR = RR x R
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{},
"input"
:
{},
"other"
:
{},
"other"
:
{},
...
@@ -168,14 +173,16 @@ class LayerNormGenerator(StrategyGenerator):
...
@@ -168,14 +173,16 @@ class LayerNormGenerator(StrategyGenerator):
communication_action_mapping
=
{}
communication_action_mapping
=
{}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
'''
"""
Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector.
Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector.
'''
"""
strategy_list
=
[]
strategy_list
=
[]
input_data_dim
=
len
(
self
.
op_data
[
"input"
].
logical_shape
)
input_data_dim
=
len
(
self
.
op_data
[
"input"
].
logical_shape
)
weight_data_dim
=
len
(
self
.
op_data
[
"other"
].
logical_shape
)
weight_data_dim
=
len
(
self
.
op_data
[
"other"
].
logical_shape
)
...
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
View file @
9e768b59
import
operator
import
operator
from
ast
import
arg
from
functools
import
reduce
from
functools
import
reduce
from
typing
import
List
from
typing
import
List
...
@@ -24,14 +23,14 @@ class MatMulStrategyGenerator(StrategyGenerator):
...
@@ -24,14 +23,14 @@ class MatMulStrategyGenerator(StrategyGenerator):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
size_mapping
=
{
size_mapping
=
{
'
input
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
"
input
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
'
other
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"other"
),
"
other
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"other"
),
'
output
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
"
output
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
,
}
}
if
self
.
has_bias
:
if
self
.
has_bias
:
bias_size
=
self
.
_compute_size_in_bytes
(
strategy
,
"bias"
)
bias_size
=
self
.
_compute_size_in_bytes
(
strategy
,
"bias"
)
size_mapping
[
'
bias
'
]
=
bias_size
size_mapping
[
"
bias
"
]
=
bias_size
# compute fwd cost incurred
# compute fwd cost incurred
# fwd_cost = input + other + bias + output
# fwd_cost = input + other + bias + output
...
@@ -41,45 +40,47 @@ class MatMulStrategyGenerator(StrategyGenerator):
...
@@ -41,45 +40,47 @@ class MatMulStrategyGenerator(StrategyGenerator):
# compute bwd cost incurred
# compute bwd cost incurred
# bwd_cost = input_grad + bias_grad
# bwd_cost = input_grad + bias_grad
bwd_activation_cost
=
sum
([
v
for
k
,
v
in
size_mapping
.
items
()
if
k
in
[
'
input
'
,
'
other
'
,
'
bias
'
]])
bwd_activation_cost
=
sum
([
v
for
k
,
v
in
size_mapping
.
items
()
if
k
in
[
"
input
"
,
"
other
"
,
"
bias
"
]])
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
0
)
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
0
)
# compute total cost
# compute total cost
total_mem_cost
=
MemoryCost
(
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
total_mem_cost
=
MemoryCost
(
parameter
=
fwd_parameter_cost
+
0
)
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
parameter
=
fwd_parameter_cost
+
0
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
strategy
.
memory_cost
=
memory_cost
strategy
.
memory_cost
=
memory_cost
class
DotProductStrategyGenerator
(
MatMulStrategyGenerator
):
class
DotProductStrategyGenerator
(
MatMulStrategyGenerator
):
def
validate
(
self
)
->
bool
:
def
validate
(
self
)
->
bool
:
input_op_data
=
self
.
op_data
[
'
input
'
]
input_op_data
=
self
.
op_data
[
"
input
"
]
other_op_data
=
self
.
op_data
[
'
other
'
]
other_op_data
=
self
.
op_data
[
"
other
"
]
assert
input_op_data
.
data
.
dim
()
==
1
and
other_op_data
.
data
.
dim
()
==
1
assert
input_op_data
.
data
.
dim
()
==
1
and
other_op_data
.
data
.
dim
()
==
1
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
input
'
]].
get_sharded_shape_per_device
()
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
input
"
]].
get_sharded_shape_per_device
()
fwd_compute_cost
=
sharded_input_shape
[
0
]
fwd_compute_cost
=
sharded_input_shape
[
0
]
bwd_compute_cost
=
fwd_compute_cost
*
2
bwd_compute_cost
=
fwd_compute_cost
*
2
compute_cost
=
TrainCycleItem
(
fwd
=
fwd_compute_cost
,
compute_cost
=
TrainCycleItem
(
bwd
=
bwd_compute_cost
,
fwd
=
fwd_compute_cost
,
bwd
=
bwd_compute_cost
,
total
=
fwd_compute_cost
+
bwd_compute_cost
total
=
fwd_compute_cost
+
bwd_compute_cost
)
)
return
compute_cost
return
compute_cost
@
ignore_sharding_exception
@
ignore_sharding_exception
def
no_split
(
self
):
def
no_split
(
self
):
name
=
f
'
R = R dot R
'
name
=
f
"
R = R dot R
"
dim_partition_dict
=
{
"input"
:
{},
"other"
:
{},
"output"
:
{},
'
bias
'
:
{}}
dim_partition_dict
=
{
"input"
:
{},
"other"
:
{},
"output"
:
{},
"
bias
"
:
{}}
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict
)
communication_action_mapping
=
{}
communication_action_mapping
=
{}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_one_dim
(
self
,
mesh_dim
):
def
split_one_dim
(
self
,
mesh_dim
):
name
=
f
'
R = S
{
mesh_dim
}
dot S
{
mesh_dim
}
'
name
=
f
"
R = S
{
mesh_dim
}
dot S
{
mesh_dim
}
"
# get sharding spec
# get sharding spec
dim_partition_dict
=
{
"input"
:
{
0
:
[
mesh_dim
]},
"other"
:
{
0
:
[
mesh_dim
]},
"output"
:
{},
"bias"
:
{
0
:
[
mesh_dim
]}}
dim_partition_dict
=
{
"input"
:
{
0
:
[
mesh_dim
]},
"other"
:
{
0
:
[
mesh_dim
]},
"output"
:
{},
"bias"
:
{
0
:
[
mesh_dim
]}}
...
@@ -87,14 +88,17 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
...
@@ -87,14 +88,17 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
# get communication action
# get communication action
output_comm_action
=
self
.
get_communication_action
(
output_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
output
'
],
sharding_spec
=
sharding_spec_mapping
[
"
output
"
],
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
logical_process_axis
=
mesh_dim
,
logical_process_axis
=
mesh_dim
,
comm_type
=
CommType
.
AFTER
)
comm_type
=
CommType
.
AFTER
,
)
communication_action_mapping
=
{
"output"
:
output_comm_action
}
communication_action_mapping
=
{
"output"
:
output_comm_action
}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
strategy_list
=
[]
strategy_list
=
[]
...
@@ -112,19 +116,18 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
...
@@ -112,19 +116,18 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
class
MatVecStrategyGenerator
(
MatMulStrategyGenerator
):
class
MatVecStrategyGenerator
(
MatMulStrategyGenerator
):
def
validate
(
self
)
->
bool
:
def
validate
(
self
)
->
bool
:
input_op_data
=
self
.
op_data
[
'
input
'
]
input_op_data
=
self
.
op_data
[
"
input
"
]
other_op_data
=
self
.
op_data
[
'
other
'
]
other_op_data
=
self
.
op_data
[
"
other
"
]
assert
input_op_data
.
data
.
dim
()
==
2
and
other_op_data
.
data
.
dim
()
==
1
assert
input_op_data
.
data
.
dim
()
==
2
and
other_op_data
.
data
.
dim
()
==
1
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
input
'
]].
get_sharded_shape_per_device
()
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
input
"
]].
get_sharded_shape_per_device
()
fwd_compute_cost
=
sharded_input_shape
[
0
]
fwd_compute_cost
=
sharded_input_shape
[
0
]
bwd_compute_cost
=
fwd_compute_cost
*
2
bwd_compute_cost
=
fwd_compute_cost
*
2
compute_cost
=
TrainCycleItem
(
fwd
=
fwd_compute_cost
,
compute_cost
=
TrainCycleItem
(
bwd
=
bwd_compute_cost
,
fwd
=
fwd_compute_cost
,
bwd
=
bwd_compute_cost
,
total
=
fwd_compute_cost
+
bwd_compute_cost
total
=
fwd_compute_cost
+
bwd_compute_cost
)
)
return
compute_cost
return
compute_cost
@
ignore_sharding_exception
@
ignore_sharding_exception
...
@@ -133,67 +136,69 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
...
@@ -133,67 +136,69 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
dim_partition_dict
=
{
"input"
:
{},
"other"
:
{},
"output"
:
{}}
dim_partition_dict
=
{
"input"
:
{},
"other"
:
{},
"output"
:
{}}
if
self
.
has_bias
:
if
self
.
has_bias
:
dim_partition_dict
[
'
bias
'
]
=
{}
dim_partition_dict
[
"
bias
"
]
=
{}
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict
)
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
{}
communication_action_mapping
=
{}
)
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_input_batch
(
self
,
mesh_dim
):
def
split_input_batch
(
self
,
mesh_dim
):
name
=
f
'
S
{
mesh_dim
}
R = S
{
mesh_dim
}
R x R
'
name
=
f
"
S
{
mesh_dim
}
R = S
{
mesh_dim
}
R x R
"
# get sharding spec
# get sharding spec
dim_partition_dict
=
{
dim_partition_dict
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim
]},
0
:
[
mesh_dim
]
},
"other"
:
{},
"other"
:
{},
"output"
:
{
"output"
:
{
0
:
[
mesh_dim
]},
0
:
[
mesh_dim
]
},
}
}
if
self
.
has_bias
:
if
self
.
has_bias
:
dim_partition_dict
[
'
bias
'
]
=
{}
dim_partition_dict
[
"
bias
"
]
=
{}
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict
)
# get communication action
# get communication action
communication_action_mapping
=
{}
communication_action_mapping
=
{}
if
self
.
is_param
(
'
other
'
):
if
self
.
is_param
(
"
other
"
):
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
other
'
],
sharding_spec
=
sharding_spec_mapping
[
"
other
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim
,
logical_process_axis
=
mesh_dim
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
other
'
],
sharding_spec
=
sharding_spec_mapping
[
"
other
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim
,
logical_process_axis
=
mesh_dim
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
1
)
arg_index
=
1
,
communication_action_mapping
[
'other'
]
=
other_comm_action
)
communication_action_mapping
[
"other"
]
=
other_comm_action
if
self
.
has_bias
:
if
self
.
has_bias
:
if
self
.
is_param
(
'
bias
'
):
if
self
.
is_param
(
"
bias
"
):
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
bias
'
],
sharding_spec
=
sharding_spec_mapping
[
"
bias
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim
,
logical_process_axis
=
mesh_dim
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
bias
'
],
sharding_spec
=
sharding_spec_mapping
[
"
bias
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim
,
logical_process_axis
=
mesh_dim
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
2
)
arg_index
=
2
,
communication_action_mapping
[
'bias'
]
=
bias_comm_action
)
communication_action_mapping
[
"bias"
]
=
bias_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
strategy_list
=
[]
strategy_list
=
[]
...
@@ -209,12 +214,13 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
...
@@ -209,12 +214,13 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
class
LinearProjectionStrategyGenerator
(
MatMulStrategyGenerator
):
class
LinearProjectionStrategyGenerator
(
MatMulStrategyGenerator
):
def
__init__
(
def
__init__
(
self
,
self
,
operation_data_mapping
,
operation_data_mapping
,
device_mesh
,
device_mesh
,
linear_projection_type
=
'linear'
,
linear_projection_type
=
"linear"
,
solver_perference
=
SolverPerference
.
STANDARD
):
solver_perference
=
SolverPerference
.
STANDARD
,
):
super
().
__init__
(
operation_data_mapping
,
device_mesh
)
super
().
__init__
(
operation_data_mapping
,
device_mesh
)
self
.
linear_projection_type
=
linear_projection_type
self
.
linear_projection_type
=
linear_projection_type
self
.
solver_perference
=
solver_perference
self
.
solver_perference
=
solver_perference
...
@@ -224,17 +230,17 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
...
@@ -224,17 +230,17 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# C: [M, N], A: [M, P], B: [P, N]
# C: [M, N], A: [M, P], B: [P, N]
# fwd cost = MNP (only count mul)
# fwd cost = MNP (only count mul)
# bwd: 2 x fwd_cost
# bwd: 2 x fwd_cost
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
input
'
]].
get_sharded_shape_per_device
()
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
input
"
]].
get_sharded_shape_per_device
()
sharded_other_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
other
'
]].
get_sharded_shape_per_device
()
sharded_other_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
other
"
]].
get_sharded_shape_per_device
()
dim_m_val
=
reduce
(
operator
.
mul
,
sharded_input_shape
[:
-
1
])
dim_m_val
=
reduce
(
operator
.
mul
,
sharded_input_shape
[:
-
1
])
dim_n_val
=
sharded_other_shape
[
-
1
]
dim_n_val
=
sharded_other_shape
[
-
1
]
dim_p_val
=
sharded_other_shape
[
0
]
dim_p_val
=
sharded_other_shape
[
0
]
fwd_compute_cost
=
dim_m_val
*
dim_n_val
*
dim_p_val
fwd_compute_cost
=
dim_m_val
*
dim_n_val
*
dim_p_val
bwd_compute_cost
=
fwd_compute_cost
*
2
bwd_compute_cost
=
fwd_compute_cost
*
2
compute_cost
=
TrainCycleItem
(
fwd
=
bwd_compute_cost
,
compute_cost
=
TrainCycleItem
(
bwd
=
bwd_compute_cost
,
fwd
=
bwd_compute_cost
,
bwd
=
bwd_compute_cost
,
total
=
fwd_compute_cost
+
bwd_compute_cost
total
=
fwd_compute_cost
+
bwd_compute_cost
)
)
strategy
.
compute_cost
=
compute_cost
strategy
.
compute_cost
=
compute_cost
def
dp_strategies
(
self
)
->
List
[
ShardingStrategy
]:
def
dp_strategies
(
self
)
->
List
[
ShardingStrategy
]:
...
@@ -301,28 +307,21 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
...
@@ -301,28 +307,21 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_lhs_space_rhs_space
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_lhs_space_rhs_space
(
self
,
mesh_dim_0
,
mesh_dim_1
):
# handle case SS = SR x RS
# handle case SS = SR x RS
name
=
f
'
S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
= S
{
mesh_dim_0
}
R x RS
{
mesh_dim_1
}
'
name
=
f
"
S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
= S
{
mesh_dim_0
}
R x RS
{
mesh_dim_1
}
"
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim_0
]},
0
:
[
mesh_dim_0
]
"other"
:
{
-
1
:
[
mesh_dim_1
]},
},
"output"
:
{
0
:
[
mesh_dim_0
],
-
1
:
[
mesh_dim_1
]},
"other"
:
{
-
1
:
[
mesh_dim_1
]
},
"output"
:
{
0
:
[
mesh_dim_0
],
-
1
:
[
mesh_dim_1
]
},
}
}
# linear bias only has one dimension, but addmm bias has same dimensions
# linear bias only has one dimension, but addmm bias has same dimensions
# as the output logically.
# as the output logically.
if
self
.
linear_projection_type
==
'
linear
'
:
if
self
.
linear_projection_type
==
"
linear
"
:
dim_partition_dict_mapping
[
'
bias
'
]
=
{
-
1
:
[
mesh_dim_1
]}
dim_partition_dict_mapping
[
"
bias
"
]
=
{
-
1
:
[
mesh_dim_1
]}
elif
self
.
linear_projection_type
==
'
addmm
'
:
elif
self
.
linear_projection_type
==
"
addmm
"
:
dim_partition_dict_mapping
[
'
bias
'
]
=
{
0
:
[
mesh_dim_0
],
-
1
:
[
mesh_dim_1
]}
dim_partition_dict_mapping
[
"
bias
"
]
=
{
0
:
[
mesh_dim_0
],
-
1
:
[
mesh_dim_1
]}
else
:
else
:
raise
(
'
Unsupported linear projection type
'
)
raise
(
"
Unsupported linear projection type
"
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
...
@@ -333,75 +332,75 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
...
@@ -333,75 +332,75 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_1
,
logical_process_axis
=
mesh_dim_1
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
)
if
self
.
is_param
(
'
other
'
):
if
self
.
is_param
(
"
other
"
):
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"other"
],
sharding_spec_mapping
[
"other"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"other"
],
sharding_spec_mapping
[
"other"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
1
)
arg_index
=
1
,
)
communication_action_mapping
[
'
input
'
]
=
input_comm_action
communication_action_mapping
[
"
input
"
]
=
input_comm_action
communication_action_mapping
[
'
other
'
]
=
other_comm_action
communication_action_mapping
[
"
other
"
]
=
other_comm_action
# we only add allreduce comm action for linear bias, because
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
# allreduce comm action for addmm bias will be considered in post processing
if
self
.
has_bias
and
self
.
linear_projection_type
==
'
linear
'
:
if
self
.
has_bias
and
self
.
linear_projection_type
==
"
linear
"
:
if
self
.
is_param
(
'
bias
'
):
if
self
.
is_param
(
"
bias
"
):
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"bias"
],
sharding_spec_mapping
[
"bias"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"bias"
],
sharding_spec_mapping
[
"bias"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
key_for_kwarg
=
'bias'
)
key_for_kwarg
=
"bias"
,
communication_action_mapping
[
'bias'
]
=
bias_comm_action
)
communication_action_mapping
[
"bias"
]
=
bias_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_lhs_space_both_contract
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_lhs_space_both_contract
(
self
,
mesh_dim_0
,
mesh_dim_1
):
# handle the case SR = SS x SR
# handle the case SR = SS x SR
name
=
f
'
S
{
mesh_dim_0
}
R = S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
x S
{
mesh_dim_1
}
R
'
name
=
f
"
S
{
mesh_dim_0
}
R = S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
x S
{
mesh_dim_1
}
R
"
# get sharding spec mapping
# get sharding spec mapping
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim_0
],
-
1
:
[
mesh_dim_1
]},
0
:
[
mesh_dim_0
],
"other"
:
{
0
:
[
mesh_dim_1
]},
-
1
:
[
mesh_dim_1
]
},
"other"
:
{
0
:
[
mesh_dim_1
]
},
"bias"
:
{},
"bias"
:
{},
"output"
:
{
"output"
:
{
0
:
[
mesh_dim_0
]},
0
:
[
mesh_dim_0
]
},
}
}
# linear bias only has one dimension, but addmm bias has same dimensions
# linear bias only has one dimension, but addmm bias has same dimensions
# as the output logically.
# as the output logically.
if
self
.
linear_projection_type
==
'
linear
'
:
if
self
.
linear_projection_type
==
"
linear
"
:
dim_partition_dict_mapping
[
'
bias
'
]
=
{}
dim_partition_dict_mapping
[
"
bias
"
]
=
{}
elif
self
.
linear_projection_type
==
'
addmm
'
:
elif
self
.
linear_projection_type
==
"
addmm
"
:
dim_partition_dict_mapping
[
'
bias
'
]
=
{
0
:
[
mesh_dim_0
]}
dim_partition_dict_mapping
[
"
bias
"
]
=
{
0
:
[
mesh_dim_0
]}
else
:
else
:
raise
(
'
Unsupported linear projection type
'
)
raise
(
"
Unsupported linear projection type
"
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
...
@@ -412,66 +411,64 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
...
@@ -412,66 +411,64 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec
=
sharding_spec_mapping
[
"output"
],
sharding_spec
=
sharding_spec_mapping
[
"output"
],
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
logical_process_axis
=
mesh_dim_1
,
logical_process_axis
=
mesh_dim_1
,
comm_type
=
CommType
.
AFTER
)
comm_type
=
CommType
.
AFTER
,
)
if
self
.
is_param
(
'
other
'
):
if
self
.
is_param
(
"
other
"
):
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"other"
],
sharding_spec_mapping
[
"other"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"other"
],
sharding_spec_mapping
[
"other"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
1
)
arg_index
=
1
,
)
communication_action_mapping
[
'
other
'
]
=
other_comm_action
communication_action_mapping
[
"
other
"
]
=
other_comm_action
communication_action_mapping
[
'
output
'
]
=
output_comm_action
communication_action_mapping
[
"
output
"
]
=
output_comm_action
# we only add allreduce comm action for linear bias, because
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
# allreduce comm action for addmm bias will be considered in post processing
if
self
.
has_bias
and
self
.
linear_projection_type
==
'
linear
'
:
if
self
.
has_bias
and
self
.
linear_projection_type
==
"
linear
"
:
if
self
.
is_param
(
'
bias
'
):
if
self
.
is_param
(
"
bias
"
):
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"bias"
],
sharding_spec_mapping
[
"bias"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec_mapping
[
"bias"
],
sharding_spec_mapping
[
"bias"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
key_for_kwarg
=
'bias'
)
key_for_kwarg
=
"bias"
,
communication_action_mapping
[
'bias'
]
=
bias_comm_action
)
communication_action_mapping
[
"bias"
]
=
bias_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_rhs_space_both_contract
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_rhs_space_both_contract
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
RS
{
mesh_dim_1
}
= RS
{
mesh_dim_0
}
x S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
'
name
=
f
"
RS
{
mesh_dim_1
}
= RS
{
mesh_dim_0
}
x S
{
mesh_dim_0
}
S
{
mesh_dim_1
}
"
# get sharding specs
# get sharding specs
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
-
1
:
[
mesh_dim_0
]},
-
1
:
[
mesh_dim_0
]
"other"
:
{
0
:
[
mesh_dim_0
],
-
1
:
[
mesh_dim_1
]},
},
"bias"
:
{
-
1
:
[
mesh_dim_1
]},
"other"
:
{
"output"
:
{
-
1
:
[
mesh_dim_1
]},
0
:
[
mesh_dim_0
],
-
1
:
[
mesh_dim_1
]
},
"bias"
:
{
-
1
:
[
mesh_dim_1
]
},
"output"
:
{
-
1
:
[
mesh_dim_1
]
},
}
}
# We don't have to do anything special for bias here, because
# We don't have to do anything special for bias here, because
...
@@ -482,34 +479,34 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
...
@@ -482,34 +479,34 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
# get communication actions
communication_action_mapping
=
{}
communication_action_mapping
=
{}
output_comm_action
=
self
.
get_communication_action
(
output_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
output
'
],
sharding_spec
=
sharding_spec_mapping
[
"
output
"
],
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
AFTER
)
comm_type
=
CommType
.
AFTER
,
)
input_comm_action
=
self
.
get_communication_action
(
input_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
input
'
],
sharding_spec
=
sharding_spec_mapping
[
"
input
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_1
,
logical_process_axis
=
mesh_dim_1
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
)
communication_action_mapping
[
"input"
]
=
input_comm_action
communication_action_mapping
[
"input"
]
=
input_comm_action
communication_action_mapping
[
'output'
]
=
output_comm_action
communication_action_mapping
[
"output"
]
=
output_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
recompute_split_both_contract
(
self
,
mesh_dim
):
def
recompute_split_both_contract
(
self
,
mesh_dim
):
name
=
f
'
RR = RS
{
mesh_dim
}
x S
{
mesh_dim
}
R
'
name
=
f
"
RR = RS
{
mesh_dim
}
x S
{
mesh_dim
}
R
"
# get sharding spec
# get sharding spec
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
-
1
:
[
mesh_dim
]},
-
1
:
[
mesh_dim
]
"other"
:
{
0
:
[
mesh_dim
]},
},
"other"
:
{
0
:
[
mesh_dim
]
},
"bias"
:
{},
"bias"
:
{},
"output"
:
{},
"output"
:
{},
}
}
...
@@ -520,32 +517,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
...
@@ -520,32 +517,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
# get communication action
communication_action_mapping
=
{}
communication_action_mapping
=
{}
output_comm_action
=
self
.
get_communication_action
(
output_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
output
'
],
sharding_spec
=
sharding_spec_mapping
[
"
output
"
],
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
logical_process_axis
=
mesh_dim
,
logical_process_axis
=
mesh_dim
,
comm_type
=
CommType
.
AFTER
)
comm_type
=
CommType
.
AFTER
,
)
communication_action_mapping
[
'output'
]
=
output_comm_action
communication_action_mapping
[
"output"
]
=
output_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_rhs_space_only
(
self
,
mesh_dim
):
def
split_rhs_space_only
(
self
,
mesh_dim
):
name
=
f
'
RS
{
mesh_dim
}
= RR x RS
{
mesh_dim
}
'
name
=
f
"
RS
{
mesh_dim
}
= RR x RS
{
mesh_dim
}
"
# get sharding spec
# get sharding spec
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{},
"input"
:
{},
"other"
:
{
"other"
:
{
-
1
:
[
mesh_dim
]},
-
1
:
[
mesh_dim
]
"bias"
:
{
-
1
:
[
mesh_dim
]},
},
"output"
:
{
-
1
:
[
mesh_dim
]},
"bias"
:
{
-
1
:
[
mesh_dim
]
},
"output"
:
{
-
1
:
[
mesh_dim
]
},
}
}
# We don't have to do anything special for bias here, because
# We don't have to do anything special for bias here, because
# the bias is already the same sharding spec as the output.
# the bias is already the same sharding spec as the output.
...
@@ -554,93 +548,94 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
...
@@ -554,93 +548,94 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
# get communication actions
communication_action_mapping
=
{}
communication_action_mapping
=
{}
input_comm_action
=
self
.
get_communication_action
(
input_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
input
'
],
sharding_spec
=
sharding_spec_mapping
[
"
input
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim
,
logical_process_axis
=
mesh_dim
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
)
communication_action_mapping
[
'input'
]
=
input_comm_action
communication_action_mapping
[
"input"
]
=
input_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_lhs_1st_dim_1d
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_lhs_1st_dim_1d
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
S
{
mesh_dim_0
}{
mesh_dim_1
}
R = S
{
mesh_dim_0
}{
mesh_dim_1
}
R x RR
'
name
=
f
"
S
{
mesh_dim_0
}{
mesh_dim_1
}
R = S
{
mesh_dim_0
}{
mesh_dim_1
}
R x RR
"
# get sharding spec
# get sharding spec
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]},
0
:
[
mesh_dim_0
,
mesh_dim_1
]
},
"other"
:
{},
"other"
:
{},
"bias"
:
{},
"bias"
:
{},
"output"
:
{
"output"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]},
0
:
[
mesh_dim_0
,
mesh_dim_1
]
},
}
}
# linear bias only has one dimension, but addmm bias has same dimensions
# linear bias only has one dimension, but addmm bias has same dimensions
# as the output logically.
# as the output logically.
if
self
.
linear_projection_type
==
'
linear
'
:
if
self
.
linear_projection_type
==
"
linear
"
:
dim_partition_dict_mapping
[
'
bias
'
]
=
{}
dim_partition_dict_mapping
[
"
bias
"
]
=
{}
elif
self
.
linear_projection_type
==
'
addmm
'
:
elif
self
.
linear_projection_type
==
"
addmm
"
:
dim_partition_dict_mapping
[
'
bias
'
]
=
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]}
dim_partition_dict_mapping
[
"
bias
"
]
=
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]}
else
:
else
:
raise
(
'
Unsupported linear projection type
'
)
raise
(
"
Unsupported linear projection type
"
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
# get communication action
# get communication action
communication_action_mapping
=
{}
communication_action_mapping
=
{}
if
self
.
is_param
(
'
other
'
):
if
self
.
is_param
(
"
other
"
):
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
other
'
],
sharding_spec
=
sharding_spec_mapping
[
"
other
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
other
'
],
sharding_spec
=
sharding_spec_mapping
[
"
other
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
1
)
arg_index
=
1
,
communication_action_mapping
[
'other'
]
=
other_comm_action
)
communication_action_mapping
[
"other"
]
=
other_comm_action
# we only add allreduce comm action for linear bias, because
# we only add allreduce comm action for linear bias, because
# allreduce comm action for addmm bias will be considered in post processing
# allreduce comm action for addmm bias will be considered in post processing
if
self
.
has_bias
and
self
.
linear_projection_type
==
'
linear
'
:
if
self
.
has_bias
and
self
.
linear_projection_type
==
"
linear
"
:
if
self
.
is_param
(
'
bias
'
):
if
self
.
is_param
(
"
bias
"
):
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
bias
'
],
sharding_spec
=
sharding_spec_mapping
[
"
bias
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
HOOK
)
comm_type
=
CommType
.
HOOK
,
)
else
:
else
:
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
bias
'
],
sharding_spec
=
sharding_spec_mapping
[
"
bias
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
key_for_kwarg
=
'bias'
)
key_for_kwarg
=
"bias"
,
communication_action_mapping
[
'bias'
]
=
bias_comm_action
)
return
self
.
get_sharding_strategy
(
name
=
name
,
communication_action_mapping
[
"bias"
]
=
bias_comm_action
sharding_spec_mapping
=
sharding_spec_mapping
,
return
self
.
get_sharding_strategy
(
communication_action_mapping
=
communication_action_mapping
)
name
=
name
,
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_lhs_2nd_dim_1d
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_lhs_2nd_dim_1d
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
RR = RS
{
mesh_dim_0
}{
mesh_dim_1
}
x S
{
mesh_dim_0
}{
mesh_dim_1
}
R
'
name
=
f
"
RR = RS
{
mesh_dim_0
}{
mesh_dim_1
}
x S
{
mesh_dim_0
}{
mesh_dim_1
}
R
"
# get sharding spec
# get sharding spec
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{
"input"
:
{
-
1
:
[
mesh_dim_0
,
mesh_dim_1
]},
-
1
:
[
mesh_dim_0
,
mesh_dim_1
]
"other"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]},
},
"other"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]
},
"bias"
:
{},
"bias"
:
{},
"output"
:
{},
"output"
:
{},
}
}
...
@@ -652,32 +647,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
...
@@ -652,32 +647,29 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
# get communication action
communication_action_mapping
=
{}
communication_action_mapping
=
{}
output_comm_action
=
self
.
get_communication_action
(
output_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
output
'
],
sharding_spec
=
sharding_spec_mapping
[
"
output
"
],
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
AFTER
)
comm_type
=
CommType
.
AFTER
,
communication_action_mapping
[
'output'
]
=
output_comm_action
)
communication_action_mapping
[
"output"
]
=
output_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_rhs_2nd_dim_1d
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_rhs_2nd_dim_1d
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
RS
{
mesh_dim_0
}{
mesh_dim_1
}
= RR x RS
{
mesh_dim_0
}{
mesh_dim_1
}
'
name
=
f
"
RS
{
mesh_dim_0
}{
mesh_dim_1
}
= RR x RS
{
mesh_dim_0
}{
mesh_dim_1
}
"
# get sharding spec
# get sharding spec
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"input"
:
{},
"input"
:
{},
"other"
:
{
"other"
:
{
-
1
:
[
mesh_dim_0
,
mesh_dim_1
]},
-
1
:
[
mesh_dim_0
,
mesh_dim_1
]
"bias"
:
{
-
1
:
[
mesh_dim_0
,
mesh_dim_1
]},
},
"output"
:
{
-
1
:
[
mesh_dim_0
,
mesh_dim_1
]},
"bias"
:
{
-
1
:
[
mesh_dim_0
,
mesh_dim_1
]
},
"output"
:
{
-
1
:
[
mesh_dim_0
,
mesh_dim_1
]
},
}
}
# We don't have to do anything special for bias here, because
# We don't have to do anything special for bias here, because
...
@@ -687,20 +679,23 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
...
@@ -687,20 +679,23 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
# get communication action
communication_action_mapping
=
{}
communication_action_mapping
=
{}
input_comm_action
=
self
.
get_communication_action
(
input_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
input
'
],
sharding_spec
=
sharding_spec_mapping
[
"
input
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
communication_action_mapping
[
'input'
]
=
input_comm_action
)
communication_action_mapping
[
"input"
]
=
input_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
non_split
(
self
):
def
non_split
(
self
):
name
=
f
'
RR = RR x RR
'
name
=
f
"
RR = RR x RR
"
# get sharding spec
# get sharding spec
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
...
@@ -717,22 +712,24 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
...
@@ -717,22 +712,24 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# get communication action
# get communication action
communication_action_mapping
=
{}
communication_action_mapping
=
{}
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
def
validate
(
self
)
->
bool
:
def
validate
(
self
)
->
bool
:
assert
"input"
in
self
.
op_data
assert
"input"
in
self
.
op_data
assert
"other"
in
self
.
op_data
assert
"other"
in
self
.
op_data
# make sure the other has 2 dim
# make sure the other has 2 dim
input_data
=
self
.
op_data
[
'
input
'
]
input_data
=
self
.
op_data
[
"
input
"
]
other_data
=
self
.
op_data
[
'
other
'
]
other_data
=
self
.
op_data
[
"
other
"
]
assert
input_data
.
data
.
dim
()
>
0
and
other_data
.
data
.
dim
()
==
2
assert
input_data
.
data
.
dim
()
>
0
and
other_data
.
data
.
dim
()
==
2
assert
other_data
.
logical_shape
[
0
]
==
input_data
.
logical_shape
[
-
1
]
assert
other_data
.
logical_shape
[
0
]
==
input_data
.
logical_shape
[
-
1
]
if
self
.
has_bias
:
if
self
.
has_bias
:
bias_data
=
self
.
op_data
[
'
bias
'
]
bias_data
=
self
.
op_data
[
"
bias
"
]
assert
bias_data
.
logical_shape
[
-
1
]
==
other_data
.
logical_shape
[
-
1
]
assert
bias_data
.
logical_shape
[
-
1
]
==
other_data
.
logical_shape
[
-
1
]
...
@@ -757,37 +754,38 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
...
@@ -757,37 +754,38 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
def
_pop_batch_dim_sharding_for_output
(
self
,
dim_partition_dict
):
def
_pop_batch_dim_sharding_for_output
(
self
,
dim_partition_dict
):
# remove partition dict for dim 0
# remove partition dict for dim 0
dim_partition_dict
[
'
output
'
].
pop
(
0
,
None
)
dim_partition_dict
[
"
output
"
].
pop
(
0
,
None
)
# decrease the remaining dim index by 1
# decrease the remaining dim index by 1
temp_dim_partition
=
{}
temp_dim_partition
=
{}
keys
=
list
(
dim_partition_dict
[
'
output
'
].
keys
())
keys
=
list
(
dim_partition_dict
[
"
output
"
].
keys
())
for
key
in
keys
:
for
key
in
keys
:
val
=
dim_partition_dict
[
'
output
'
].
pop
(
key
)
val
=
dim_partition_dict
[
"
output
"
].
pop
(
key
)
temp_dim_partition
[
key
-
1
]
=
val
temp_dim_partition
[
key
-
1
]
=
val
dim_partition_dict
[
'
output
'
].
update
(
temp_dim_partition
)
dim_partition_dict
[
"
output
"
].
update
(
temp_dim_partition
)
def
validate
(
self
)
->
bool
:
def
validate
(
self
)
->
bool
:
input_op_data
=
self
.
op_data
[
'
input
'
]
input_op_data
=
self
.
op_data
[
"
input
"
]
other_op_data
=
self
.
op_data
[
'
other
'
]
other_op_data
=
self
.
op_data
[
"
other
"
]
assert
len
(
input_op_data
.
logical_shape
)
==
3
or
len
(
other_op_data
.
logical_shape
)
==
3
assert
len
(
input_op_data
.
logical_shape
)
==
3
or
len
(
other_op_data
.
logical_shape
)
==
3
if
'
bias
'
in
self
.
op_data
:
if
"
bias
"
in
self
.
op_data
:
bias_op_data
=
self
.
op_data
[
'
bias
'
]
bias_op_data
=
self
.
op_data
[
"
bias
"
]
assert
bias_op_data
.
data
.
dim
()
<
3
and
len
(
bias_op_data
.
logical_shape
)
==
2
assert
bias_op_data
.
data
.
dim
()
<
3
and
len
(
bias_op_data
.
logical_shape
)
==
2
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
fwd_compute_cost
=
self
.
op_data
[
'input'
].
data
.
shape
[
-
1
]
*
reduce
(
operator
.
mul
,
fwd_compute_cost
=
self
.
op_data
[
"input"
].
data
.
shape
[
-
1
]
*
reduce
(
self
.
op_data
[
'output'
].
data
.
shape
)
operator
.
mul
,
self
.
op_data
[
"output"
].
data
.
shape
)
bwd_compute_cost
=
fwd_compute_cost
*
2
bwd_compute_cost
=
fwd_compute_cost
*
2
compute_cost
=
TrainCycleItem
(
fwd
=
fwd_compute_cost
,
compute_cost
=
TrainCycleItem
(
bwd
=
bwd_compute_cost
,
fwd
=
fwd_compute_cost
,
bwd
=
bwd_compute_cost
,
total
=
fwd_compute_cost
+
bwd_compute_cost
total
=
fwd_compute_cost
+
bwd_compute_cost
)
)
strategy
.
compute_cost
=
compute_cost
strategy
.
compute_cost
=
compute_cost
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_one_batch_dim
(
self
,
mesh_dim
):
def
split_one_batch_dim
(
self
,
mesh_dim
):
name
=
f
'
Sb
{
mesh_dim
}
= Sb
{
mesh_dim
}
x Sb
{
mesh_dim
}
'
name
=
f
"
Sb
{
mesh_dim
}
= Sb
{
mesh_dim
}
x Sb
{
mesh_dim
}
"
# get sharding_spec
# get sharding_spec
dim_partition_dict
=
{
"input"
:
{
0
:
[
mesh_dim
]},
"other"
:
{
0
:
[
mesh_dim
]},
"bias"
:
{},
"output"
:
{
0
:
[
mesh_dim
]}}
dim_partition_dict
=
{
"input"
:
{
0
:
[
mesh_dim
]},
"other"
:
{
0
:
[
mesh_dim
]},
"bias"
:
{},
"output"
:
{
0
:
[
mesh_dim
]}}
...
@@ -799,30 +797,27 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
...
@@ -799,30 +797,27 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
communication_action_mapping
=
{}
communication_action_mapping
=
{}
if
self
.
has_bias
:
if
self
.
has_bias
:
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
bias
'
],
sharding_spec
=
sharding_spec_mapping
[
"
bias
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim
,
logical_process_axis
=
mesh_dim
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
communication_action_mapping
[
'bias'
]
=
bias_comm_action
)
return
self
.
get_sharding_strategy
(
name
=
name
,
communication_action_mapping
[
"bias"
]
=
bias_comm_action
sharding_spec_mapping
=
sharding_spec_mapping
,
return
self
.
get_sharding_strategy
(
communication_action_mapping
=
communication_action_mapping
)
name
=
name
,
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_two_batch_dim
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_two_batch_dim
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
Sb
{
mesh_dim_0
}{
mesh_dim_1
}
= Sb
{
mesh_dim_0
}{
mesh_dim_1
}
x Sb
{
mesh_dim_0
}{
mesh_dim_1
}
'
name
=
f
"
Sb
{
mesh_dim_0
}{
mesh_dim_1
}
= Sb
{
mesh_dim_0
}{
mesh_dim_1
}
x Sb
{
mesh_dim_0
}{
mesh_dim_1
}
"
dim_partition_dict
=
{
dim_partition_dict
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]},
0
:
[
mesh_dim_0
,
mesh_dim_1
]
"other"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]},
},
"other"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]
},
"bias"
:
{},
"bias"
:
{},
"output"
:
{
"output"
:
{
0
:
[
mesh_dim_0
,
mesh_dim_1
]},
0
:
[
mesh_dim_0
,
mesh_dim_1
]
}
}
}
if
self
.
squeeze_batch_dim
:
if
self
.
squeeze_batch_dim
:
self
.
_pop_batch_dim_sharding_for_output
(
dim_partition_dict
)
self
.
_pop_batch_dim_sharding_for_output
(
dim_partition_dict
)
...
@@ -832,35 +827,28 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
...
@@ -832,35 +827,28 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
communication_action_mapping
=
{}
communication_action_mapping
=
{}
if
self
.
has_bias
:
if
self
.
has_bias
:
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
bias
'
],
sharding_spec
=
sharding_spec_mapping
[
"
bias
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
communication_action_mapping
[
'bias'
]
=
bias_comm_action
)
communication_action_mapping
[
"bias"
]
=
bias_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_batch_dim_lhs_space
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_batch_dim_lhs_space
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
Sb
{
mesh_dim_0
}
Si
{
mesh_dim_1
}
= Sb
{
mesh_dim_0
}
Si
{
mesh_dim_1
}
x Sb
{
mesh_dim_0
}
'
name
=
f
"
Sb
{
mesh_dim_0
}
Si
{
mesh_dim_1
}
= Sb
{
mesh_dim_0
}
Si
{
mesh_dim_1
}
x Sb
{
mesh_dim_0
}
"
dim_partition_dict
=
{
dim_partition_dict
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim_0
],
1
:
[
mesh_dim_1
]},
0
:
[
mesh_dim_0
],
"other"
:
{
0
:
[
mesh_dim_0
]},
1
:
[
mesh_dim_1
]
"bias"
:
{
0
:
[
mesh_dim_1
]},
},
"output"
:
{
0
:
[
mesh_dim_0
],
1
:
[
mesh_dim_1
]},
"other"
:
{
0
:
[
mesh_dim_0
]
},
"bias"
:
{
0
:
[
mesh_dim_1
]
},
"output"
:
{
0
:
[
mesh_dim_0
],
1
:
[
mesh_dim_1
]
}
}
}
if
self
.
squeeze_batch_dim
:
if
self
.
squeeze_batch_dim
:
self
.
_pop_batch_dim_sharding_for_output
(
dim_partition_dict
)
self
.
_pop_batch_dim_sharding_for_output
(
dim_partition_dict
)
...
@@ -869,46 +857,40 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
...
@@ -869,46 +857,40 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
# get communication actions
communication_action_mapping
=
{}
communication_action_mapping
=
{}
other_comm_action
=
self
.
get_communication_action
(
other_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
other
'
],
sharding_spec
=
sharding_spec_mapping
[
"
other
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_1
,
logical_process_axis
=
mesh_dim_1
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
1
)
arg_index
=
1
,
communication_action_mapping
[
'other'
]
=
other_comm_action
)
communication_action_mapping
[
"other"
]
=
other_comm_action
if
self
.
has_bias
:
if
self
.
has_bias
:
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
bias
'
],
sharding_spec
=
sharding_spec_mapping
[
"
bias
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
logical_process_axis
=
[
mesh_dim_0
,
mesh_dim_1
],
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
communication_action_mapping
[
'bias'
]
=
bias_comm_action
)
communication_action_mapping
[
"bias"
]
=
bias_comm_action
# for addbmm case, other is the third argument instead of second.
# for addbmm case, other is the third argument instead of second.
communication_action_mapping
[
'
other
'
].
arg_index
+=
1
communication_action_mapping
[
"
other
"
].
arg_index
+=
1
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_batch_dim_rhs_space
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_batch_dim_rhs_space
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
Sb
{
mesh_dim_0
}
Sj
{
mesh_dim_1
}
= Sb
{
mesh_dim_0
}
R x Sb
{
mesh_dim_0
}
Sj
{
mesh_dim_1
}
'
name
=
f
"
Sb
{
mesh_dim_0
}
Sj
{
mesh_dim_1
}
= Sb
{
mesh_dim_0
}
R x Sb
{
mesh_dim_0
}
Sj
{
mesh_dim_1
}
"
dim_partition_dict
=
{
dim_partition_dict
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim_0
]},
0
:
[
mesh_dim_0
]
"other"
:
{
0
:
[
mesh_dim_0
],
2
:
[
mesh_dim_1
]},
},
"bias"
:
{
1
:
[
mesh_dim_1
]},
"other"
:
{
"output"
:
{
0
:
[
mesh_dim_0
],
2
:
[
mesh_dim_1
]},
0
:
[
mesh_dim_0
],
2
:
[
mesh_dim_1
]
},
"bias"
:
{
1
:
[
mesh_dim_1
]
},
"output"
:
{
0
:
[
mesh_dim_0
],
2
:
[
mesh_dim_1
]
}
}
}
if
self
.
squeeze_batch_dim
:
if
self
.
squeeze_batch_dim
:
self
.
_pop_batch_dim_sharding_for_output
(
dim_partition_dict
)
self
.
_pop_batch_dim_sharding_for_output
(
dim_partition_dict
)
...
@@ -917,43 +899,41 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
...
@@ -917,43 +899,41 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
# get communication actions
communication_action_mapping
=
{}
communication_action_mapping
=
{}
input_comm_action
=
self
.
get_communication_action
(
input_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
input
'
],
sharding_spec
=
sharding_spec_mapping
[
"
input
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_1
,
logical_process_axis
=
mesh_dim_1
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
communication_action_mapping
[
'input'
]
=
input_comm_action
)
communication_action_mapping
[
"input"
]
=
input_comm_action
if
self
.
has_bias
:
if
self
.
has_bias
:
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
bias
'
],
sharding_spec
=
sharding_spec_mapping
[
"
bias
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
)
comm_type
=
CommType
.
BEFORE
,
communication_action_mapping
[
'bias'
]
=
bias_comm_action
)
communication_action_mapping
[
"bias"
]
=
bias_comm_action
# for addbmm case, other is the second argument instead of first.
# for addbmm case, other is the second argument instead of first.
communication_action_mapping
[
'
input
'
].
arg_index
+=
1
communication_action_mapping
[
"
input
"
].
arg_index
+=
1
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
@
ignore_sharding_exception
@
ignore_sharding_exception
def
split_batch_dim_both_contract
(
self
,
mesh_dim_0
,
mesh_dim_1
):
def
split_batch_dim_both_contract
(
self
,
mesh_dim_0
,
mesh_dim_1
):
name
=
f
'
Sb
{
mesh_dim_0
}
R = Sb
{
mesh_dim_0
}
Sk
{
mesh_dim_1
}
x Sb
{
mesh_dim_0
}
Sk
{
mesh_dim_1
}
'
name
=
f
"
Sb
{
mesh_dim_0
}
R = Sb
{
mesh_dim_0
}
Sk
{
mesh_dim_1
}
x Sb
{
mesh_dim_0
}
Sk
{
mesh_dim_1
}
"
dim_partition_dict
=
{
dim_partition_dict
=
{
"input"
:
{
"input"
:
{
0
:
[
mesh_dim_0
],
2
:
[
mesh_dim_1
]},
0
:
[
mesh_dim_0
],
"other"
:
{
0
:
[
mesh_dim_0
],
1
:
[
mesh_dim_1
]},
2
:
[
mesh_dim_1
]
},
"other"
:
{
0
:
[
mesh_dim_0
],
1
:
[
mesh_dim_1
]
},
"bias"
:
{},
"bias"
:
{},
"output"
:
{
"output"
:
{
0
:
[
mesh_dim_0
],
0
:
[
mesh_dim_0
],
}
}
,
}
}
if
self
.
squeeze_batch_dim
:
if
self
.
squeeze_batch_dim
:
self
.
_pop_batch_dim_sharding_for_output
(
dim_partition_dict
)
self
.
_pop_batch_dim_sharding_for_output
(
dim_partition_dict
)
...
@@ -962,29 +942,33 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
...
@@ -962,29 +942,33 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# get communication actions
# get communication actions
communication_action_mapping
=
{}
communication_action_mapping
=
{}
output_comm_action
=
self
.
get_communication_action
(
output_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
output
'
],
sharding_spec
=
sharding_spec_mapping
[
"
output
"
],
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
communication_pattern
=
CollectiveCommPattern
.
ALLREDUCE_FWD_IDENTITY_BWD
,
logical_process_axis
=
mesh_dim_1
,
logical_process_axis
=
mesh_dim_1
,
comm_type
=
CommType
.
AFTER
)
comm_type
=
CommType
.
AFTER
,
communication_action_mapping
[
'output'
]
=
output_comm_action
)
communication_action_mapping
[
"output"
]
=
output_comm_action
if
self
.
has_bias
:
if
self
.
has_bias
:
bias_comm_action
=
self
.
get_communication_action
(
bias_comm_action
=
self
.
get_communication_action
(
sharding_spec
=
sharding_spec_mapping
[
'
bias
'
],
sharding_spec
=
sharding_spec_mapping
[
"
bias
"
],
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
communication_pattern
=
CollectiveCommPattern
.
IDENTITY_FWD_ALLREDUCE_BWD
,
logical_process_axis
=
mesh_dim_0
,
logical_process_axis
=
mesh_dim_0
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
communication_action_mapping
[
'bias'
]
=
bias_comm_action
)
communication_action_mapping
[
"bias"
]
=
bias_comm_action
return
self
.
get_sharding_strategy
(
name
=
name
,
return
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
strategy_list
=
[]
strategy_list
=
[]
device_mesh_is_1d
=
True
device_mesh_is_1d
=
True
if
len
(
self
.
device_mesh
.
mesh_
shape
)
==
2
and
1
not
in
self
.
device_mesh
.
mesh_
shape
:
if
len
(
self
.
device_mesh
.
shape
)
==
2
and
1
not
in
self
.
device_mesh
.
shape
:
device_mesh_is_1d
=
False
device_mesh_is_1d
=
False
if
device_mesh_is_1d
:
if
device_mesh_is_1d
:
...
@@ -992,10 +976,10 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
...
@@ -992,10 +976,10 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
# Sb = Sb x Sb
# Sb = Sb x Sb
# can be None as it is only for 1D device mesh
# can be None as it is only for 1D device mesh
# only for 1D device mesh
# only for 1D device mesh
if
len
(
self
.
device_mesh
.
mesh_
shape
)
==
1
:
if
len
(
self
.
device_mesh
.
shape
)
==
1
:
mesh_dim
=
0
mesh_dim
=
0
else
:
else
:
mesh_dim
=
self
.
device_mesh
.
mesh_
shape
.
index
(
1
)
mesh_dim
=
self
.
device_mesh
.
shape
.
index
(
1
)
strategy_list
.
append
(
self
.
split_one_batch_dim
(
mesh_dim
))
strategy_list
.
append
(
self
.
split_one_batch_dim
(
mesh_dim
))
else
:
else
:
# for 2D device mesh
# for 2D device mesh
...
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/normal_pooling_generator.py
View file @
9e768b59
...
@@ -17,32 +17,35 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
...
@@ -17,32 +17,35 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
"""
"""
NormalPoolStrategyGenerator is a generic class to generate strategies for pool operation like MaxPoolxd.
NormalPoolStrategyGenerator is a generic class to generate strategies for pool operation like MaxPoolxd.
The reason we call this normal pool is AvgPoolxd and MaxPoolxd are taking the kernel size element from image,
The reason we call this normal pool is AvgPoolxd and MaxPoolxd are taking the kernel size element from image,
and reduce them depening on the operation type.
and reduce them depen
d
ing on the operation type.
"""
"""
def
validate
(
self
)
->
bool
:
def
validate
(
self
)
->
bool
:
'''
"""
In sanity check, we need make sure the input data having correct dimension size.
In sanity check, we need make sure the input data having correct dimension size.
For Pool1d, the dim of input data should be 3([N, C, L]).
For Pool1d, the dim of input data should be 3([N, C, L]).
For Pool2d, the dim of input data should be 4([N, C, H, W]).
For Pool2d, the dim of input data should be 4([N, C, H, W]).
For Pool3d, the dim of input data should be 5([N, C, H, W, D]).
For Pool3d, the dim of input data should be 5([N, C, H, W, D]).
'''
"""
input_op_data
=
self
.
op_data
[
'
input
'
]
input_op_data
=
self
.
op_data
[
"
input
"
]
assert
input_op_data
.
data
.
dim
()
in
(
assert
input_op_data
.
data
.
dim
()
in
(
3
,
4
,
5
),
f
'We suppose the dim of input fed into Pool op should in range of [3, 5].'
3
,
4
,
5
,
),
f
"We suppose the dim of input fed into Pool op should in range of [3, 5]."
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
)
->
TrainCycleItem
:
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
)
->
TrainCycleItem
:
'''
"""
Compute the computation cost per device with this specific strategy.
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be d
e
vided by TFLOPS, now it just shows the computation size.
Note: compute_cost need to be d
i
vided by TFLOPS, now it just shows the computation size.
'''
"""
# TODO: compute_cost need to be d
e
vided by TFLOPS, now it just shows the computation size.
# TODO: compute_cost need to be d
i
vided by TFLOPS, now it just shows the computation size.
# 1D: (Lout) * N * C * kernel
# 1D: (Lout) * N * C * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
sharded_output_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
output
'
]].
get_sharded_shape_per_device
()
sharded_output_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
output
"
]].
get_sharded_shape_per_device
()
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
input
'
]].
get_sharded_shape_per_device
()
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
input
"
]].
get_sharded_shape_per_device
()
kernel_size
=
self
.
op_data
[
"other"
].
data
kernel_size
=
self
.
op_data
[
"other"
].
data
if
isinstance
(
kernel_size
,
int
):
if
isinstance
(
kernel_size
,
int
):
...
@@ -61,8 +64,8 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
...
@@ -61,8 +64,8 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
forward_size_mapping
=
{
forward_size_mapping
=
{
'
input
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
"
input
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
'
output
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
"
output
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
,
}
}
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
...
@@ -88,12 +91,16 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
...
@@ -88,12 +91,16 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
name
=
f
'
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
=
{
sharding_spec_mapping
[
"input"
].
sharding_sequence
}
'
name
=
(
f
'
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
=
{
sharding_spec_mapping
[
"input"
].
sharding_sequence
}
'
)
communication_action_mapping
=
{}
communication_action_mapping
=
{}
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
return
strategy
return
strategy
...
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/output_generator.py
View file @
9e768b59
...
@@ -12,7 +12,7 @@ from colossalai.device.device_mesh import DeviceMesh
...
@@ -12,7 +12,7 @@ from colossalai.device.device_mesh import DeviceMesh
from
.strategy_generator
import
OutputStrategyGenerator
from
.strategy_generator
import
OutputStrategyGenerator
__all__
=
[
'
OutputGenerator
'
]
__all__
=
[
"
OutputGenerator
"
]
class
OutputGenerator
(
OutputStrategyGenerator
):
class
OutputGenerator
(
OutputStrategyGenerator
):
...
@@ -20,8 +20,13 @@ class OutputGenerator(OutputStrategyGenerator):
...
@@ -20,8 +20,13 @@ class OutputGenerator(OutputStrategyGenerator):
OutputGenerator is a generic class to generate strategies for Output Node.
OutputGenerator is a generic class to generate strategies for Output Node.
"""
"""
def
__init__
(
self
,
operation_data_mapping
:
Dict
[
str
,
OperationData
],
device_mesh
:
DeviceMesh
,
def
__init__
(
predecessor_nodes
:
List
[
Node
],
output_option
:
str
):
self
,
operation_data_mapping
:
Dict
[
str
,
OperationData
],
device_mesh
:
DeviceMesh
,
predecessor_nodes
:
List
[
Node
],
output_option
:
str
,
):
super
().
__init__
(
operation_data_mapping
,
device_mesh
,
predecessor_nodes
)
super
().
__init__
(
operation_data_mapping
,
device_mesh
,
predecessor_nodes
)
self
.
output_option
=
output_option
self
.
output_option
=
output_option
...
@@ -33,9 +38,9 @@ class OutputGenerator(OutputStrategyGenerator):
...
@@ -33,9 +38,9 @@ class OutputGenerator(OutputStrategyGenerator):
strategy
.
compute_cost
=
compute_cost
strategy
.
compute_cost
=
compute_cost
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
'''
"""
Compute the memory cost per device with this specific strategy.
Compute the memory cost per device with this specific strategy.
'''
"""
fwd_mem_cost
=
MemoryCost
(
activation
=
0
,
parameter
=
0
)
fwd_mem_cost
=
MemoryCost
(
activation
=
0
,
parameter
=
0
)
bwd_mem_cost
=
MemoryCost
(
activation
=
0
,
parameter
=
0
)
bwd_mem_cost
=
MemoryCost
(
activation
=
0
,
parameter
=
0
)
...
@@ -65,16 +70,18 @@ class OutputGenerator(OutputStrategyGenerator):
...
@@ -65,16 +70,18 @@ class OutputGenerator(OutputStrategyGenerator):
else
:
else
:
dim_partition_dict_for_output
=
tuple
(
dim_partition_dict_for_output
)
dim_partition_dict_for_output
=
tuple
(
dim_partition_dict_for_output
)
dim_partition_dict_mapping
[
'
output
'
]
=
dim_partition_dict_for_output
dim_partition_dict_mapping
[
"
output
"
]
=
dim_partition_dict_for_output
communication_action_mapping
=
{}
communication_action_mapping
=
{}
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
name
=
'
Replica Output
'
name
=
"
Replica Output
"
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
return
strategy
return
strategy
def
distributed_strategy
(
self
,
mesh_list
:
List
[
List
[
int
]]
=
None
)
->
List
[
ShardingStrategy
]:
def
distributed_strategy
(
self
,
mesh_list
:
List
[
List
[
int
]]
=
None
)
->
List
[
ShardingStrategy
]:
...
@@ -82,19 +89,15 @@ class OutputGenerator(OutputStrategyGenerator):
...
@@ -82,19 +89,15 @@ class OutputGenerator(OutputStrategyGenerator):
Generate distributed strategy for output node.
Generate distributed strategy for output node.
"""
"""
# TODO: need to take care of the case when the first element of output only need to be sharded.
# TODO: need to take care of the case when the first element of output only need to be sharded.
output_op_data
=
self
.
op_data
[
'
output
'
]
output_op_data
=
self
.
op_data
[
"
output
"
]
if
isinstance
(
output_op_data
.
data
,
tuple
):
if
isinstance
(
output_op_data
.
data
,
tuple
):
length
=
len
(
output_op_data
.
data
)
length
=
len
(
output_op_data
.
data
)
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"output"
:
[{
"output"
:
[{
0
:
mesh_list
}]
*
length
,
0
:
mesh_list
}]
*
length
,
}
}
else
:
else
:
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"output"
:
{
"output"
:
{
0
:
mesh_list
},
0
:
mesh_list
},
}
}
for
index
,
_
in
enumerate
(
self
.
predecessor_nodes
):
for
index
,
_
in
enumerate
(
self
.
predecessor_nodes
):
mapping_name
=
f
"input_
{
index
}
"
mapping_name
=
f
"input_
{
index
}
"
...
@@ -103,19 +106,21 @@ class OutputGenerator(OutputStrategyGenerator):
...
@@ -103,19 +106,21 @@ class OutputGenerator(OutputStrategyGenerator):
communication_action_mapping
=
{}
communication_action_mapping
=
{}
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
name
=
'
Distributed Output
'
name
=
"
Distributed Output
"
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
return
strategy
return
strategy
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
strategy_list
=
[]
strategy_list
=
[]
mesh_list
=
[
0
,
1
]
mesh_list
=
[
0
,
1
]
if
self
.
output_option
==
'
replicated
'
:
if
self
.
output_option
==
"
replicated
"
:
strategy_list
.
append
(
self
.
replica_strategy
())
strategy_list
.
append
(
self
.
replica_strategy
())
elif
self
.
output_option
==
'
distributed
'
:
elif
self
.
output_option
==
"
distributed
"
:
strategy_list
.
append
(
self
.
distributed_strategy
(
mesh_list
))
strategy_list
.
append
(
self
.
distributed_strategy
(
mesh_list
))
return
strategy_list
return
strategy_list
colossalai/auto_parallel/tensor_shard/node_handler/strategy/placeholder_generator.py
View file @
9e768b59
...
@@ -10,7 +10,7 @@ from colossalai.device.device_mesh import DeviceMesh
...
@@ -10,7 +10,7 @@ from colossalai.device.device_mesh import DeviceMesh
from
.strategy_generator
import
StrategyGenerator
from
.strategy_generator
import
StrategyGenerator
__all__
=
[
'
PlaceholderGenerator
'
]
__all__
=
[
"
PlaceholderGenerator
"
]
class
PlaceholderGenerator
(
StrategyGenerator
):
class
PlaceholderGenerator
(
StrategyGenerator
):
...
@@ -18,8 +18,9 @@ class PlaceholderGenerator(StrategyGenerator):
...
@@ -18,8 +18,9 @@ class PlaceholderGenerator(StrategyGenerator):
PlaceholderGenerator is a generic class to generate strategies for placeholder node.
PlaceholderGenerator is a generic class to generate strategies for placeholder node.
"""
"""
def
__init__
(
self
,
operation_data_mapping
:
Dict
[
str
,
OperationData
],
device_mesh
:
DeviceMesh
,
def
__init__
(
placeholder_option
:
str
):
self
,
operation_data_mapping
:
Dict
[
str
,
OperationData
],
device_mesh
:
DeviceMesh
,
placeholder_option
:
str
):
super
().
__init__
(
operation_data_mapping
,
device_mesh
)
super
().
__init__
(
operation_data_mapping
,
device_mesh
)
self
.
placeholder_option
=
placeholder_option
self
.
placeholder_option
=
placeholder_option
...
@@ -31,10 +32,10 @@ class PlaceholderGenerator(StrategyGenerator):
...
@@ -31,10 +32,10 @@ class PlaceholderGenerator(StrategyGenerator):
strategy
.
compute_cost
=
compute_cost
strategy
.
compute_cost
=
compute_cost
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
'''
"""
Compute the memory cost per device with this specific strategy.
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping
=
{
'
output
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)}
forward_size_mapping
=
{
"
output
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)}
# compute fwd cost incurred
# compute fwd cost incurred
# fwd_cost = output
# fwd_cost = output
...
@@ -58,11 +59,13 @@ class PlaceholderGenerator(StrategyGenerator):
...
@@ -58,11 +59,13 @@ class PlaceholderGenerator(StrategyGenerator):
communication_action_mapping
=
{}
communication_action_mapping
=
{}
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
name
=
'
Replica Placeholder
'
name
=
"
Replica Placeholder
"
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
return
strategy
return
strategy
...
@@ -71,29 +74,31 @@ class PlaceholderGenerator(StrategyGenerator):
...
@@ -71,29 +74,31 @@ class PlaceholderGenerator(StrategyGenerator):
Generate distributed strategy for placeholder node.
Generate distributed strategy for placeholder node.
"""
"""
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"output"
:
{
"output"
:
{
0
:
mesh_list
},
0
:
mesh_list
},
}
}
communication_action_mapping
=
{}
communication_action_mapping
=
{}
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
name
=
'
Distributed Placeholder
'
name
=
"
Distributed Placeholder
"
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
return
strategy
return
strategy
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
strategy_list
=
[]
strategy_list
=
[]
if
self
.
placeholder_option
==
'
distributed
'
:
if
self
.
placeholder_option
==
"
distributed
"
:
mesh_list
=
[
0
,
1
]
mesh_list
=
[
0
,
1
]
distributed_strategy
=
self
.
distributed_placeholder
(
mesh_list
)
distributed_strategy
=
self
.
distributed_placeholder
(
mesh_list
)
strategy_list
.
append
(
distributed_strategy
)
strategy_list
.
append
(
distributed_strategy
)
else
:
else
:
assert
self
.
placeholder_option
==
'replicated'
,
f
'placeholder_option
{
self
.
placeholder_option
}
is not supported'
assert
(
self
.
placeholder_option
==
"replicated"
),
f
"placeholder_option
{
self
.
placeholder_option
}
is not supported"
replicated_strategy
=
self
.
replica_placeholder
()
replicated_strategy
=
self
.
replica_placeholder
()
strategy_list
.
append
(
replicated_strategy
)
strategy_list
.
append
(
replicated_strategy
)
...
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py
View file @
9e768b59
...
@@ -17,7 +17,7 @@ from colossalai.auto_parallel.tensor_shard.utils import (
...
@@ -17,7 +17,7 @@ from colossalai.auto_parallel.tensor_shard.utils import (
from
colossalai.tensor.shape_consistency
import
CollectiveCommPattern
from
colossalai.tensor.shape_consistency
import
CollectiveCommPattern
from
colossalai.tensor.sharding_spec
import
ShardingSpec
from
colossalai.tensor.sharding_spec
import
ShardingSpec
__all__
=
[
'
ReshapeGenerator
'
,
'
ViewGenerator
'
,
'
PermuteGenerator
'
,
'
TransposeGenerator
'
,
'
SplitGenerator
'
]
__all__
=
[
"
ReshapeGenerator
"
,
"
ViewGenerator
"
,
"
PermuteGenerator
"
,
"
TransposeGenerator
"
,
"
SplitGenerator
"
]
class
ReshapeGenerator
(
FollowingStrategyGenerator
):
class
ReshapeGenerator
(
FollowingStrategyGenerator
):
...
@@ -33,12 +33,12 @@ class ReshapeGenerator(FollowingStrategyGenerator):
...
@@ -33,12 +33,12 @@ class ReshapeGenerator(FollowingStrategyGenerator):
strategy
.
compute_cost
=
compute_cost
strategy
.
compute_cost
=
compute_cost
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
'''
"""
Compute the memory cost per device with this specific strategy.
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping
=
{
forward_size_mapping
=
{
'
input
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
"
input
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
'
output
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
"
output
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
,
}
}
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
...
@@ -56,8 +56,9 @@ class ReshapeGenerator(FollowingStrategyGenerator):
...
@@ -56,8 +56,9 @@ class ReshapeGenerator(FollowingStrategyGenerator):
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
# compute total cost
# compute total cost
total_mem_cost
=
MemoryCost
(
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
total_mem_cost
=
MemoryCost
(
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
)
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
strategy
.
memory_cost
=
memory_cost
strategy
.
memory_cost
=
memory_cost
...
@@ -77,8 +78,8 @@ class ViewGenerator(ReshapeGenerator):
...
@@ -77,8 +78,8 @@ class ViewGenerator(ReshapeGenerator):
communication_action_mapping
=
{}
communication_action_mapping
=
{}
input_sharding_spec
=
strategy
.
output_sharding_specs
[
self
.
op_data
[
"input"
]]
input_sharding_spec
=
strategy
.
output_sharding_specs
[
self
.
op_data
[
"input"
]]
origin_shape
=
self
.
op_data
[
'
input
'
].
data
.
shape
origin_shape
=
self
.
op_data
[
"
input
"
].
data
.
shape
tgt_shape
=
self
.
op_data
[
'
tgt_shape
'
].
data
tgt_shape
=
self
.
op_data
[
"
tgt_shape
"
].
data
reshape_mapping_dict
=
detect_reshape_mapping
(
origin_shape
,
tgt_shape
)
reshape_mapping_dict
=
detect_reshape_mapping
(
origin_shape
,
tgt_shape
)
...
@@ -86,8 +87,9 @@ class ViewGenerator(ReshapeGenerator):
...
@@ -86,8 +87,9 @@ class ViewGenerator(ReshapeGenerator):
keep_sharding_status
=
check_keep_sharding_status
(
dim_partition_dict_for_input
,
reshape_mapping_dict
)
keep_sharding_status
=
check_keep_sharding_status
(
dim_partition_dict_for_input
,
reshape_mapping_dict
)
if
keep_sharding_status
:
if
keep_sharding_status
:
dim_partition_dict_for_output
=
infer_output_dim_partition_dict
(
dim_partition_dict_for_input
,
dim_partition_dict_for_output
=
infer_output_dim_partition_dict
(
reshape_mapping_dict
)
dim_partition_dict_for_input
,
reshape_mapping_dict
)
else
:
else
:
dim_partition_dict_for_output
=
{}
dim_partition_dict_for_output
=
{}
...
@@ -119,7 +121,8 @@ class ViewGenerator(ReshapeGenerator):
...
@@ -119,7 +121,8 @@ class ViewGenerator(ReshapeGenerator):
communication_pattern
=
CollectiveCommPattern
.
GATHER_FWD_SPLIT_BWD
,
communication_pattern
=
CollectiveCommPattern
.
GATHER_FWD_SPLIT_BWD
,
logical_process_axis
=
total_mesh_dim_list
,
logical_process_axis
=
total_mesh_dim_list
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
)
# it will gather the input through gather_dim during forward phase.
# it will gather the input through gather_dim during forward phase.
input_comm_action
.
comm_spec
.
gather_dim
=
shard_dim
input_comm_action
.
comm_spec
.
gather_dim
=
shard_dim
# it will split the input activation grad through shard_dim during backward phase.
# it will split the input activation grad through shard_dim during backward phase.
...
@@ -127,10 +130,10 @@ class ViewGenerator(ReshapeGenerator):
...
@@ -127,10 +130,10 @@ class ViewGenerator(ReshapeGenerator):
elif
len
(
total_mesh_dim_list
)
>=
2
:
elif
len
(
total_mesh_dim_list
)
>=
2
:
source_spec
=
sharding_spec_mapping
[
"input"
]
source_spec
=
sharding_spec_mapping
[
"input"
]
target_spec
=
ShardingSpec
(
device_mesh
=
self
.
device_mesh
,
target_spec
=
ShardingSpec
(
entire_shape
=
source_spec
.
entire_shape
,
device_mesh
=
self
.
device_mesh
,
entire_shape
=
source_spec
.
entire_shape
,
dim_partition_dict
=
{}
dim_partition_dict
=
{}
)
)
comm_spec
=
{
'
src_spec
'
:
source_spec
,
'
tgt_spec
'
:
target_spec
}
comm_spec
=
{
"
src_spec
"
:
source_spec
,
"
tgt_spec
"
:
target_spec
}
input_comm_action
=
CommAction
(
comm_spec
=
comm_spec
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
input_comm_action
=
CommAction
(
comm_spec
=
comm_spec
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
else
:
else
:
...
@@ -139,9 +142,11 @@ class ViewGenerator(ReshapeGenerator):
...
@@ -139,9 +142,11 @@ class ViewGenerator(ReshapeGenerator):
if
input_comm_action
is
not
None
:
if
input_comm_action
is
not
None
:
communication_action_mapping
[
"input"
]
=
input_comm_action
communication_action_mapping
[
"input"
]
=
input_comm_action
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
strategy_list
.
append
(
strategy
)
strategy_list
.
append
(
strategy
)
return
strategy_list
return
strategy_list
...
@@ -159,7 +164,7 @@ class PermuteGenerator(ReshapeGenerator):
...
@@ -159,7 +164,7 @@ class PermuteGenerator(ReshapeGenerator):
communication_action_mapping
=
{}
communication_action_mapping
=
{}
input_sharding_spec
=
strategy
.
output_sharding_specs
[
self
.
op_data
[
"input"
]]
input_sharding_spec
=
strategy
.
output_sharding_specs
[
self
.
op_data
[
"input"
]]
permute_dims
=
self
.
op_data
[
'
permute_dims
'
].
data
permute_dims
=
self
.
op_data
[
"
permute_dims
"
].
data
dim_partition_dict_for_input
=
input_sharding_spec
.
dim_partition_dict
dim_partition_dict_for_input
=
input_sharding_spec
.
dim_partition_dict
dim_partition_dict_for_output
=
{}
dim_partition_dict_for_output
=
{}
for
dim_index
,
permute_dim
in
enumerate
(
permute_dims
):
for
dim_index
,
permute_dim
in
enumerate
(
permute_dims
):
...
@@ -177,9 +182,11 @@ class PermuteGenerator(ReshapeGenerator):
...
@@ -177,9 +182,11 @@ class PermuteGenerator(ReshapeGenerator):
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name
=
f
'
{
sharding_spec_mapping
[
"input"
].
sharding_sequence
}
->
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
_
{
index
}
'
name
=
f
'
{
sharding_spec_mapping
[
"input"
].
sharding_sequence
}
->
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
_
{
index
}
'
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
strategy_list
.
append
(
strategy
)
strategy_list
.
append
(
strategy
)
return
strategy_list
return
strategy_list
...
@@ -199,7 +206,7 @@ class TransposeGenerator(ReshapeGenerator):
...
@@ -199,7 +206,7 @@ class TransposeGenerator(ReshapeGenerator):
dim_partition_dict_for_input
=
input_sharding_spec
.
dim_partition_dict
dim_partition_dict_for_input
=
input_sharding_spec
.
dim_partition_dict
dim_partition_dict_for_output
=
{}
dim_partition_dict_for_output
=
{}
transpose_dims
=
self
.
op_data
[
'
transpose_dims
'
].
data
transpose_dims
=
self
.
op_data
[
"
transpose_dims
"
].
data
dim_0
=
transpose_dims
[
0
]
dim_0
=
transpose_dims
[
0
]
dim_1
=
transpose_dims
[
1
]
dim_1
=
transpose_dims
[
1
]
for
dim
,
sharded_dims
in
dim_partition_dict_for_input
.
items
():
for
dim
,
sharded_dims
in
dim_partition_dict_for_input
.
items
():
...
@@ -221,9 +228,11 @@ class TransposeGenerator(ReshapeGenerator):
...
@@ -221,9 +228,11 @@ class TransposeGenerator(ReshapeGenerator):
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name
=
f
'
{
sharding_spec_mapping
[
"input"
].
sharding_sequence
}
->
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
_
{
index
}
'
name
=
f
'
{
sharding_spec_mapping
[
"input"
].
sharding_sequence
}
->
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
_
{
index
}
'
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
strategy_list
.
append
(
strategy
)
strategy_list
.
append
(
strategy
)
return
strategy_list
return
strategy_list
...
@@ -242,7 +251,7 @@ class SplitGenerator(ReshapeGenerator):
...
@@ -242,7 +251,7 @@ class SplitGenerator(ReshapeGenerator):
communication_action_mapping
=
{}
communication_action_mapping
=
{}
input_sharding_spec
=
strategy
.
output_sharding_specs
[
self
.
op_data
[
"input"
]]
input_sharding_spec
=
strategy
.
output_sharding_specs
[
self
.
op_data
[
"input"
]]
dim_partition_dict_for_input
=
copy
.
deepcopy
(
input_sharding_spec
.
dim_partition_dict
)
dim_partition_dict_for_input
=
copy
.
deepcopy
(
input_sharding_spec
.
dim_partition_dict
)
split_size
,
split_dim
=
self
.
op_data
[
'
split_info
'
].
data
split_size
,
split_dim
=
self
.
op_data
[
"
split_info
"
].
data
if
split_dim
in
dim_partition_dict_for_input
:
if
split_dim
in
dim_partition_dict_for_input
:
recover_dims
=
dim_partition_dict_for_input
.
pop
(
split_dim
)
recover_dims
=
dim_partition_dict_for_input
.
pop
(
split_dim
)
...
@@ -271,7 +280,8 @@ class SplitGenerator(ReshapeGenerator):
...
@@ -271,7 +280,8 @@ class SplitGenerator(ReshapeGenerator):
communication_pattern
=
CollectiveCommPattern
.
GATHER_FWD_SPLIT_BWD
,
communication_pattern
=
CollectiveCommPattern
.
GATHER_FWD_SPLIT_BWD
,
logical_process_axis
=
recover_dims
,
logical_process_axis
=
recover_dims
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
)
# it will gather the input through gather_dim during forward phase.
# it will gather the input through gather_dim during forward phase.
input_comm_action
.
comm_spec
.
gather_dim
=
split_dim
input_comm_action
.
comm_spec
.
gather_dim
=
split_dim
# it will split the input activation grad through split_dim during backward phase.
# it will split the input activation grad through split_dim during backward phase.
...
@@ -282,7 +292,7 @@ class SplitGenerator(ReshapeGenerator):
...
@@ -282,7 +292,7 @@ class SplitGenerator(ReshapeGenerator):
source_spec
=
input_sharding_spec
source_spec
=
input_sharding_spec
# target sharding spec
# target sharding spec
target_spec
=
sharding_spec_mapping
[
"input"
]
target_spec
=
sharding_spec_mapping
[
"input"
]
comm_spec
=
{
'
src_spec
'
:
source_spec
,
'
tgt_spec
'
:
target_spec
}
comm_spec
=
{
"
src_spec
"
:
source_spec
,
"
tgt_spec
"
:
target_spec
}
input_comm_action
=
CommAction
(
comm_spec
=
comm_spec
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
input_comm_action
=
CommAction
(
comm_spec
=
comm_spec
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
else
:
else
:
...
@@ -291,9 +301,11 @@ class SplitGenerator(ReshapeGenerator):
...
@@ -291,9 +301,11 @@ class SplitGenerator(ReshapeGenerator):
if
input_comm_action
is
not
None
:
if
input_comm_action
is
not
None
:
communication_action_mapping
[
"input"
]
=
input_comm_action
communication_action_mapping
[
"input"
]
=
input_comm_action
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
strategy_list
.
append
(
strategy
)
strategy_list
.
append
(
strategy
)
return
strategy_list
return
strategy_list
...
@@ -341,16 +353,17 @@ class DefaultReshapeGenerator(ReshapeGenerator):
...
@@ -341,16 +353,17 @@ class DefaultReshapeGenerator(ReshapeGenerator):
communication_pattern
=
CollectiveCommPattern
.
GATHER_FWD_SPLIT_BWD
,
communication_pattern
=
CollectiveCommPattern
.
GATHER_FWD_SPLIT_BWD
,
logical_process_axis
=
total_mesh_dim_list
,
logical_process_axis
=
total_mesh_dim_list
,
comm_type
=
CommType
.
BEFORE
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
arg_index
=
0
,
)
input_comm_action
.
comm_spec
.
gather_dim
=
total_mesh_dim_list
input_comm_action
.
comm_spec
.
gather_dim
=
total_mesh_dim_list
input_comm_action
.
comm_spec
.
shard_dim
=
total_mesh_dim_list
input_comm_action
.
comm_spec
.
shard_dim
=
total_mesh_dim_list
elif
len
(
total_mesh_dim_list
)
>=
2
:
elif
len
(
total_mesh_dim_list
)
>=
2
:
source_spec
=
sharding_spec_mapping
[
"input"
]
source_spec
=
sharding_spec_mapping
[
"input"
]
target_spec
=
ShardingSpec
(
device_mesh
=
self
.
device_mesh
,
target_spec
=
ShardingSpec
(
entire_shape
=
source_spec
.
entire_shape
,
device_mesh
=
self
.
device_mesh
,
entire_shape
=
source_spec
.
entire_shape
,
dim_partition_dict
=
{}
dim_partition_dict
=
{}
)
)
comm_spec
=
{
'
src_spec
'
:
source_spec
,
'
tgt_spec
'
:
target_spec
}
comm_spec
=
{
"
src_spec
"
:
source_spec
,
"
tgt_spec
"
:
target_spec
}
input_comm_action
=
CommAction
(
comm_spec
=
comm_spec
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
input_comm_action
=
CommAction
(
comm_spec
=
comm_spec
,
comm_type
=
CommType
.
BEFORE
,
arg_index
=
0
)
else
:
else
:
...
@@ -358,9 +371,11 @@ class DefaultReshapeGenerator(ReshapeGenerator):
...
@@ -358,9 +371,11 @@ class DefaultReshapeGenerator(ReshapeGenerator):
if
input_comm_action
is
not
None
:
if
input_comm_action
is
not
None
:
communication_action_mapping
[
"input"
]
=
input_comm_action
communication_action_mapping
[
"input"
]
=
input_comm_action
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
strategy_list
.
append
(
strategy
)
strategy_list
.
append
(
strategy
)
return
strategy_list
return
strategy_list
colossalai/auto_parallel/tensor_shard/node_handler/strategy/softmax_generator.py
View file @
9e768b59
...
@@ -4,21 +4,9 @@ from functools import reduce
...
@@ -4,21 +4,9 @@ from functools import reduce
from
typing
import
List
from
typing
import
List
from
colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator
import
FollowingStrategyGenerator
from
colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator
import
FollowingStrategyGenerator
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
ShardingStrategy
,
TrainCycleItem
CommAction
,
CommType
,
__all__
=
[
"SoftmaxGenerator"
]
MemoryCost
,
ShardingStrategy
,
TrainCycleItem
,
)
from
colossalai.auto_parallel.tensor_shard.utils
import
(
check_keep_sharding_status
,
detect_reshape_mapping
,
infer_output_dim_partition_dict
,
)
from
colossalai.tensor.shape_consistency
import
CollectiveCommPattern
__all__
=
[
'SoftmaxGenerator'
]
class
SoftmaxGenerator
(
FollowingStrategyGenerator
):
class
SoftmaxGenerator
(
FollowingStrategyGenerator
):
...
@@ -30,11 +18,11 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
...
@@ -30,11 +18,11 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
return
super
().
validate
()
return
super
().
validate
()
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
):
'''
"""
Compute the computation cost per device with this specific strategy.
Compute the computation cost per device with this specific strategy.
'''
"""
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
input
'
]].
get_sharded_shape_per_device
()
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
input
"
]].
get_sharded_shape_per_device
()
sharded_output_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
output
'
]].
get_sharded_shape_per_device
()
sharded_output_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
output
"
]].
get_sharded_shape_per_device
()
input_size_product
=
reduce
(
operator
.
mul
,
sharded_input_shape
)
input_size_product
=
reduce
(
operator
.
mul
,
sharded_input_shape
)
output_size_product
=
reduce
(
operator
.
mul
,
sharded_output_shape
)
output_size_product
=
reduce
(
operator
.
mul
,
sharded_output_shape
)
...
@@ -45,12 +33,12 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
...
@@ -45,12 +33,12 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
strategy
.
compute_cost
=
compute_cost
strategy
.
compute_cost
=
compute_cost
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
'''
"""
Compute the memory cost per device with this specific strategy.
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping
=
{
forward_size_mapping
=
{
'
input
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
"
input
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
'
output
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
"
output
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
,
}
}
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
...
@@ -68,8 +56,9 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
...
@@ -68,8 +56,9 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
# compute total cost
# compute total cost
total_mem_cost
=
MemoryCost
(
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
total_mem_cost
=
MemoryCost
(
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
)
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
strategy
.
memory_cost
=
memory_cost
strategy
.
memory_cost
=
memory_cost
...
@@ -80,10 +69,10 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
...
@@ -80,10 +69,10 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
communication_action_mapping
=
{}
communication_action_mapping
=
{}
input_sharding_spec
=
strategy
.
output_sharding_specs
[
self
.
op_data
[
"input"
]]
input_sharding_spec
=
strategy
.
output_sharding_specs
[
self
.
op_data
[
"input"
]]
dim_partition_dict_for_input
=
copy
.
deepcopy
(
input_sharding_spec
.
dim_partition_dict
)
dim_partition_dict_for_input
=
copy
.
deepcopy
(
input_sharding_spec
.
dim_partition_dict
)
softmax_dim
=
self
.
op_data
[
'
softmax_dim
'
].
data
softmax_dim
=
self
.
op_data
[
"
softmax_dim
"
].
data
if
softmax_dim
in
dim_partition_dict_for_input
:
if
softmax_dim
in
dim_partition_dict_for_input
:
recover_dims
=
dim_partition_dict_for_input
.
pop
(
softmax_dim
)
dim_partition_dict_for_input
.
pop
(
softmax_dim
)
dim_partition_dict_for_output
=
copy
.
deepcopy
(
dim_partition_dict_for_input
)
dim_partition_dict_for_output
=
copy
.
deepcopy
(
dim_partition_dict_for_input
)
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
...
@@ -96,9 +85,11 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
...
@@ -96,9 +85,11 @@ class SoftmaxGenerator(FollowingStrategyGenerator):
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name
=
f
'
{
sharding_spec_mapping
[
"input"
].
sharding_sequence
}
->
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
_
{
index
}
'
name
=
f
'
{
sharding_spec_mapping
[
"input"
].
sharding_sequence
}
->
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
_
{
index
}
'
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
strategy_list
.
append
(
strategy
)
strategy_list
.
append
(
strategy
)
return
strategy_list
return
strategy_list
colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
View file @
9e768b59
...
@@ -39,7 +39,7 @@ class StrategyGenerator(ABC):
...
@@ -39,7 +39,7 @@ class StrategyGenerator(ABC):
"""
"""
A utility method to check for the existence of bias operand for convenience.
A utility method to check for the existence of bias operand for convenience.
"""
"""
return
'
bias
'
in
self
.
op_data
return
"
bias
"
in
self
.
op_data
def
is_param
(
self
,
op_data_name
):
def
is_param
(
self
,
op_data_name
):
other_data
=
self
.
op_data
[
op_data_name
]
other_data
=
self
.
op_data
[
op_data_name
]
...
@@ -49,8 +49,12 @@ class StrategyGenerator(ABC):
...
@@ -49,8 +49,12 @@ class StrategyGenerator(ABC):
other_data
=
self
.
op_data
[
op_data_name
]
other_data
=
self
.
op_data
[
op_data_name
]
return
other_data
.
type
==
OperationDataType
.
BUFFER
return
other_data
.
type
==
OperationDataType
.
BUFFER
def
get_sharding_strategy
(
self
,
name
:
str
,
sharding_spec_mapping
:
Dict
[
str
,
ShardingSpec
],
def
get_sharding_strategy
(
communication_action_mapping
:
Dict
[
str
,
CommSpec
]):
self
,
name
:
str
,
sharding_spec_mapping
:
Dict
[
str
,
ShardingSpec
],
communication_action_mapping
:
Dict
[
str
,
CommSpec
],
):
"""
"""
A factory method to produce a ShardingStrategy object.
A factory method to produce a ShardingStrategy object.
...
@@ -80,24 +84,28 @@ class StrategyGenerator(ABC):
...
@@ -80,24 +84,28 @@ class StrategyGenerator(ABC):
op_data
=
self
.
op_data
[
op_data_name
]
op_data
=
self
.
op_data
[
op_data_name
]
def
_to_sharding_spec
(
def
_to_sharding_spec
(
data
:
any
,
logical_shape
:
any
,
data
:
any
,
logical_shape
:
any
,
dim_partition_dict
:
Dict
[
int
,
List
[
int
]]
dim_partition_dict
:
Dict
[
int
,
List
[
int
]]
)
->
Union
[
ShardingSpec
,
List
[
ShardingSpec
],
None
]:
)
->
Union
[
ShardingSpec
,
List
[
ShardingSpec
],
None
]:
"""
"""
This is a recursive function to convert the dim partition dict to a ShardingSpec object.
This is a recursive function to convert the dim partition dict to a ShardingSpec object.
"""
"""
if
isinstance
(
data
,
torch
.
Tensor
):
if
isinstance
(
data
,
torch
.
Tensor
):
dim_size
=
len
(
logical_shape
)
dim_size
=
len
(
logical_shape
)
dim_partition_dict
=
convert_dim_partition_dict
(
dim_size
,
dim_partition_dict
)
dim_partition_dict
=
convert_dim_partition_dict
(
dim_size
,
dim_partition_dict
)
sharding_spec
=
ShardingSpec
(
device_mesh
=
self
.
device_mesh
,
sharding_spec
=
ShardingSpec
(
entire_shape
=
logical_shape
,
device_mesh
=
self
.
device_mesh
,
dim_partition_dict
=
dim_partition_dict
)
entire_shape
=
logical_shape
,
dim_partition_dict
=
dim_partition_dict
,
)
return
sharding_spec
return
sharding_spec
elif
isinstance
(
data
,
(
list
,
tuple
)):
elif
isinstance
(
data
,
(
list
,
tuple
)):
sharding_spec
=
[]
sharding_spec
=
[]
for
data_element
,
logical_shape_element
,
dim_partition_dict_element
in
zip
(
for
data_element
,
logical_shape_element
,
dim_partition_dict_element
in
zip
(
data
,
logical_shape
,
dim_partition_dict
):
data
,
logical_shape
,
dim_partition_dict
):
sharding_spec
.
append
(
sharding_spec
.
append
(
_to_sharding_spec
(
data_element
,
logical_shape_element
,
dim_partition_dict_element
))
_to_sharding_spec
(
data_element
,
logical_shape_element
,
dim_partition_dict_element
)
)
return
sharding_spec
return
sharding_spec
else
:
else
:
return
None
return
None
...
@@ -116,31 +124,41 @@ class StrategyGenerator(ABC):
...
@@ -116,31 +124,41 @@ class StrategyGenerator(ABC):
results
[
op_data
]
=
v
results
[
op_data
]
=
v
return
results
return
results
def
get_communication_spec
(
self
,
sharding_spec
:
ShardingSpec
,
communication_pattern
:
CollectiveCommPattern
,
def
get_communication_spec
(
logical_process_axis
:
Union
[
int
,
List
[
int
]]):
self
,
sharding_spec
:
ShardingSpec
,
communication_pattern
:
CollectiveCommPattern
,
logical_process_axis
:
Union
[
int
,
List
[
int
]],
):
"""
"""
A factory method to produce a CommSpec object.
A factory method to produce a CommSpec object.
"""
"""
return
CommSpec
(
comm_pattern
=
communication_pattern
,
return
CommSpec
(
sharding_spec
=
sharding_spec
,
comm_pattern
=
communication_pattern
,
sharding_spec
=
sharding_spec
,
logical_process_axis
=
logical_process_axis
logical_process_axis
=
logical_process_axis
)
)
def
get_communication_action
(
self
,
def
get_communication_action
(
sharding_spec
:
ShardingSpec
,
self
,
communication_pattern
:
CollectiveCommPattern
,
sharding_spec
:
ShardingSpec
,
logical_process_axis
:
Union
[
int
,
List
[
int
]],
communication_pattern
:
CollectiveCommPattern
,
comm_type
:
CommType
,
logical_process_axis
:
Union
[
int
,
List
[
int
]],
arg_index
:
int
=
-
1
,
comm_type
:
CommType
,
key_for_kwarg
:
any
=
None
)
->
CommAction
:
arg_index
:
int
=
-
1
,
key_for_kwarg
:
any
=
None
,
)
->
CommAction
:
"""
"""
A factory method to produce a CommAction object.
A factory method to produce a CommAction object.
"""
"""
return
CommAction
(
comm_spec
=
self
.
get_communication_spec
(
sharding_spec
=
sharding_spec
,
return
CommAction
(
communication_pattern
=
communication_pattern
,
comm_spec
=
self
.
get_communication_spec
(
logical_process_axis
=
logical_process_axis
),
sharding_spec
=
sharding_spec
,
comm_type
=
comm_type
,
communication_pattern
=
communication_pattern
,
arg_index
=
arg_index
,
logical_process_axis
=
logical_process_axis
,
key_for_kwarg
=
key_for_kwarg
)
),
comm_type
=
comm_type
,
arg_index
=
arg_index
,
key_for_kwarg
=
key_for_kwarg
,
)
def
update_communication_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
def
update_communication_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
"""
"""
...
@@ -155,9 +173,9 @@ class StrategyGenerator(ABC):
...
@@ -155,9 +173,9 @@ class StrategyGenerator(ABC):
size_per_elem_bytes
=
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
size_per_elem_bytes
=
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
for
phase
,
cost
in
num_ele_in_comm
.
items
():
for
phase
,
cost
in
num_ele_in_comm
.
items
():
num_ele_in_comm
[
phase
]
=
num_ele_in_comm
[
phase
]
*
size_per_elem_bytes
num_ele_in_comm
[
phase
]
=
num_ele_in_comm
[
phase
]
*
size_per_elem_bytes
comm_cost
.
fwd
+=
num_ele_in_comm
[
'
forward
'
]
comm_cost
.
fwd
+=
num_ele_in_comm
[
"
forward
"
]
comm_cost
.
bwd
+=
num_ele_in_comm
[
'
backward
'
]
comm_cost
.
bwd
+=
num_ele_in_comm
[
"
backward
"
]
comm_cost
.
total
+=
num_ele_in_comm
[
'
total
'
]
comm_cost
.
total
+=
num_ele_in_comm
[
"
total
"
]
# check if communication action exists
# check if communication action exists
# if so, loop over each action and compute the cost of each action
# if so, loop over each action and compute the cost of each action
...
@@ -169,8 +187,8 @@ class StrategyGenerator(ABC):
...
@@ -169,8 +187,8 @@ class StrategyGenerator(ABC):
# this condition branch will be removed after all the handler updated.
# this condition branch will be removed after all the handler updated.
comm_spec
=
comm_action
comm_spec
=
comm_action
if
isinstance
(
comm_spec
,
dict
):
if
isinstance
(
comm_spec
,
dict
):
src_spec
=
comm_spec
[
'
src_spec
'
]
src_spec
=
comm_spec
[
"
src_spec
"
]
tgt_spec
=
comm_spec
[
'
tgt_spec
'
]
tgt_spec
=
comm_spec
[
"
tgt_spec
"
]
shape_consistency_manager
=
ShapeConsistencyManager
()
shape_consistency_manager
=
ShapeConsistencyManager
()
_
,
comm_action_sequence
,
_
=
shape_consistency_manager
.
shape_consistency
(
src_spec
,
tgt_spec
)
_
,
comm_action_sequence
,
_
=
shape_consistency_manager
.
shape_consistency
(
src_spec
,
tgt_spec
)
for
comm_spec_
in
comm_action_sequence
:
for
comm_spec_
in
comm_action_sequence
:
...
@@ -187,14 +205,12 @@ class StrategyGenerator(ABC):
...
@@ -187,14 +205,12 @@ class StrategyGenerator(ABC):
"""
"""
Customize this method to compute the computation flops.
Customize this method to compute the computation flops.
"""
"""
pass
@
abstractmethod
@
abstractmethod
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
)
->
ShardingStrategy
:
"""
"""
Customize this method to compute the memory cost in bytes.
Customize this method to compute the memory cost in bytes.
"""
"""
pass
def
_compute_size_in_bytes
(
self
,
strategy
:
ShardingStrategy
,
key
:
str
):
def
_compute_size_in_bytes
(
self
,
strategy
:
ShardingStrategy
,
key
:
str
):
"""
"""
...
@@ -212,20 +228,21 @@ class StrategyGenerator(ABC):
...
@@ -212,20 +228,21 @@ class StrategyGenerator(ABC):
num_elements
=
1
num_elements
=
1
else
:
else
:
num_elements
=
reduce
(
operator
.
mul
,
sharded_shape
)
num_elements
=
reduce
(
operator
.
mul
,
sharded_shape
)
dtype
=
getattr
(
meta_data
,
'
dtype
'
)
dtype
=
getattr
(
meta_data
,
"
dtype
"
)
size_per_elem_bytes
=
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
size_per_elem_bytes
=
torch
.
tensor
([],
dtype
=
dtype
).
element_size
()
return
num_elements
*
size_per_elem_bytes
return
num_elements
*
size_per_elem_bytes
if
isinstance
(
op_data
.
data
,
tuple
):
if
isinstance
(
op_data
.
data
,
tuple
):
assert
isinstance
(
strategy
.
sharding_specs
[
op_data
],
list
),
\
assert
isinstance
(
'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.'
strategy
.
sharding_specs
[
op_data
],
list
),
"sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple."
total_bytes
=
0
total_bytes
=
0
for
index
,
sharding_spec
in
enumerate
(
strategy
.
sharding_specs
[
op_data
]):
for
index
,
sharding_spec
in
enumerate
(
strategy
.
sharding_specs
[
op_data
]):
meta_data
=
op_data
.
data
[
index
]
meta_data
=
op_data
.
data
[
index
]
if
isinstance
(
meta_data
,
torch
.
Tensor
):
if
isinstance
(
meta_data
,
torch
.
Tensor
):
element_bytes
=
_compute_size_in_bytes_helper
(
sharding_spec
,
meta_data
)
element_bytes
=
_compute_size_in_bytes_helper
(
sharding_spec
,
meta_data
)
else
:
else
:
# if meta_data is not a tensor, we count the mem
r
oy as 0
# if meta_data is not a tensor, we count the memo
r
y as 0
element_bytes
=
0
element_bytes
=
0
total_bytes
+=
element_bytes
total_bytes
+=
element_bytes
...
@@ -233,7 +250,7 @@ class StrategyGenerator(ABC):
...
@@ -233,7 +250,7 @@ class StrategyGenerator(ABC):
if
isinstance
(
op_data
.
data
,
torch
.
Tensor
):
if
isinstance
(
op_data
.
data
,
torch
.
Tensor
):
total_bytes
=
_compute_size_in_bytes_helper
(
strategy
.
sharding_specs
[
op_data
],
op_data
.
data
)
total_bytes
=
_compute_size_in_bytes_helper
(
strategy
.
sharding_specs
[
op_data
],
op_data
.
data
)
else
:
else
:
# if op_data.data is not a tensor, we count the mem
r
oy as 0
# if op_data.data is not a tensor, we count the memo
r
y as 0
total_bytes
=
0
total_bytes
=
0
return
total_bytes
return
total_bytes
...
@@ -270,7 +287,6 @@ class StrategyGenerator(ABC):
...
@@ -270,7 +287,6 @@ class StrategyGenerator(ABC):
Validate if the operands are of desired shape.
Validate if the operands are of desired shape.
If True, means this generator can be used for the current operation.
If True, means this generator can be used for the current operation.
"""
"""
pass
class
FollowingStrategyGenerator
(
StrategyGenerator
):
class
FollowingStrategyGenerator
(
StrategyGenerator
):
...
@@ -280,8 +296,9 @@ class FollowingStrategyGenerator(StrategyGenerator):
...
@@ -280,8 +296,9 @@ class FollowingStrategyGenerator(StrategyGenerator):
TODO: remove the original strategy_generator.py after refactoring
TODO: remove the original strategy_generator.py after refactoring
"""
"""
def
__init__
(
self
,
operation_data_mapping
:
Dict
[
str
,
OperationData
],
device_mesh
:
DeviceMesh
,
def
__init__
(
predecessor_node
:
Node
):
self
,
operation_data_mapping
:
Dict
[
str
,
OperationData
],
device_mesh
:
DeviceMesh
,
predecessor_node
:
Node
):
self
.
op_data
=
operation_data_mapping
self
.
op_data
=
operation_data_mapping
self
.
device_mesh
=
device_mesh
self
.
device_mesh
=
device_mesh
self
.
predecessor_node
=
predecessor_node
self
.
predecessor_node
=
predecessor_node
...
@@ -292,7 +309,8 @@ class OutputStrategyGenerator(StrategyGenerator):
...
@@ -292,7 +309,8 @@ class OutputStrategyGenerator(StrategyGenerator):
OutputStrategyGenerator is used to generate the sharding strategies for Output Node.
OutputStrategyGenerator is used to generate the sharding strategies for Output Node.
"""
"""
def
__init__
(
self
,
operation_data_mapping
:
Dict
[
str
,
OperationData
],
device_mesh
:
DeviceMesh
,
def
__init__
(
predecessor_nodes
:
List
[
Node
]):
self
,
operation_data_mapping
:
Dict
[
str
,
OperationData
],
device_mesh
:
DeviceMesh
,
predecessor_nodes
:
List
[
Node
]
):
super
().
__init__
(
operation_data_mapping
,
device_mesh
)
super
().
__init__
(
operation_data_mapping
,
device_mesh
)
self
.
predecessor_nodes
=
predecessor_nodes
self
.
predecessor_nodes
=
predecessor_nodes
colossalai/auto_parallel/tensor_shard/node_handler/strategy/sum_generator.py
View file @
9e768b59
...
@@ -4,22 +4,9 @@ from functools import reduce
...
@@ -4,22 +4,9 @@ from functools import reduce
from
typing
import
List
from
typing
import
List
from
colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator
import
FollowingStrategyGenerator
from
colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator
import
FollowingStrategyGenerator
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
ShardingStrategy
,
TrainCycleItem
CommAction
,
CommType
,
__all__
=
[
"SumGenerator"
]
MemoryCost
,
ShardingStrategy
,
TrainCycleItem
,
)
from
colossalai.auto_parallel.tensor_shard.utils
import
(
check_keep_sharding_status
,
detect_reshape_mapping
,
infer_output_dim_partition_dict
,
)
from
colossalai.tensor.shape_consistency
import
CollectiveCommPattern
from
colossalai.tensor.sharding_spec
import
ShardingSpec
__all__
=
[
'SumGenerator'
]
class
SumGenerator
(
FollowingStrategyGenerator
):
class
SumGenerator
(
FollowingStrategyGenerator
):
...
@@ -31,24 +18,24 @@ class SumGenerator(FollowingStrategyGenerator):
...
@@ -31,24 +18,24 @@ class SumGenerator(FollowingStrategyGenerator):
return
super
().
validate
()
return
super
().
validate
()
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_compute_cost
(
self
,
strategy
:
ShardingStrategy
):
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
input
'
]].
get_sharded_shape_per_device
()
sharded_input_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
input
"
]].
get_sharded_shape_per_device
()
sharded_output_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
'
output
'
]].
get_sharded_shape_per_device
()
sharded_output_shape
=
strategy
.
sharding_specs
[
self
.
op_data
[
"
output
"
]].
get_sharded_shape_per_device
()
input_size_product
=
reduce
(
operator
.
mul
,
sharded_input_shape
)
input_size_product
=
reduce
(
operator
.
mul
,
sharded_input_shape
)
output_size_product
=
reduce
(
operator
.
mul
,
sharded_output_shape
)
output_size_product
=
reduce
(
operator
.
mul
,
sharded_output_shape
)
compute_cost
=
TrainCycleItem
(
fwd
=
input_size_product
,
compute_cost
=
TrainCycleItem
(
bwd
=
output_size_product
,
fwd
=
input_size_product
,
bwd
=
output_size_product
,
total
=
input_size_product
+
output_size_product
total
=
input_size_product
+
output_size_product
)
)
strategy
.
compute_cost
=
compute_cost
strategy
.
compute_cost
=
compute_cost
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
'''
"""
Compute the memory cost per device with this specific strategy.
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping
=
{
forward_size_mapping
=
{
'
input
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
"
input
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"input"
),
'
output
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
"
output
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)
,
}
}
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
backward_size_mapping
=
copy
.
deepcopy
(
forward_size_mapping
)
...
@@ -66,8 +53,9 @@ class SumGenerator(FollowingStrategyGenerator):
...
@@ -66,8 +53,9 @@ class SumGenerator(FollowingStrategyGenerator):
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
bwd_mem_cost
=
MemoryCost
(
activation
=
bwd_activation_cost
,
parameter
=
bwd_parameter_cost
)
# compute total cost
# compute total cost
total_mem_cost
=
MemoryCost
(
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
total_mem_cost
=
MemoryCost
(
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
)
activation
=
fwd_activation_cost
+
bwd_activation_cost
,
parameter
=
fwd_parameter_cost
+
bwd_parameter_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
strategy
.
memory_cost
=
memory_cost
strategy
.
memory_cost
=
memory_cost
...
@@ -78,7 +66,7 @@ class SumGenerator(FollowingStrategyGenerator):
...
@@ -78,7 +66,7 @@ class SumGenerator(FollowingStrategyGenerator):
communication_action_mapping
=
{}
communication_action_mapping
=
{}
input_sharding_spec
=
strategy
.
output_sharding_specs
[
self
.
op_data
[
"input"
]]
input_sharding_spec
=
strategy
.
output_sharding_specs
[
self
.
op_data
[
"input"
]]
dim_partition_dict_for_input
=
copy
.
deepcopy
(
input_sharding_spec
.
dim_partition_dict
)
dim_partition_dict_for_input
=
copy
.
deepcopy
(
input_sharding_spec
.
dim_partition_dict
)
sum_dims
,
sum_mapping_dict
=
self
.
op_data
[
'
sum_info
'
].
data
sum_dims
,
sum_mapping_dict
=
self
.
op_data
[
"
sum_info
"
].
data
# TODO: a better way to handle the distributed sum is sum all the data on chip and then do all reduce
# TODO: a better way to handle the distributed sum is sum all the data on chip and then do all reduce
# among all the shard groups
# among all the shard groups
...
@@ -90,7 +78,7 @@ class SumGenerator(FollowingStrategyGenerator):
...
@@ -90,7 +78,7 @@ class SumGenerator(FollowingStrategyGenerator):
elif
dim
in
sum_mapping_dict
:
elif
dim
in
sum_mapping_dict
:
dim_partition_dict_for_output
[
sum_mapping_dict
[
dim
]]
=
dim_partition_dict_for_input
[
dim
]
dim_partition_dict_for_output
[
sum_mapping_dict
[
dim
]]
=
dim_partition_dict_for_input
[
dim
]
else
:
else
:
raise
RuntimeError
(
f
'
dim
{
dim
}
is not in sum_mapping_dict or sum_dims
'
)
raise
RuntimeError
(
f
"
dim
{
dim
}
is not in sum_mapping_dict or sum_dims
"
)
for
dim
in
recover_dims
:
for
dim
in
recover_dims
:
dim_partition_dict_for_input
.
pop
(
dim
)
dim_partition_dict_for_input
.
pop
(
dim
)
...
@@ -105,9 +93,11 @@ class SumGenerator(FollowingStrategyGenerator):
...
@@ -105,9 +93,11 @@ class SumGenerator(FollowingStrategyGenerator):
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name
=
f
'
{
sharding_spec_mapping
[
"input"
].
sharding_sequence
}
->
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
_
{
index
}
'
name
=
f
'
{
sharding_spec_mapping
[
"input"
].
sharding_sequence
}
->
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
_
{
index
}
'
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
strategy_list
.
append
(
strategy
)
strategy_list
.
append
(
strategy
)
return
strategy_list
return
strategy_list
colossalai/auto_parallel/tensor_shard/node_handler/strategy/tensor_constructor_generator.py
View file @
9e768b59
import
copy
from
typing
import
List
from
typing
import
List
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
ShardingStrategy
,
TrainCycleItem
CommAction
,
CommType
,
MemoryCost
,
ShardingStrategy
,
TrainCycleItem
,
)
from
colossalai.tensor.shape_consistency
import
CollectiveCommPattern
from
colossalai.tensor.sharding_spec
import
ShardingSpec
from
.strategy_generator
import
StrategyGenerator
from
.strategy_generator
import
StrategyGenerator
__all__
=
[
'
TensorConstructorGenerator
'
]
__all__
=
[
"
TensorConstructorGenerator
"
]
class
TensorConstructorGenerator
(
StrategyGenerator
):
class
TensorConstructorGenerator
(
StrategyGenerator
):
...
@@ -30,10 +21,10 @@ class TensorConstructorGenerator(StrategyGenerator):
...
@@ -30,10 +21,10 @@ class TensorConstructorGenerator(StrategyGenerator):
strategy
.
compute_cost
=
compute_cost
strategy
.
compute_cost
=
compute_cost
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
def
update_memory_cost
(
self
,
strategy
:
ShardingStrategy
):
'''
"""
Compute the memory cost per device with this specific strategy.
Compute the memory cost per device with this specific strategy.
'''
"""
forward_size_mapping
=
{
'
output
'
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)}
forward_size_mapping
=
{
"
output
"
:
self
.
_compute_size_in_bytes
(
strategy
,
"output"
)}
# compute fwd cost incurred
# compute fwd cost incurred
# fwd_cost = input + output
# fwd_cost = input + output
...
@@ -57,11 +48,13 @@ class TensorConstructorGenerator(StrategyGenerator):
...
@@ -57,11 +48,13 @@ class TensorConstructorGenerator(StrategyGenerator):
communication_action_mapping
=
{}
communication_action_mapping
=
{}
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
name
=
'
Replica Tensor Constructor
'
name
=
"
Replica Tensor Constructor
"
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
strategy
=
self
.
get_sharding_strategy
(
sharding_spec_mapping
=
sharding_spec_mapping
,
name
=
name
,
communication_action_mapping
=
communication_action_mapping
)
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
,
)
strategy_list
.
append
(
strategy
)
strategy_list
.
append
(
strategy
)
return
strategy_list
return
strategy_list
Prev
1
…
14
15
16
17
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment