OpenDAS / ColossalAI · Commits

Commit a4ce180e (unverified), authored Oct 20, 2022 by YuliangLiu0306, committed by GitHub on Oct 20, 2022.

[autoparallel] add sequential order to communication actions (#1735)

Parent: b893342f
Showing 7 changed files with 292 additions and 89 deletions.
colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py   +79  -49
colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py         +25  -7
colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py        +47  -9
colossalai/auto_parallel/tensor_shard/sharding_strategy.py                               +36  -3
colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py                    +63  -5
colossalai/tensor/comm_spec.py                                                           +7   -5
tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py                +35  -11
colossalai/auto_parallel/tensor_shard/node_handler/strategy/conv_strategy_generator.py  (view file @ a4ce180e)

@@ -4,9 +4,18 @@ import warnings
 from functools import reduce
 from typing import List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
 from colossalai.auto_parallel.tensor_shard.utils import \
     ignore_sharding_exception
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
 
 from .strategy_generator import StrategyGenerator

@@ -122,26 +131,28 @@ class ConvStrategyGenerator(StrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
-        input_comm_spec = self.get_communication_spec(
+        input_comm_action = self.get_communication_action(
             sharding_spec=sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_1)
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.BEFORE)
 
-        communication_action_mapping = {"input": input_comm_spec}
+        communication_action_mapping = {"input": input_comm_action}
 
         if self.is_param("other"):
-            other_comm_spec = self.get_communication_spec(
+            other_comm_action = self.get_communication_action(
                 sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping["other"] = other_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+            communication_action_mapping["other"] = other_comm_action
 
         if self.has_bias and self.is_param("bias"):
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec_mapping["bias"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping["bias"] = bias_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+            communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,

@@ -167,18 +178,20 @@ class ConvStrategyGenerator(StrategyGenerator):
         communication_action_mapping = {}
         if self.is_param("other"):
-            other_comm_spec = self.get_communication_spec(
+            other_comm_action = self.get_communication_action(
                 sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping["other"] = other_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+            communication_action_mapping["other"] = other_comm_action
 
         if self.has_bias and self.is_param("bias"):
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec_mapping["bias"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping["bias"] = bias_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+            communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,

@@ -206,26 +219,30 @@ class ConvStrategyGenerator(StrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_1)
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.AFTER,
+            arg_index=0)
 
-        communication_action_mapping = {"output": output_comm_spec}
+        communication_action_mapping = {"output": output_comm_action}
 
         if self.is_param("other"):
-            other_comm_spec = self.get_communication_spec(
+            other_comm_action = self.get_communication_action(
                 sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping["other"] = other_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+            communication_action_mapping["other"] = other_comm_action
 
         if self.has_bias and self.is_param("bias"):
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec_mapping["bias"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping["bias"] = bias_comm_spec
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+            communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,

@@ -256,16 +273,20 @@ class ConvStrategyGenerator(StrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_0)
+            logical_process_axis=mesh_dim_0,
+            comm_type=CommType.AFTER,
+            arg_index=0)
 
-        input_comm_spec = self.get_communication_spec(
+        input_comm_action = self.get_communication_action(
             sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
+            logical_process_axis=mesh_dim_0,
+            comm_type=CommType.BEFORE,
+            arg_index=0)
 
-        communication_action_mapping = {"output": output_comm_spec, "input": input_comm_spec}
+        communication_action_mapping = {"output": output_comm_action, "input": input_comm_action}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,

@@ -291,12 +312,14 @@ class ConvStrategyGenerator(StrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_0)
+            logical_process_axis=mesh_dim_0,
+            comm_type=CommType.AFTER,
+            arg_index=0)
 
-        communication_action_mapping = {"output": output_comm_spec}
+        communication_action_mapping = {"output": output_comm_action}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,

@@ -324,12 +347,13 @@ class ConvStrategyGenerator(StrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
-        input_comm_spec = self.get_communication_spec(
+        input_comm_action = self.get_communication_action(
             sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
+            logical_process_axis=mesh_dim_0,
+            comm_type=CommType.BEFORE)
 
-        communication_action_mapping = {"input": input_comm_spec}
+        communication_action_mapping = {"input": input_comm_action}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,

@@ -375,18 +399,20 @@ class ConvStrategyGenerator(StrategyGenerator):
         communication_action_mapping = {}
         if self.is_param("other"):
-            other_comm_spec = self.get_communication_spec(
+            other_comm_action = self.get_communication_action(
                 sharding_spec_mapping["other"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=[mesh_dim_0, mesh_dim_1])
-            communication_action_mapping["other"] = other_comm_spec
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.HOOK)
+            communication_action_mapping["other"] = other_comm_action
 
         if self.has_bias and self.is_param("bias"):
-            bias_comm_spec = self.get_communication_spec(
+            bias_comm_action = self.get_communication_action(
                 sharding_spec_mapping["bias"],
                 communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=[mesh_dim_0, mesh_dim_1])
-            communication_action_mapping["bias"] = bias_comm_spec
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.HOOK)
+            communication_action_mapping["bias"] = bias_comm_action
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,

@@ -411,12 +437,14 @@ class ConvStrategyGenerator(StrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
-        output_comm_spec = self.get_communication_spec(
+        output_comm_action = self.get_communication_action(
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+            logical_process_axis=[mesh_dim_0, mesh_dim_1],
+            comm_type=CommType.AFTER,
+            arg_index=0)
 
-        communication_action_mapping = {"output": output_comm_spec}
+        communication_action_mapping = {"output": output_comm_action}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,

@@ -443,12 +471,14 @@ class ConvStrategyGenerator(StrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
 
         # set communication action
-        input_comm_spec = self.get_communication_spec(
+        input_comm_action = self.get_communication_action(
             sharding_spec_mapping["input"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+            logical_process_axis=[mesh_dim_0, mesh_dim_1],
+            comm_type=CommType.BEFORE,
+            arg_index=0)
 
-        communication_action_mapping = {"input": input_comm_spec}
+        communication_action_mapping = {"input": input_comm_action}
 
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
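Taken together, the conv handler now records when each collective runs, not just what it is. A hedged sketch (not part of the commit; comm_spec is left as None purely to keep the snippet self-contained) of the mapping shape a strategy now carries:

# Illustrative sketch only: every operand maps to a CommAction that bundles the
# CommSpec with its sequential order relative to the conv computation.
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType

communication_action_mapping = {
    # activation resharding runs just before the node
    "input": CommAction(comm_spec=None, comm_type=CommType.BEFORE, arg_index=0),
    # partial-sum all-reduce on the output runs right after the node
    "output": CommAction(comm_spec=None, comm_type=CommType.AFTER, arg_index=0),
    # weight/bias gradient all-reduce is attached as a backward hook
    "other": CommAction(comm_spec=None, comm_type=CommType.HOOK),
    "bias": CommAction(comm_spec=None, comm_type=CommType.HOOK),
}
# In the real handler, comm_spec comes from self.get_communication_spec(...).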
colossalai/auto_parallel/tensor_shard/node_handler/strategy/reshape_generator.py  (view file @ a4ce180e)

 import copy
 from typing import List
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
 from colossalai.tensor.shape_consistency import CollectiveCommPattern
+from colossalai.tensor.sharding_spec import ShardingSpec
 
 from .strategy_generator import FollowingStrategyGenerator

@@ -81,12 +88,23 @@ class ReshapeGenerator(FollowingStrategyGenerator):
         # if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
         if len(total_mesh_dim_list) == 1:
             total_mesh_dim_list = total_mesh_dim_list[0]
-        input_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping["input"],
-            communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-            logical_process_axis=total_mesh_dim_list)
-        communication_action_mapping["input"] = input_comm_spec
+            input_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping["input"],
+                communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
+                logical_process_axis=total_mesh_dim_list,
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
+        else:
+            source_spec = sharding_spec_mapping["input"]
+            target_spec = ShardingSpec(device_mesh=self.device_mesh,
+                                       entire_shape=source_spec.entire_shape,
+                                       dim_partition_dict={})
+            comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
+            input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
 
+        communication_action_mapping["input"] = input_comm_action
         strategy = self.get_sharding_strategy(name=name,
                                               sharding_spec_mapping=sharding_spec_mapping,
                                               communication_action_mapping=communication_action_mapping)
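When the reshape input is sharded along more than one mesh axis, the handler no longer tries to express the conversion as a single CommSpec; it records the source and target sharding specs and lets ShapeConsistencyManager derive the collective sequence later. A minimal hedged sketch of that fallback (the helper name and arguments are illustrative, not from the commit):

from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType

def make_shape_consistency_action(source_spec, target_spec):
    # source_spec / target_spec are ShardingSpec objects; the dict-style comm_spec is
    # expanded at cost-estimation and runtime via ShapeConsistencyManager.shape_consistency().
    comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
    # arg_index=0: the tensor to convert is the node's first positional argument
    return CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)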
colossalai/auto_parallel/tensor_shard/node_handler/strategy/strategy_generator.py  (view file @ a4ce180e)

@@ -4,17 +4,27 @@ from functools import reduce
 from typing import Any, Dict, List, Union
 
 import torch
+from torch.fx import Node
 
-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy, TrainCycleItem)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommAction,
+    CommType,
+    OperationData,
+    OperationDataType,
+    ShardingStrategy,
+    TrainCycleItem,
+)
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
-from torch.fx import Node
 
 
 class StrategyGenerator(ABC):
     """
     StrategyGenerator is used to generate the same group of sharding strategies.
 
     TODO: remove the original strategy_generator.py after refactoring
     """

@@ -97,6 +107,21 @@ class StrategyGenerator(ABC):
                                            sharding_spec=sharding_spec,
                                            logical_process_axis=logical_process_axis)
 
+    def get_communication_action(self,
+                                 sharding_spec: ShardingSpec,
+                                 communication_pattern: CollectiveCommPattern,
+                                 logical_process_axis: Union[int, List[int]],
+                                 comm_type: CommType,
+                                 arg_index: int = -1) -> CommAction:
+        """
+        A factory method to produce a CommAction object.
+        """
+        return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec,
+                                                                communication_pattern=communication_pattern,
+                                                                logical_process_axis=logical_process_axis),
+                          comm_type=comm_type,
+                          arg_index=arg_index)
+
     def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
         """
         Compute the communication cost involved in the forward and backward iteration.

@@ -117,8 +142,21 @@ class StrategyGenerator(ABC):
         # check if communication action exists
         # if so, loop over each action and compute the cost of each action
         if strategy.communication_actions is not None:
-            for operand, comm_spec in strategy.communication_actions.items():
-                _compute_and_add(operand, comm_spec)
+            for operand, comm_action in strategy.communication_actions.items():
+                if isinstance(comm_action, CommAction):
+                    comm_spec = comm_action.comm_spec
+                else:
+                    # this condition branch will be removed after all the handler updated.
+                    comm_spec = comm_action
+                if isinstance(comm_spec, dict):
+                    src_spec = comm_spec['src_spec']
+                    tgt_spec = comm_spec['tgt_spec']
+                    shape_consistency_manager = ShapeConsistencyManager()
+                    _, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec)
+                    for comm_spec_ in comm_action_sequence:
+                        _compute_and_add(operand, comm_spec_)
+                else:
+                    _compute_and_add(operand, comm_spec)
 
         # update the communication cost attribute in-place
         strategy.communication_cost = comm_cost

@@ -141,7 +179,7 @@ class StrategyGenerator(ABC):
     def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str):
         """
         Compute the size of a tensor in bytes.
 
         Args:
             strategy (ShardingStrategy): the ShardingStrategy generated.
             key (str): the name of the operation data defined by the generator.

@@ -182,7 +220,7 @@ class StrategyGenerator(ABC):
     @abstractmethod
     def validate(self) -> bool:
         """
         Validate if the operands are of desired shape.
         If True, means this generator can be used for the current operation.
         """
         pass

@@ -190,7 +228,7 @@ class StrategyGenerator(ABC):
 class FollowingStrategyGenerator(StrategyGenerator):
     """
     FollowingStrategyGenerator is used to generate the sharding strategies which depends on its predecessor node.
 
     TODO: remove the original strategy_generator.py after refactoring
     """
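The new get_communication_action factory keeps existing get_communication_spec call sites nearly unchanged; handlers just pass the extra comm_type (and, for inline collectives, arg_index). A hedged sketch of the call pattern, wrapped in a hypothetical helper so it is self-contained (the helper itself is not part of the commit):

from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType
from colossalai.tensor.shape_consistency import CollectiveCommPattern

def weight_grad_allreduce_action(generator, weight_sharding_spec, mesh_dim) -> CommAction:
    # Mirrors the updated conv handler: the weight-gradient all-reduce is tagged
    # as a HOOK so the pass can register it as a backward hook instead of a graph node.
    return generator.get_communication_action(
        weight_sharding_spec,
        communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
        logical_process_axis=mesh_dim,
        comm_type=CommType.HOOK)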
colossalai/auto_parallel/tensor_shard/sharding_strategy.py  (view file @ a4ce180e)

@@ -4,11 +4,12 @@ from enum import Enum
 from typing import Any, Dict, List, Tuple, Union
 
 import torch
+from torch.fx.node import Node
 
 from colossalai.tensor.shape_consistency import CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
-from torch.fx.node import Node
 
-from .constants import (BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP)
+from .constants import BCAST_FUNC_OP, ELEMENTWISE_FUNC_OP, ELEMENTWISE_MODULE_OP, RESHAPE_FUNC_OP
 
 __all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']

@@ -84,6 +85,38 @@ class MemoryCost:
     buffer: int = 0
 
 
+class CommType(Enum):
+    """
+    CommType describes the sequential order of a communication action and a computation action.
+
+    Meaning:
+        BEFORE: the communication action happens just before the computation operation.
+        AFTER: the communication action happens after the computation operation.
+        HOOK: the communication action is used to do the grad all reduce.
+        IMPLICIT: the communication action happens during the kernel execution, such as SyncBatchNorm
+    """
+    BEFORE = 0
+    AFTER = 1
+    HOOK = 2
+    IMPLICIT = 3
+
+
+@dataclass
+class CommAction:
+    """
+    CommAction is used to record the communication action.
+
+    Args:
+        comm_spec: express the communication pattern and the process groups to execute the communication action.
+        comm_type: describes the sequential order of a communication action and a computation action.
+        arg_index: record the location of tensor which join the communication, we cannot use name of node or op_data at runtime,
+                   because the args of node may be changed by graph transform passes.
+    """
+    comm_spec: CommSpec = None
+    comm_type: CommType = None
+    arg_index: int = -1
+
+
 @dataclass
 class ShardingStrategy:
     """

@@ -102,7 +135,7 @@ class ShardingStrategy:
     compute_cost: TrainCycleItem = None
     communication_cost: TrainCycleItem = None
     memory_cost: TrainCycleItem = None
-    communication_actions: Dict[OperationData, CommSpec] = None
+    communication_actions: Dict[OperationData, CommAction] = None
     resharding_costs: Dict[Node, List[TrainCycleItem]] = None
 
     @property
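The four CommType values are the "sequential order" the commit title refers to: BEFORE and AFTER collectives are materialized as extra graph nodes around the computation, HOOK collectives become gradient hooks, and IMPLICIT ones are left to the kernel. A small hedged sketch (the describe helper is hypothetical, not part of the commit):

from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType

def describe(action: CommAction) -> str:
    # Hypothetical helper: summarize when a CommAction runs relative to its node.
    if action.comm_type == CommType.BEFORE:
        return f"reshard positional arg {action.arg_index} just before the op"
    if action.comm_type == CommType.AFTER:
        return f"run a collective on the output (arg_index={action.arg_index}) after the op"
    if action.comm_type == CommType.HOOK:
        return "all-reduce the parameter gradient via a backward hook"
    return "handled implicitly inside the kernel (e.g. SyncBatchNorm)"

print(describe(CommAction(comm_spec=None, comm_type=CommType.HOOK)))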
colossalai/fx/passes/experimental/adding_shape_consistency_pass_v2.py  (view file @ a4ce180e)

@@ -8,8 +8,10 @@ import torch
 from torch.fx import symbolic_trace
 from torch.fx.node import Node
 
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction, CommType, OperationDataType
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx.passes.split_module import split_module
+from colossalai.tensor.comm_spec import CommSpec, _all_reduce
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec

@@ -19,9 +21,9 @@ shape_consistency_manager = ShapeConsistencyManager()
 class ConsistencyApply(torch.autograd.Function):
 
     @staticmethod
-    def forward(ctx, node, origin_dict, input_dict, node_index, user_node_index):
-        ctx.origin_sharding_spec = origin_dict[node_index]
-        ctx.target_sharding_spec = input_dict[node_index][user_node_index]
+    def forward(ctx, node, origin_sharding_spec, target_sharding_spec):
+        ctx.origin_sharding_spec = origin_sharding_spec
+        ctx.target_sharding_spec = target_sharding_spec
         return shape_consistency_manager.apply_for_autoparallel_runtime(node, ctx.origin_sharding_spec,
                                                                         ctx.target_sharding_spec)

@@ -32,7 +34,9 @@ class ConsistencyApply(torch.autograd.Function):
 def runtime_apply_for_leaf_node(node, origin_dict, input_dict, node_index, user_node_index):
-    return ConsistencyApply.apply(node, origin_dict, input_dict, node_index, user_node_index)
+    origin_sharding_spec = origin_dict[node_index]
+    target_sharding_spec = input_dict[node_index][user_node_index]
+    return ConsistencyApply.apply(node, origin_sharding_spec, target_sharding_spec)
 
 
 def runtime_apply(node, origin_dict, input_dict, node_index, user_node_index):

@@ -41,6 +45,18 @@ def runtime_apply(node, origin_dict, input_dict, node_index, user_node_index):
     return shape_consistency_manager.apply_for_autoparallel_runtime(node, origin_sharding_spec, target_sharding_spec)
 
 
+def runtime_comm_spec_apply(tensor, comm_actions_dict, node_index, op_data):
+    comm_action = comm_actions_dict[node_index][op_data]
+    if isinstance(comm_action.comm_spec, CommSpec):
+        rst = comm_action.comm_spec.covert_spec_to_action(tensor)
+    else:
+        origin_sharding_spec = comm_action.comm_spec['src_spec']
+        tgt_sharding_spec = comm_action.comm_spec['tgt_spec']
+        rst = ConsistencyApply.apply(tensor, origin_sharding_spec, tgt_sharding_spec)
+    return rst
+
+
 def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], device_mesh):
     mod_graph = gm.graph
     nodes = tuple(mod_graph.nodes)

@@ -63,6 +79,16 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
                 setattr(param, 'sharding_spec', origin_sharding_spec)
                 target_sharding_spec = node.best_strategy.get_sharding_spec_by_name(name)
                 shape_consistency_manager.apply(param, target_sharding_spec)
+                comm_actions = node.best_strategy.communication_actions
+                for operation_data, comm_action in comm_actions.items():
+                    comm_spec_to_use = comm_action.comm_spec
+                    if operation_data.type == OperationDataType.PARAM and operation_data.name == name and comm_action.comm_type == CommType.HOOK:
+
+                        def hook_fn(grad):
+                            _all_reduce(grad, comm_spec_to_use)
+
+                        param.register_hook(hook_fn)
+
             for name, buffer in target_module.named_buffers():
                 origin_sharding_spec = ShardingSpec(device_mesh, buffer.shape, {})

@@ -79,15 +105,24 @@ def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int], de
             target_sharding_specs.append(target_sharding_spec)
         sharding_spec_convert_dict[index] = target_sharding_specs
 
+    # the dict to record comm actions of nodes
+    comm_actions_dict = {}
+    for index, node in enumerate(nodes):
+        comm_action_dict = {}
+        for op_data, comm_action in node.best_strategy.communication_actions.items():
+            comm_action_dict[op_data.name] = comm_action
+        comm_actions_dict[index] = comm_action_dict
+
     # add above dicts into graph
     for node in nodes:
        if node.op != 'placeholder':
            with mod_graph.inserting_before(node):
                input_specs_node = mod_graph.create_node('placeholder', target='sharding_spec_convert_dict')
                origin_specs_node = mod_graph.create_node('placeholder', target='origin_node_sharding_spec_dict')
+               comm_actions_dict_node = mod_graph.create_node('placeholder', target='comm_actions_dict')
            break
 
-    return sharding_spec_convert_dict, origin_node_sharding_spec_dict
+    return sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict
 
 
 def shape_consistency_pass(gm: torch.fx.GraphModule):

@@ -106,6 +141,9 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
         if node.target == 'origin_node_sharding_spec_dict':
             origin_dict_node = node
             continue
+        if node.target == 'comm_actions_dict':
+            comm_actions_dict_node = node
+            continue
         if not hasattr(node, 'best_strategy'):
             continue
         node_to_index_dict[node] = index

@@ -138,4 +176,24 @@ def shape_consistency_pass(gm: torch.fx.GraphModule):
                 new_args[origin_index_args] = shape_consistency_node
                 user_node.args = new_args
 
+        comm_actions = node.best_strategy.communication_actions
+        for op_data, comm_action in comm_actions.items():
+            comm_object = node.args[comm_action.arg_index]
+            if op_data.type == OperationDataType.ARG:
+                if comm_action.comm_type == CommType.BEFORE:
+                    with mod_graph.inserting_before(node):
+                        comm_spec_apply_node = mod_graph.create_node('call_function',
+                                                                     runtime_comm_spec_apply,
+                                                                     args=(comm_object, comm_actions_dict_node,
+                                                                           node_to_index_dict[node], op_data.name))
+                elif comm_action.comm_type == CommType.AFTER:
+                    with mod_graph.inserting_after(node):
+                        comm_spec_apply_node = mod_graph.create_node('call_function',
+                                                                     runtime_comm_spec_apply,
+                                                                     args=(comm_object, comm_actions_dict_node,
+                                                                           node_to_index_dict[node], op_data.name))
+            # TODO: consider other OperationDataType, such as OperationDataType.OUTPUT
+            new_args = list(node.args)
+            new_args[comm_action.arg_index] = comm_spec_apply_node
+            node.args = new_args
+
     return gm
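With this pass, solution_annotatation_pass now returns a third dict, comm_actions_dict, which is threaded into the compiled GraphModule as an extra placeholder; BEFORE/AFTER actions become runtime_comm_spec_apply call nodes and HOOK actions become parameter gradient hooks. A hedged sketch of the new calling convention (it mirrors the updated test below; gm, solution and device_mesh are assumed to be prepared by the solver as usual):

import torch
from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
    shape_consistency_pass,
    solution_annotatation_pass,
)

def run_with_comm_actions(gm: torch.fx.GraphModule, solution, device_mesh, input_tensor):
    # The pass returns a third dict mapping node index -> {operand name: CommAction};
    # the recompiled module expects it as an additional runtime input.
    sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
    shape_consistency_pass(gm)
    gm.recompile()
    return gm(input_tensor, sharding_spec_dict, origin_spec_dict, comm_actions_dict)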
colossalai/tensor/comm_spec.py  (view file @ a4ce180e)

-import torch
+import operator
 from enum import Enum
-import torch.distributed as dist
 from functools import reduce
-import operator
+
+import torch
+import torch.distributed as dist
 from torch.distributed import ReduceOp
 
 __all__ = [

@@ -238,7 +239,7 @@ class CommSpec:
     1. Compute the communication cost which will be used in auto parallel solver.
     2. Convert the communication spec to real action which will be used in runtime.
     It contains comm_pattern to determine the
     communication method, sharding_spec to determine the communication size, gather_dim and shard_dim
     to determine the buffer shape, and logical_process_axis
 
     Argument:

@@ -296,7 +297,7 @@ class CommSpec:
         '''
         For all_gather, all2all, and all_reduce operation, the formula provided in DeviceMesh with alpha-beta model is used to
         compute the communication cost.
         For shard operation, it is an on-chip operation, so the communication cost is zero.
         '''
         comm_size = reduce(operator.mul, self.sharding_spec.get_sharded_shape_per_device(), 1)
         cost_dict = {}

@@ -347,6 +348,7 @@ class CommSpec:
             tensor.data = pattern_to_func_dict[self.comm_pattern](tensor, self)
         else:
             tensor.data = tensor
+        return tensor
 
 
 pattern_to_func_dict = {
tests/test_auto_parallel/test_tensor_shard/test_shape_consistency_pass.py  (view file @ a4ce180e)

+import copy
 from functools import partial
 
 import pytest

@@ -6,15 +7,22 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from torch.fx import GraphModule
 
-from colossalai.auto_parallel.tensor_shard.solver import (CostGraph, GraphAnalyser, Solver, SolverOptions,
-                                                          StrategiesConstructor)
+from colossalai.auto_parallel.tensor_shard.solver import (
+    CostGraph,
+    GraphAnalyser,
+    Solver,
+    SolverOptions,
+    StrategiesConstructor,
+)
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (shape_consistency_pass,
-                                                                                solution_annotatation_pass)
+from colossalai.fx.passes.experimental.adding_shape_consistency_pass_v2 import (
+    shape_consistency_pass,
+    solution_annotatation_pass,
+)
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.initialize import launch
 from colossalai.logging import disable_existing_loggers
-from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.testing import assert_close, rerun_if_address_is_in_use
 from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.utils import free_port

@@ -27,6 +35,7 @@ class ConvModel(nn.Module):
     def forward(self, x):
         x = self.conv(x)
+        x = torch.flatten(x)
         return x

@@ -38,12 +47,13 @@ def check_apply(rank, world_size, port):
     mesh_shape = (2, 2)
     # [[0, 1]
     #  [2, 3]]
-    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=False)
-    entire_shape = torch.Size((4, 4, 8, 8))
+    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
     tracer = ColoTracer()
     model = ConvModel(4, 4).cuda()
-    origin_output = model(input)
+    test_model = copy.deepcopy(model)
+    test_input = copy.deepcopy(input)
     input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
     # graph():
     #     %x : torch.Tensor [#users=1] = placeholder[target=x]

@@ -62,16 +72,30 @@ def check_apply(rank, world_size, port):
     solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
     ret = solver.call_solver_serialized_args()
     solution = list(ret[0])
-    device_mesh.process_groups_dict = device_mesh.create_process_groups_for_logical_mesh()
-    sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
+    sharding_spec_dict, origin_spec_dict, comm_actions_dict = solution_annotatation_pass(gm, solution, device_mesh)
     shape_consistency_pass(gm)
     gm.recompile()
     nodes = [node for node in gm.graph.nodes]
     # TODO: wrap the gm to avoid the influence of the user training code
-    output = gm(input, sharding_spec_dict, origin_spec_dict)
+    output = gm(input, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
+    origin_output = test_model(test_input)
     assert output.equal(origin_output)
+    origin_loss = origin_output.sum()
+    loss = output.sum()
+    origin_loss.backward()
+    loss.backward()
+
+    grad_0 = test_model.conv.weight.grad.narrow(0, 0, 2)
+    grad_1 = test_model.conv.weight.grad.narrow(0, 2, 2)
+
+    if rank in (0, 1):
+        assert_close(gm.conv.weight.grad.data, grad_0.data)
+    elif rank in (2, 3):
+        assert_close(gm.conv.weight.grad.data, grad_1.data)
 
 
+# skip this test due to pulp not installed in CI environment
 @run_on_environment_flag(name='AUTO_PARALLEL')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()