Unverified commit d373e67b, authored by YuliangLiu0306 and committed by GitHub
Browse files

[hotfix] resharding cost issue (#1742)

parent 24e84eba
import operator import operator
from functools import reduce
import warnings import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
exception_handler,
)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
__all__ = ['WhereHandler'] __all__ = ['WhereHandler']
...@@ -94,7 +101,7 @@ class WhereHandler(OperatorHandler): ...@@ -94,7 +101,7 @@ class WhereHandler(OperatorHandler):
# compute the resharding cost # compute the resharding cost
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency( _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
input_sharding_spec, input_spec) input_sharding_spec, input_spec)
total_resharding_cost = total_resharding_cost['total']
# we need multiply the size of elem dtype to get correct communication cost # we need multiply the size of elem dtype to get correct communication cost
resharding_cost = total_resharding_cost * size_per_elem_bytes resharding_cost = total_resharding_cost * size_per_elem_bytes
resharding_costs[input_node].append(resharding_cost) resharding_costs[input_node].append(resharding_cost)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment