OpenDAS / ColossalAI / Commits

Commit 27de2523 (unverified)
Authored Nov 01, 2022 by YuliangLiu0306; committed by GitHub, Nov 01, 2022
[autoparallel] fix conv handler numerical test (#1771)
Parent: 1e88811c
Showing 2 changed files, with 87 additions and 24 deletions (+87 -24):

  +87 -22  colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
  +0  -2   tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py
colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py  (view file @ 27de2523)
@@ -141,14 +141,31 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE,
+                arg_index=1)
         communication_action_mapping["other"] = other_comm_action
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0,
-                comm_type=CommType.HOOK)
+        if self.has_bias:
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
             communication_action_mapping["bias"] = bias_comm_action
         return self.get_sharding_strategy(name=name,
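Note on the branching added above: the new else branches split the old `self.has_bias and self.is_param("bias")` test so that a bias which is present but is not a parameter still gets a communication action. A minimal, self-contained sketch of the dispatch rule this encodes (the local CommType enum and build_bias_comm_action helper below are illustrative stand-ins, not imports from ColossalAI): a backward all-reduce can be registered as a gradient HOOK only on a parameter, while a plain tensor operand needs a BEFORE action addressed by positional index (arg_index) or keyword (key_for_kwarg).

    from enum import Enum

    class CommType(Enum):
        HOOK = "hook"      # all-reduce registered as a gradient hook on a parameter
        BEFORE = "before"  # comm action inserted before the op on a plain tensor

    def build_bias_comm_action(is_param: bool, spec, mesh_dim):
        # Mirrors the branching above: parameters keep the HOOK pattern,
        # non-parameter bias tensors fall back to BEFORE, addressed via the
        # op's 'bias' keyword argument.
        action = {
            "sharding_spec": spec,
            "communication_pattern": "IDENTITY_FWD_ALLREDUCE_BWD",
            "logical_process_axis": mesh_dim,
        }
        if is_param:
            action["comm_type"] = CommType.HOOK
        else:
            action["comm_type"] = CommType.BEFORE
            action["key_for_kwarg"] = "bias"
        return action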
@@ -180,14 +197,31 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE,
+                arg_index=1)
         communication_action_mapping["other"] = other_comm_action
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0,
-                comm_type=CommType.HOOK)
+        if self.has_bias:
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
             communication_action_mapping["bias"] = bias_comm_action
         return self.get_sharding_strategy(name=name,
@@ -230,14 +264,29 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=mesh_dim_0,
                 comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE,
+                arg_index=1)
         communication_action_mapping["other"] = other_comm_action
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0,
-                comm_type=CommType.HOOK)
+        if self.has_bias:
+            if self.is_param("bias"):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
             communication_action_mapping["bias"] = bias_comm_action
         return self.get_sharding_strategy(name=name,
@@ -277,7 +326,7 @@ class ConvStrategyGenerator(StrategyGenerator):
         input_comm_action = self.get_communication_action(
             sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0,
+            logical_process_axis=mesh_dim_1,
             comm_type=CommType.BEFORE,
             arg_index=0)
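This hunk carries the one-line numerical fix named in the commit title: the input's backward all-reduce moves from logical mesh axis 0 to axis 1, so it runs over the process groups matching how the surrounding strategy shards the operands (the exact strategy context is elided from this hunk). A toy NumPy illustration, assuming a hypothetical 2x2 logical device mesh, of why the two axes give different reduction groups:

    import numpy as np

    # Assumed 2x2 logical device mesh: [[0, 1], [2, 3]].
    mesh = np.arange(4).reshape(2, 2)
    # Groups along mesh axis 0 (axis-1 coordinate fixed): [0, 2] and [1, 3].
    print([list(col) for col in mesh.T])
    # Groups along mesh axis 1 (axis-0 coordinate fixed): [0, 1] and [2, 3].
    print([list(row) for row in mesh])
    # Reducing over the wrong axis sums gradients across unrelated shards,
    # which is the kind of numerical mismatch the test caught.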
@@ -399,14 +448,30 @@ class ConvStrategyGenerator(StrategyGenerator):
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
                 logical_process_axis=[mesh_dim_0, mesh_dim_1],
                 comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["other"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.BEFORE,
+                arg_index=1)
         communication_action_mapping["other"] = other_comm_action
-        if self.has_bias and self.is_param("bias"):
-            bias_comm_action = self.get_communication_action(
-                sharding_spec_mapping["bias"],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=[mesh_dim_0, mesh_dim_1],
-                comm_type=CommType.HOOK)
+        if self.has_bias:
+            if self.is_param("bias"):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
             communication_action_mapping["bias"] = bias_comm_action
         return self.get_sharding_strategy(name=name,
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py  (view file @ 27de2523)
@@ -290,7 +290,6 @@ def check_conv_function_handler(rank, bias, world_size, port):
     assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[1]


-@pytest.mark.skip("some cases need to be fixed")
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 # We temporarily ban the bias option before doing bias add
@@ -303,7 +302,6 @@ def test_conv_module_handler(bias=False):
     mp.spawn(run_func, nprocs=world_size)


-@pytest.mark.skip("some cases need to be fixed")
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 # We temporarily ban the bias option before doing bias add
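Removing the @pytest.mark.skip markers re-enables both tests, but they stay gated by run_on_environment_flag and only run when the AUTO_PARALLEL environment variable is set. A sketch of a local invocation, assuming a standard pytest setup:

    import os
    import pytest

    # Set the gate checked by run_on_environment_flag, then run the test file.
    os.environ["AUTO_PARALLEL"] = "1"
    pytest.main(["tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_conv_handler.py"])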