OpenDAS / ColossalAI · Commits · 0c703189
"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "8e3d0ad8f1d098baf9731eb4f57b90c6c2c0a34e"
Unverified
Commit 0c703189, authored Sep 23, 2022 by YuliangLiu0306, committed by GitHub on Sep 23, 2022

[autoparallel] add layernorm handler (#1629)

parent bf77d3ab
Showing 6 changed files with 434 additions and 62 deletions (+434, -62):
  colossalai/auto_parallel/solver/_utils.py                          +37   -1
  colossalai/auto_parallel/solver/constants.py                        +16   -2
  colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py       +6  -34
  colossalai/auto_parallel/solver/op_handler/layer_norm_handler.py   +233   -0   (new file)
  colossalai/auto_parallel/solver/strategies_constructor.py           +72  -25
  tests/test_auto_parallel/test_layer_norm_handler.py                 +70   -0   (new file)
colossalai/auto_parallel/solver/_utils.py

@@ -94,7 +94,43 @@ def exception_handler(func):

     def wrapper(*args, **kwargs):
         try:
             func(*args, **kwargs)
-        except Exception as e:
+        except AssertionError as e:
             warnings.warn(f'{e}')

     return wrapper


+def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
+    dim_partition_list = []
+    # enumerate all the 2D sharding cases
+    for i in range(dim_size):
+        for j in range(i + 1, dim_size):
+            dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
+            dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
+            dim_partition_list.append(dim_partition_dict_0)
+            dim_partition_list.append(dim_partition_dict_1)
+    for i in range(dim_size):
+        dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
+        dim_partition_list.append(dim_partition_dict_flatten)
+
+    return dim_partition_list
+
+
+def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
+    dim_partition_list = []
+    # enumerate all the 1D sharding cases
+    for i in range(dim_size):
+        dim_partition_dict_0 = {i: [mesh_dim_0]}
+        dim_partition_list.append(dim_partition_dict_0)
+
+    return dim_partition_list
+
+
+def generate_sharding_size(dim_partition_dict, device_mesh):
+    total_sharding_size = 1
+    for mesh_dim_list in dim_partition_dict.values():
+        mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
+        sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
+        total_sharding_size *= sharding_size
+
+    return total_sharding_size
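To make the new enumeration helpers concrete, here is a small self-contained sketch that re-implements the 2-D loop above (rather than importing from colossalai) and shows what it yields; the tensor rank (3) and the mesh axes (0 and 1) are illustrative assumptions, not values fixed by the commit.

# Standalone sketch mirroring enumerate_all_possible_2d_sharding above,
# evaluated for an assumed 3-dimensional tensor on a two-axis device mesh.
def sketch_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
    partitions = []
    for i in range(dim_size):
        for j in range(i + 1, dim_size):
            # tensor dims i and j go to different mesh axes, in both orders
            partitions.append({i: [mesh_dim_0], j: [mesh_dim_1]})
            partitions.append({i: [mesh_dim_1], j: [mesh_dim_0]})
    for i in range(dim_size):
        # a single tensor dim sharded over the flattened 2-D mesh
        partitions.append({i: [mesh_dim_0, mesh_dim_1]})
    return partitions

partitions = sketch_2d_sharding(0, 1, 3)
assert len(partitions) == 3 * 2 + 3          # n*(n-1) pairwise cases + n flattened cases
assert partitions[0] == {0: [0], 1: [1]}     # dim 0 on mesh axis 0, dim 1 on mesh axis 1
assert partitions[-1] == {2: [0, 1]}         # dim 2 over both mesh axes (S01-style)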
colossalai/auto_parallel/solver/constants.py

@@ -3,7 +3,8 @@ import operator

 __all__ = [
     'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
-    'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP'
+    'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
+    'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP'
 ]

 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]

@@ -11,7 +12,18 @@ ELEMENTWISE_FUNC_OP = [
     torch.abs, torch.cos, torch.exp, operator.neg, torch.multiply, torch.nn.functional.relu,
     torch.nn.functional.dropout, torch.flatten
 ]
-RESHAPE_FUNC_OP = [torch.flatten, torch.Tensor.view, torch.reshape]
+ELEMENTWISE_METHOD_OP = [
+    torch.Tensor.to,
+    torch.Tensor.type,
+]
+RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
+RESHAPE_METHOD_OP = [
+    torch.Tensor.view,
+    torch.Tensor.unsqueeze,
+    torch.Tensor.split,
+    torch.Tensor.permute,
+    torch.Tensor.transpose,
+]
 BCAST_FUNC_OP = [
     torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
     operator.mul, operator.floordiv, operator.truediv, torch.matmul

@@ -23,9 +35,11 @@ CONV_MODULE_OP = [
 CONV_FUNC_OP = [
     torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
 ]
+EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding]
 LINEAR_MODULE_OP = [torch.nn.Linear]
 LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
 BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
+LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
 POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
 NON_PARAM_FUNC_OP = RESHAPE_FUNC_OP + ELEMENTWISE_FUNC_OP
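These lists are what the strategies constructor dispatches on (see the strategies_constructor.py hunk below): for a call_module node it checks the submodule's type against each list. A minimal sketch of that membership test, using a throwaway nn.LayerNorm instance as an assumed example:

import torch.nn as nn

# Mirrors LAYERNORM_MODULE_OP from the diff above; the module instance below
# is a hypothetical stand-in for a traced call_module target.
LAYERNORM_MODULE_OP = [nn.LayerNorm]

submod = nn.LayerNorm(128)
submod_type = type(submod)

# dispatch is by type membership, so any LayerNorm configuration matches
assert submod_type in LAYERNORM_MODULE_OP   # -> routed to LayerNormHandler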
colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py

@@ -8,7 +8,7 @@ from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
 from copy import deepcopy
 from typing import Dict, List
-from colossalai.auto_parallel.solver._utils import exception_handler
+from colossalai.auto_parallel.solver._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding

 __all__ = ['BcastOpHandler']

@@ -110,45 +110,19 @@ class BcastOpHandler(OperatorHandler):

         return sharding_spec_list

-    def _enumerate_all_possible_2d_sharding(self, mesh_dim_0, mesh_dim_1, dim_size):
-        dim_partition_list = []
-        # enumerate all the 2D sharding cases
-        for i in range(dim_size):
-            for j in range(i + 1, dim_size):
-                dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
-                dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
-                dim_partition_list.append(dim_partition_dict_0)
-                dim_partition_list.append(dim_partition_dict_1)
-        for i in range(dim_size):
-            dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
-            dim_partition_list.append(dim_partition_dict_flatten)
-
-        # sharding_spec_list = self._convert_partition_dict_to_sharding_spec(dim_partition_list)
-        return dim_partition_list
-
-    def _enumerate_all_possible_1d_sharding(self, mesh_dim_0, dim_size):
-        dim_partition_list = []
-        # enumerate all the 1D sharding cases
-        for i in range(dim_size):
-            dim_partition_dict_0 = {i: [mesh_dim_0]}
-            dim_partition_list.append(dim_partition_dict_0)
-
-        # sharding_spec_list = self._convert_partition_dict_to_sharding_spec(dim_partition_list)
-        return dim_partition_list
-
     def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
         # use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 in here for N-D device mesh scaliablity.
         output_dim_partition_list = []
         dim_size = self.output_data.dim()
         # enumerate all the 2D sharding cases
-        sharding_list_2d = self._enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
+        sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
         output_dim_partition_list.extend(sharding_list_2d)

         # enumerate all the 1D sharding cases
-        sharding_list_1d_on_dim_0 = self._enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
+        sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
         output_dim_partition_list.extend(sharding_list_1d_on_dim_0)

-        sharding_list_1d_on_dim_1 = self._enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
+        sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
         output_dim_partition_list.extend(sharding_list_1d_on_dim_1)

         # add empty dict for fully replicated case

@@ -545,15 +519,13 @@ class BcastOpHandler(OperatorHandler):
         dim_size = self.output_data.dim() - 2

         # Both device mesh axises are uesd on batch dimensions
-        dim_partition_dicts_2d = self._enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1], dim_size)
+        dim_partition_dicts_2d = enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1], dim_size)
         for dim_partition_dict in dim_partition_dicts_2d:
             self._registry_no_split_strategies_for_matmul(dim_partition_dict)

         # Only one device mesh axis is uesd on batch dimensions
         for mesh_dim_index in [0, 1]:
-            dim_partition_dicts_1d = self._enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index], dim_size)
+            dim_partition_dicts_1d = enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index], dim_size)
             for dim_partition_dict in dim_partition_dicts_1d:
                 self._registry_no_split_strategies_for_matmul(dim_partition_dict)
                 self._registry_1d_strategies_for_matmul(dim_partition_dict, [MESH_DIM_LIST[mesh_dim_index - 1]])
colossalai/auto_parallel/solver/op_handler/layer_norm_handler.py  (new file, 0 → 100644)

import operator
from functools import reduce
import torch
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.solver._utils import exception_handler, enumerate_all_possible_2d_sharding, enumerate_all_possible_1d_sharding, generate_sharding_size

__all__ = ['LayerNormHandler']


class LayerNormHandler(OperatorHandler):
    """
    A OperatorHandler which deals with the sharding strategies of normalization.

    Note: To keep the math consistency, LayerNorm do not allow shards on hidden dimension.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.input_data = self.predecessor_node[0]._meta_data
        self.weight = self.module_named_parameters['weight']
        self.bias = self.module_named_parameters['bias']
        self.output_data = self.node._meta_data

    def _generate_compute_cost(self, total_sharding_size):
        '''
        Compute the computation cost per device with this specific strategy.

        Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.

        Argument:
            bs(int): Batch size of the input data.
            channel_in(int): The channel dimension of input data.

        Return:
            compute_cost(float): Computation cost per device with this specific strategy
        '''
        # TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
        # TODO: a constant coefficient need to be added.

        norm_kernel_size = self.weight.shape
        # in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization.
        input_batch_shape = self.input_data.shape[:-len(norm_kernel_size)]
        input_batch_product = reduce(operator.mul, input_batch_shape, 1)
        norm_kernel_product = reduce(operator.mul, norm_kernel_size, 1)
        forward_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
        backward_activation_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
        # To compute gradient of on norm kernel element requires input_batch_product times computation, so
        # the total cost is input_batch_product * norm_kernel_product
        backward_weight_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
        backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
        compute_cost = forward_compute_cost + backward_compute_cost
        return compute_cost

    def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
        '''
        Compute the memory cost per device with this specific strategy.

        Argument:
            sharding_size_forward(int): The forward activation will be divided
                into sharding_size_forward number partions.
            sharding_size_backward_activation(int): The backward activation will
                be divided into sharding_size_backward_activation number partions.
            sharding_size_weight(int): The backward weight will be divided
                into sharding_size_weight number partions.

        Return:
            memory_cost(Tuple[float]): Memory cost per device with this
                specific strategy, the first element of this tuple is forward
                memory cost, and the second element of this tuple is backward
                memory cost.
            memory_cost_forward(float): Memory cost of forward activation per
                device with this specific strategy.
            memory_cost_backward_activation(float): Memory cost of backward activation
                per device with this specific strategy.
        '''
        # compute the memory cost of this strategy
        dtype = self.input_data.dtype
        numel_output = self.output_data.numel()
        # this operation will not change the shape of input
        numel_input = numel_output
        numel_weight = self.weight.numel()
        size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()

        # forward memory_cost
        memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
        memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight

        # backward memory_cost
        memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
        memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
        memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight

        # memory_cost pair
        memory_cost = (memory_cost_forward, memory_cost_backward)

        return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight

    def _generate_strategy_with_dim_partition(self, dim_partition):
        dim_partition_dict_for_input = dim_partition
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = dim_partition
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_input.sharding_sequence} x {sharding_spec_for_weight.sharding_sequence}'
        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        total_sharding_size = generate_sharding_size(dim_partition, self.device_mesh)
        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(total_sharding_size)

        # compute the memory cost of this strategy
        sharding_size_forward = generate_sharding_size(dim_partition_dict_for_input, self.device_mesh)
        sharding_size_backward_activation = generate_sharding_size(dim_partition_dict_for_output, self.device_mesh)
        sharding_size_weight = generate_sharding_size(dim_partition_dict_for_weight, self.device_mesh)
        memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        total_mesh_dim_list = []
        for mesh_dim_list in dim_partition.values():
            total_mesh_dim_list.extend(mesh_dim_list)

        # This strategy do not need to do all_reduce operation for activation
        communication_cost_forward_activation = 0
        communication_cost_backward_activation = 0
        if len(total_mesh_dim_list) == 1:
            communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, total_mesh_dim_list[0])
        else:
            assert len(total_mesh_dim_list) == 2, f'temporally we just support 2d device mesh.'
            communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_backward_weight, 0)
        communication_cost = communication_cost_forward_activation + communication_cost_backward_activation + communication_cost_backward_weight

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))

        self.strategies_vector.append(sharding_strategies)

    @exception_handler
    def split_input_batch_single_mesh_dim(self, mesh_dim_0):
        batch_dimension_length = self.input_data.dim() - self.weight.dim()
        dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
        for dim_partition in dim_partition_list:
            self._generate_strategy_with_dim_partition(dim_partition)

    @exception_handler
    def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1):
        batch_dimension_length = self.input_data.dim() - self.weight.dim()
        dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
        for dim_partition in dim_partition_list:
            self._generate_strategy_with_dim_partition(dim_partition)

    @exception_handler
    def non_split(self):
        name = f'RR = RR x R'

        dim_partition_dict_for_input = {}
        sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)

        dim_partition_dict_for_weight = {}
        sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)

        dim_partition_dict_for_output = {}
        sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)

        # generate resharding cost for this strategy
        resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])

        total_sharding_size = 1
        # compute the computation cost of this strategy
        compute_cost = self._generate_compute_cost(total_sharding_size)

        # compute the memory cost of this strategy
        sharding_size_forward = 1
        sharding_size_backward_activation = 1
        sharding_size_weight = 1
        memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)

        # This strategy do not need to do all_reduce operation
        communication_cost = 0

        sharding_strategies = ShardingStrategy(name,
                                               output_sharding_spec=sharding_spec_for_output,
                                               compute_cost=compute_cost,
                                               communication_cost=communication_cost,
                                               memory_cost=memory_cost,
                                               resharding_costs=resharding_costs,
                                               input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))

        self.strategies_vector.append(sharding_strategies)

    def register_strategy(self) -> StrategiesVector:
        '''
        Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.

        Example:
            norm_handler = BatchNormHandler(node, strategies_vector,
                                            self.shape_consistency_manager)
            norm_handler.register_strategy()
            for strategy in norm_handler.strategies_vector:
                print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')

        Output:
            RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
            RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
            RR = RR x R, computation_cost: 262144, memory_cost: 1048576
            RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0
        '''

        # SR = SR x R with single mesh dim on batch dimensions
        self.split_input_batch_single_mesh_dim(0)
        self.split_input_batch_single_mesh_dim(1)

        # SR = SR x R with both mesh dims on batch dimensions
        self.split_input_batch_both_mesh_dim(0, 1)

        # RR = RR x R
        self.non_split()

        return self.strategies_vector
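To make the cost model above concrete, here is a hand-worked trace of the _generate_compute_cost and _generate_memory_cost formulas for the same shapes the test below uses (input (4, 4, 128), nn.LayerNorm(128), float32). The chosen strategy, splitting the first batch dimension over one mesh axis holding 2 devices, is an assumption picked only for illustration; the arithmetic is derived by hand from the formulas, not produced by running the handler.

import operator
from functools import reduce

# Assumed configuration, mirroring the test below: x of shape (4, 4, 128),
# nn.LayerNorm(128) (weight shape (128,)), float32, and a strategy that
# shards the first batch dimension over one mesh axis of size 2.
input_shape, weight_shape = (4, 4, 128), (128,)
total_sharding_size = 2
size_per_elem_bytes = 4                                                            # float32

# _generate_compute_cost: forward + backward(activation) + backward(weight),
# each term is input_batch_product * norm_kernel_product / total_sharding_size
input_batch_product = reduce(operator.mul, input_shape[:-len(weight_shape)], 1)    # 4 * 4 = 16
norm_kernel_product = reduce(operator.mul, weight_shape, 1)                        # 128
per_term = input_batch_product * norm_kernel_product / total_sharding_size         # 1024.0
compute_cost = per_term + (per_term + per_term)
assert compute_cost == 3072.0

# _generate_memory_cost: activations are sharded, the weight stays replicated
numel_output = reduce(operator.mul, input_shape, 1)          # 2048 (LayerNorm keeps the shape)
numel_weight = reduce(operator.mul, weight_shape, 1)         # 128
memory_cost_forward = numel_output * size_per_elem_bytes / 2 + numel_weight * size_per_elem_bytes / 1
memory_cost_backward = memory_cost_forward                   # same sharding for input and output
assert (memory_cost_forward, memory_cost_backward) == (4608.0, 4608.0)

# The only communication in this strategy is the all-reduce of the 512-byte
# weight gradient over the single mesh axis, priced by device_mesh.all_reduce_cost.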
colossalai/auto_parallel/solver/strategies_constructor.py

 from torch.fx import Graph, Node
 from colossalai.auto_parallel.solver.op_handler.bcast_op_handler import BcastOpHandler
+from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager

@@ -216,6 +217,15 @@ class StrategiesConstructor:
                                                          input_shardings=[input_sharding_spec])
                     strategies_vector.append(sharding_strategy)

+                # embedding module
+                elif submod_type in EMBEDDING_MODULE_OP:
+                    embedding_handler = EmbeddingHandler(node, self.device_mesh, strategies_vector)
+                    embedding_handler.register_strategy()
+
+                # layernorm module
+                elif submod_type in LAYERNORM_MODULE_OP:
+                    layernorm_handler = LayerNormHandler(node, self.device_mesh, strategies_vector)
+                    layernorm_handler.register_strategy()
+
                 # other module
                 else:
                     raise RuntimeError(f'{submod_type} module is NOT supported now.')

@@ -349,35 +359,72 @@
                 elif target == operator.getitem:
                     index = node.args[1]
                     input_tensor_node = strategies_vector.predecessor_nodes[0]
-                    for strategy in input_tensor_node.strategies_vector:
-                        input_sharding_spec = input_tensor_node.output_sharding_spec[index]
-                        assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
-                        dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
-                        entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
-                        output_sharding_spec = ShardingSpec(self.device_mesh,
-                                                            entire_shape_output,
-                                                            dim_partition_dict=dim_partition_dict_for_output)
-                        # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
-                        compute_cost = 0
-                        memory_cost = 0
-                        resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
-                                                                     [input_sharding_spec])
-                        # to prevent the resharding happening, set their resharding cost to inf.
-                        resharding_costs[input_tensor_node] = [
-                            cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
-                        ]
-                        sharding_strategy = ShardingStrategy(name,
-                                                             output_sharding_spec,
-                                                             compute_cost=compute_cost,
-                                                             memory_cost=memory_cost,
-                                                             resharding_costs=resharding_costs,
-                                                             input_shardings=[input_tensor_node.output_sharding_spec])
-                        strategies_vector.append(sharding_strategy)
+                    if isinstance(input_tensor_node, torch.Tensor):
+                        for strategy in input_tensor_node.strategies_vector:
+                            input_sharding_spec = strategy.output_sharding_spec[index]
+                            assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
+                            dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
+                            entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
+                            output_sharding_spec = ShardingSpec(self.device_mesh,
+                                                                entire_shape_output,
+                                                                dim_partition_dict=dim_partition_dict_for_output)
+                            # TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
+                            compute_cost = 0
+                            memory_cost = 0
+                            resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
+                                                                         [input_sharding_spec])
+                            # to prevent the resharding happening, set their resharding cost to inf.
+                            resharding_costs[input_tensor_node] = [
+                                cost if cost == 0 else math.inf for cost in resharding_costs[input_tensor_node]
+                            ]
+                            sharding_strategy = ShardingStrategy(name,
+                                                                 output_sharding_spec,
+                                                                 compute_cost=compute_cost,
+                                                                 memory_cost=memory_cost,
+                                                                 resharding_costs=resharding_costs,
+                                                                 input_shardings=[input_tensor_node.output_sharding_spec])
+                            strategies_vector.append(sharding_strategy)
+
+                # torch.arange function
+                elif target == torch.arange:
+                    name = f'FULLY REPLICATED ARANGE'
+                    entire_shape_output = node._meta_data.shape
+                    dim_partition_dict_for_output = {}
+                    output_sharding_spec = ShardingSpec(self.device_mesh,
+                                                        entire_shape_output,
+                                                        dim_partition_dict=dim_partition_dict_for_output)
+                    memory_cost = node._meta_data.numel()
+                    sharding_strategy = ShardingStrategy(name,
+                                                         output_sharding_spec,
+                                                         compute_cost=0,
+                                                         memory_cost=memory_cost)
+                    strategies_vector.append(sharding_strategy)
+
+                # op list to be processed to support gpt2
+                elif target in (builtins.getattr, operator.le, torch.addmm, operator.pow, torch.where, torch.softmax,
+                                torch.nn.functional.softmax, torch.pow, torch.tanh):
+                    pass
+
                 # other function
                 else:
                     raise RuntimeError(f'{target} function is NOT supported now.')

+            # call_method node
+            if node.op == 'call_method':
+                method = getattr(node.args[0]._meta_data.__class__, node.target)
+                if method in (torch.Tensor.size, torch.Tensor.contiguous):
+                    pass
+                elif method in ELEMENTWISE_METHOD_OP:
+                    unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
+                    unary_elementwise_handler.register_strategy()
+                elif method in RESHAPE_METHOD_OP:
+                    reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
+                    reshape_handler.register_strategy()
+                else:
+                    raise RuntimeError(f'{method} function is NOT supported now.')
+
             # output node
             if node.op == 'output':
                 if self.solver_options.fast:
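The new call_method branch resolves the bound method from the traced tensor's class before checking it against the op lists from constants.py. A small self-contained sketch of that lookup; the tensor and the target name are made-up stand-ins for node.args[0]._meta_data and node.target:

import torch

# RESHAPE_METHOD_OP as defined in constants.py above, repeated here so the
# sketch runs on its own.
RESHAPE_METHOD_OP = [torch.Tensor.view, torch.Tensor.unsqueeze, torch.Tensor.split,
                     torch.Tensor.permute, torch.Tensor.transpose]

meta_tensor = torch.empty(2, 3)    # stand-in for node.args[0]._meta_data
target = 'view'                    # stand-in for node.target on a call_method node

# same resolution the constructor performs: look up the class attribute
method = getattr(meta_tensor.__class__, target)
assert method is torch.Tensor.view
assert method in RESHAPE_METHOD_OP   # -> the node is routed to ReshapeHandler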
tests/test_auto_parallel/test_layer_norm_handler.py  (new file, 0 → 100644)

import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest

from colossalai.auto_parallel.solver import sharding_strategy
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.op_handler.layer_norm_handler import LayerNormHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh


class LNModel(nn.Module):

    def __init__(self, c):
        super().__init__()
        self.ln = nn.LayerNorm(c)

    def forward(self, x):
        x = x * 2
        x = self.ln(x)
        return x


def test_bn_handler():
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    entire_shape = torch.Size((4, 4, 128))

    tracer = ColoTracer()
    model = LNModel(128)
    input_sample = {'x': torch.rand(4, 4, 128).to('meta')}
    # graph():
    #     %x : torch.Tensor [#users=1] = placeholder[target=x]
    #     %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
    #     %ln : [#users=1] = call_module[target=ln](args = (%mul,), kwargs = {})
    #     return ln
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()
    # [x, mul, ln, output]
    nodes = [node for node in gm.graph.nodes]
    sharding_spec_for_input = ShardingSpec(device_mesh, entire_shape, {})
    sharding_strategy_for_input = ShardingStrategy('node_1', sharding_spec_for_input)
    strategies_vector_for_input = StrategiesVector(nodes[1])
    strategies_vector_for_input.append(sharding_strategy_for_input)
    setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)

    # generate bn strategy
    strategies_vector = StrategiesVector(node=nodes[2])
    ln_handler = LayerNormHandler(
        node=nodes[2],
        device_mesh=device_mesh,
        strategies_vector=strategies_vector,
    )
    ln_handler.register_strategy()
    # ['[S0, R, R] = [S0, R, R] x [R]', '[R, S0, R] = [R, S0, R] x [R]', '[S1, R, R] = [S1, R, R] x [R]', '[R, S1, R] = [R, S1, R] x [R]',
    #  '[S0, S1, R] = [S0, S1, R] x [R]', '[S1, S0, R] = [S1, S0, R] x [R]', '[S01, R, R] = [S01, R, R] x [R]', '[R, S01, R] = [R, S01, R] x [R]', 'RR = RR x R']
    strategy_name_list = [strategy.name for strategy in ln_handler.strategies_vector]
    assert len(strategy_name_list) == 9


if __name__ == '__main__':
    test_bn_handler()