OpenDAS / ColossalAI · Commits

Commit 35e6b9ec (Unverified)
Authored Nov 21, 2022 by YuliangLiu0306; committed by GitHub on Nov 21, 2022

[autoparallel] adapt handlers with attention block (#1990)

* [autoparallel] adapt handlers with attention block
* polish

Parent: b5dbb461
Showing 7 changed files with 114 additions and 33 deletions (+114, -33).
colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py                 +1   -0
colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py   +9   -3
colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py      +8   -5
colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py        +7   -3
colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py       +5   -0
colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py                   +0   -18
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py  +84  -4
colossalai/auto_parallel/tensor_shard/node_handler/reshape_handler.py

@@ -12,6 +12,7 @@ __all__ = ['ReshapeHandler']
 @operator_registry.register(torch.reshape)
 @operator_registry.register(torch.Tensor.split)
+@operator_registry.register(torch.split)
 @operator_registry.register(torch.flatten)
 @operator_registry.register(torch.Tensor.transpose)
 @operator_registry.register(torch.Tensor.permute)
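Note (illustrative, not part of the commit): registering the free function torch.split alongside the already registered torch.Tensor.split matters because traced attention blocks commonly call the functional form when slicing the fused QKV projection. A minimal sanity check of the equivalence the two registrations rely on:

import torch

x = torch.rand(4, 8)
a = torch.split(x, 2, dim=0)    # free-function form, newly routed to ReshapeHandler
b = x.split(2, dim=0)           # method form, registered previously
assert all(torch.equal(u, v) for u, v in zip(a, b))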
colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py

@@ -220,7 +220,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                                            logical_process_axis=mesh_dim_0,
                                                            comm_type=CommType.IMPLICIT)
-        communication_action_mapping = {"output": output_comm_action}
+        # TODO: Temporary solution has no communication cost,
+        # above action should be added after the SyncBN replace pass completed.
+        communication_action_mapping = {}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,

@@ -256,7 +258,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                                            logical_process_axis=[mesh_dim_0, mesh_dim_1],
                                                            comm_type=CommType.IMPLICIT)
-        communication_action_mapping = {"output": output_comm_action}
+        # TODO: Temporary solution has no communication cost,
+        # above action should be added after the SyncBN replace pass completed.
+        communication_action_mapping = {}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,

@@ -302,7 +306,9 @@ class BatchNormStrategyGenerator(StrategyGenerator):
                                                            logical_process_axis=[mesh_dim_0],
                                                            comm_type=CommType.IMPLICIT)
-        communication_action_mapping = {"output": output_comm_action}
+        # TODO: Temporary solution has no communication cost,
+        # above action should be added after the SyncBN replace pass completed.
+        communication_action_mapping = {}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
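Note (illustrative, not part of the commit): the dropped IMPLICIT all-reduce exists because, once the batch dimension is sharded, each rank only sees a slice of the batch and batch-norm statistics must be reduced across ranks; the commit defers that to a future SyncBN replace pass. A small single-process emulation of why the reduction is needed, with two equal "shards":

import torch

full = torch.rand(8, 3)
shard_a, shard_b = torch.chunk(full, 2, dim=0)       # pretend these live on two ranks
local_means = torch.stack([shard_a.mean(0), shard_b.mean(0)])
# averaging the per-shard means recovers the global mean only if they are combined
assert torch.allclose(local_means.mean(0), full.mean(0))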
colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py

@@ -69,7 +69,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
     def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []
-        for strategy in self.predecessor_node.strategies_vector:
+        for index, strategy in enumerate(self.predecessor_node.strategies_vector):
             dim_partition_dict_mapping = {}
             communication_action_mapping = {}
             dim_partition_dict_for_input = strategy.output_sharding_specs[self.op_data["input"]].dim_partition_dict

@@ -96,7 +96,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
                                                                              arg_index=0)
             communication_action_mapping["input"] = input_communication_action
-            name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
+            name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}_{index}'
             strategy = self.get_sharding_strategy(name=name,
                                                   sharding_spec_mapping=sharding_spec_mapping,

@@ -121,7 +121,7 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
         strategy_list = []
         index = self.op_data["index"].data
-        for strategy in self.predecessor_node.strategies_vector:
+        for strategy_index, strategy in enumerate(self.predecessor_node.strategies_vector):
             # the sharding spec for input in this case is a tuple of ShardingSpec.
             sharding_spec_for_input = strategy.output_sharding_specs[self.op_data["input"]]
             dim_partition_dict_for_output = sharding_spec_for_input[index].dim_partition_dict

@@ -132,8 +132,11 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
             }
             sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
             sharding_spec_mapping["input"] = sharding_spec_for_input
-            name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
+            input_sharding_info = f"get the {index} element from ("
+            for sharding_spec in sharding_spec_for_input:
+                input_sharding_info += f'{sharding_spec.sharding_sequence}, '
+            input_sharding_info += ")"
+            name = f'{sharding_spec_mapping["output"].sharding_sequence} = {input_sharding_info}_{strategy_index}'
             strategy = self.get_sharding_strategy(name=name,
                                                   sharding_spec_mapping=sharding_spec_mapping,
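Note (illustrative, not part of the commit): the enumerate() changes exist so that each derived getitem strategy gets a name that is unique per predecessor strategy; without the index suffix, two predecessor strategies that resolve to the same sharding sequence would collide on the same name. A toy check of that effect:

predecessor_names = ["[R, S0] = [R, S0]", "[R, S0] = [R, S0]"]   # hypothetical duplicate sequences
new_names = [f"{n}_{i}" for i, n in enumerate(predecessor_names)]
assert len(set(predecessor_names)) == 1 and len(set(new_names)) == 2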
colossalai/auto_parallel/tensor_shard/node_handler/strategy/where_generator.py
View file @
35e6b9ec
import
copy
import
copy
from
typing
import
List
from
typing
import
List
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
(
MemoryCost
,
ShardingStrategy
,
TrainCycleItem
)
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
MemoryCost
,
ShardingStrategy
,
TrainCycleItem
from
colossalai.auto_parallel.tensor_shard.utils
import
(
enumerate_all_possible_1d_sharding
,
from
colossalai.auto_parallel.tensor_shard.utils
import
(
enumerate_all_possible_2d_sharding
)
enumerate_all_possible_1d_sharding
,
enumerate_all_possible_2d_sharding
,
ignore_sharding_exception
,
)
from
.strategy_generator
import
StrategyGenerator
from
.strategy_generator
import
StrategyGenerator
...
@@ -50,6 +53,7 @@ class WhereGenerator(StrategyGenerator):
...
@@ -50,6 +53,7 @@ class WhereGenerator(StrategyGenerator):
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
memory_cost
=
TrainCycleItem
(
fwd
=
fwd_mem_cost
,
bwd
=
bwd_mem_cost
,
total
=
total_mem_cost
)
strategy
.
memory_cost
=
memory_cost
strategy
.
memory_cost
=
memory_cost
@
ignore_sharding_exception
def
_generate_strategy_with_dim_partition
(
self
,
dim_partition
):
def
_generate_strategy_with_dim_partition
(
self
,
dim_partition
):
dim_partition_dict_mapping
=
{
dim_partition_dict_mapping
=
{
"condition"
:
dim_partition
,
"condition"
:
dim_partition
,
...
...
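Note (illustrative, not part of the commit): the new decorator guards strategy generation against invalid shardings. The sketch below shows what an exception-swallowing decorator of this kind could look like under that assumption; the real ignore_sharding_exception in colossalai.auto_parallel.tensor_shard.utils may log and filter differently.

import functools

def ignore_sharding_exception_sketch(func):

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception:
            # an invalid sharding simply yields no strategy instead of aborting
            return None

    return wrapper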
colossalai/auto_parallel/tensor_shard/node_handler/unary_elementwise_handler.py

@@ -14,6 +14,11 @@ __all__ = ['UnaryElementwiseHandler']
 @operator_registry.register(torch.Tensor.type)
 @operator_registry.register(torch.abs)
 @operator_registry.register(torch.nn.ReLU)
+# TODO: softmax need to be relocated
+@operator_registry.register(torch.nn.functional.softmax)
+@operator_registry.register(torch.nn.modules.dropout.Dropout)
+@operator_registry.register(torch.Tensor.contiguous)
+@operator_registry.register(torch.nn.functional.dropout)
 class UnaryElementwiseHandler(NodeHandler):
     """
     A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
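Note (illustrative, not part of the commit): these registrations are what "adapt handlers with attention block" refers to. A toy attention forward, written only to show that softmax, dropout, contiguous and the free-function split this commit registers all occur in such a block; it is not ColossalAI code.

import torch
import torch.nn.functional as F

def toy_attention(x: torch.Tensor, qkv_weight: torch.Tensor) -> torch.Tensor:
    qkv = x @ qkv_weight                          # (B, S, 3*H)
    q, k, v = torch.split(qkv, x.shape[-1], -1)   # free-function split, see reshape_handler.py
    scores = F.softmax(q @ k.transpose(-2, -1) / x.shape[-1] ** 0.5, dim=-1)
    scores = F.dropout(scores, p=0.0)
    return (scores @ v).contiguous()

out = toy_attention(torch.rand(2, 4, 8), torch.rand(8, 24))
assert out.shape == (2, 4, 8)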
colossalai/auto_parallel/tensor_shard/node_handler/where_handler.py

@@ -57,24 +57,6 @@ class WhereHandler(NodeHandler):
         logical_operand.logical_shape = target_shape
         return logical_operand

-    def register_strategy(self, compute_resharding_cost: bool = False) -> StrategiesVector:
-        """
-        Register different sharding strategies for the current node.
-        """
-        strategy_generators = self.get_strategy_generator()
-
-        for generator in strategy_generators:
-            strategies = generator.generate()
-            strategies_vector = map(self.post_process, strategies)
-            # compute the resharding costs based on the previous node
-            # strategies if specified
-            if compute_resharding_cost:
-                strategies = list(map(self.update_resharding_cost, strategies))
-            self.strategies_vector.extend(strategies)
-
-        self.strategies_vector = list(strategies_vector)
-        return self.strategies_vector
-
     def post_process(self, strategy: ShardingStrategy):
         logical_op_data_mapping, physical_op_data_mapping = self.get_operation_data_mapping()
         for key in logical_op_data_mapping.keys():
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py

@@ -3,6 +3,8 @@ import torch.nn as nn
 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler import GetItemHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler import PlacehodlerHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler import ReshapeHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer

@@ -10,7 +12,7 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
 from colossalai.testing.pytest_wrapper import run_on_environment_flag


-class GetItemModel(nn.Module):
+class GetItemFromTensorModel(nn.Module):

     def __init__(self):
         super().__init__()

@@ -21,8 +23,8 @@ class GetItemModel(nn.Module):
         return x


-def test_getitem_function_handler():
-    model = GetItemModel()
+def test_getitem_from_tensor_handler():
+    model = GetItemFromTensorModel()
     tracer = ColoTracer()
     # graph():
     #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]

@@ -83,5 +85,83 @@ def test_getitem_function_handler():
     assert len(getitem_strategies_vector) == len(conv_strategies_vector)


+class GetItemFromTupleModel(nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input):
+        split_node = torch.split(input, 2, 0)
+        x = split_node[1]
+        return x
+
+
+def test_getitem_from_tuple_handler():
+    model = GetItemFromTupleModel()
+    tracer = ColoTracer()
+    # graph():
+    #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
+    #     %split : [#users=1] = call_function[target=torch.functional.split](args = (%conv2d, 2), kwargs = {dim: 0})
+    #     %getitem : [#users=1] = call_function[target=operator.getitem](args = (%split, 1), kwargs = {})
+    #     return getitem
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 4, 64, 64).to('meta')})
+    gm = ColoGraphModule(model, graph)
+    physical_mesh_id = torch.arange(0, 4)
+
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+    input_node = list(graph.nodes)[0]
+    split_node = list(graph.nodes)[1]
+    getitem_node = list(graph.nodes)[2]
+    input_strategies_vector = StrategiesVector(input_node)
+    getitem_strategies_vector = StrategiesVector(getitem_node)
+    split_strategies_vector = StrategiesVector(split_node)
+
+    # build handler
+    input_handler = PlacehodlerHandler(
+        node=input_node,
+        device_mesh=device_mesh,
+        strategies_vector=input_strategies_vector,
+        placeholder_option='replicated',
+    )
+    input_handler.register_strategy(compute_resharding_cost=False)
+    setattr(input_node, 'strategies_vector', input_strategies_vector)
+    split_handler = ReshapeHandler(node=split_node, device_mesh=device_mesh, strategies_vector=split_strategies_vector)
+    split_handler.register_strategy(compute_resharding_cost=False)
+    setattr(split_node, 'strategies_vector', split_strategies_vector)
+    getitem_handler = GetItemHandler(node=getitem_node,
+                                     device_mesh=device_mesh,
+                                     strategies_vector=getitem_strategies_vector)
+    getitem_handler.register_strategy(compute_resharding_cost=False)
+    setattr(getitem_node, 'strategies_vector', getitem_strategies_vector)
+
+    # check operation data mapping
+    mapping = getitem_handler.get_operation_data_mapping()
+
+    for name, op_data in mapping.items():
+        op_data: OperationData
+        # make sure they have valid values
+        assert op_data.data is not None
+
+    assert mapping['input'].name == "split"
+    assert mapping['input'].type == OperationDataType.ARG
+    assert mapping['input'].logical_shape == (torch.Size([2, 4, 64, 64]), torch.Size([2, 4, 64, 64]))
+
+    assert mapping['index'].name == "index"
+    assert isinstance(mapping['index'].data, int)
+    assert mapping['index'].type == OperationDataType.ARG
+
+    assert mapping['output'].name == "getitem"
+    assert mapping['output'].data.is_meta
+    assert mapping['output'].data.shape == torch.Size([2, 4, 64, 64])
+    assert mapping['output'].type == OperationDataType.OUTPUT
+
+    # getitem is a following strategy handler, so the number of strategies is equal to the predecessor node.
+    assert len(getitem_strategies_vector) == len(split_strategies_vector)
+
+
 if __name__ == '__main__':
-    test_getitem_function_handler()
+    test_getitem_from_tensor_handler()
+    test_getitem_from_tuple_handler()
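Note (illustrative, not part of the commit): the shapes asserted in test_getitem_from_tuple_handler follow directly from the split arguments. Splitting a (4, 4, 64, 64) input into chunks of 2 along dim 0 yields a tuple of two (2, 4, 64, 64) tensors, and getitem picks the second one:

import torch

chunks = torch.split(torch.rand(4, 4, 64, 64), 2, 0)
assert tuple(c.shape for c in chunks) == (torch.Size([2, 4, 64, 64]), torch.Size([2, 4, 64, 64]))
assert chunks[1].shape == torch.Size([2, 4, 64, 64])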