Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
78509124
Unverified
Commit
78509124
authored
Dec 27, 2022
by
YuliangLiu0306
Committed by
GitHub
Dec 27, 2022
Browse files
[autoparallel] update getitem handler (#2207)
parent
29868a9e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
123 additions
and
75 deletions
+123
-75
colossalai/auto_parallel/passes/runtime_preparation_pass.py
colossalai/auto_parallel/passes/runtime_preparation_pass.py
+2
-1
colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
...l/tensor_shard/node_handler/binary_elementwise_handler.py
+1
-1
colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
...l/tensor_shard/node_handler/strategy/getitem_generator.py
+56
-32
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
...st_tensor_shard/test_node_handler/test_getitem_handler.py
+64
-41
No files found.
colossalai/auto_parallel/passes/runtime_preparation_pass.py
View file @
78509124
...
@@ -223,7 +223,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
...
@@ -223,7 +223,8 @@ def _size_value_converting(gm: torch.fx.GraphModule, device_mesh: DeviceMesh):
node
.
args
=
new_args
node
.
args
=
new_args
elif
isinstance
(
getitem_index
,
(
tuple
,
list
)):
elif
isinstance
(
getitem_index
,
(
tuple
,
list
)):
assert
isinstance
(
getitem_index
[
0
],
slice
)
if
not
isinstance
(
getitem_index
[
0
],
slice
):
continue
new_slice_items
=
[]
new_slice_items
=
[]
for
slice_item
in
getitem_index
:
for
slice_item
in
getitem_index
:
...
...
colossalai/auto_parallel/tensor_shard/node_handler/binary_elementwise_handler.py
View file @
78509124
...
@@ -16,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']
...
@@ -16,7 +16,7 @@ __all__ = ['BinaryElementwiseHandler']
@
operator_registry
.
register
(
BCAST_FUNC_OP
)
@
operator_registry
.
register
(
BCAST_FUNC_OP
)
class
BinaryElementwiseHandler
(
MetaInfo
NodeHandler
):
class
BinaryElementwiseHandler
(
NodeHandler
):
"""
"""
An BinaryBcastOpHandler is a node handler which deals with operations which have two
An BinaryBcastOpHandler is a node handler which deals with operations which have two
operands and broadcasting occurs such as torch.add.
operands and broadcasting occurs such as torch.add.
...
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
View file @
78509124
...
@@ -7,7 +7,9 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
...
@@ -7,7 +7,9 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
ShardingStrategy
,
ShardingStrategy
,
TrainCycleItem
,
TrainCycleItem
,
)
)
from
colossalai.logging
import
get_dist_logger
from
colossalai.tensor.shape_consistency
import
CollectiveCommPattern
from
colossalai.tensor.shape_consistency
import
CollectiveCommPattern
from
colossalai.tensor.sharding_spec
import
ShardingSpecException
from
.strategy_generator
import
FollowingStrategyGenerator
from
.strategy_generator
import
FollowingStrategyGenerator
...
@@ -69,39 +71,61 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
...
@@ -69,39 +71,61 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
def
collate_strategies
(
self
)
->
List
[
ShardingStrategy
]:
strategy_list
=
[]
strategy_list
=
[]
getitem_index
=
self
.
op_data
[
'index'
].
data
for
index
,
strategy
in
enumerate
(
self
.
predecessor_node
.
strategies_vector
):
for
index
,
strategy
in
enumerate
(
self
.
predecessor_node
.
strategies_vector
):
dim_partition_dict_mapping
=
{}
try
:
communication_action_mapping
=
{}
logger
=
get_dist_logger
()
dim_partition_dict_for_input
=
strategy
.
output_sharding_specs
[
self
.
op_data
[
"input"
]].
dim_partition_dict
dim_partition_dict_mapping
=
{}
dim_partition_dict_for_output
=
copy
.
deepcopy
(
dim_partition_dict_for_input
)
communication_action_mapping
=
{}
gather_input
=
0
in
dim_partition_dict_for_input
dim_partition_dict_for_input
=
copy
.
deepcopy
(
if
gather_input
:
strategy
.
output_sharding_specs
[
self
.
op_data
[
"input"
]].
dim_partition_dict
)
logical_process_axis
=
dim_partition_dict_for_output
.
pop
(
0
)
int_index
=
False
shift_dim_partition_dict_for_output
=
{}
if
isinstance
(
getitem_index
,
int
):
for
dim
,
mesh_dim_list
in
dim_partition_dict_for_output
.
items
():
int_index
=
True
shift_dim_partition_dict_for_output
[
dim
-
1
]
=
mesh_dim_list
getitem_dims
=
[
dim_partition_dict_for_output
=
shift_dim_partition_dict_for_output
0
,
dim_partition_dict_mapping
=
{
]
"input"
:
dim_partition_dict_for_input
,
shift_length
=
1
"output"
:
dim_partition_dict_for_output
,
elif
isinstance
(
getitem_index
,
slice
):
}
getitem_dims
=
[
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
0
,
if
gather_input
:
]
input_communication_action
=
self
.
get_communication_action
(
else
:
sharding_spec_mapping
[
"input"
],
getitem_dims
=
[
i
for
i
in
range
(
len
(
getitem_index
))]
communication_pattern
=
CollectiveCommPattern
.
GATHER_FWD_SPLIT_BWD
,
if
isinstance
(
getitem_index
[
0
],
int
):
logical_process_axis
=
logical_process_axis
,
int_index
=
True
comm_type
=
CommType
.
BEFORE
,
shift_length
=
len
(
getitem_index
)
arg_index
=
0
)
communication_action_mapping
[
"input"
]
=
input_communication_action
gather_dims
=
[]
for
dim
in
getitem_dims
:
name
=
f
'
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
=
{
sharding_spec_mapping
[
"input"
].
sharding_sequence
}
_
{
index
}
'
if
dim
in
dim_partition_dict_for_input
:
gather_dims
.
append
(
dim
)
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
sharding_spec_mapping
=
sharding_spec_mapping
,
for
dim
in
gather_dims
:
communication_action_mapping
=
communication_action_mapping
)
dim_partition_dict_for_input
.
pop
(
dim
)
dim_partition_dict_for_output
=
copy
.
deepcopy
(
dim_partition_dict_for_input
)
if
int_index
:
shift_dim_partition_dict_for_output
=
{}
for
dim
,
mesh_dim_list
in
dim_partition_dict_for_output
.
items
():
shift_dim_partition_dict_for_output
[
dim
-
shift_length
]
=
mesh_dim_list
dim_partition_dict_for_output
=
shift_dim_partition_dict_for_output
dim_partition_dict_mapping
=
{
"input"
:
dim_partition_dict_for_input
,
"output"
:
dim_partition_dict_for_output
,
}
sharding_spec_mapping
=
self
.
to_sharding_spec_mapping
(
dim_partition_dict_mapping
)
name
=
f
'
{
sharding_spec_mapping
[
"output"
].
sharding_sequence
}
=
{
sharding_spec_mapping
[
"input"
].
sharding_sequence
}
_
{
index
}
'
strategy
=
self
.
get_sharding_strategy
(
name
=
name
,
sharding_spec_mapping
=
sharding_spec_mapping
,
communication_action_mapping
=
communication_action_mapping
)
except
ShardingSpecException
as
e
:
logger
.
debug
(
e
)
continue
strategy_list
.
append
(
strategy
)
strategy_list
.
append
(
strategy
)
for
strategy
in
strategy_list
:
for
strategy
in
strategy_list
:
...
...
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_getitem_handler.py
View file @
78509124
from
functools
import
partial
import
pytest
import
torch
import
torch
import
torch.multiprocessing
as
mp
import
torch.nn
as
nn
import
torch.nn
as
nn
from
colossalai.auto_parallel.tensor_shard.node_handler.conv_handler
import
ConvFunctionHandler
from
colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler
import
GetItemHandler
from
colossalai.auto_parallel.tensor_shard.node_handler.getitem_handler
import
GetItemHandler
from
colossalai.auto_parallel.tensor_shard.node_handler.linear_handler
import
LinearFunctionHandler
from
colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler
import
PlacehodlerHandler
from
colossalai.auto_parallel.tensor_shard.node_handler.placeholder_handler
import
PlacehodlerHandler
from
colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler
import
ReshapeHandler
from
colossalai.auto_parallel.tensor_shard.node_handler.reshape_handler
import
ReshapeHandler
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
OperationData
,
OperationDataType
,
StrategiesVector
from
colossalai.auto_parallel.tensor_shard.sharding_strategy
import
OperationData
,
OperationDataType
,
StrategiesVector
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.device.device_mesh
import
DeviceMesh
from
colossalai.fx
import
ColoGraphModule
,
ColoTracer
from
colossalai.fx
import
ColoGraphModule
,
ColoTracer
from
colossalai.fx.tracer.meta_patch.patched_module
import
linear
from
colossalai.fx.tracer.meta_patch.patched_module
import
linear
from
colossalai.initialize
import
launch
from
colossalai.logging
import
disable_existing_loggers
from
colossalai.testing
import
assert_close
,
parameterize
,
rerun_if_address_is_in_use
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
from
colossalai.testing.pytest_wrapper
import
run_on_environment_flag
from
colossalai.utils
import
free_port
from
tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils
import
numerical_test_for_node_strategy
class
GetItemFromTensorModel
(
nn
.
Module
):
class
GetItemFromTensorModel
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
,
getitem_index
):
super
().
__init__
()
super
().
__init__
()
self
.
getitem_index
=
getitem_index
def
forward
(
self
,
input
,
other
):
def
forward
(
self
,
input
,
other
):
conv
_node
=
nn
.
functional
.
conv2d
(
input
,
other
)
linear
_node
=
nn
.
functional
.
linear
(
input
,
other
,
bias
=
None
)
x
=
conv_node
[
1
]
x
=
linear_node
[
self
.
getitem_index
]
return
x
return
x
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
def
check_getitem_from_tensor_handler
(
rank
,
getitem_index
,
world_size
,
port
):
def
test_getitem_from_tensor_handler
():
disable_existing_loggers
()
model
=
GetItemFromTensorModel
()
launch
(
config
=
{},
rank
=
rank
,
world_size
=
world_size
,
host
=
'localhost'
,
port
=
port
,
backend
=
'nccl'
)
model
=
GetItemFromTensorModel
(
getitem_index
=
getitem_index
)
input
=
torch
.
rand
(
8
,
16
,
64
,
32
).
to
(
'cuda'
)
other
=
torch
.
rand
(
64
,
32
).
to
(
'cuda'
)
# index of linear node in computation graph
node_index
=
2
# total number of linear strategies
strategy_number
=
23
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
mesh_shape
=
(
2
,
2
)
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
,
init_process_group
=
True
)
numerical_test_for_node_strategy
(
model
=
model
,
device_mesh
=
device_mesh
,
node_index
=
node_index
,
strategy_number
=
strategy_number
,
input_args
=
[
input
,
other
],
meta_arg_names
=
[
'input'
,
'other'
],
node_type
=
'following'
)
tracer
=
ColoTracer
()
tracer
=
ColoTracer
()
# graph():
# %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
# %other : torch.Tensor [#users=1] = placeholder[target=other]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%input_1, %other), kwargs = {})
# %getitem : [#users=1] = call_function[target=operator.getitem](args = (%conv2d, 1), kwargs = {})
# return getitem
graph
=
tracer
.
trace
(
model
,
graph
=
tracer
.
trace
(
model
,
meta_args
=
{
meta_args
=
{
"input"
:
torch
.
rand
(
4
,
4
,
64
,
64
).
to
(
'meta'
),
"input"
:
torch
.
rand
(
8
,
16
,
64
,
32
).
to
(
'meta'
),
"other"
:
torch
.
rand
(
4
,
16
,
3
,
3
).
to
(
'meta'
),
"other"
:
torch
.
rand
(
64
,
32
).
to
(
'meta'
),
})
})
gm
=
ColoGraphModule
(
model
,
graph
)
physical_mesh_id
=
torch
.
arange
(
0
,
4
)
mesh_shape
=
(
2
,
2
)
gm
=
ColoGraphModule
(
model
,
graph
)
device_mesh
=
DeviceMesh
(
physical_mesh_id
,
mesh_shape
)
linear_mod_node
=
list
(
graph
.
nodes
)[
2
]
conv_mod_node
=
list
(
graph
.
nodes
)[
2
]
getitem_mod_node
=
list
(
graph
.
nodes
)[
3
]
getitem_mod_node
=
list
(
graph
.
nodes
)[
3
]
getitem_strategies_vector
=
StrategiesVector
(
getitem_mod_node
)
getitem_strategies_vector
=
StrategiesVector
(
getitem_mod_node
)
conv
_strategies_vector
=
StrategiesVector
(
conv
_mod_node
)
linear
_strategies_vector
=
StrategiesVector
(
linear
_mod_node
)
# build handler
# build handler
conv
_handler
=
Conv
FunctionHandler
(
node
=
conv
_mod_node
,
linear
_handler
=
Linear
FunctionHandler
(
node
=
linear
_mod_node
,
device_mesh
=
device_mesh
,
device_mesh
=
device_mesh
,
strategies_vector
=
conv
_strategies_vector
)
strategies_vector
=
linear
_strategies_vector
)
conv
_handler
.
register_strategy
(
compute_resharding_cost
=
False
)
linear
_handler
.
register_strategy
(
compute_resharding_cost
=
False
)
setattr
(
conv
_mod_node
,
'strategies_vector'
,
conv
_strategies_vector
)
setattr
(
linear
_mod_node
,
'strategies_vector'
,
linear
_strategies_vector
)
getitem_handler
=
GetItemHandler
(
node
=
getitem_mod_node
,
getitem_handler
=
GetItemHandler
(
node
=
getitem_mod_node
,
device_mesh
=
device_mesh
,
device_mesh
=
device_mesh
,
strategies_vector
=
getitem_strategies_vector
)
strategies_vector
=
getitem_strategies_vector
)
...
@@ -67,23 +91,22 @@ def test_getitem_from_tensor_handler():
...
@@ -67,23 +91,22 @@ def test_getitem_from_tensor_handler():
# make sure they have valid values
# make sure they have valid values
assert
op_data
.
data
is
not
None
assert
op_data
.
data
is
not
None
assert
mapping
[
'input'
].
name
==
"conv2d"
# getitem is a following strategy handler, so the number of strategies is equal to the predecessor node.
assert
mapping
[
'input'
].
data
.
is_meta
assert
len
(
getitem_strategies_vector
)
==
len
(
linear_strategies_vector
)
assert
mapping
[
'input'
].
data
.
shape
==
torch
.
Size
([
4
,
4
,
62
,
62
])
assert
mapping
[
'input'
].
type
==
OperationDataType
.
ARG
assert
mapping
[
'input'
].
logical_shape
==
torch
.
Size
([
4
,
4
,
62
,
62
])
assert
mapping
[
'index'
].
name
==
"index"
assert
isinstance
(
mapping
[
'index'
].
data
,
int
)
assert
mapping
[
'index'
].
type
==
OperationDataType
.
ARG
assert
mapping
[
'output'
].
name
==
"getitem"
assert
mapping
[
'output'
].
data
.
is_meta
assert
mapping
[
'output'
].
data
.
shape
==
torch
.
Size
([
4
,
62
,
62
])
assert
mapping
[
'output'
].
type
==
OperationDataType
.
OUTPUT
# getitem is a following strategy handler, so the number of strategies is equal to the predecessor node.
@
run_on_environment_flag
(
name
=
'AUTO_PARALLEL'
)
assert
len
(
getitem_strategies_vector
)
==
len
(
conv_strategies_vector
)
@
pytest
.
mark
.
dist
@
rerun_if_address_is_in_use
()
# @parameterize('getitem_index', [slice(0, 2), (slice(None), slice(None))])
@
parameterize
(
'getitem_index'
,
[
1
,
(
1
,
4
),
slice
(
0
,
2
),
(
slice
(
None
),
slice
(
None
))])
def
test_getitem_from_tensor_handler
(
getitem_index
):
world_size
=
4
run_func
=
partial
(
check_getitem_from_tensor_handler
,
getitem_index
=
getitem_index
,
world_size
=
world_size
,
port
=
free_port
())
mp
.
spawn
(
run_func
,
nprocs
=
world_size
)
class
GetItemFromTupleModel
(
nn
.
Module
):
class
GetItemFromTupleModel
(
nn
.
Module
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment