OpenDAS / ColossalAI · Commits

Commit 746f8f97 (unverified)
Authored Sep 29, 2022 by YuliangLiu0306; committed by GitHub on Sep 29, 2022

[autoparallel] add batch norm handler v2 (#1666)
Parent: 9708638d

Showing 3 changed files with 424 additions and 0 deletions:

- colossalai/auto_parallel/solver/op_handler/batch_norm_handler_v2.py (+45, −0)
- colossalai/auto_parallel/solver/strategy/batch_norm_generator.py (+291, −0)
- tests/test_auto_parallel/test_node_handler/test_batch_norm_handler_v2.py (+88, −0)
colossalai/auto_parallel/solver/op_handler/batch_norm_handler_v2.py (new file, 0 → 100644)

```python
import torch
import torch.nn.functional as F
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
from ..strategy import BatchNormStrategyGenerator, StrategyGenerator_V2
from typing import List, Dict
from .registry import operator_registry

__all__ = ['BatchNormModuleHandler']


@operator_registry.register(torch.nn.BatchNorm1d)
@operator_registry.register(torch.nn.BatchNorm2d)
@operator_registry.register(torch.nn.BatchNorm3d)
class BatchNormModuleHandler(ModuleHandler):
    """
    A BatchNormModuleHandler which deals with the sharding strategies for nn.BatchNormXd modules.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(BatchNormStrategyGenerator(op_data_mapping, self.device_mesh))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shape for strategies
        # the strategies will be transformed back to their original shape in self.post_process
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        physical_other_operand = OperationData(name="weight",
                                               type=OperationDataType.PARAM,
                                               data=self.named_parameters['weight'],
                                               logical_shape=self.named_parameters['weight'].shape)
        physical_output = OperationData(name=str(self.node),
                                        type=OperationDataType.OUTPUT,
                                        data=self.node._meta_data)

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

        if self.named_parameters['bias'] is not None:
            physical_bias_operand = OperationData(name="bias",
                                                  type=OperationDataType.PARAM,
                                                  data=self.named_parameters['bias'])
            mapping['bias'] = physical_bias_operand
        return mapping
```
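The three stacked `@operator_registry.register(...)` decorators map each `nn.BatchNorm*d` class to this handler class. As a minimal sketch of that decorator-registry pattern (illustrative only; the `Registry` below is hypothetical and not ColossalAI's actual implementation):

```python
# Hypothetical sketch of a decorator-based operator registry: register()
# returns a decorator that records the handler class under the given
# torch module type, so the solver can later look handlers up by type.
import torch


class Registry:
    def __init__(self):
        self._store = {}

    def register(self, source):
        def wrapper(handler_cls):
            self._store[source] = handler_cls
            return handler_cls
        return wrapper

    def get(self, source):
        return self._store[source]


operator_registry = Registry()


@operator_registry.register(torch.nn.BatchNorm2d)
class DummyHandler:
    pass


assert operator_registry.get(torch.nn.BatchNorm2d) is DummyHandler
```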
colossalai/auto_parallel/solver/strategy/batch_norm_generator.py (new file, 0 → 100644)

```python
import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator_V2
from typing import List
from .._utils import exception_handler
import copy

__all__ = ['BatchNormStrategyGenerator']


class BatchNormStrategyGenerator(StrategyGenerator_V2):
    """
    A StrategyGenerator which deals with the sharding strategies of batch normalization.

    To keep the math consistent when the input is sharded along the batch dimension,
    there are two ways to compute BatchNorm:
    1. Gather the input partitions along the batch dimension, then run the normal BatchNorm.
    2. Run SyncBatchNorm on each input partition separately; the SyncBN op keeps the
       computation correct across partitions.
    This generator considers both methods.
    """

    @property
    def has_bias(self):
        return 'bias' in self.op_data

    def validate(self) -> bool:
        '''
        In the sanity check, we need to make sure the input data has the correct number of dimensions.
        For BatchNorm1d, the input should have dim 3 ([N, C, L]).
        For BatchNorm2d, the input should have dim 4 ([N, C, H, W]).
        For BatchNorm3d, the input should have dim 5 ([N, C, H, W, D]).
        '''
        input_op_data = self.op_data['input']
        assert input_op_data.dim() in (3, 4, 5), \
            'We expect the dim of the input fed into a BatchNorm op to be in the range of [3, 5].'

    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
        '''
        Compute the computation cost per device with this specific strategy.
        Note: compute_cost needs to be divided by TFLOPS; for now it only reflects the computation size.
        '''
        # TODO: a constant coefficient needs to be added.
        # 1D: (L) * N * Cin
        # 2D: (H * W) * N * Cin
        # 3D: (H * W * D) * N * Cin
        sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
        sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
        if self.has_bias:
            # bias add is an element-wise operation, so its cost equals the product of the output shape.
            bias_compute_cost = reduce(operator.mul, sharded_output_shape)
        input_product = reduce(operator.mul, sharded_input_shape, 1)
        forward_compute_cost = input_product
        backward_activation_compute_cost = input_product
        backward_weight_compute_cost = input_product
        backward_compute_cost = backward_weight_compute_cost + backward_activation_compute_cost
        if self.has_bias:
            forward_compute_cost += bias_compute_cost
            backward_compute_cost += bias_compute_cost
        total_compute_cost = forward_compute_cost + backward_compute_cost
        compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
        return compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
        forward_size_mapping = {
            'input': self._compute_size_in_bytes(strategy, "input"),
            'other': self._compute_size_in_bytes(strategy, "other"),
            'output': self._compute_size_in_bytes(strategy, "output")
        }

        if self.has_bias:
            bias_size = self._compute_size_in_bytes(strategy, "bias")
            forward_size_mapping['bias'] = bias_size

        backward_size_mapping = copy.deepcopy(forward_size_mapping)
        backward_size_mapping.pop("output")

        # compute fwd cost incurred
        # fwd_cost = input + other + bias + output
        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)

        # compute bwd cost incurred
        # bwd_cost = input_grad + other_grad + bias_grad
        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

        # compute total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    def split_input_channel(self, mesh_dim_0):
        name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
        dim_partition_dict_mapping = {
            "input": {
                1: [mesh_dim_0]
            },
            "other": {
                0: [mesh_dim_0]
            },
            "output": {
                1: [mesh_dim_0]
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0]}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
        communication_action_mapping = {}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
        dim_partition_dict_mapping = {
            "input": {
                1: [mesh_dim_0, mesh_dim_1]
            },
            "other": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "output": {
                1: [mesh_dim_0, mesh_dim_1]
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {0: [mesh_dim_0, mesh_dim_1]}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
        communication_action_mapping = {}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def non_split(self):
        name = 'RR = RR x R'
        dim_partition_dict_mapping = {
            "input": {},
            "other": {},
            "output": {},
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
        communication_action_mapping = {}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_input_batch(self, mesh_dim_0):
        name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0]
            },
            "other": {},
            "output": {
                0: [mesh_dim_0]
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        # For the SyncBN case, we don't need to communicate the weight and bias.
        # TODO: the communication happens internally in the SyncBN operation. We need to replace
        # the BN operation with a SyncBN operation instead of inserting a communication node.
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=mesh_dim_0)

        communication_action_mapping = {"output": output_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0, mesh_dim_1]
            },
            "other": {},
            "output": {
                0: [mesh_dim_0, mesh_dim_1]
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {}
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        # For the SyncBN case, we don't need to communicate the gradients of weight and bias.
        # TODO: the communication happens internally in the SyncBN operation. We need to replace
        # the BN operation with a SyncBN operation instead of inserting a communication node.
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=[mesh_dim_0, mesh_dim_1])

        communication_action_mapping = {"output": output_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
        name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
        dim_partition_dict_mapping = {
            "input": {
                0: [mesh_dim_0],
                1: [mesh_dim_1],
            },
            "other": {
                0: [mesh_dim_1],
            },
            "output": {
                0: [mesh_dim_0],
                1: [mesh_dim_1],
            },
        }
        if self.has_bias:
            dim_partition_dict_mapping["bias"] = {
                0: [mesh_dim_1],
            }
        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)

        # set communication action
        # For the SyncBN case, we don't need to communicate the gradients of weight and bias.
        # TODO: the communication happens internally in the SyncBN operation. We need to replace
        # the BN operation with a SyncBN operation instead of inserting a communication node.
        output_comm_spec = self.get_communication_spec(
            sharding_spec=sharding_spec_mapping["output"],
            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
            logical_process_axis=[mesh_dim_0])

        communication_action_mapping = {"output": output_comm_spec}
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)

    def generate(self):
        '''
        Generate every possible strategy for a BatchNorm node and record them all in the strategies_vector.
        '''
        strategy_list = []
        # RS = RS x S
        strategy_list.append(self.split_input_channel(0))
        strategy_list.append(self.split_input_channel(1))

        # RR = RR x R
        strategy_list.append(self.non_split())

        # RS01 = RS01 x S01
        strategy_list.append(self.split_input_channel_1d(0, 1))

        # SR = SR x R WITH SYNC_BN
        strategy_list.append(self.split_input_batch(0))
        strategy_list.append(self.split_input_batch(1))

        # SS = SS x S WITH SYNC_BN
        strategy_list.append(self.split_input_both_dim(0, 1))
        strategy_list.append(self.split_input_both_dim(1, 0))

        # S01R = S01R x R WITH SYNC_BN
        strategy_list.append(self.split_input_batch_1d(0, 1))

        return strategy_list
```
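To make the per-device compute-cost rule in `update_compute_cost` concrete, here is a small worked example. The numbers are illustrative assumptions, not from the commit: a BatchNorm2d input of shape [4, 16, 64, 64] whose channel dimension is sharded over a mesh axis of size 2, giving a per-device shape of [4, 8, 64, 64], with bias present:

```python
# Worked example of update_compute_cost's rule, assuming a per-device
# input/output shape of [4, 8, 64, 64] (channel dim 16 sharded over 2 devices)
# and a bias parameter present.
import operator
from functools import reduce

sharded_input_shape = [4, 8, 64, 64]
sharded_output_shape = [4, 8, 64, 64]

input_product = reduce(operator.mul, sharded_input_shape, 1)        # 131072
bias_compute_cost = reduce(operator.mul, sharded_output_shape)      # 131072 (element-wise add)

forward_compute_cost = input_product + bias_compute_cost            # 262144
# backward = activation grad + weight grad, plus the bias grad
backward_compute_cost = 2 * input_product + bias_compute_cost       # 393216
total_compute_cost = forward_compute_cost + backward_compute_cost   # 655360
```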
tests/test_auto_parallel/test_node_handler/test_batch_norm_handler_v2.py (new file, 0 → 100644)

```python
from colossalai.fx.tracer.meta_patch.patched_module import linear
import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.batch_norm_handler_v2 import BatchNormModuleHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh


def test_bn_module_handler():
    model = nn.Sequential(nn.BatchNorm2d(16).to('meta'))
    tracer = ColoTracer()
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     %_0 : [#users=1] = call_module[target=0](args = (%input_1,), kwargs = {})
    #     return _0
    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16, 64, 64).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    bn_mod_node = list(graph.nodes)[1]
    strategies_vector = StrategiesVector(bn_mod_node)

    # build handler
    handler = BatchNormModuleHandler(node=bn_mod_node, device_mesh=device_mesh, strategies_vector=strategies_vector)

    # check operation data mapping
    mapping = handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.logical_shape is not None
        assert op_data.data is not None

    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 16, 64, 64])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 16, 64, 64])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
    assert mapping['other'].data.shape == torch.Size([16])
    assert mapping['other'].type == OperationDataType.PARAM
    assert mapping['other'].logical_shape == torch.Size([16])

    assert mapping['bias'].name == "bias"
    assert mapping['bias'].data.is_meta
    assert mapping['bias'].data.shape == torch.Size([16])
    assert mapping['bias'].type == OperationDataType.PARAM
    assert mapping['bias'].logical_shape == torch.Size([16])

    assert mapping['output'].name == "_0"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 16, 64, 64])
    assert mapping['output'].type == OperationDataType.OUTPUT

    strategies_vector = handler.register_strategy()
    strategy_name_list = [val.name for val in strategies_vector]

    # RS = RS x S
    assert 'RS0 = RS0 x S0' in strategy_name_list
    assert 'RS1 = RS1 x S1' in strategy_name_list

    # RR = RR x R
    assert 'RR = RR x R' in strategy_name_list

    # RS01 = RS01 x S01
    assert 'RS01 = RS01 x S01' in strategy_name_list

    # SR = SR x R WITH SYNC_BN
    assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list
    assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list

    # SS = SS x S WITH SYNC_BN
    assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list
    assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list

    # S01R = S01R x R WITH SYNC_BN
    assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list


if __name__ == '__main__':
    test_bn_module_handler()
```
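The `WITH SYNC_BN` names asserted above mark strategies that shard the batch dimension and therefore need synchronized batch statistics. As the generator's TODO notes, the BN op itself would eventually have to be swapped for SyncBN; in plain PyTorch that conversion looks like the following sketch (illustrative only, not part of this commit):

```python
# Illustrative sketch: how batch-dim sharding relates to SyncBatchNorm in
# plain PyTorch. The handler in this commit only plans the strategy and,
# per its TODO comments, does not yet perform this replacement itself.
import torch.nn as nn

model = nn.Sequential(nn.BatchNorm2d(16))
# convert_sync_batchnorm swaps every BatchNorm*d layer for nn.SyncBatchNorm,
# which all-reduces batch statistics across the process group in forward,
# keeping per-partition computation mathematically consistent.
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
```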