Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
56088e6d
Unverified
Commit
56088e6d
authored
Oct 13, 2022
by
YuliangLiu0306
Committed by
GitHub
Oct 13, 2022
Browse files
[autoparallel] add pooling handler (#1690)
* [autoparallel] add pooling handler * polish code
parent
319d654f
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
213 additions
and
1 deletion
+213
-1
colossalai/auto_parallel/solver/op_handler/normal_pooling_handler.py
...auto_parallel/solver/op_handler/normal_pooling_handler.py
+40
-0
colossalai/auto_parallel/solver/strategy/__init__.py
colossalai/auto_parallel/solver/strategy/__init__.py
+2
-1
colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py
...auto_parallel/solver/strategy/normal_pooling_generator.py
+117
-0
tests/test_auto_parallel/test_node_handler/test_norm_pooling_handler.py
...o_parallel/test_node_handler/test_norm_pooling_handler.py
+54
-0
No files found.
colossalai/auto_parallel/solver/op_handler/normal_pooling_handler.py
0 → 100644
View file @
56088e6d
import
torch
import
torch.nn.functional
as
F
from
.node_handler
import
ModuleHandler
,
NodeHandler
from
..sharding_strategy
import
ShardingStrategy_V2
,
OperationDataType
,
OperationData
from
..strategy
import
NormalPoolStrategyGenerator
,
StrategyGenerator_V2
from
typing
import
List
,
Dict
from
.registry
import
operator_registry
# Public API of this module: only the pooling handler defined below is exported.
__all__ = ['NormPoolingHandler']
@operator_registry.register(torch.nn.MaxPool1d)
@operator_registry.register(torch.nn.MaxPool2d)
@operator_registry.register(torch.nn.MaxPool3d)
@operator_registry.register(torch.nn.AvgPool1d)
@operator_registry.register(torch.nn.AvgPool2d)
@operator_registry.register(torch.nn.AvgPool3d)
class NormPoolingHandler(ModuleHandler):
    """
    A NormPoolingHandler which deals with the sharding strategies for
    nn.MaxPoolxd and nn.AvgPoolxd modules (x in 1, 2, 3).
    """

    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
        """
        Build the strategy generators used for this pooling node.

        Returns:
            A single-element list holding a NormalPoolStrategyGenerator that is
            constructed from this handler's operation data mapping and device mesh.
        """
        op_data_mapping = self.get_operation_data_mapping()
        return [NormalPoolStrategyGenerator(op_data_mapping, self.device_mesh)]

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        """
        Describe the operands of the pooling node.

        Returns:
            A dict with three entries:
              - "input":  the meta data of the node's first argument,
              - "other":  the module's ``kernel_size`` (stored under the name "kernel"),
              - "output": the meta data of the node itself.
        """
        # the activation fed into the pooling module
        input_node = self.node.args[0]
        physical_input_operand = OperationData(name=str(input_node),
                                               type=OperationDataType.ARG,
                                               data=input_node._meta_data)

        # pooling modules have no weight; the kernel size is passed along instead
        # so that the generator can use it when estimating the compute cost
        physical_weight_operand = OperationData(name="kernel",
                                                type=OperationDataType.ARG,
                                                data=self.module.kernel_size)

        physical_output = OperationData(name=str(self.node),
                                        type=OperationDataType.OUTPUT,
                                        data=self.node._meta_data)

        mapping = {
            "input": physical_input_operand,
            "other": physical_weight_operand,
            "output": physical_output,
        }
        return mapping
colossalai/auto_parallel/solver/strategy/__init__.py
View file @
56088e6d
...
@@ -7,10 +7,11 @@ from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator
...
@@ -7,10 +7,11 @@ from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator
from
.layer_norm_generator
import
LayerNormGenerator
from
.layer_norm_generator
import
LayerNormGenerator
from
.where_generator
import
WhereGenerator
from
.where_generator
import
WhereGenerator
from
.reshape_generator
import
ReshapeGenerator
from
.reshape_generator
import
ReshapeGenerator
from
.normal_pooling_generator
import
NormalPoolStrategyGenerator
__all__
=
[
__all__
=
[
'StrategyGenerator_V2'
,
'DotProductStrategyGenerator'
,
'MatVecStrategyGenerator'
,
'StrategyGenerator_V2'
,
'DotProductStrategyGenerator'
,
'MatVecStrategyGenerator'
,
'LinearProjectionStrategyGenerator'
,
'BatchedMatMulStrategyGenerator'
,
'ConvStrategyGenerator'
,
'LinearProjectionStrategyGenerator'
,
'BatchedMatMulStrategyGenerator'
,
'ConvStrategyGenerator'
,
'UnaryElementwiseGenerator'
,
'BatchNormStrategyGenerator'
,
'GetItemStrategyGenerator'
,
'TensorStrategyGenerator'
,
'UnaryElementwiseGenerator'
,
'BatchNormStrategyGenerator'
,
'GetItemStrategyGenerator'
,
'TensorStrategyGenerator'
,
'TensorTupleStrategyGenerator'
,
'LayerNormGenerator'
,
"WhereGenerator"
,
'ReshapeGenerator'
'TensorTupleStrategyGenerator'
,
'LayerNormGenerator'
,
"WhereGenerator"
,
'ReshapeGenerator'
,
'NormalPoolStrategyGenerator'
]
]
colossalai/auto_parallel/solver/strategy/normal_pooling_generator.py
0 → 100644
View file @
56088e6d
import
operator
from
functools
import
reduce
from
..sharding_strategy
import
ShardingStrategy_V2
,
TrainCycleItem
,
MemoryCost
from
colossalai.tensor.shape_consistency
import
CollectiveCommPattern
from
.strategy_generator
import
StrategyGenerator_V2
from
typing
import
List
from
.._utils
import
exception_handler
,
enumerate_all_possible_1d_sharding
,
enumerate_all_possible_2d_sharding
import
copy
class NormalPoolStrategyGenerator(StrategyGenerator_V2):
    """
    NormalPoolStrategyGenerator is a generic class to generate strategies for pool operations like MaxPoolxd.
    The reason we call this normal pool is that AvgPoolxd and MaxPoolxd both take a kernel-size window of
    elements from the image and reduce them depending on the operation type.
    """

    def validate(self) -> bool:
        '''
        In sanity check, we need to make sure the input data has the correct number of dimensions.
        For Pool1d, the dim of input data should be 3([N, C, L]).
        For Pool2d, the dim of input data should be 4([N, C, H, W]).
        For Pool3d, the dim of input data should be 5([N, C, D, H, W]).

        Returns:
            True when the check passes.

        Raises:
            AssertionError: if the input does not have 3, 4 or 5 dimensions.
        '''
        input_op_data = self.op_data['input']
        # The tensor lives in the `data` field of the operation data (that is where the
        # handler stores it), so the rank has to be queried there, not on the wrapper.
        assert input_op_data.data.dim() in (
            3, 4, 5), 'We suppose the dim of input fed into Pool op should in range of [3, 5].'
        return True

    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
        '''
        Compute the computation cost per device with this specific strategy.

        The cost is stored on ``strategy.compute_cost`` and also returned.

        Note: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
        '''
        # TODO: compute_cost need to be divided by TFLOPS, now it just shows the computation size.
        # forward:  (number of sharded output elements) * (number of elements in one kernel window)
        # backward: (number of sharded input elements)  * (number of elements in one kernel window)
        sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()

        kernel_size = self.op_data["other"].data
        if isinstance(kernel_size, int):
            # an int kernel size applies to every spatial dim, i.e. all dims except the first two
            kernel_size = [kernel_size] * (len(sharded_output_shape) - 2)
        kernel_size_product = reduce(operator.mul, kernel_size)
        output_size_product = reduce(operator.mul, sharded_output_shape)
        input_size_product = reduce(operator.mul, sharded_input_shape)

        forward_compute_cost = output_size_product * kernel_size_product
        backward_compute_cost = input_size_product * kernel_size_product
        total_compute_cost = forward_compute_cost + backward_compute_cost

        compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
        # generate() ignores the return value, so attach the cost to the strategy here,
        # mirroring how update_memory_cost records its result.
        strategy.compute_cost = compute_cost
        return compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> None:
        '''
        Estimate the activation memory of this strategy and store it on ``strategy.memory_cost``.

        Pooling has no parameters, so the parameter cost is always 0.
        '''
        forward_size_mapping = {
            'input': self._compute_size_in_bytes(strategy, "input"),
            'output': self._compute_size_in_bytes(strategy, "output")
        }
        # the backward pass only accounts for the gradient of the input
        backward_size_mapping = {key: size for key, size in forward_size_mapping.items() if key != "output"}

        # compute fwd cost incurred
        # fwd_cost = input + output
        fwd_activation_cost = sum(forward_size_mapping.values())
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)

        # compute bwd cost incurred
        # bwd_cost = input_grad
        bwd_activation_cost = sum(backward_size_mapping.values())
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)

        # compute total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, parameter=0)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    def _generate_strategy_with_dim_partition(self, dim_partition):
        '''
        Build one sharding strategy in which input and output share the same dim partition.

        The strategy is named "<output sharding sequence> = <input sharding sequence>" and
        carries no communication actions.
        '''
        dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
        communication_action_mapping = {}

        strategy = self.get_sharding_strategy(name=name,
                                              sharding_spec_mapping=sharding_spec_mapping,
                                              communication_action_mapping=communication_action_mapping)
        return strategy

    def enumerate_all_possible_batch_dimensions_dim_partition(self, mesh_dim_0, mesh_dim_1):
        '''
        Collect every candidate dim partition: 1D sharding along each mesh dim, 2D sharding
        along both mesh dims, and the unsharded case.
        '''
        # NOTE(review): the literal 2 is presumably the number of leading tensor dims
        # (batch and channel) that may be sharded -- confirm against the helpers in _utils.
        dim_partition_list = []
        dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_0, 2))
        dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_1, 2))
        dim_partition_list.extend(enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, 2))
        # append {} for non_split case
        dim_partition_list.append({})

        return dim_partition_list

    def generate(self) -> List[ShardingStrategy_V2]:
        '''
        Generate all sharding strategies for the pooling op and fill in their
        communication, compute and memory costs.
        '''
        dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1)
        strategy_list = [
            self._generate_strategy_with_dim_partition(dim_partition) for dim_partition in dim_partition_list
        ]

        for strategy in strategy_list:
            self.update_communication_cost(strategy)
            self.update_compute_cost(strategy)
            self.update_memory_cost(strategy)

        return strategy_list
tests/test_auto_parallel/test_node_handler/test_norm_pooling_handler.py
0 → 100644
View file @
56088e6d
from
colossalai.fx.tracer.meta_patch.patched_module
import
linear
import
torch
import
torch.nn
as
nn
from
colossalai.fx
import
ColoTracer
,
ColoGraphModule
from
colossalai.auto_parallel.solver.op_handler.normal_pooling_handler
import
NormPoolingHandler
from
colossalai.auto_parallel.solver.sharding_strategy
import
OperationData
,
OperationDataType
,
StrategiesVector
from
colossalai.device.device_mesh
import
DeviceMesh
def test_norm_pool_handler():
    """Trace a single MaxPool2d on meta tensors, then check the handler's operand mapping and strategy count."""
    model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
    tracer = ColoTracer()
    # The traced graph is expected to look like:
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
    #     return _0
    meta_args = {"input": torch.rand(4, 4, 64, 64).to('meta')}
    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)

    # 4 devices arranged as a 2 x 2 mesh
    device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

    # node 0 is the placeholder, node 1 is the pooling module call
    pool_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(pool_node)

    # build handler
    handler = NormPoolingHandler(node=pool_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    # every operand must carry a value
    for operand in mapping.values():
        assert operand.data is not None

    input_operand = mapping['input']
    assert input_operand.name == "input_1"
    assert input_operand.data.is_meta
    assert input_operand.data.shape == torch.Size([4, 4, 64, 64])
    assert input_operand.type == OperationDataType.ARG
    assert input_operand.logical_shape == torch.Size([4, 4, 64, 64])

    output_operand = mapping['output']
    assert output_operand.name == "_0"
    assert output_operand.data.is_meta
    assert output_operand.data.shape == torch.Size([4, 4, 16, 16])
    assert output_operand.type == OperationDataType.OUTPUT

    # register the strategies and count them by name
    strategies_vector = handler.register_strategy()
    strategy_name_list = [strategy.name for strategy in strategies_vector]
    assert len(strategy_name_list) == 9


if __name__ == '__main__':
    test_norm_pool_handler()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment