OpenDAS / ColossalAI · Commit b0f7c8bd (unverified)
Authored Oct 28, 2022 by YuliangLiu0306; committed via GitHub on Oct 28, 2022
Parent: 16b0abf9

[autoparallel] update CommSpec to CommActions (#1768)

* [autoparallel] update CommSpec to CommActions
* polish code
Showing 7 changed files with 267 additions and 122 deletions (+267 / -122):

  colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py                       +5   -4
  colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py        +18  -10
  colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py           +11  -4
  colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py        +18  -9
  colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py   +210 -94
  colossalai/tensor/comm_spec.py                                                             +3   -1
  tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py        +2   -0
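The common thread across these files is that the strategy generators stop calling get_communication_spec and instead call get_communication_action, threading a CommType (and, where needed, an arg_index or key_for_kwarg) through every call. The sketch below is a minimal illustration of that shape only: the CommType values and field names are taken from the diff, but the concrete class layout is an assumption, not ColossalAI's actual definition.

# Illustrative stand-in for the CommSpec -> CommAction change; not the real ColossalAI classes.
from dataclasses import dataclass
from enum import Enum
from typing import List, Optional, Union


class CommType(Enum):
    BEFORE = 0    # communication inserted before the op (e.g. gather an input argument)
    AFTER = 1     # communication inserted after the op (e.g. all-reduce the output)
    HOOK = 2      # communication attached as a gradient hook on a parameter


@dataclass
class CommAction:
    comm_pattern: str                      # e.g. "ALLREDUCE_FWD_IDENTITY_BWD"
    logical_process_axis: Union[int, List[int]]
    comm_type: CommType
    arg_index: int = -1                    # which positional arg the comm applies to (BEFORE only)
    key_for_kwarg: Optional[str] = None    # or which keyword arg, e.g. 'bias'


# A strategy can now record *where* a communication happens, not just which collective it is:
output_action = CommAction("ALLREDUCE_FWD_IDENTITY_BWD", 0, CommType.AFTER)
bias_action = CommAction("IDENTITY_FWD_ALLREDUCE_BWD", 0, CommType.BEFORE, key_for_kwarg='bias')
print(output_action, bias_action)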
colossalai/auto_parallel/tensor_shard/node_handler/linear_handler.py

@@ -202,16 +202,17 @@ class LinearFunctionHandler(NodeHandler):
         mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}

-        if self.node.args[2] is not None:
+        if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None:
             # check if the other operand is a parameter
-            if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter):
+            if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
                 data_type = OperationDataType.PARAM
             else:
                 data_type = OperationDataType.ARG
-            physical_bias_operand = OperationData(name=str(self.node.args[2]),
+            physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
                                                   type=data_type,
-                                                  data=self.node.args[2]._meta_data)
+                                                  data=self.node.kwargs["bias"]._meta_data)
             mapping['bias'] = physical_bias_operand
         return mapping

     def post_process(self, strategy: ShardingStrategy):
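The handler change above stops indexing the bias as node.args[2] and reads node.kwargs['bias'] instead. With torch.fx, an argument keeps whatever form it was passed in at the call site, so a keyword-passed bias never appears in args. A minimal sketch of that difference with plain torch.fx (not the ColossalAI tracer):

# Keyword-passed bias lands in node.kwargs, not node.args.
import torch
import torch.nn.functional as F
from torch import fx


def linear_fn(x, w, b):
    return F.linear(x, w, bias=b)


traced = fx.symbolic_trace(linear_fn)
for node in traced.graph.nodes:
    if node.op == 'call_function' and node.target is F.linear:
        print(node.args)      # (x, w) -> only two positional args
        print(node.kwargs)    # {'bias': b} -> node.args[2] would raise IndexError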
colossalai/auto_parallel/tensor_shard/node_handler/strategy/batch_norm_generator.py

@@ -3,7 +3,12 @@ import operator
 from functools import reduce
 from typing import List

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
 from colossalai.tensor.shape_consistency import CollectiveCommPattern

 from .strategy_generator import StrategyGenerator

@@ -204,12 +209,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         # For SyncBN case, we don't need to do communication for weight and bias.
         # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
         # to SyncBN operation instead of inserting a communication node.
-        output_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping["output"],
-            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_0)
-
-        communication_action_mapping = {"output": output_comm_spec}
+        output_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping["output"],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim_0,
+            comm_type=CommType.AFTER)
+
+        communication_action_mapping = {"output": output_comm_action}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,

@@ -238,12 +244,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         # For SyncBN case, we don't need to do communication for gradients of weight and bias.
         # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
         # to SyncBN operation instead of inserting a communication node.
-        output_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping["output"],
-            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-
-        communication_action_mapping = {"output": output_comm_spec}
+        output_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping["output"],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1],
+            comm_type=CommType.AFTER)
+
+        communication_action_mapping = {"output": output_comm_action}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,

@@ -282,12 +289,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
         # For SyncBN case, we don't need to do communication for gradients of weight and bias.
         # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
         # to SyncBN operation instead of inserting a communication node.
-        output_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping["output"],
-            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=[mesh_dim_0])
-
-        communication_action_mapping = {"output": output_comm_spec}
+        output_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping["output"],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=[mesh_dim_0],
+            comm_type=CommType.AFTER)
+
+        communication_action_mapping = {"output": output_comm_action}
        return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
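All three hunks attach an all-reduce to the BN output with comm_type=CommType.AFTER, i.e. the collective runs on the op's result in the forward pass and leaves the gradient untouched (the ALLREDUCE_FWD_IDENTITY_BWD pattern). A rough autograd sketch of that pattern, assuming an already-initialized torch.distributed process group; this is illustrative, not ColossalAI's implementation:

# ALLREDUCE_FWD_IDENTITY_BWD: all-reduce in forward, pass the gradient through in backward.
import torch
import torch.distributed as dist


class AllReduceFwdIdentityBwd(torch.autograd.Function):

    @staticmethod
    def forward(ctx, tensor, group=None):
        tensor = tensor.clone()
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output, None


def allreduce_after(op_output, group=None):
    # "CommType.AFTER": the collective is applied to the op's output node.
    return AllReduceFwdIdentityBwd.apply(op_output, group)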
colossalai/auto_parallel/tensor_shard/node_handler/strategy/getitem_generator.py

 import copy
 from typing import List

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
 from colossalai.tensor.shape_consistency import CollectiveCommPattern

 from .strategy_generator import FollowingStrategyGenerator

@@ -83,11 +88,13 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
         }
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
         if gather_input:
-            input_communication_spec = self.get_communication_spec(
-                sharding_spec_mapping["input"],
-                communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
-                logical_process_axis=logical_process_axis)
-            communication_action_mapping["input"] = input_communication_spec
+            input_communication_action = self.get_communication_action(
+                sharding_spec_mapping["input"],
+                communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
+                logical_process_axis=logical_process_axis,
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            communication_action_mapping["input"] = input_communication_action

         name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
colossalai/auto_parallel/tensor_shard/node_handler/strategy/layer_norm_generator.py

@@ -3,9 +3,16 @@ import operator
 from functools import reduce
 from typing import List

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
-from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
-                                                          enumerate_all_possible_2d_sharding)
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
+from colossalai.auto_parallel.tensor_shard.utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+)
 from colossalai.tensor.shape_consistency import CollectiveCommPattern

 from .strategy_generator import StrategyGenerator

@@ -107,18 +114,20 @@ class LayerNormGenerator(StrategyGenerator):
             total_mesh_dim_list = total_mesh_dim_list[0]
         communication_action_mapping = {}

-        other_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping["other"],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=total_mesh_dim_list)
-        communication_action_mapping["other"] = other_comm_spec
+        other_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping["other"],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=total_mesh_dim_list,
+            comm_type=CommType.HOOK)
+        communication_action_mapping["other"] = other_comm_action

         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
-                sharding_spec=sharding_spec_mapping["bias"],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=total_mesh_dim_list)
-            communication_action_mapping["bias"] = bias_comm_spec
+            bias_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping["bias"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=total_mesh_dim_list,
+                comm_type=CommType.HOOK)
+            communication_action_mapping["bias"] = bias_comm_action

         strategy = self.get_sharding_strategy(name=name,
                                               sharding_spec_mapping=sharding_spec_mapping,
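Here the weight and bias get comm_type=CommType.HOOK with the IDENTITY_FWD_ALLREDUCE_BWD pattern: nothing happens in the forward pass, and the parameter's gradient is all-reduced when it is produced. A hedged sketch of what such a hook can look like, again assuming an initialized process group and not claiming to match ColossalAI's runtime pass:

# HOOK-style communication action: all-reduce a parameter's gradient via a tensor hook.
import torch
import torch.distributed as dist


def attach_grad_allreduce_hook(param: torch.nn.Parameter, group=None):
    def _allreduce_grad(grad):
        grad = grad.contiguous()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=group)
        return grad

    # identity in forward; the collective only touches the gradient in backward
    param.register_hook(_allreduce_grad)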
colossalai/auto_parallel/tensor_shard/node_handler/strategy/matmul_strategy_generator.py

 import operator
+from ast import arg
 from functools import reduce
 from typing import List

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
+from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
+    CommType,
+    MemoryCost,
+    ShardingStrategy,
+    TrainCycleItem,
+)
 from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
 from colossalai.tensor.shape_consistency import CollectiveCommPattern

@@ -77,11 +83,12 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication action
-        output_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['output'],
-            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim)
-
-        communication_action_mapping = {"output": output_comm_spec}
+        output_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping['output'],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim,
+            comm_type=CommType.AFTER)
+
+        communication_action_mapping = {"output": output_comm_action}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -124,15 +131,35 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication action
-        other_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['other'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim)
-
-        bias_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['bias'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim)
-
-        communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
+        if self.is_param('other'):
+            other_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['other'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim,
+                comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['other'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim,
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+
+        if self.has_bias:
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec=sharding_spec_mapping['bias'],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec=sharding_spec_mapping['bias'],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim,
+                    comm_type=CommType.BEFORE,
+                    arg_index=2)
+
+        communication_action_mapping = {'other': other_comm_action, 'bias': bias_comm_action}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -227,24 +254,45 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # set communication action
         communication_action_mapping = {}
-        input_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping["input"],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_1)
-
-        other_comm_spec = self.get_communication_spec(
-            sharding_spec_mapping["output"],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
+        input_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping["input"],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.BEFORE,
+            arg_index=0)
+
+        if self.is_param('other'):
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["output"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["output"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE,
+                arg_index=1)

-        communication_action_mapping['input'] = input_comm_spec
-        communication_action_mapping['other'] = other_comm_spec
+        communication_action_mapping['input'] = input_comm_action
+        communication_action_mapping['other'] = other_comm_action
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
-                sharding_spec_mapping["bias"],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping['bias'] = bias_comm_spec
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
+            communication_action_mapping['bias'] = bias_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -273,24 +321,45 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # get communication action mapping
         communication_action_mapping = {}
-        input_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping["input"],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_0)
-        output_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping["output"],
-            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_1)
-
-        communication_action_mapping['input'] = input_comm_spec
-        communication_action_mapping['output'] = output_comm_spec
+        output_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping["output"],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.AFTER)
+        if self.is_param('other'):
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["output"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec_mapping["output"],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+
+        communication_action_mapping['other'] = other_comm_action
+        communication_action_mapping['output'] = output_comm_action
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
-                sharding_spec=sharding_spec_mapping["bias"],
-                communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-                logical_process_axis=mesh_dim_1)
-            communication_action_mapping['bias'] = bias_comm_spec
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec_mapping["bias"],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=mesh_dim_0,
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
+            communication_action_mapping['bias'] = bias_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -320,16 +389,19 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # get communication actions
         communication_action_mapping = {}
-        output_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['output'],
-            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_0)
-        input_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['input'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_1)
-
-        communication_action_mapping["input"] = input_comm_spec
-        communication_action_mapping['output'] = output_comm_spec
+        output_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping['output'],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim_0,
+            comm_type=CommType.AFTER)
+        input_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping['input'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.BEFORE,
+            arg_index=0)
+
+        communication_action_mapping["input"] = input_comm_action
+        communication_action_mapping['output'] = output_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -354,12 +426,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # get communication action
         communication_action_mapping = {}
-        output_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['output'],
-            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim)
-        communication_action_mapping['output'] = output_comm_spec
+        output_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping['output'],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim,
+            comm_type=CommType.AFTER)
+        communication_action_mapping['output'] = output_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -386,12 +459,14 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # get communication actions
         communication_action_mapping = {}
-        input_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['input'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim)
-        communication_action_mapping['input'] = input_comm_spec
+        input_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping['input'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim,
+            comm_type=CommType.BEFORE,
+            arg_index=0)
+        communication_action_mapping['input'] = input_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -414,18 +489,36 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # get communication action
         communication_action_mapping = {}
-        other_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['other'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping['other'] = other_comm_spec
+        if self.is_param('other'):
+            other_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['other'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.HOOK)
+        else:
+            other_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['other'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.BEFORE,
+                arg_index=1)
+        communication_action_mapping['other'] = other_comm_action

         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
-                sharding_spec=sharding_spec_mapping['bias'],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=[mesh_dim_0, mesh_dim_1])
-            communication_action_mapping['bias'] = bias_comm_spec
+            if self.is_param('bias'):
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec=sharding_spec_mapping['bias'],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                    comm_type=CommType.HOOK)
+            else:
+                bias_comm_action = self.get_communication_action(
+                    sharding_spec=sharding_spec_mapping['bias'],
+                    communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                    logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                    comm_type=CommType.BEFORE,
+                    key_for_kwarg='bias')
+            communication_action_mapping['bias'] = bias_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -449,11 +542,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # get communication action
         communication_action_mapping = {}
-        output_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['output'],
-            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping['output'] = output_comm_spec
+        output_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping['output'],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1],
+            comm_type=CommType.AFTER)
+        communication_action_mapping['output'] = output_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -480,11 +574,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
         # get communication action
         communication_action_mapping = {}
-        input_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['input'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communication_action_mapping['input'] = input_comm_spec
+        input_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping['input'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1],
+            comm_type=CommType.BEFORE,
+            arg_index=0)
+        communication_action_mapping['input'] = input_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -516,8 +612,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
     [b, i, k] x [b, k, j] -> [b, i, j]

     The bias term is considered to have a 2D logical shape.
+
+    Note: This class will be used to generate strategies for torch.bmm
+    and torch.addbmm. However, the result of torch.addbmm is not correct,
+    some extra runtime apply actions are required to keep numerical correctness.
     """

+    # TODO: torch.addbmm correctness issue need to be fixed.
     def __init__(self, *args, **kwargs):
         self.squeeze_batch_dim = False
         super().__init__(*args, **kwargs)
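The new docstring and TODO single out torch.addbmm as the problematic case. Its signature is torch.addbmm(input, batch1, batch2): the bias-like matrix is positional argument 0 and the two batched operands sit at indices 1 and 2, which is why the addbmm branches in the later hunks bump arg_index on the 'other'/'input' actions. A quick shape check:

# torch.addbmm(input, batch1, batch2): out = beta * input + alpha * sum_b(batch1[b] @ batch2[b])
import torch

bias_like = torch.zeros(3, 5)     # 2D "input", arg 0, broadcastable to (n, p)
batch1 = torch.randn(4, 3, 2)     # (b, n, m), arg 1
batch2 = torch.randn(4, 2, 5)     # (b, m, p), arg 2

out = torch.addbmm(bias_like, batch1, batch2)
print(out.shape)                  # torch.Size([3, 5])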
@@ -566,16 +667,16 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             self._pop_batch_dim_sharding_for_output(dim_partition_dict)
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+        print(sharding_spec_mapping)

         # get communication actions
         communication_action_mapping = {}
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
-                sharding_spec=sharding_spec_mapping['bias'],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim)
-            communication_action_mapping['bias'] = bias_comm_spec
+            bias_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim,
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            communication_action_mapping['bias'] = bias_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -602,11 +703,13 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         # get communication actions
         communication_action_mapping = {}
         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
-                sharding_spec=sharding_spec_mapping['bias'],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=[mesh_dim_0, mesh_dim_1])
-            communication_action_mapping['bias'] = bias_comm_spec
+            bias_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            communication_action_mapping['bias'] = bias_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -637,18 +740,24 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         # get communication actions
         communication_action_mapping = {}
-        other_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['other'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_1)
-        communication_action_mapping['other'] = other_comm_spec
+        other_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping['other'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.BEFORE,
+            arg_index=1)
+        communication_action_mapping['other'] = other_comm_action

         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
-                sharding_spec=sharding_spec_mapping['bias'],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=[mesh_dim_0, mesh_dim_1])
-            communication_action_mapping['bias'] = bias_comm_spec
+            bias_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=[mesh_dim_0, mesh_dim_1],
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            communication_action_mapping['bias'] = bias_comm_action
+            # for addbmm case, other is the third argument instead of second.
+            communication_action_mapping['other'].arg_index += 1
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -679,18 +788,23 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         # get communication actions
         communication_action_mapping = {}
-        input_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['input'],
-            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-            logical_process_axis=mesh_dim_1)
-        communication_action_mapping['input'] = input_comm_spec
+        input_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping['input'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.BEFORE,
+            arg_index=0)
+        communication_action_mapping['input'] = input_comm_action

         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
-                sharding_spec=sharding_spec_mapping['bias'],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping['bias'] = bias_comm_spec
+            bias_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE)
+            communication_action_mapping['bias'] = bias_comm_action
+            # for addbmm case, other is the second argument instead of first.
+            communication_action_mapping['input'].arg_index += 1
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -719,18 +833,21 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         # get communication actions
         communication_action_mapping = {}
-        output_comm_spec = self.get_communication_spec(
-            sharding_spec=sharding_spec_mapping['output'],
-            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
-            logical_process_axis=mesh_dim_1)
-        communication_action_mapping['output'] = output_comm_spec
+        output_comm_action = self.get_communication_action(
+            sharding_spec=sharding_spec_mapping['output'],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim_1,
+            comm_type=CommType.AFTER)
+        communication_action_mapping['output'] = output_comm_action

         if self.has_bias:
-            bias_comm_spec = self.get_communication_spec(
-                sharding_spec=sharding_spec_mapping['bias'],
-                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
-                logical_process_axis=mesh_dim_0)
-            communication_action_mapping['bias'] = bias_comm_spec
+            bias_comm_action = self.get_communication_action(
+                sharding_spec=sharding_spec_mapping['bias'],
+                communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+                logical_process_axis=mesh_dim_0,
+                comm_type=CommType.BEFORE,
+                arg_index=0)
+            communication_action_mapping['bias'] = bias_comm_action
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -771,6 +888,5 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         # split two batch dim
         strategy_list.append(self.split_two_batch_dim(0, 1))
-        strategy_list.append(self.split_two_batch_dim(1, 0))

         return strategy_list
colossalai/tensor/comm_spec.py

@@ -41,7 +41,7 @@ def _split(tensor, comm_spec):
             dim = comm_spec.shard_dim
             length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
             start = length * rank_list.index(dist.get_rank())
-            output = torch.narrow(tensor, dim, start, length)
+            output = torch.narrow(tensor, dim, start, length).contiguous()
             return output

@@ -76,6 +76,8 @@ def _all_reduce(tensor, comm_spec):
     process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
     for rank_list, process_group in process_groups_list:
         if dist.get_rank() in rank_list:
+            if not tensor.is_contiguous():
+                tensor = tensor.contiguous()
             dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
     return tensor
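Both changes in comm_spec.py are about memory layout: torch.narrow returns a view that is generally non-contiguous, and the all-reduce path presumably needs (or is safer with) a contiguous buffer, so the split shard is made contiguous and the all-reduce input is checked first. A small check of the narrow behaviour:

# torch.narrow returns a view; slicing along a non-leading dim makes it non-contiguous.
import torch

t = torch.arange(12).reshape(3, 4)
shard = torch.narrow(t, 1, 1, 2)            # take columns 1..2
print(shard.is_contiguous())                # False
print(shard.contiguous().is_contiguous())   # True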
tests/test_auto_parallel/test_tensor_shard/test_node_handler/test_linear_handler.py

@@ -11,6 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
 )
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag
 from colossalai.testing.utils import parameterize

@@ -109,6 +110,7 @@ def test_linear_module_handler(bias):
     assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]


+@run_on_environment_flag(name='AUTO_PARALLEL')
 @parameterize('bias', [True, False])
 def test_linear_function_handler(bias):
     model = nn.Linear(16, 32, bias=bias).to('meta')
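The test now stacks @run_on_environment_flag(name='AUTO_PARALLEL') on top of @parameterize('bias', [True, False]), so it only runs when the corresponding flag is set in the environment. For reference, equivalent gating in plain pytest could look like the sketch below; this is an illustration only, not how colossalai.testing implements it:

# Plain-pytest equivalent of gating a parametrized test behind an environment flag.
import os

import pytest


@pytest.mark.skipif(os.environ.get('AUTO_PARALLEL') != '1',
                    reason='set AUTO_PARALLEL=1 to run auto-parallel tests')
@pytest.mark.parametrize('bias', [True, False])
def test_linear_function_handler(bias):
    assert bias in (True, False)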