Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
517b6393
Unverified
Commit
517b6393
authored
Oct 09, 2022
by
YuliangLiu0306
Committed by
GitHub
Oct 09, 2022
Browse files
[autoparallel] add unary element wise handler v2 (#1674)
parent
f6c6a932
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
204 additions
and
1 deletion
+204
-1
colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler_v2.py
...arallel/solver/op_handler/unary_elementwise_handler_v2.py
+35
-0
colossalai/auto_parallel/solver/strategy/__init__.py
colossalai/auto_parallel/solver/strategy/__init__.py
+2
-1
colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py
...o_parallel/solver/strategy/unary_elementwise_generator.py
+84
-0
tests/test_auto_parallel/test_node_handler/test_unary_element_wise_handler_v2.py
...l/test_node_handler/test_unary_element_wise_handler_v2.py
+83
-0
No files found.
colossalai/auto_parallel/solver/op_handler/unary_elementwise_handler_v2.py
0 → 100644
View file @
517b6393
import
torch
from
.node_handler
import
NodeHandler
from
..sharding_strategy
import
ShardingStrategy_V2
,
OperationDataType
,
OperationData
,
StrategiesVector
from
..strategy
import
UnaryElementwiseGenerator
,
StrategyGenerator_V2
from
typing
import
List
,
Dict
from
.registry
import
operator_registry
import
operator
__all__
=
[
'UnaryElementwiseHandler'
]
@operator_registry.register(torch.abs)
@operator_registry.register(torch.nn.ReLU)
class UnaryElementwiseHandler(NodeHandler):
    """
    A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.

    Registered for ``torch.abs`` and ``torch.nn.ReLU``: ops with a single
    tensor argument whose output can share the input's sharding layout.
    """

    def get_strategy_generator(self) -> List[StrategyGenerator_V2]:
        """Return the strategy generators for this node.

        A single ``UnaryElementwiseGenerator`` is used; it follows the
        strategies of the node's first argument (``self.node.args[0]``).
        """
        mapping = self.get_operation_data_mapping()
        return [UnaryElementwiseGenerator(mapping, self.device_mesh, self.node.args[0])]

    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        """Build the ``{"input": ..., "output": ...}`` OperationData mapping.

        The input operand is derived from ``self.node.args[0]`` and its meta
        tensor (``_meta_data``); the output corresponds to the node itself.
        """
        input_node = self.node.args[0]
        physical_input_operand = OperationData(name=str(input_node),
                                               type=OperationDataType.ARG,
                                               data=input_node._meta_data)
        physical_output = OperationData(name=str(self.node),
                                        type=OperationDataType.OUTPUT,
                                        data=self.node._meta_data)
        return {"input": physical_input_operand, "output": physical_output}
colossalai/auto_parallel/solver/strategy/__init__.py
View file @
517b6393
...
@@ -2,12 +2,13 @@ from .strategy_generator import StrategyGenerator_V2
...
@@ -2,12 +2,13 @@ from .strategy_generator import StrategyGenerator_V2
from .strategy_generator import StrategyGenerator_V2
from .matmul_strategy_generator import (DotProductStrategyGenerator, MatVecStrategyGenerator,
                                        LinearProjectionStrategyGenerator, BatchedMatMulStrategyGenerator)
from .conv_strategy_generator import ConvStrategyGenerator
from .batch_norm_generator import BatchNormStrategyGenerator
from .unary_elementwise_generator import UnaryElementwiseGenerator
from .getitem_generator import GetItemStrategyGenerator, TensorStrategyGenerator, TensorTupleStrategyGenerator
from .layer_norm_generator import LayerNormGenerator

# Public API of the strategy package: every generator re-exported here.
__all__ = [
    'StrategyGenerator_V2', 'DotProductStrategyGenerator', 'MatVecStrategyGenerator',
    'LinearProjectionStrategyGenerator', 'BatchedMatMulStrategyGenerator', 'ConvStrategyGenerator',
    'UnaryElementwiseGenerator', 'BatchNormStrategyGenerator', 'GetItemStrategyGenerator',
    'TensorStrategyGenerator', 'TensorTupleStrategyGenerator', 'LayerNormGenerator'
]
colossalai/auto_parallel/solver/strategy/unary_elementwise_generator.py
0 → 100644
View file @
517b6393
import
operator
from
functools
import
reduce
from
..sharding_strategy
import
ShardingStrategy_V2
,
TrainCycleItem
,
MemoryCost
from
colossalai.tensor.shape_consistency
import
CollectiveCommPattern
from
.strategy_generator
import
FollowingStrategyGenerator
from
typing
import
List
from
.._utils
import
exception_handler
import
copy
__all__
=
[
'UnaryElementwiseGenerator'
]
class UnaryElementwiseGenerator(FollowingStrategyGenerator):
    """
    UnaryElementwiseGenerator which deals with the sharding strategies of UnaryElementwiseOp.

    As a *following* generator it does not enumerate shardings of its own: it
    replays every strategy of the predecessor node, keeping the output
    sharding spec identical to the input one.
    """

    def validate(self) -> bool:
        """Delegate validation to the base class; no extra constraints here."""
        return super().validate()

    def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
        """Attach a constant placeholder compute cost to ``strategy``.

        Element-wise ops are treated as negligible compute, so a small fixed
        cost is used regardless of the sharding.
        """
        compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
        # Bug fix: generate() calls this method and discards the return value,
        # so the cost must be stored on the strategy here — previously
        # strategy.compute_cost was never set anywhere.
        strategy.compute_cost = compute_cost
        return compute_cost

    def update_memory_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
        '''
        Compute the memory cost per device with this specific strategy.

        fwd cost = size(input) + size(output) activations;
        bwd cost = size(input) only (the input gradient).
        '''
        forward_size_mapping = {
            'input': self._compute_size_in_bytes(strategy, "input"),
            'output': self._compute_size_in_bytes(strategy, "output")
        }
        # Backward only materializes the input gradient, so drop the output entry.
        backward_size_mapping = copy.deepcopy(forward_size_mapping)
        backward_size_mapping.pop("output")

        # compute fwd cost incurred
        # fwd_cost = input + output
        fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
        fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
        fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)

        # compute bwd cost incurred
        # bwd_cost = input_grad
        bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
        bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
        bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)

        # compute total cost
        total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
                                    parameter=fwd_parameter_cost + bwd_parameter_cost)
        memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
        strategy.memory_cost = memory_cost
        return super().update_memory_cost(strategy)

    def generate(self):
        """Create one sharding strategy per predecessor strategy and return the list.

        For element-wise functions the output sharding spec is kept the same
        as the input's, so each predecessor strategy maps directly to one
        strategy for this node.
        """
        strategy_list = []
        # For element-wise function, we keep the sharding spec of output node same as
        # the input. Therefore, the different strategies of input node with same
        # output sharding spec will generate same strategy for element-wise function.
        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
            communication_action_mapping = {}
            input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
            dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
            dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
            dim_partition_dict_mapping = {
                "input": dim_partition_dict_for_input,
                "output": dim_partition_dict_for_output,
            }
            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
            # add index into name to pass the duplicated check
            # we keep same strategies with different name for node merging, and it will not increase the searching space,
            # because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
            name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
            strategy = self.get_sharding_strategy(name=name,
                                                  sharding_spec_mapping=sharding_spec_mapping,
                                                  communication_action_mapping=communication_action_mapping)
            strategy_list.append(strategy)

        for strategy in strategy_list:
            self.update_communication_cost(strategy)
            self.update_compute_cost(strategy)
            self.update_memory_cost(strategy)

        return strategy_list
tests/test_auto_parallel/test_node_handler/test_unary_element_wise_handler_v2.py
0 → 100644
View file @
517b6393
from
colossalai.fx.tracer.meta_patch.patched_module
import
linear
import
torch
import
torch.nn
as
nn
from
colossalai.fx
import
ColoTracer
,
ColoGraphModule
from
colossalai.auto_parallel.solver.op_handler.unary_elementwise_handler_v2
import
UnaryElementwiseHandler
from
colossalai.auto_parallel.solver.op_handler.conv_handler_v2
import
ConvFunctionHandler
from
colossalai.auto_parallel.solver.sharding_strategy
import
OperationData
,
OperationDataType
,
StrategiesVector
from
colossalai.device.device_mesh
import
DeviceMesh
class ReLuModel(nn.Module):
    """A minimal module: a functional conv2d followed by a ReLU submodule.

    The ReLU is kept as a submodule attribute named ``act`` (referenced by
    name elsewhere), while the convolution is applied functionally with
    ``other`` serving as the weight tensor.
    """

    def __init__(self):
        super().__init__()
        self.act = torch.nn.ReLU()

    def forward(self, input, other):
        # conv2d uses `other` directly as the kernel; result is passed
        # straight through the ReLU activation.
        return self.act(nn.functional.conv2d(input, other))
def test_elementwise_handler():
    """Exercise UnaryElementwiseHandler on a ReLU that follows a conv2d."""
    model = ReLuModel()
    tracer = ColoTracer()
    # graph():
    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
    #     %other : torch.Tensor [#users=1] = placeholder[target=other]
    #     %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
    #     %act : [#users=1] = call_module[target=act](args = (%conv2d,), kwargs = {})
    #     return act
    meta_args = {
        "input": torch.rand(4, 4, 64, 64).to('meta'),
        "other": torch.rand(4, 16, 3, 3).to('meta'),
    }
    graph = tracer.trace(model, meta_args=meta_args)
    gm = ColoGraphModule(model, graph)

    # 2x2 logical mesh over 4 physical devices
    physical_mesh_id = torch.arange(0, 4)
    device_mesh = DeviceMesh(physical_mesh_id, (2, 2))

    nodes = list(graph.nodes)
    conv_mod_node, relu_mod_node = nodes[2], nodes[3]
    relu_strategies_vector = StrategiesVector(relu_mod_node)
    conv_strategies_vector = StrategiesVector(conv_mod_node)

    # Build the conv handler first so the relu handler has predecessor
    # strategies to follow.
    conv_handler = ConvFunctionHandler(node=conv_mod_node,
                                       device_mesh=device_mesh,
                                       strategies_vector=conv_strategies_vector)
    conv_handler.register_strategy()
    setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
    relu_handler = UnaryElementwiseHandler(node=relu_mod_node,
                                           device_mesh=device_mesh,
                                           strategies_vector=relu_strategies_vector)
    relu_handler.register_strategy()

    # check operation data mapping
    mapping = relu_handler.get_operation_data_mapping()
    for name, op_data in mapping.items():
        op_data: OperationData
        # make sure they have valid values
        assert op_data.data is not None

    # input comes from the conv node; conv output of a 64x64 image with a
    # 3x3 kernel (default stride/padding) is 62x62
    assert mapping['input'].name == "conv2d"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([4, 4, 62, 62])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([4, 4, 62, 62])

    assert mapping['output'].name == "act"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([4, 4, 62, 62])
    assert mapping['output'].type == OperationDataType.OUTPUT

    # the unary element-wise handler follows its predecessor, so it produces
    # exactly one strategy per conv strategy
    assert len(relu_strategies_vector) == len(conv_strategies_vector)


if __name__ == '__main__':
    test_elementwise_handler()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment