OpenDAS / ColossalAI · Commits

Commit 22a11540 (unverified)
[autoparallel] fixed broken node handler tests (#1708)
Authored Oct 14, 2022 by Frank Lee, committed by GitHub on Oct 14, 2022
Parent: 1468e4bc
Showing 8 changed files with 53 additions and 49 deletions (+53, -49):

  colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py       (+0, -4)
  colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py    (+0, -4)
  colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py          (+0, -4)
  colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py       (+0, -4)
  colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py  (+45, -29)
  colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py         (+7, -0)
  tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py          (+0, -2)
  tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py (+1, -2)
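Read together, the eight files tell one story: the per-generator `has_bias` property duplicated across batch_norm_generator.py, conv_strategy_generator.py, getitem_generator.py, layer_norm_generator.py and matmul_strategy_generator.py is removed, a single copy is added to the `StrategyGenerator` base class in strategy_generator.py, and the batched-matmul strategies now only build a bias communication spec when a bias operand is actually present. The two test files are adjusted to match.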
colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py
...
@@ -22,10 +22,6 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         In this generator, both methods will be considered.
     """
 
-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def validate(self) -> bool:
         '''
         In sanity check, we need make sure the input data having correct dimension size.
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py
...
@@ -17,10 +17,6 @@ class ConvStrategyGenerator(StrategyGenerator):
     The operation data is defined as `output = input x other + bias`.
     """
 
-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def validate(self) -> bool:
         '''
         In sanity check, we need make sure the input data having correct dimension size.
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py
...
@@ -19,10 +19,6 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
     3. args_0._meta_data: Tuple[torch.Tensor], args_1._meta_data: int
     """
 
-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def validate(self) -> bool:
         return super().validate()
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py
...
@@ -18,10 +18,6 @@ class LayerNormGenerator(StrategyGenerator):
     The operation data is defined as `output = input x other + bias`.
     """
 
-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def validate(self) -> bool:
         return super().validate()
...
colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py
...
@@ -14,10 +14,6 @@ class MatMulStrategyGenerator(StrategyGenerator):
     The operation data is defined as `output = input x other + bias`.
     """
 
-    @property
-    def has_bias(self):
-        return 'bias' in self.op_data
-
     def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         size_mapping = {
             'input': self._compute_size_in_bytes(strategy, "input"),
...
@@ -512,11 +508,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim)
-        communication_action_mapping = {"bias": bias_comm_spec}
+        communication_action_mapping = {}
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim)
+            communication_action_mapping['bias'] = bias_comm_spec
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
...
@@ -538,11 +536,14 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping = {"bias": bias_comm_spec}
+        communication_action_mapping = {}
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1])
+            communication_action_mapping['bias'] = bias_comm_spec
+
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
...
@@ -566,15 +567,20 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
+        communication_action_mapping = {}
         other_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
+        communication_action_mapping['other'] = other_comm_spec
+
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1])
+            communication_action_mapping['bias'] = bias_comm_spec
+
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
...
@@ -600,15 +606,20 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
+        communication_action_mapping = {}
         input_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
-        communication_action_mapping = {'input': input_comm_spec, 'bias': bias_comm_spec}
+        communication_action_mapping['input'] = input_comm_spec
+
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0)
+            communication_action_mapping['bias'] = bias_comm_spec
+
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
...
@@ -633,15 +644,20 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
 
         # get communication actions
+        communication_action_mapping = {}
         output_comm_spec = self.get_communication_spec(
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_1)
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
-        communication_action_mapping = {'output': output_comm_spec, 'bias': bias_comm_spec}
+        communication_action_mapping['output'] = output_comm_spec
+
+        if self.has_bias:
+            bias_comm_spec = self.get_communication_spec(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0)
+            communication_action_mapping['bias'] = bias_comm_spec
+
        return self.get_sharding_strategy(name=name,
                                          sharding_spec_mapping=sharding_spec_mapping,
                                          communication_action_mapping=communication_action_mapping)
...
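Every hunk in this file applies the same fix: instead of unconditionally indexing `sharding_spec_mapping['bias']`, the bias communication spec is only constructed when `self.has_bias` is true, which is presumably what broke the handler tests for bias-free matmuls. A stripped-down sketch of the corrected control flow, using a placeholder class rather than the real ColossalAI API:

# Minimal sketch of the guarded-bias pattern; FakeGenerator and its methods
# are illustrative stand-ins, not the real StrategyGenerator interface.
class FakeGenerator:
    def __init__(self, op_data):
        self.op_data = op_data

    @property
    def has_bias(self):
        return 'bias' in self.op_data

    def get_communication_spec(self, sharding_spec, communication_pattern, logical_process_axis):
        # Stand-in for the real spec construction.
        return (sharding_spec, communication_pattern, logical_process_axis)

    def build_comm_actions(self, sharding_spec_mapping, mesh_dim):
        communication_action_mapping = {}
        # Only touch the 'bias' entry when a bias operand exists; the old code
        # indexed it unconditionally and so assumed every matmul carried a bias.
        if self.has_bias:
            communication_action_mapping['bias'] = self.get_communication_spec(
                sharding_spec=sharding_spec_mapping['bias'],
                communication_pattern='IDENTITY_FWD_ALLREDUCE_BWD',
                logical_process_axis=mesh_dim)
        return communication_action_mapping


bias_free = FakeGenerator({'input': ..., 'other': ...})
print(bias_free.build_comm_actions({}, mesh_dim=0))    # {} instead of a KeyError on 'bias'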
colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py
...
@@ -24,6 +24,13 @@ class StrategyGenerator(ABC):
         self.op_data = operation_data_mapping
         self.device_mesh = device_mesh
 
+    @property
+    def has_bias(self):
+        """
+        A utility method to check for the existence of bias operand for convenience.
+        """
+        return 'bias' in self.op_data
+
     def is_param(self, op_data_name):
         other_data = self.op_data[op_data_name]
         return other_data.type == OperationDataType.PARAM
...
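With this addition, the subclasses in the other generator files can drop their private copies and simply inherit the property. A minimal sketch of the resulting inheritance relationship (class names below are illustrative, not the actual ColossalAI hierarchy):

# Sketch only: shows why the duplicated property definitions become redundant
# once the base class owns has_bias.
class BaseStrategyGenerator:
    def __init__(self, op_data):
        self.op_data = op_data

    @property
    def has_bias(self):
        # Mirrors the added base-class property: a 'bias' key in the operand
        # mapping means the traced node takes a bias argument.
        return 'bias' in self.op_data


class LayerNormLikeGenerator(BaseStrategyGenerator):
    # No has_bias override needed here any more.
    pass


assert LayerNormLikeGenerator({'input': ..., 'other': ...}).has_bias is False
assert LayerNormLikeGenerator({'input': ..., 'other': ..., 'bias': ...}).has_bias is True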
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_bmm_handler.py
...
@@ -22,7 +22,6 @@ class BMMTorchFunctionModule(nn.Module):
        return torch.bmm(x1, x2)


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_2d_device_mesh(module):
...
@@ -93,7 +92,6 @@ def test_2d_device_mesh(module):
    assert 'Sb1R = Sb1Sk0 x Sb1Sk0' in strategy_name_list


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.parametrize('module', [BMMTensorMethodModule, BMMTorchFunctionModule])
def test_1d_device_mesh(module):
    model = module()
...
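Both tests remain gated by `run_on_environment_flag(name='AUTO_PARALLEL')` from `colossalai.testing.pytest_wrapper`. Its implementation is not part of this diff, but the name suggests it skips the test unless the corresponding environment variable is set. A hedged sketch of that kind of guard, written as a plain pytest marker rather than the real wrapper:

# Illustrative only: approximates what an environment-flag guard typically does.
import os

import pytest


def env_flag_guard(name):
    # Skip the decorated test unless the named environment variable is set,
    # e.g. run with `AUTO_PARALLEL=1 pytest test_bmm_handler.py`.
    return pytest.mark.skipif(os.environ.get(name) is None,
                              reason=f'environment flag {name} is not set')


@env_flag_guard('AUTO_PARALLEL')
def test_runs_only_when_flag_is_set():
    assert True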
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_norm_pooling_handler.py
...
@@ -11,7 +11,6 @@ from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.testing.pytest_wrapper import run_on_environment_flag


@run_on_environment_flag(name='AUTO_PARALLEL')
def test_norm_pool_handler():
    model = nn.Sequential(nn.MaxPool2d(4, padding=1).to('meta'))
    tracer = ColoTracer()
...
@@ -50,7 +49,7 @@ def test_norm_pool_handler():
     assert mapping['output'].data.shape == torch.Size([4, 4, 16, 16])
     assert mapping['output'].type == OperationDataType.OUTPUT
 
-    strategies_vector = handler.register_strategy()
+    strategies_vector = handler.register_strategy(compute_resharding_cost=False)
     strategy_name_list = [val.name for val in strategies_vector]
 
     assert len(strategy_name_list) == 9
...
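The substantive change here is the extra `compute_resharding_cost=False` argument to `handler.register_strategy`. The diff does not show the handler's signature, but the flag's name suggests the resharding-cost pass is skipped, which is presumably enough for a unit test that only checks the generated strategy names and their count. A toy sketch of that shape of API, purely illustrative:

# Illustrative only: a register-style helper where the caller can opt out of
# the cost computation, as the updated test does.
def register_strategies(names, compute_resharding_cost=True):
    strategies = [{'name': name} for name in names]
    if compute_resharding_cost:
        for strategy in strategies:
            strategy['resharding_cost'] = 0.0    # placeholder for a real cost model
    return strategies


strategies_vector = register_strategies(['strategy_a', 'strategy_b'],
                                        compute_resharding_cost=False)
strategy_name_list = [s['name'] for s in strategies_vector]
assert len(strategy_name_list) == 2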