OpenDAS / ColossalAI · Commits

Commit ea0f6b8d (unverified), authored Nov 25, 2022 by YuliangLiu0306, committed by GitHub on Nov 25, 2022

[autoparallel] add runtime pass and numerical test for view handler (#2018)

Parent: bb624561

Showing 5 changed files with 250 additions and 49 deletions (+250, -49):
colossalai/auto_parallel/passes/runtime_preparation_pass.py  (+24, -0)
colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_generator.py  (+6, -1)
colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py  (+2, -1)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py  (+213, -46)
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py  (+5, -1)
colossalai/auto_parallel/passes/runtime_preparation_pass.py

@@ -37,6 +37,30 @@ def _solution_annotatation(gm: torch.fx.GraphModule, solution: List[int]):
         origin_node_sharding_spec_dict[node_index] = strategies_vector[strategy_index].get_sharding_spec_by_name(str(node))

+    # experimental pass for torch.Tensor.view
+    # Arguments of view op will be divided in the sharded dimensions.
+    for node in nodes:
+        if node.op == 'call_method' and getattr(node.args[0]._meta_data.__class__, node.target) in (torch.Tensor.view,):
+            output_dim_partition_dict = node.sharding_spec.dim_partition_dict
+            device_mesh = node.sharding_spec.device_mesh
+            new_args = []
+            for arg in node.args:
+                if isinstance(arg, Node):
+                    if isinstance(arg._meta_data, int):
+                        new_args.append(arg._meta_data)
+                    else:
+                        new_args.append(arg)
+                else:
+                    assert isinstance(arg, int), 'The argument in view node should be either type of Node or int.'
+                    new_args.append(arg)
+            for dim, shard_dims in output_dim_partition_dict.items():
+                total_shard_size = 1
+                for shard_dim in shard_dims:
+                    total_shard_size *= device_mesh.shape[shard_dim]
+                new_args[dim + 1] //= total_shard_size
+            node.args = tuple(new_args)
+
     # the dict to get input sharding specs of user node
     sharding_spec_convert_dict = {}
     # the dict to record comm actions of nodes
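The heart of this pass is simple arithmetic: for every `view` size argument that corresponds to a sharded output dimension, the global size is divided by the total number of shards along that dimension, so the op runs on local per-device shapes (the `dim + 1` offset skips `args[0]`, the tensor the method is called on). A minimal self-contained sketch of that rewrite, with a hypothetical `dim_partition_dict` and mesh shape standing in for the node's real sharding spec:

# Sketch of the view-argument rewrite (hypothetical inputs, not the pass itself).
# dim_partition_dict maps an output dim -> the mesh axes it is sharded over.
mesh_shape = (2, 2)                    # a 2x2 device mesh, as in the test below
dim_partition_dict = {0: [0], 1: [1]}  # dim 0 split over mesh axis 0, dim 1 over axis 1

view_args = [32, 4, 64, 16, 4]         # global target shape: x.view(32, 4, 64, 16, 4)
for dim, shard_dims in dim_partition_dict.items():
    total_shard_size = 1
    for shard_dim in shard_dims:
        total_shard_size *= mesh_shape[shard_dim]
    view_args[dim] //= total_shard_size

assert view_args == [16, 2, 64, 16, 4]  # the per-device view shape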
colossalai/auto_parallel/tensor_shard/node_handler/experimental/view_generator.py

@@ -103,13 +103,18 @@ class ViewGenerator(FollowingStrategyGenerator):
             # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
             if len(total_mesh_dim_list) == 1:
                 total_mesh_dim_list = total_mesh_dim_list[0]
+                # the total mesh dim list only has one element, so the shard dim has only one element as well.
+                shard_dim = list(dim_partition_dict_for_input.keys())[0]
                 input_comm_action = self.get_communication_action(
                     sharding_spec=sharding_spec_mapping["input"],
                     communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                     logical_process_axis=total_mesh_dim_list,
                     comm_type=CommType.BEFORE,
                     arg_index=0)
-                input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
+                # it will gather the input through gather_dim during forward phase.
+                input_comm_action.comm_spec.gather_dim = shard_dim
+                # it will split the input activation grad through shard_dim during backward phase.
+                input_comm_action.comm_spec.shard_dim = shard_dim
             elif len(total_mesh_dim_list) >= 2:
                 source_spec = sharding_spec_mapping["input"]
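The fix here is easy to miss but real: `gather_dim` is a tensor dimension, while the unwrapped `total_mesh_dim_list` is a mesh-axis index, so the old assignment gathered along the wrong dim whenever the two differed. With GATHER_FWD_SPLIT_BWD, the input is all-gathered along the sharded tensor dim before the view executes, and the activation gradient is re-split along that same dim in backward. A rough single-process illustration of that symmetry, using `chunk`/`cat` as stand-ins for the actual collectives:

import torch

shard_dim = 0
full = torch.arange(16.).reshape(4, 4)
shards = list(full.chunk(2, dim=shard_dim))   # what each "device" holds locally

# forward: all-gather along shard_dim so view sees the full tensor (GATHER_FWD)
gathered = torch.cat(shards, dim=shard_dim)
viewed = gathered.view(2, 8)

# backward: the grad w.r.t. the gathered input is split back along shard_dim,
# so each "device" keeps only its own slice (SPLIT_BWD)
grad = torch.ones_like(gathered)
grad_shards = grad.chunk(2, dim=shard_dim)
assert grad_shards[0].shape == shards[0].shape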
colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py

@@ -105,6 +105,7 @@ def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy,
                                                                 dim_mapping={0: i},
                                                                 physical_shape=output_op_data.data.shape,
                                                                 inplace=True)
+            strategy_copy.name = f'{strategy.name}_{i}'
             sharding_strategies.append(strategy_copy)
         except ShardingNotDivisibleError as e:
             logger.debug(

@@ -194,7 +195,7 @@ class LinearModuleHandler(ModuleHandler):
 @operator_registry.register(F.linear)
 class LinearFunctionHandler(NodeHandler):
     """
-    A LinearModuleHandler which deals with the sharding strategies for nn.Linear module.
+    A LinearFunctionHandler which deals with the sharding strategies for F.Linear.
     """

     def get_strategy_generator(self) -> List[StrategyGenerator]:
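The one-line `strategy_copy.name = f'{strategy.name}_{i}'` is what produces the indexed names asserted in the test below: when one logical strategy expands into several physical ones, each copy now gets a unique `_{i}` suffix instead of a duplicated name. A trivial sketch of the resulting scheme (hypothetical name):

logical_name = '[S0, R, R, S1] -> [S0, R, R, S1, R]'
physical_names = [f'{logical_name}_{i}' for i in range(2)]
# ['[S0, R, R, S1] -> [S0, R, R, S1, R]_0', '[S0, R, R, S1] -> [S0, R, R, S1, R]_1']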
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_view_handler.py

 from functools import partial

 import pytest
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn

 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import ConvFunctionHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.experimental import ViewHandler
+from colossalai.auto_parallel.tensor_shard.node_handler.linear_handler import LinearFunctionHandler
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import OperationData, OperationDataType, StrategiesVector
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
 from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.utils import free_port
 from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy


+class ConvViewModel(nn.Module):
+
+    def __init__(self, tgt_shape):
+        super().__init__()
+        self.tgt_shape = tgt_shape
+
+    def forward(self, input, other):
+        conv_node = nn.functional.conv2d(input, other, bias=None)
+        reshape_node = conv_node.view(*self.tgt_shape)
+        return reshape_node
+
+
-class ViewModel(nn.Module):
+class LinearViewModel(nn.Module):

-    def __init__(self):
+    def __init__(self, tgt_shape):
         super().__init__()
+        self.tgt_shape = tgt_shape

     def forward(self, input, other):
-        conv_node = nn.functional.conv2d(input, other)
-        reshape_node = conv_node.view(32, 4, 32, 32, 4)
+        linear_node = nn.functional.linear(input, other, bias=None)
+        reshape_node = linear_node.view(*self.tgt_shape)
         return reshape_node


-def test_view_handler():
-    model = ViewModel()
+def check_view_handler(rank, tgt_shape, model_cls, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    model = model_cls(tgt_shape).cuda()
+
+    if model_cls.__name__ == 'ConvViewModel':
+        input = torch.rand(8, 8, 66, 66).to('cuda')
+        other = torch.rand(16, 8, 3, 3).to('cuda')
+        # index of conv node in computation graph
+        node_index = 2
+        # total number of conv strategies
+        strategy_number = 16
+    if model_cls.__name__ == 'LinearViewModel':
+        input = torch.rand(8, 16, 64, 32).to('cuda')
+        other = torch.rand(64, 32).to('cuda')
+        # index of linear node in computation graph
+        node_index = 2
+        # total number of linear strategies
+        strategy_number = 23
+
+    physical_mesh_id = torch.arange(0, 4)
+    mesh_shape = (2, 2)
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
+
+    numerical_test_for_node_strategy(model=model,
+                                     device_mesh=device_mesh,
+                                     node_index=node_index,
+                                     strategy_number=strategy_number,
+                                     input_args=[input, other],
+                                     meta_arg_names=['input', 'other'],
+                                     node_type='following')
     tracer = ColoTracer()
+    if model_cls.__name__ == 'ConvViewModel':
+        # graph():
+        #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
+        #     %other : torch.Tensor [#users=1] = placeholder[target=other]

@@ -31,25 +84,47 @@ def test_view_handler():
         #     return view
-    graph = tracer.trace(model,
-                         meta_args={
-                             "input": torch.rand(8, 8, 66, 66).to('meta'),
-                             "other": torch.rand(16, 8, 3, 3).to('meta'),
-                         })
+        graph = tracer.trace(model,
+                             meta_args={
+                                 "input": torch.rand(8, 16, 66, 66).to('meta'),
+                                 "other": torch.rand(16, 8, 3, 3).to('meta'),
+                             })
+    if model_cls.__name__ == 'LinearViewModel':
+        # graph():
+        #     %input_1 : torch.Tensor [#users=1] = placeholder[target=input]
+        #     %other : torch.Tensor [#users=1] = placeholder[target=other]
+        #     %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%input_1, %other), kwargs = {bias: None})
+        #     %view : [#users=1] = call_method[target=view](args = (%linear, 32, 4, 32, 32, 4), kwargs = {})
+        #     return view
+        graph = tracer.trace(model,
+                             meta_args={
+                                 "input": torch.rand(8, 16, 64, 32).to('meta'),
+                                 "other": torch.rand(64, 32).to('meta'),
+                             })

     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    conv_mod_node = list(graph.nodes)[2]
+    previous_mod_node = list(graph.nodes)[2]
     view_node = list(graph.nodes)[3]
     view_strategies_vector = StrategiesVector(view_node)
-    conv_strategies_vector = StrategiesVector(conv_mod_node)
+    previous_strategies_vector = StrategiesVector(previous_mod_node)

     # build handler
-    conv_handler = ConvFunctionHandler(node=conv_mod_node,
-                                       device_mesh=device_mesh,
-                                       strategies_vector=conv_strategies_vector)
-    conv_handler.register_strategy(compute_resharding_cost=False)
-    setattr(conv_mod_node, 'strategies_vector', conv_strategies_vector)
+    if model_cls.__name__ == 'ConvViewModel':
+        conv_handler = ConvFunctionHandler(node=previous_mod_node,
+                                           device_mesh=device_mesh,
+                                           strategies_vector=previous_strategies_vector)
+        conv_handler.register_strategy(compute_resharding_cost=False)
+        setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)
+
+    if model_cls.__name__ == 'LinearViewModel':
+        assert len(previous_strategies_vector) == 0
+        linear_handler = LinearFunctionHandler(node=previous_mod_node,
+                                               device_mesh=device_mesh,
+                                               strategies_vector=previous_strategies_vector)
+        linear_handler.register_strategy(compute_resharding_cost=False)
+        setattr(previous_mod_node, 'strategies_vector', previous_strategies_vector)

     view_handler = ViewHandler(node=view_node, device_mesh=device_mesh, strategies_vector=view_strategies_vector)
     view_handler.register_strategy(compute_resharding_cost=False)

@@ -62,7 +137,10 @@ def test_view_handler():
     # make sure they have valid values
     assert op_data.data is not None

-    assert mapping['input'].name == "conv2d"
+    if model_cls.__name__ == 'ConvViewModel':
+        assert mapping['input'].name == "conv2d"
+    else:
+        assert mapping['input'].name == "linear"
     assert mapping['input'].data.is_meta
     assert mapping['input'].data.shape == torch.Size([8, 16, 64, 64])
     assert mapping['input'].type == OperationDataType.ARG

@@ -70,12 +148,16 @@ def test_view_handler():
     assert mapping['output'].name == "view"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([32, 4, 32, 32, 4])
+    assert mapping['output'].data.shape == torch.Size(tgt_shape)
     assert mapping['output'].type == OperationDataType.OUTPUT

     # reshape handler is a following strategy handler, so the number of strategies is equal to the predecessor node.
-    assert len(view_strategies_vector) == len(conv_strategies_vector)
+    assert len(view_strategies_vector) == len(previous_strategies_vector)
     strategy_name_list = [strategy.name for strategy in view_strategies_vector]
+    if model_cls.__name__ == 'ConvViewModel':
+        if tgt_shape == (32, 4, 64, 16, 4):
+            assert '[S0, S1, R, R] -> FULLY REPLICATED_0' in strategy_name_list
+            assert '[S1, S0, R, R] -> FULLY REPLICATED_1' in strategy_name_list
+            assert '[S0, R, R, R] -> [S0, R, R, R, R]_2' in strategy_name_list

@@ -93,6 +175,91 @@ def test_view_handler():
+            assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list
+            assert '[R, S01, R, R] -> FULLY REPLICATED_15' in strategy_name_list
+        if tgt_shape == (8, 4, 4, 64, 16, 4):
+            assert '[S0, S1, R, R] -> [S0, S1, R, R, R, R]_0' in strategy_name_list
+            assert '[S1, S0, R, R] -> [S1, S0, R, R, R, R]_1' in strategy_name_list
+            assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_2' in strategy_name_list
+            assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_3' in strategy_name_list
+            assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_4' in strategy_name_list
+            assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_5' in strategy_name_list
+            assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_6' in strategy_name_list
+            assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list
+            assert '[R, R, R, R] -> [R, R, R, R, R, R]_8' in strategy_name_list
+            assert '[R, R, R, R] -> [R, R, R, R, R, R]_9' in strategy_name_list
+            assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_10' in strategy_name_list
+            assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_11' in strategy_name_list
+            assert '[R, R, R, R] -> [R, R, R, R, R, R]_12' in strategy_name_list
+            assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_13' in strategy_name_list
+            assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list
+            assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_15' in strategy_name_list
+    if model_cls.__name__ == 'LinearViewModel':
+        if tgt_shape == (32, 4, 64, 16, 4):
+            assert '[S0, R, R, S1] -> [S0, R, R, S1, R]_0' in strategy_name_list
+            assert '[R, S0, R, S1] -> FULLY REPLICATED_1' in strategy_name_list
+            assert '[R, R, S0, S1] -> [R, R, S0, S1, R]_2' in strategy_name_list
+            assert '[S1, R, R, S0] -> [S1, R, R, S0, R]_3' in strategy_name_list
+            assert '[R, S1, R, S0] -> FULLY REPLICATED_4' in strategy_name_list
+            assert '[R, R, S1, S0] -> [R, R, S1, S0, R]_5' in strategy_name_list
+            assert '[S0, R, R, R] -> [S0, R, R, R, R]_6' in strategy_name_list
+            assert '[R, S0, R, R] -> FULLY REPLICATED_7' in strategy_name_list
+            assert '[R, R, S0, R] -> [R, R, S0, R, R]_8' in strategy_name_list
+            assert '[S1, R, R, R] -> [S1, R, R, R, R]_9' in strategy_name_list
+            assert '[R, S1, R, R] -> FULLY REPLICATED_10' in strategy_name_list
+            assert '[R, R, S1, R] -> [R, R, S1, R, R]_11' in strategy_name_list
+            assert '[R, R, R, S1] -> [R, R, R, S1, R]_12' in strategy_name_list
+            assert '[R, R, R, S0] -> [R, R, R, S0, R]_13' in strategy_name_list
+            assert '[R, R, R, R] -> [R, R, R, R, R]_14' in strategy_name_list
+            assert '[R, R, R, R] -> [R, R, R, R, R]_15' in strategy_name_list
+            assert '[R, R, R, S0] -> [R, R, R, S0, R]_16' in strategy_name_list
+            assert '[R, R, R, S1] -> [R, R, R, S1, R]_17' in strategy_name_list
+            assert '[S01, R, R, R] -> [S01, R, R, R, R]_18' in strategy_name_list
+            assert '[R, S01, R, R] -> FULLY REPLICATED_19' in strategy_name_list
+            assert '[R, R, S01, R] -> [R, R, S01, R, R]_20' in strategy_name_list
+            assert '[R, R, R, R] -> [R, R, R, R, R]_21' in strategy_name_list
+            assert '[R, R, R, S01] -> [R, R, R, S01, R]_22' in strategy_name_list
+        if tgt_shape == (8, 4, 4, 64, 16, 4):
+            assert '[S0, R, R, S1] -> [S0, R, R, R, S1, R]_0' in strategy_name_list
+            assert '[R, S0, R, S1] -> [R, S0, R, R, S1, R]_1' in strategy_name_list
+            assert '[R, R, S0, S1] -> [R, R, R, S0, S1, R]_2' in strategy_name_list
+            assert '[S1, R, R, S0] -> [S1, R, R, R, S0, R]_3' in strategy_name_list
+            assert '[R, S1, R, S0] -> [R, S1, R, R, S0, R]_4' in strategy_name_list
+            assert '[R, R, S1, S0] -> [R, R, R, S1, S0, R]_5' in strategy_name_list
+            assert '[S0, R, R, R] -> [S0, R, R, R, R, R]_6' in strategy_name_list
+            assert '[R, S0, R, R] -> [R, S0, R, R, R, R]_7' in strategy_name_list
+            assert '[R, R, S0, R] -> [R, R, R, S0, R, R]_8' in strategy_name_list
+            assert '[S1, R, R, R] -> [S1, R, R, R, R, R]_9' in strategy_name_list
+            assert '[R, S1, R, R] -> [R, S1, R, R, R, R]_10' in strategy_name_list
+            assert '[R, R, S1, R] -> [R, R, R, S1, R, R]_11' in strategy_name_list
+            assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_12' in strategy_name_list
+            assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_13' in strategy_name_list
+            assert '[R, R, R, R] -> [R, R, R, R, R, R]_14' in strategy_name_list
+            assert '[R, R, R, R] -> [R, R, R, R, R, R]_15' in strategy_name_list
+            assert '[R, R, R, S0] -> [R, R, R, R, S0, R]_16' in strategy_name_list
+            assert '[R, R, R, S1] -> [R, R, R, R, S1, R]_17' in strategy_name_list
+            assert '[S01, R, R, R] -> [S01, R, R, R, R, R]_18' in strategy_name_list
+            assert '[R, S01, R, R] -> [R, S01, R, R, R, R]_19' in strategy_name_list
+            assert '[R, R, S01, R] -> [R, R, R, S01, R, R]_20' in strategy_name_list
+            assert '[R, R, R, R] -> [R, R, R, R, R, R]_21' in strategy_name_list
+            assert '[R, R, R, S01] -> [R, R, R, R, S01, R]_22' in strategy_name_list


+@run_on_environment_flag(name='AUTO_PARALLEL')
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@parameterize('tgt_shape', [(32, 4, 64, 16, 4), (8, 4, 4, 64, 16, 4)])
+@parameterize('model_cls', [ConvViewModel, LinearViewModel])
+def test_view_handler(tgt_shape, model_cls):
+    world_size = 4
+    run_func = partial(check_view_handler,
+                       tgt_shape=tgt_shape,
+                       model_cls=model_cls,
+                       world_size=world_size,
+                       port=free_port())
+    mp.spawn(run_func, nprocs=world_size)


 if __name__ == '__main__':
     test_view_handler()
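Each asserted name above encodes "predecessor output spec -> view output spec": `S0`/`S1` mean sharded over mesh axis 0/1, `S01` over both, `R` replicated, and `FULLY REPLICATED` marks cases where the input must be gathered before the view can run. A quick sanity check of the local-shape arithmetic behind these specs, with a hypothetical `local_shape` helper on the test's (2, 2) mesh:

def local_shape(global_shape, spec, mesh_shape=(2, 2)):
    """Per-device shape: each dim is divided by the sizes of the mesh axes sharding it."""
    out = list(global_shape)
    for dim, axes in enumerate(spec):   # axes is a tuple of mesh axes, () for R
        for axis in axes:
            out[dim] //= mesh_shape[axis]
    return tuple(out)

# conv2d output in the test is (8, 16, 64, 64); under [S0, S1, R, R]
# every device holds a (4, 8, 64, 64) slice.
assert local_shape((8, 16, 64, 64), [(0,), (1,), (), ()]) == (4, 8, 64, 64)
# [S01, R, R, R] shards dim 0 over both mesh axes: (2, 16, 64, 64) per device.
assert local_shape((8, 16, 64, 64), [(0, 1), (), (), ()]) == (2, 16, 64, 64)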
tests/test_auto_parallel/test_tensor_shard/test_node_handler/utils.py

@@ -87,6 +87,11 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
         solution_len = len(strategies_constructor.leaf_strategies)
         solution = [0] * solution_len
         solution[node_index] = strategy_index
+    elif node_type == 'following':
+        solution_len = len(strategies_constructor.leaf_strategies)
+        solution = [0] * solution_len
+        solution[node_index] = strategy_index
+        solution[node_index + 1] = strategy_index
     else:
         node_vector = strategies_constructor.leaf_strategies[node_index]
         strategy_to_keep = node_vector[strategy_index]

@@ -121,7 +126,6 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
         grad_to_shard = grad_to_shard_dict[key]
         grad_to_compare = grad_to_compare_dict[key]
         assert_close_helper(grad_to_shard, grad_to_compare, strategy_index=strategy_index, type='input grad')
-
     # extract the strategy used in this iter
     strategy_in_use = target_node.strategies_vector[strategy_index]
     param_to_shard_dict = dict(gm.named_parameters())
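The new 'following' branch pins the tested node and its follower (the view node at `node_index + 1`) to the same strategy index, which is exactly the invariant a following strategy generator guarantees: its strategy list mirrors the predecessor's one-to-one. A toy illustration with hypothetical sizes:

solution_len = 5      # pretend the graph has 5 strategized nodes
node_index = 2        # the conv/linear node under test
strategy_index = 3    # strategy being exercised this iteration

solution = [0] * solution_len
solution[node_index] = strategy_index
solution[node_index + 1] = strategy_index   # the 'following' view node mirrors it
assert solution == [0, 0, 3, 3, 0]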