Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
69448f64
Unverified
Commit
69448f64
authored
Sep 23, 2022
by
YuliangLiu0306
Committed by
GitHub
Sep 23, 2022
Browse files
[autoparallel] protect bcast handler from invalid strategies (#1631)
parent
0c703189
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
8 deletions
+19
-8
colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py
...salai/auto_parallel/solver/op_handler/bcast_op_handler.py
+19
-8
No files found.
colossalai/auto_parallel/solver/op_handler/bcast_op_handler.py
View file @
69448f64
...
...
@@ -40,6 +40,11 @@ class BcastOpHandler(OperatorHandler):
for
dim_index
,
_
in
dim_partition_dict
.
items
():
if
shape
[
dim_index
]
==
1
:
processed_dim_partition_dict
.
pop
(
dim_index
)
for
dim_index
,
sharding_index_list
in
processed_dim_partition_dict
.
items
():
sharding_list
=
[
self
.
device_mesh
.
mesh_shape
[
sharding_index
]
for
sharding_index
in
sharding_index_list
]
sharding_size
=
reduce
(
operator
.
mul
,
sharding_list
,
1
)
assert
shape
[
dim_index
]
%
sharding_size
==
0
,
f
'we cannot shard the
{
dim_index
}
dimension of tensor into
{
sharding_size
}
partitions.'
sharding_spec
=
ShardingSpec
(
device_mesh
=
self
.
device_mesh
,
entire_shape
=
shape
,
dim_partition_dict
=
processed_dim_partition_dict
)
...
...
@@ -83,14 +88,10 @@ class BcastOpHandler(OperatorHandler):
entire_shape
=
new_entire_shape
,
dim_partition_dict
=
new_dim_partition_dict
)
# compute the resharding cost
during forward phase
_
,
_
,
resharding_cost
_forward
=
shape_consistency_manager
.
shape_consistency
(
# compute the resharding cost
_
,
_
,
total_
resharding_cost
=
shape_consistency_manager
.
shape_consistency
(
input_sharding_spec
,
input_spec
)
_
,
_
,
resharding_cost_backward
=
shape_consistency_manager
.
shape_consistency
(
input_spec
,
input_sharding_spec
)
total_resharding_cost
=
resharding_cost_forward
+
resharding_cost_backward
# we need to multiply by the size of the element dtype in bytes to get the correct communication cost
resharding_cost
=
total_resharding_cost
*
size_per_elem_bytes
resharding_costs
[
input_node
].
append
(
resharding_cost
)
...
...
@@ -102,7 +103,11 @@ class BcastOpHandler(OperatorHandler):
sharding_spec_list
=
[]
check_duplicated_list
=
[]
for
output_dim_partition_dict
in
dim_partition_list
:
output_sharding_spec
=
self
.
_generate_sharding_spec
(
self
.
output_data
,
output_dim_partition_dict
)
try
:
output_sharding_spec
=
self
.
_generate_sharding_spec
(
self
.
output_data
,
output_dim_partition_dict
)
except
AssertionError
as
e
:
warnings
.
warn
(
f
'
{
e
}
'
)
break
sharding_seq
=
output_sharding_spec
.
sharding_sequence
if
sharding_seq
not
in
check_duplicated_list
:
check_duplicated_list
.
append
(
sharding_seq
)
...
...
@@ -166,7 +171,7 @@ class BcastOpHandler(OperatorHandler):
##############################################
#used to generate strategies for torch.matmul#
##############################################
#
@exception_handler
@
exception_handler
def
_registry_no_split_strategies_for_matmul
(
self
,
dim_partition_dict_for_batch_dim
):
# this dim partition dict only describes the batch dimensions, but in this scenario,
# the matrix dimensions are fully replicated, so it does not need any extra processing.
...
...
@@ -205,6 +210,7 @@ class BcastOpHandler(OperatorHandler):
self
.
strategies_vector
.
append
(
sharding_strategies
)
@
exception_handler
def
_split_dim_i
(
self
,
dim_partition_dict_for_batch_dim
,
mesh_dim_on_matrix
):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describes the batch dimensions, so we should append the matrix dimension sharding info to it.
...
...
@@ -262,6 +268,7 @@ class BcastOpHandler(OperatorHandler):
self
.
strategies_vector
.
append
(
sharding_strategies
)
@
exception_handler
def
_split_dim_k
(
self
,
dim_partition_dict_for_batch_dim
,
mesh_dim_on_matrix
):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describes the batch dimensions, so we should append the matrix dimension sharding info to it.
...
...
@@ -325,6 +332,7 @@ class BcastOpHandler(OperatorHandler):
self
.
strategies_vector
.
append
(
sharding_strategies
)
@
exception_handler
def
_split_dim_j
(
self
,
dim_partition_dict_for_batch_dim
,
mesh_dim_on_matrix
):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describes the batch dimensions, so we should append the matrix dimension sharding info to it.
...
...
@@ -390,6 +398,7 @@ class BcastOpHandler(OperatorHandler):
self
.
_split_dim_k
(
dim_partition_dict
,
mesh_dim_list
)
self
.
_split_dim_j
(
dim_partition_dict
,
mesh_dim_list
)
@
exception_handler
def
_split_lhs_space_both_contract
(
self
,
mesh_dim_0
,
mesh_dim_1
):
dim_partition_dict_for_lhs
=
{
-
2
:
[
mesh_dim_0
],
-
1
:
[
mesh_dim_1
]}
sharding_spec_for_lhs
=
self
.
_generate_sharding_spec
(
self
.
lhs_data
,
dim_partition_dict_for_lhs
)
...
...
@@ -426,6 +435,7 @@ class BcastOpHandler(OperatorHandler):
self
.
strategies_vector
.
append
(
sharding_strategies
)
@
exception_handler
def
_split_rhs_space_both_contract
(
self
,
mesh_dim_0
,
mesh_dim_1
):
dim_partition_dict_for_lhs
=
{
-
1
:
[
mesh_dim_0
]}
sharding_spec_for_lhs
=
self
.
_generate_sharding_spec
(
self
.
lhs_data
,
dim_partition_dict_for_lhs
)
...
...
@@ -464,6 +474,7 @@ class BcastOpHandler(OperatorHandler):
self
.
strategies_vector
.
append
(
sharding_strategies
)
@
exception_handler
def
_split_lhs_space_rhs_space
(
self
,
mesh_dim_0
,
mesh_dim_1
):
dim_partition_dict_for_lhs
=
{
-
2
:
[
mesh_dim_0
]}
sharding_spec_for_lhs
=
self
.
_generate_sharding_spec
(
self
.
lhs_data
,
dim_partition_dict_for_lhs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment