OpenDAS / ColossalAI
Commit d373e67b (Unverified)
Authored Oct 19, 2022 by YuliangLiu0306
Committed by GitHub, Oct 19, 2022
[hotfix] resharding cost issue (#1742)
Parent: 24e84eba
Showing 1 changed file with 13 additions and 6 deletions

colossalai/auto_parallel/tensor_shard/deprecated/op_handler/where_handler.py (+13 -6)
 import operator
-from functools import reduce
 import warnings
+from copy import deepcopy
+from functools import reduce
+from typing import Dict, List
 import torch
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
+    enumerate_all_possible_1d_sharding,
+    enumerate_all_possible_2d_sharding,
+    exception_handler,
+)
 from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
-from .operator_handler import OperatorHandler
 from colossalai.tensor.shape_consistency import ShapeConsistencyManager
 from colossalai.tensor.sharding_spec import ShardingSpec
-from copy import deepcopy
-from typing import Dict, List
+from .operator_handler import OperatorHandler
-from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
 __all__ = ['WhereHandler']
...
@@ -94,7 +101,7 @@ class WhereHandler(OperatorHandler):
 # compute the resharding cost
 _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
     input_sharding_spec, input_spec)
+total_resharding_cost = total_resharding_cost['total']
 # we need multiply the size of elem dtype to get correct communication cost
 resharding_cost = total_resharding_cost * size_per_elem_bytes
 resharding_costs[input_node].append(resharding_cost)
...
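
The added line suggests that the third value returned by `shape_consistency` is a dict-like cost record rather than a plain number, so it has to be reduced to its `'total'` entry before being scaled by the element size. The following is a minimal standalone sketch of that pattern, not the ColossalAI implementation: `fake_shape_consistency_cost` and `resharding_cost_in_bytes` are hypothetical helpers, and the `'forward'` and `'backward'` keys and all numeric values are assumptions made for illustration. Only the `'total'` key appears in the diff.

```python
# Hypothetical stand-in for the third return value of shape_consistency().
# Only the 'total' key is confirmed by the diff; 'forward' and 'backward'
# are assumed here purely for illustration.
def fake_shape_consistency_cost():
    return {"forward": 128, "backward": 128, "total": 256}


def resharding_cost_in_bytes(cost, size_per_elem_bytes):
    """Scale an element-count cost to bytes, accepting a dict or a number."""
    if isinstance(cost, dict):
        # Mirror the hotfix: reduce the cost record to its scalar total first.
        cost = cost["total"]
    return cost * size_per_elem_bytes


cost_dict = fake_shape_consistency_cost()

# Without the extraction step, multiplying the dict itself raises TypeError.
try:
    cost_dict * 4
    multiplied_dict = True
except TypeError:
    multiplied_dict = False

# With the extraction step, the scalar total is scaled by the element size
# (4 bytes would correspond to a 32-bit dtype).
fixed_cost = resharding_cost_in_bytes(cost_dict, size_per_elem_bytes=4)
print(multiplied_dict, fixed_cost)  # False 1024
```

Under these assumptions, the extraction has to happen before the multiplication by `size_per_elem_bytes`, which matches where the hotfix places the new line.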