OpenDAS / ColossalAI · Commits

Commit 247a9dbc (Unverified)
[autoparallel] added bias comm spec to matmul strategy (#1664)

Authored Sep 29, 2022 by Frank Lee; committed via GitHub, Sep 29, 2022.
Parent commit: 746f8f97

Changes: 1 changed file with 68 additions and 47 deletions.

colossalai/auto_parallel/solver/strategy/matmul_strategy_generator.py (+68, -47)
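The recurring pattern in the hunks below: each strategy's dim_partition_dict gains a bias entry (always unsharded, "bias": {}), and a communication spec is attached for it so the solver accounts for the collective a replicated bias needs. A replicated bias requires no communication in the forward pass, but each device computes only a partial bias gradient in the backward pass, which must be summed across the mesh axis; that is the IDENTITY_FWD_ALLREDUCE_BWD pattern used in most hunks. A minimal NumPy sketch of that reasoning (illustrative only, not ColossalAI code):

# Illustrative sketch: a replicated bias needs no forward communication,
# but its partial gradients must be all-reduced in the backward pass when
# the batch rows are sharded across devices.
import numpy as np

np.random.seed(0)
grad_output = np.random.randn(4, 5)        # upstream gradient, batch dim = 4

# unsharded reference: bias gradient sums over the batch dimension
grad_bias_ref = grad_output.sum(axis=0)

# shard the batch dim over a hypothetical 2-device mesh axis
shards = np.split(grad_output, 2, axis=0)
partial_grads = [s.sum(axis=0) for s in shards]   # one partial per device

# the backward all-reduce (modeled here as a plain sum) recovers the gradient
assert np.allclose(sum(partial_grads), grad_bias_ref)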
@@ -1,3 +1,4 @@
+from audioop import bias
 import operator
 from functools import reduce
 from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
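Note on the import hunk: audioop.bias is a standard-library audio function (it adds a constant to the samples of an audio fragment) and is unrelated to a linear layer's bias term; the new import appears to be an accidental IDE auto-import, and nothing in the visible diff references it.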
@@ -121,7 +122,7 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
         name = f'S{mesh_dim}R = S{mesh_dim}R x R'

         # get sharding spec
-        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
+        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}, "bias": {}}
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication action
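As a reading aid for the dim_partition_dict entries (this interpretation is inferred from the strategy names, not from library documentation): each operand maps a tensor dimension index to the list of device-mesh axes that dimension is sharded over, and an empty dict means fully replicated. A hypothetical concrete instance of the dictionary above, for mesh_dim = 0:

# Sketch only, not library output
mesh_dim = 0
dim_partition_dict = {
    "input":  {0: [mesh_dim]},   # dim 0 sharded over mesh axis 0 -> "S0R"
    "other":  {},                # weight replicated               -> "R"
    "output": {0: [mesh_dim]},   # dim 0 sharded over mesh axis 0 -> "S0R"
    "bias":   {},                # bias replicated (new in this commit)
}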
@@ -129,7 +130,11 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim)
-        communication_action_mapping = {'other': other_comm_spec}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim)
+        communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -236,8 +241,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_0)
-        communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec_mapping["bias"],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_0)
+        communication_action_mapping = {"input": input_comm_spec, "other": other_comm_spec, "bias": bias_comm_spec}

         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
@@ -272,8 +281,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping["output"],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_1)
-        communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping["bias"],
+            communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
+            logical_process_axis=mesh_dim_1)
+        communication_action_mapping = {"input": input_comm_spec, 'output': output_comm_spec, 'bias': bias_comm_spec}

         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
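This hunk is the one place where the new bias spec uses ALLREDUCE_FWD_IDENTITY_BWD rather than IDENTITY_FWD_ALLREDUCE_BWD, matching the output: when the contracted dimension of the matmul is sharded over mesh_dim_1, every device holds only a partial sum, so the forward pass itself needs an all-reduce. A plain NumPy sketch of that partial-sum effect (illustrative, not ColossalAI code):

# Sharding the contracted dim k produces partial matmul results that must
# be summed (all-reduced) in the forward pass to recover the output.
import numpy as np

np.random.seed(1)
x = np.random.randn(4, 6)                  # k = 6
w = np.random.randn(6, 5)
full = x @ w

# split k across a hypothetical 2-device mesh axis
partials = [x[:, :3] @ w[:3, :], x[:, 3:] @ w[3:, :]]
assert np.allclose(sum(partials), full)    # forward all-reduce == exact result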
@@ -390,8 +403,12 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=[mesh_dim_0, mesh_dim_1])
-        communcation_action_mapping = {"other": other_comm_spec}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+        communcation_action_mapping = {"other": other_comm_spec, "bias": bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communcation_action_mapping)
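Two details worth flagging in this hunk: logical_process_axis is a list here ([mesh_dim_0, mesh_dim_1]), which, judging from the surrounding strategy names, means the collective runs over the group formed by flattening both mesh axes; and communcation_action_mapping is a pre-existing misspelling that is harmless because it is defined and consumed consistently within the method. A tiny sketch of the group-size consequence of a list-valued axis (an inference, not library documentation):

# If the reduction spans several mesh axes, the participating group is the
# product of those axis lengths (assumed semantics of a list-valued axis).
mesh_shape = (2, 4)                  # hypothetical 2D device mesh
logical_process_axis = [0, 1]
group_size = 1
for axis in logical_process_axis:
    group_size *= mesh_shape[axis]
assert group_size == 8               # all devices take part in the all-reduce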
@@ -486,40 +503,22 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
     def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
         return self.op_data['input'].data.shape[-1] * reduce(operator.mul, self.op_data['output'].data.shape)

-    def split_one_batch_dim(self):
-        device_mesh_is_1d = True
-        if len(self.device_mesh.mesh_shape) == 1:
-            mesh_dim = 0
-        elif 1 in self.device_mesh.mesh_shape:
-            mesh_dim = self.device_mesh.mesh_shape.index(1)
-        else:
-            device_mesh_is_1d = False
-
-        if device_mesh_is_1d:
-            name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
-
-            # get sharding_spec
-            dim_partition_dict = {
-                "input": {0: [mesh_dim]},
-                "other": {0: [mesh_dim]},
-                "bias": {},
-                "output": {0: [mesh_dim]}
-            }
-            sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
-
-            # get communication actions
-            communication_action_mapping = {}
-            return self.get_sharding_strategy(name=name,
-                                              sharding_spec_mapping=sharding_spec_mapping,
-                                              communication_action_mapping=communication_action_mapping)
-        else:
-            return None
+    def split_one_batch_dim(self, mesh_dim):
+        name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
+
+        # get sharding_spec
+        dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
+        sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)
+
+        # get communication actions
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim)
+        communication_action_mapping = {"bias": bias_comm_spec}
+        return self.get_sharding_strategy(name=name,
+                                          sharding_spec_mapping=sharding_spec_mapping,
+                                          communication_action_mapping=communication_action_mapping)

     def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
         name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
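The hunk above is the main structural change: split_one_batch_dim no longer probes the device mesh itself (and no longer returns None for a genuinely 2D mesh); it now receives mesh_dim from the caller and always builds a strategy, including the new bias comm spec. The mesh-shape probing moves into generate(), shown in the final hunk; a standalone sketch of that selection logic follows it below.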
@@ -538,7 +537,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
         sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict)

         # get communication actions
-        communication_action_mapping = {}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+        communication_action_mapping = {"bias": bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -566,7 +569,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['other'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1)
-        communication_action_mapping = {'other': other_comm_spec}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=[mesh_dim_0, mesh_dim_1])
+        communication_action_mapping = {'other': other_comm_spec, 'bias': bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -596,7 +603,11 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['input'],
             communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
             logical_process_axis=mesh_dim_1)
-        communication_action_mapping = {'input': input_comm_spec}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_0)
+        communication_action_mapping = {'input': input_comm_spec, 'bias': bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)
@@ -625,21 +636,31 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             sharding_spec=sharding_spec_mapping['output'],
             communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
             logical_process_axis=mesh_dim_1)
-        communication_action_mapping = {'output': output_comm_spec}
+        bias_comm_spec = self.get_communication_spec(
+            sharding_spec=sharding_spec_mapping['bias'],
+            communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
+            logical_process_axis=mesh_dim_0)
+        communication_action_mapping = {'output': output_comm_spec, 'bias': bias_comm_spec}
         return self.get_sharding_strategy(name=name,
                                           sharding_spec_mapping=sharding_spec_mapping,
                                           communication_action_mapping=communication_action_mapping)

     def generate(self) -> List[ShardingStrategy_V2]:
         strategy_list = []

-        # split only the batch dimension
-        # Sb = Sb x Sb
-        # can be None as it is only for 1D device mesh
-        strategy = self.split_one_batch_dim()
-        if strategy:
-            strategy_list.append(strategy)
+        device_mesh_is_1d = True
+        if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
+            device_mesh_is_1d = False
+
+        if device_mesh_is_1d:
+            # split only the batch dimension
+            # Sb = Sb x Sb
+            # only for 1D device mesh
+            if len(self.device_mesh.mesh_shape) == 1:
+                mesh_dim = 0
+            else:
+                mesh_dim = self.device_mesh.mesh_shape.index(1)
+            strategy_list.append(self.split_one_batch_dim(mesh_dim))
+        else:
             # for 2D device mesh
             # split batch dim of two inputs and the i dim of the first tensor
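To make the relocated mesh-dim selection in generate() easy to check in isolation, here is a standalone restatement of its logic (the function name is invented for illustration; it mirrors the diff, it is not a library API):

def pick_batch_mesh_dim(mesh_shape):
    """Mirror of the logic in generate(): return the mesh axis to use for
    batch-dim sharding on a (logically) 1D mesh, or None for a true 2D mesh."""
    if len(mesh_shape) == 2 and 1 not in mesh_shape:
        return None                      # truly 2D: handled by the else branch
    if len(mesh_shape) == 1:
        return 0
    return mesh_shape.index(1)           # degenerate axis of length 1

assert pick_batch_mesh_dim((8,)) == 0
assert pick_batch_mesh_dim((1, 8)) == 0
assert pick_batch_mesh_dim((8, 1)) == 1
assert pick_batch_mesh_dim((2, 4)) is None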