OpenDAS / ColossalAI

Commit af718e83 (unverified)
Authored Oct 11, 2022 by YuliangLiu0306; committed by GitHub on Oct 11, 2022

[autoparallel] add reshape handler v2 and fix some previous bug (#1683)

Parent: 6878e422
Showing 10 changed files with 250 additions and 23 deletions (+250, -23).
colossalai/auto_parallel/solver/op_handler/reshape_handler_v2.py  +35 -0
colossalai/auto_parallel/solver/strategy/__init__.py  +4 -3
colossalai/auto_parallel/solver/strategy/batch_norm_generator.py  +3 -3
colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py  +3 -3
colossalai/auto_parallel/solver/strategy/getitem_generator.py  +4 -4
colossalai/auto_parallel/solver/strategy/layer_norm_generator.py  +6 -3
colossalai/auto_parallel/solver/strategy/reshape_generator.py  +100 -0
colossalai/auto_parallel/solver/strategy/strategy_generator.py  +10 -3
colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py  +4 -4
tests/test_auto_parallel/test_node_handler/test_reshape_handler_v2.py  +81 -0
colossalai/auto_parallel/solver/op_handler/reshape_handler_v2.py (new file, 0 → 100644)

import torch
from .node_handler import NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData, StrategiesVector
from ..strategy import ReshapeGenerator, StrategyGenerator_V2
from typing import List, Dict
from .registry import operator_registry
import operator

__all__ = ['ReshapeHandler']


@operator_registry.register(torch.reshape)
@operator_registry.register(torch.Tensor.permute)
class ReshapeHandler(NodeHandler):
    """
    A ReshapeHandler which deals with the sharding strategies for Reshape Op, such as torch.reshape.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
        op_data_mapping = self.get_operation_data_mapping()
        generators = []
        generators.append(ReshapeGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
        return generators

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shape for strategies
        # the strategies will be transformed back to its original shape in self.post_process
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=self.node.args[0]._meta_data)
        physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)

        mapping = {"input": physical_input_operand, "output": physical_output}

        return mapping
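The handler above is wired into strategy construction purely through the operator_registry decorators, which map torch.reshape and torch.Tensor.permute to ReshapeHandler. As a rough sketch of that decorator-based dispatch pattern (SimpleRegistry and its get method are hypothetical stand-ins, not ColossalAI's actual registry API):

import torch

class SimpleRegistry:
    """Minimal illustrative registry mapping a torch op to a handler class."""

    def __init__(self):
        self._store = {}

    def register(self, source):
        # register(source) returns a decorator that records source -> handler class
        def wrapper(cls):
            self._store[source] = cls
            return cls
        return wrapper

    def get(self, source):
        return self._store[source]

registry = SimpleRegistry()

@registry.register(torch.reshape)
@registry.register(torch.Tensor.permute)
class DummyHandler:
    pass

assert registry.get(torch.reshape) is DummyHandler
assert registry.get(torch.Tensor.permute) is DummyHandler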
colossalai/auto_parallel/solver/strategy/__init__.py

@@ -5,10 +5,11 @@ from .batch_norm_generator import BatchNormStrategyGenerator
 from .unary_elementwise_generator import UnaryElementwiseGenerator
 from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
 from .layer_norm_generator import LayerNormGenerator
+from .reshape_generator import ReshapeGenerator

 __all__ = [
     'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
-    'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
-    'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator',
-    'TensorTupleStrategyGenerator', 'LayerNormGenerator'
+    'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
+    'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator', 'TensorStrategyGenerator',
+    'TensorTupleStrategyGenerator', 'LayerNormGenerator', 'ReshapeGenerator'
 ]
colossalai/auto_parallel/solver/strategy/batch_norm_generator.py

@@ -37,7 +37,7 @@ class BatchNormStrategyGenerator(StrategyGenerator_V2):
         assert input_op_data.dim() in (3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'

-    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
+    def update_compute_cost(self, strategy: ShardingStrategy_V2):
         '''
         Compute the computation cost per device with this specific strategy.

@@ -62,9 +62,9 @@ class BatchNormStrategyGenerator(StrategyGenerator_V2):
             backward_compute_cost += bias_compute_cost
         total_compute_cost = forward_compute_cost + backward_compute_cost
         compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
-        return compute_cost
+        strategy.compute_cost = compute_cost

-    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
+    def update_memory_cost(self, strategy: ShardingStrategy_V2):
         forward_size_mapping = {
             'input': self._compute_size_in_bytes(strategy, "input"),
             'other': self._compute_size_in_bytes(strategy, "other"),
             ...
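The change shown here, from an update_compute_cost that returns a TrainCycleItem to one that writes strategy.compute_cost in place, repeats below in the conv, getitem, layer-norm, and unary-elementwise generators, and is presumably part of the "fix some previous bug" portion of this commit. A minimal sketch of the two conventions, using simplified stand-in classes rather than ColossalAI's real TrainCycleItem and ShardingStrategy_V2:

from dataclasses import dataclass

@dataclass
class TrainCycleItem:          # stand-in for the real cost record
    fwd: float
    bwd: float
    total: float

@dataclass
class Strategy:                # stand-in for ShardingStrategy_V2
    compute_cost: TrainCycleItem = None

# old convention: compute and return the cost; the caller had to attach it,
# so the returned value could silently be dropped
def update_compute_cost_old(strategy: Strategy) -> TrainCycleItem:
    return TrainCycleItem(fwd=10, bwd=10, total=20)

# new convention (this commit): mutate the strategy in place
def update_compute_cost_new(strategy: Strategy) -> None:
    strategy.compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)

s = Strategy()
update_compute_cost_new(s)
assert s.compute_cost.total == 20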
colossalai/auto_parallel/solver/strategy/conv_strategy_generator.py

@@ -29,7 +29,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
         assert input_op_data.dim() in (3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'

-    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
+    def update_compute_cost(self, strategy: ShardingStrategy_V2):
         '''
         Compute the computation cost per device with this specific strategy.

@@ -67,9 +67,9 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
         total_compute_cost = forward_compute_cost + backward_compute_cost
         compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
-        return compute_cost
+        strategy.compute_cost = compute_cost

-    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+    def update_memory_cost(self, strategy: ShardingStrategy_V2):
         forward_size_mapping = {
             'input': self._compute_size_in_bytes(strategy, "input"),
             'other': self._compute_size_in_bytes(strategy, "other"),
             ...
colossalai/auto_parallel/solver/strategy/getitem_generator.py

@@ -28,10 +28,11 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
     def validate(self) -> bool:
         return super().validate()

-    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
-        return TrainCycleItem(fwd=10, bwd=10, total=20)
+    def update_compute_cost(self, strategy: ShardingStrategy_V2):
+        compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
+        strategy.compute_cost = compute_cost

-    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
+    def update_memory_cost(self, strategy: ShardingStrategy_V2):
         '''
         Compute the memory cost per device with this specific strategy.
         '''

@@ -59,7 +60,6 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
                                     parameter=fwd_parameter_cost + bwd_parameter_cost)
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost
-        return super().update_memory_cost(strategy)


 class TensorStrategyGenerator(GetItemStrategyGenerator):
     ...
colossalai/auto_parallel/solver/strategy/layer_norm_generator.py

@@ -23,7 +23,7 @@ class LayerNormGenerator(StrategyGenerator_V2):
     def validate(self) -> bool:
         return super().validate()

-    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
+    def update_compute_cost(self, strategy: ShardingStrategy_V2):
         '''
         Compute the computation cost per device with this specific strategy.

@@ -52,9 +52,9 @@ class LayerNormGenerator(StrategyGenerator_V2):
             backward_compute_cost += bias_compute_cost
         total_compute_cost = forward_compute_cost + backward_compute_cost
         compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
-        return compute_cost
+        strategy.compute_cost = compute_cost

-    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
+    def update_memory_cost(self, strategy: ShardingStrategy_V2):
         '''
         Compute the memory cost per device with this specific strategy.
         '''

@@ -103,6 +103,9 @@ class LayerNormGenerator(StrategyGenerator_V2):
         total_mesh_dim_list = []
         for mesh_dim_list in dim_partition.values():
             total_mesh_dim_list.extend(mesh_dim_list)
+        # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
+        if len(total_mesh_dim_list) == 1:
+            total_mesh_dim_list = total_mesh_dim_list[0]

         communication_action_mapping = {}
         other_comm_spec = self.get_communication_spec(
         ...
colossalai/auto_parallel/solver/strategy/reshape_generator.py (new file, 0 → 100644)

import operator
from functools import reduce
from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
from typing import List
import copy

__all__ = ['ReshapeGenerator']


class ReshapeGenerator(FollowingStrategyGenerator):
    """
    ReshapeGenerator which deals with the sharding strategies of Reshape Op, such as torch.Tensor.permute.
    """

    def validate(self) -> bool:
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy_V2):
        compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
        strategy.compute_cost = compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy_V2):
        '''
        Compute the memory cost per device with this specific strategy.
        '''
        forward_size_mapping = {
            'input': self._compute_size_in_bytes(strategy, "input"),
            'output': self._compute_size_in_bytes(strategy, "output")
        }

        backward_size_mapping = copy.deepcopy(forward_size_mapping)
        backward_size_mapping.pop("output")
        # compute fwd cost incurred
        # fwd_cost = input + output
        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)

        # compute bwd cost incurred
        # bwd_cost = input_grad
        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

        # compute total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost

    def generate(self):
        strategy_list = []
        # For reshape function, to keep the computing correctness we keep the sharding
        # spec of input is fully replicated. In addition, we will keep the output in
        # replica status and let the successor node choose the way to resharding the
        # output node. Therefore, the different strategies of input node with same
        # output sharding spec will generate same strategy for reshape function.
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
            dim_partition_dict_mapping = {}
            communication_action_mapping = {}
            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
            dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
            dim_partition_dict_for_output = {}
            if isinstance(self.op_data["output"].data, tuple):
                dim_partition_dict_for_output = [{} for _ in range(len(self.op_data["output"].data))]
            dim_partition_dict_mapping = {
                "input": dim_partition_dict_for_input,
                "output": dim_partition_dict_for_output,
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
            # add index into name to pass the duplicated check
            # we keep same strategies with different name for node merging, and it will not increase the searching space,
            # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
            name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'

            total_mesh_dim_list = []
            for mesh_dim_list in dim_partition_dict_for_input.values():
                total_mesh_dim_list.extend(mesh_dim_list)
            # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
            if len(total_mesh_dim_list) == 1:
                total_mesh_dim_list = total_mesh_dim_list[0]
            input_comm_spec = self.get_communication_spec(
                sharding_spec=sharding_spec_mapping["input"],
                communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                logical_process_axis=total_mesh_dim_list)

            communication_action_mapping["input"] = input_comm_spec
            strategy = self.get_sharding_strategy(name=name,
                                                  sharding_spec_mapping=sharding_spec_mapping,
                                                  communication_action_mapping=communication_action_mapping)
            strategy_list.append(strategy)

        for strategy in strategy_list:
            self.update_communication_cost(strategy)
            self.update_compute_cost(strategy)
            self.update_memory_cost(strategy)

        return strategy_list
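In generate(), the logical_process_axis passed to get_communication_spec is derived by flattening all device-mesh dimensions referenced by the input's dim_partition_dict and collapsing a single-element list to a scalar (the same snippet this commit adds to layer_norm_generator.py above). A small self-contained sketch of that flattening, with made-up dim_partition_dict values for illustration:

def flatten_mesh_dims(dim_partition_dict):
    # dim_partition_dict maps a tensor dim to the list of device-mesh dims sharding it,
    # e.g. {0: [0], 2: [1]} on a 2D mesh
    total_mesh_dim_list = []
    for mesh_dim_list in dim_partition_dict.values():
        total_mesh_dim_list.extend(mesh_dim_list)
    # with only one sharding dimension, use the scalar value instead of a list
    if len(total_mesh_dim_list) == 1:
        total_mesh_dim_list = total_mesh_dim_list[0]
    return total_mesh_dim_list

assert flatten_mesh_dims({0: [0], 2: [1]}) == [0, 1]   # sharded along two mesh dims
assert flatten_mesh_dims({0: [0]}) == 0                # single mesh dim becomes a scalar axis
assert flatten_mesh_dims({}) == []                     # fully replicated input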
colossalai/auto_parallel/solver/strategy/strategy_generator.py

@@ -53,6 +53,13 @@ class StrategyGenerator_V2(ABC):
         for op_data_name, dim_partition_dict in mapping.items():
             if op_data_name in self.op_data:
                 op_data = self.op_data[op_data_name]
-                sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
-                                             entire_shape=op_data.logical_shape,
-                                             dim_partition_dict=dim_partition_dict)
+                if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor):
+                    sharding_spec = []
+                    for output, dim_partition_dict_element in zip(op_data.data, dim_partition_dict):
+                        sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
+                                                     entire_shape=output.shape,
+                                                     dim_partition_dict=dim_partition_dict_element)
+                else:
+                    sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
+                                                 entire_shape=op_data.logical_shape,
+                                                 dim_partition_dict=dim_partition_dict)
                 ...
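The new branch above builds a sharding spec per tuple element when an operand's data is a tuple of tensors (for example a multi-output op), instead of a single spec from the operand's logical shape. A rough sketch of that dispatch, using plain dicts as hypothetical stand-ins for ShardingSpec and collecting the per-element specs into a list for clarity:

import torch

def make_spec(entire_shape, dim_partition_dict):
    # stand-in for ShardingSpec(device_mesh=..., entire_shape=..., dim_partition_dict=...)
    return {"entire_shape": tuple(entire_shape), "dim_partition_dict": dim_partition_dict}

def build_sharding_spec(data, logical_shape, dim_partition_dict):
    if isinstance(data, tuple) and isinstance(data[0], torch.Tensor):
        # tuple of tensors: pair each element with its own partition dict
        return [make_spec(element.shape, partition)
                for element, partition in zip(data, dim_partition_dict)]
    # single tensor: one spec built from the operand's logical shape
    return make_spec(logical_shape, dim_partition_dict)

specs = build_sharding_spec((torch.empty(2, 2), torch.empty(4)), None, [{0: [0]}, {}])
assert len(specs) == 2 and specs[0]["dim_partition_dict"] == {0: [0]}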
colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py

@@ -18,10 +18,11 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
     def validate(self) -> bool:
         return super().validate()

-    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
-        return TrainCycleItem(fwd=10, bwd=10, total=20)
+    def update_compute_cost(self, strategy: ShardingStrategy_V2):
+        compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
+        strategy.compute_cost = compute_cost

-    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
+    def update_memory_cost(self, strategy: ShardingStrategy_V2):
         '''
         Compute the memory cost per device with this specific strategy.
         '''

@@ -49,7 +50,6 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
                                     parameter=fwd_parameter_cost + bwd_parameter_cost)
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost
-        return super().update_memory_cost(strategy)

     def generate(self):
         strategy_list = []
         ...
tests/test_auto_parallel/test_node_handler/test_reshape_handler_v2.py (new file, 0 → 100644)

import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.conv_handler_v2 import ConvFunctionHandler
from colossalai.auto_parallel.solver.op_handler.reshape_handler_v2 import ReshapeHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh


class ReshapeModel(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, input, other):
        conv_node = nn.functional.conv2d(input, other)
        reshape_node = conv_node.view(2, -1)
        return reshape_node


def test_reshape_handler():
    model = ReshapeModel()
    tracer = ColoTracer()
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     %other : torch.Tensor [#users=1] = placeholder[target=other]
    #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
    #     %view : [#users=1] = call_method[target=view](args = (%conv2d, 2, -1), kwargs = {})
    #     return view
    graph = tracer.trace(model,
                         meta_args={
                             "input": torch.rand(4, 4, 64, 64).to('meta'),
                             "other": torch.rand(4, 16, 3, 3).to('meta'),
                         })
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)

    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
    conv_mod_node = list(graph.nodes)[2]
    reshape_node = list(graph.nodes)[3]
    reshape_strategies_vector = StrategiesVector(reshape_node)
    conv_strategies_vector = StrategiesVector(conv_mod_node)

    # build handler
    conv_handler = ConvFunctionHandler(node=conv_mod_node,
                                       device_mesh=device_mesh,
                                       strategies_vector=conv_strategies_vector)
    conv_handler.register_strategy()
    setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
    reshape_handler = ReshapeHandler(node=reshape_node,
                                     device_mesh=device_mesh,
                                     strategies_vector=reshape_strategies_vector)

    reshape_handler.register_strategy()

    # check operation data mapping
    mapping = reshape_handler.get_operation_data_mapping()

    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.data is not None

    assert mapping['input'].name == "conv2d"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62])

    assert mapping['output'].name == "view"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([2, 30752])
    assert mapping['output'].type == OperationDataType.OUTPUT

    # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
    assert len(reshape_strategies_vector) == len(conv_strategies_vector)


if __name__ == '__main__':
    test_reshape_handler()