Unverified commit 8283e95d authored by Frank Lee, committed by GitHub

[autoparallel] collated all deprecated files (#1700)

* [autoparallel] collated all deprecated files

* polish code
parent e2355d01
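The commit collects the old solver-based implementation under a deprecated package and drops the `_V2` suffix from the tensor-shard classes. A minimal sketch (not part of the diff itself) of how the rename shows up at import sites, using only paths that appear in this diff:

# before this commit
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from .strategy_generator import StrategyGenerator_V2

# after this commit
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .strategy_generator import StrategyGenerator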
import operator
from functools import reduce
-from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
-from .strategy_generator import StrategyGenerator_V2
+from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler
import copy
@@ -10,7 +10,7 @@ import copy
__all__ = ['BatchNormStrategyGenerator']
-class BatchNormStrategyGenerator(StrategyGenerator_V2):
+class BatchNormStrategyGenerator(StrategyGenerator):
"""
A StrategyGenerator which deals with the sharding strategies of batch normalization.
@@ -37,7 +37,7 @@ class BatchNormStrategyGenerator(StrategyGenerator_V2):
assert input_op_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
-def update_compute_cost(self, strategy: ShardingStrategy_V2):
+def update_compute_cost(self, strategy: ShardingStrategy):
'''
Compute the computation cost per device with this specific strategy.
@@ -64,7 +64,7 @@ class BatchNormStrategyGenerator(StrategyGenerator_V2):
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
strategy.compute_cost = compute_cost
-def update_memory_cost(self, strategy: ShardingStrategy_V2):
+def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
...
import operator
from functools import reduce
-from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
-from .strategy_generator import StrategyGenerator_V2
+from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler
import warnings
import copy
-class ConvStrategyGenerator(StrategyGenerator_V2):
+class ConvStrategyGenerator(StrategyGenerator):
"""
ConvStrategyGenerator is a generic class to generate strategies.
The operation data is defined as `output = input x other + bias`.
@@ -30,7 +30,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
assert input_op_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
-def update_compute_cost(self, strategy: ShardingStrategy_V2):
+def update_compute_cost(self, strategy: ShardingStrategy):
'''
Compute the computation cost per device with this specific strategy.
@@ -70,7 +70,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
strategy.compute_cost = compute_cost
-def update_memory_cost(self, strategy: ShardingStrategy_V2):
+def update_memory_cost(self, strategy: ShardingStrategy):
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
@@ -455,7 +455,7 @@ class ConvStrategyGenerator(StrategyGenerator_V2):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
-def generate(self) -> List[ShardingStrategy_V2]:
+def generate(self) -> List[ShardingStrategy]:
strategies = []
# SS = SR x RS
strategies.append(self.split_input_batch_weight_out_channel(0, 1))
...
import operator
from functools import reduce
-from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
from typing import List
@@ -28,11 +28,11 @@ class GetItemStrategyGenerator(FollowingStrategyGenerator):
def validate(self) -> bool:
return super().validate()
-def update_compute_cost(self, strategy: ShardingStrategy_V2):
+def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
-def update_memory_cost(self, strategy: ShardingStrategy_V2):
+def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
...
import operator
from functools import reduce
-from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
-from .strategy_generator import StrategyGenerator_V2
+from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
import copy
@@ -10,7 +10,7 @@ import copy
__all__ = ['LayerNormGenerator']
-class LayerNormGenerator(StrategyGenerator_V2):
+class LayerNormGenerator(StrategyGenerator):
"""
LayerNormGenerator is a generic class to generate strategies for LayerNorm operation.
The operation data is defined as `output = input x other + bias`.
@@ -23,7 +23,7 @@ class LayerNormGenerator(StrategyGenerator_V2):
def validate(self) -> bool:
return super().validate()
-def update_compute_cost(self, strategy: ShardingStrategy_V2):
+def update_compute_cost(self, strategy: ShardingStrategy):
'''
Compute the computation cost per device with this specific strategy.
@@ -54,7 +54,7 @@ class LayerNormGenerator(StrategyGenerator_V2):
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
strategy.compute_cost = compute_cost
-def update_memory_cost(self, strategy: ShardingStrategy_V2):
+def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
...
from audioop import bias
import operator
from functools import reduce
-from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
-from .strategy_generator import StrategyGenerator_V2
+from .strategy_generator import StrategyGenerator
from typing import List
-class MatMulStrategyGenerator(StrategyGenerator_V2):
+class MatMulStrategyGenerator(StrategyGenerator):
"""
MatMulStrategyGenerator is a generic class to cover all matrix multiplication cases.
The operation data is defined as `output = input x other + bias`.
@@ -17,7 +17,7 @@ class MatMulStrategyGenerator(StrategyGenerator_V2):
def has_bias(self):
return 'bias' in self.op_data
-def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'other': self._compute_size_in_bytes(strategy, "other"),
@@ -53,7 +53,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
other_op_data = self.op_data['other']
assert input_op_data.data.dim() == 1 and other_op_data.data.dim() == 1
-def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
fwd_compute_cost = sharded_input_shape[0]
bwd_compute_cost = sharded_input_shape * 2
@@ -88,7 +88,7 @@ class DotProductStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
-def generate(self) -> List[ShardingStrategy_V2]:
+def generate(self) -> List[ShardingStrategy]:
strategy_list = []
# do not split dimensions for dot product
@@ -139,7 +139,7 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
-def generate(self) -> List[ShardingStrategy_V2]:
+def generate(self) -> List[ShardingStrategy]:
strategy_list = []
# no split
@@ -154,7 +154,7 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
-def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# C = AB
# C: [M, N], A: [M, P], B: [P, N]
# fwd cost = MNP (only count mul)
@@ -172,7 +172,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
total=fwd_compute_cost + bwd_compute_cost)
strategy.compute_cost = compute_cost
-def generate(self) -> List[ShardingStrategy_V2]:
+def generate(self) -> List[ShardingStrategy]:
strategies = []
# SS = SR x RS
@@ -500,7 +500,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
other_op_data = self.op_data['other']
assert input_op_data.data.dim() > 2 or other_op_data.data.dim() > 2
-def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
return self.op_data['input'].data.shape[-1] * reduce(operator.mul, self.op_data['output'].data.shape)
def split_one_batch_dim(self, mesh_dim):
@@ -645,7 +645,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
-def generate(self) -> List[ShardingStrategy_V2]:
+def generate(self) -> List[ShardingStrategy]:
strategy_list = []
device_mesh_is_1d = True
if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
...
import operator
from functools import reduce
-from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
-from .strategy_generator import StrategyGenerator_V2
+from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
import copy
-class NormalPoolStrategyGenerator(StrategyGenerator_V2):
+class NormalPoolStrategyGenerator(StrategyGenerator):
"""
NormalPoolStrategyGenerator is a generic class to generate strategies for pool operation like MaxPoolxd.
The reason we call this normal pool is AvgPoolxd and MaxPoolxd are taking the kernel size element from image,
@@ -26,7 +26,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator_V2):
assert input_op_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].'
-def update_compute_cost(self, strategy: ShardingStrategy_V2) -> TrainCycleItem:
+def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem:
'''
Compute the computation cost per device with this specific strategy.
@@ -54,7 +54,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator_V2):
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
return compute_cost
-def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
@@ -101,7 +101,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator_V2):
return dim_partition_list
-def generate(self) -> List[ShardingStrategy_V2]:
+def generate(self) -> List[ShardingStrategy]:
strategy_list = []
dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1)
...
import operator
from functools import reduce
-from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import OutputStrategyGenerator
from typing import List
@@ -18,11 +18,11 @@ class OutputGenerator(OutputStrategyGenerator):
def validate(self) -> bool:
return super().validate()
-def update_compute_cost(self, strategy: ShardingStrategy_V2):
+def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
-def update_memory_cost(self, strategy: ShardingStrategy_V2):
+def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
...
import operator
from functools import reduce
-from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
-from .strategy_generator import StrategyGenerator_V2
+from .strategy_generator import StrategyGenerator
from typing import List
from .._utils import exception_handler
import copy
@@ -10,7 +10,7 @@ import copy
__all__ = ['PlaceholderGenerator']
-class PlaceholderGenerator(StrategyGenerator_V2):
+class PlaceholderGenerator(StrategyGenerator):
"""
PlaceholderGenerator is a generic class to generate strategies for placeholder node.
"""
@@ -18,11 +18,11 @@ class PlaceholderGenerator(StrategyGenerator_V2):
def validate(self) -> bool:
return super().validate()
-def update_compute_cost(self, strategy: ShardingStrategy_V2):
+def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
-def update_memory_cost(self, strategy: ShardingStrategy_V2):
+def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
...
import operator
from functools import reduce
-from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
from typing import List
@@ -17,11 +17,11 @@ class ReshapeGenerator(FollowingStrategyGenerator):
def validate(self) -> bool:
return super().validate()
-def update_compute_cost(self, strategy: ShardingStrategy_V2):
+def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
-def update_memory_cost(self, strategy: ShardingStrategy_V2):
+def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
...
@@ -7,12 +7,12 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Dict, List, Union, Any
-from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem, OperationDataType
+from ..sharding_strategy import OperationData, ShardingStrategy, TrainCycleItem, OperationDataType
from torch.fx import Node
import copy
-class StrategyGenerator_V2(ABC):
+class StrategyGenerator(ABC):
"""
StrategyGenerator is used to generate the same group of sharding strategies.
@@ -38,9 +38,7 @@ class StrategyGenerator_V2(ABC):
"""
sharding_specs = self.replace_op_name_with_op_data(sharding_spec_mapping)
communication_actions = self.replace_op_name_with_op_data(communication_action_mapping)
-return ShardingStrategy_V2(name=name,
-sharding_specs=sharding_specs,
-communication_actions=communication_actions)
+return ShardingStrategy(name=name, sharding_specs=sharding_specs, communication_actions=communication_actions)
def to_sharding_spec_mapping(self, mapping: Dict[str, Dict[int, List[int]]]):
"""
@@ -85,7 +83,7 @@ class StrategyGenerator_V2(ABC):
sharding_spec=sharding_spec,
logical_process_axis=logical_process_axis)
-def update_communication_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Compute the communication cost involved in the forward and backward iteration.
"""
@@ -113,20 +111,20 @@ class StrategyGenerator_V2(ABC):
return strategy
@abstractmethod
-def update_compute_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Customize this method to compute the computation flops.
"""
pass
@abstractmethod
-def update_memory_cost(self, strategy: ShardingStrategy_V2) -> ShardingStrategy_V2:
+def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Customize this method to compute the memory cost in bytes.
"""
pass
-def _compute_size_in_bytes(self, strategy: ShardingStrategy_V2, key: str):
+def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str):
"""
Compute the size of a tensor in bytes.
@@ -142,7 +140,7 @@ class StrategyGenerator_V2(ABC):
return reduce(operator.mul, sharded_shape) * size_per_elem_bytes
@abstractmethod
-def generate(self) -> List[ShardingStrategy_V2]:
+def generate(self) -> List[ShardingStrategy]:
"""
Generate all possible sharding strategies for this operation.
"""
@@ -157,7 +155,7 @@ class StrategyGenerator_V2(ABC):
pass
-class FollowingStrategyGenerator(StrategyGenerator_V2):
+class FollowingStrategyGenerator(StrategyGenerator):
"""
FollowingStrategyGenerator is used to generate the sharding strategies which depends on its predecessor node.
@@ -171,7 +169,7 @@ class FollowingStrategyGenerator(StrategyGenerator_V2):
self.predecessor_node = predecessor_node
-class OutputStrategyGenerator(StrategyGenerator_V2):
+class OutputStrategyGenerator(StrategyGenerator):
"""
OutputStrategyGenerator is used to generate the sharding strategies for Output Node.
"""
...
import operator
from functools import reduce
-from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator
from typing import List
@@ -18,11 +18,11 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
def validate(self) -> bool:
return super().validate()
-def update_compute_cost(self, strategy: ShardingStrategy_V2):
+def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
-def update_memory_cost(self, strategy: ShardingStrategy_V2):
+def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
...
import operator
from functools import reduce
-from ..sharding_strategy import ShardingStrategy_V2, TrainCycleItem, MemoryCost
+from ..sharding_strategy import ShardingStrategy, TrainCycleItem, MemoryCost
from colossalai.tensor.shape_consistency import CollectiveCommPattern
-from .strategy_generator import StrategyGenerator_V2, FollowingStrategyGenerator
+from .strategy_generator import StrategyGenerator, FollowingStrategyGenerator
from typing import List
from .._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
import copy
@@ -10,7 +10,7 @@ import copy
__all__ = ['WhereGenerator']
-class WhereGenerator(StrategyGenerator_V2):
+class WhereGenerator(StrategyGenerator):
"""
WhereGenerator is a generic class to generate strategies for Where operation.
"""
@@ -18,11 +18,11 @@ class WhereGenerator(StrategyGenerator_V2):
def validate(self) -> bool:
return super().validate()
-def update_compute_cost(self, strategy: ShardingStrategy_V2):
+def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
-def update_memory_cost(self, strategy: ShardingStrategy_V2):
+def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
...
from .options import SolverOptions
from .strategies_constructor import StrategiesConstructor
from .sharding_strategy import ShardingStrategy, StrategiesVector
from .cost_graph import CostGraph
from .solver import Solver
from .graph_analysis import GraphAnalyser
\ No newline at end of file
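# Usage sketch (not part of the commit): the re-exports above can be consumed from the
# deprecated package path that appears in the handler imports later in this diff, e.g.
#
#   from colossalai.auto_parallel.tensor_shard.deprecated import (
#       SolverOptions, StrategiesConstructor, CostGraph, Solver, GraphAnalyser)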
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
import torch
from torch.fx.node import Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Union, Dict, List, Optional
import warnings
from functools import reduce
import functools
import operator
from .constants import INFINITY_COST
def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
"""
Generate the sharding spec of the tensor based on the given dim_partition_dict.
Args:
input_ (Union[Node, torch.Tensor]): the input can be a Node object or a PyTorch tensor. If a node is used, it will look for its meta data associated with this node.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
dim_partition_dict (Dict[int, List[int]]): a dictionary to specify the sharding specs, the key is the tensor dimension and the value is the mesh dimension for sharding.
"""
if isinstance(input_, Node):
assert hasattr(input_, '_meta_data'), 'The given node has no attribute _meta_data'
meta_tensor = input_._meta_data
assert meta_tensor is not None, "The given node's _meta_data attribute is None"
shape = meta_tensor.shape
elif isinstance(input_, torch.Tensor):
shape = input_.shape
else:
raise TypeError(
f'We cannot generate sharding spec for {type(input_)} type, only torch.fx.Node or torch.Tensor is expected.'
)
for dim_index, sharding_index_list in dim_partition_dict.items():
sharding_list = [device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
sharding_size = reduce(operator.mul, sharding_list, 1)
assert shape[
dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
sharding_spec = ShardingSpec(device_mesh=device_mesh, entire_shape=shape, dim_partition_dict=dim_partition_dict)
return sharding_spec
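# Usage sketch (not part of the original module; the DeviceMesh constructor arguments are assumptions):
#
#   physical_mesh_id = torch.arange(4)
#   device_mesh = DeviceMesh(physical_mesh_id, mesh_shape=(2, 2))
#   # shard dim 0 of an 8 x 16 tensor along mesh axis 0, i.e. into 2 partitions
#   spec = generate_sharding_spec(torch.empty(8, 16), device_mesh, dim_partition_dict={0: [0]})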
def generate_resharding_costs(nodes: List[Node],
sharding_specs: List[ShardingSpec],
count_backward: Optional[bool] = True,
dtype: Optional[torch.dtype] = None,
index=None):
'''
Compute the resharding costs with this specific strategy.
Argument:
nodes (List[Node]): a list of nodes
sharding_specs (List[ShardingSpec]): a list of ShardingSpec, one for each node.
count_backward (Optional[bool]): whether to include the cost of resharding in the backward pass, default is True. False can be used for inference.
dtype (Optional[torch.dtype]): the data type for cost calculation, default is None.
'''
# The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs = {}
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# shape consistency manager is a singleton class
shape_consistency_manager = ShapeConsistencyManager()
for input_node, input_spec in zip(nodes, sharding_specs):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
if not isinstance(input_sharding_spec, ShardingSpec):
assert isinstance(input_sharding_spec, list), 'only ShardingSpec or List[ShardingSpec] is expected.'
input_sharding_spec = input_sharding_spec[index]
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
try:
# compute the resharding cost
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
input_sharding_spec, input_spec)
# we need to multiply by the element size to get the correct communication cost
resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
except AssertionError as e:
warnings.warn(f'{e}')
resharding_cost = INFINITY_COST
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def exception_handler(func):
"""
A decorator which catches AssertionError raised by the wrapped function and reports it as a warning instead of propagating the error.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
rst = func(*args, **kwargs)
return rst
except AssertionError as e:
warnings.warn(f'{e}')
return wrapper
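# Usage sketch (illustrative only): strategy-generation helpers can be wrapped so that a
# failing assertion is reported via warnings.warn and the wrapper simply returns None.
#
#   @exception_handler
#   def split_input_batch(mesh_dim):
#       assert mesh_dim in (0, 1), 'only a 2D device mesh is handled in this sketch'
#       ...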
def enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size):
dim_partition_list = []
# enumerate all the 2D sharding cases
for i in range(dim_size):
for j in range(i + 1, dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0], j: [mesh_dim_1]}
dim_partition_dict_1 = {i: [mesh_dim_1], j: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
dim_partition_list.append(dim_partition_dict_1)
for i in range(dim_size):
dim_partition_dict_flatten = {i: [mesh_dim_0, mesh_dim_1]}
dim_partition_list.append(dim_partition_dict_flatten)
return dim_partition_list
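# Worked example (not in the original file): enumerate_all_possible_2d_sharding(0, 1, 2) returns
#   [{0: [0], 1: [1]}, {0: [1], 1: [0]}, {0: [0, 1]}, {1: [0, 1]}]
# i.e. both assignments of the two mesh axes to two tensor dims, plus flattening both mesh
# axes onto a single tensor dim.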
def enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size):
dim_partition_list = []
# enumerate all the 1D sharding cases
for i in range(dim_size):
dim_partition_dict_0 = {i: [mesh_dim_0]}
dim_partition_list.append(dim_partition_dict_0)
return dim_partition_list
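# Worked example: enumerate_all_possible_1d_sharding(0, 3) returns
#   [{0: [0]}, {1: [0]}, {2: [0]}]
# one candidate per tensor dimension, each sharded along mesh axis 0.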
def generate_sharding_size(dim_partition_dict, device_mesh):
total_sharding_size = 1
for mesh_dim_list in dim_partition_dict.values():
mesh_dim_sharding_size = [device_mesh.shape[mesh_dim] for mesh_dim in mesh_dim_list]
sharding_size = reduce(operator.mul, mesh_dim_sharding_size)
total_sharding_size *= sharding_size
return total_sharding_size
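# Worked example (device mesh shape assumed to be (2, 4) purely for illustration):
#   generate_sharding_size({0: [0, 1]}, device_mesh) == 2 * 4 == 8
# a dimension sharded over both mesh axes is split into 8 partitions.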
import torch
import operator
__all__ = [
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST'
]
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
ELEMENTWISE_FUNC_OP = [
torch.abs,
torch.cos,
torch.exp,
operator.neg,
torch.multiply,
torch.nn.functional.relu,
torch.nn.functional.dropout,
# softmax should not be here
torch.nn.functional.softmax
]
ELEMENTWISE_METHOD_OP = [
torch.Tensor.to,
torch.Tensor.type,
# TODO: contiguous may need some extra processing.
torch.Tensor.contiguous
]
RESHAPE_FUNC_OP = [torch.flatten, torch.reshape]
RESHAPE_METHOD_OP = [
torch.Tensor.view,
torch.Tensor.unsqueeze,
torch.Tensor.split,
torch.Tensor.permute,
torch.Tensor.transpose,
]
BCAST_FUNC_OP = [
torch.add, torch.sub, torch.mul, torch.div, torch.floor_divide, torch.true_divide, operator.add, operator.sub,
operator.mul, operator.floordiv, operator.truediv, torch.matmul, torch.where, operator.pow, torch.pow, torch.tanh
]
CONV_MODULE_OP = [
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d
]
CONV_FUNC_OP = [
torch.conv1d, torch.conv2d, torch.conv3d, torch.conv_transpose1d, torch.conv_transpose2d, torch.conv_transpose3d
]
EMBEDDING_MODULE_OP = [torch.nn.modules.sparse.Embedding]
LINEAR_MODULE_OP = [torch.nn.Linear]
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
LAYERNORM_MODULE_OP = [torch.nn.LayerNorm]
POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
NON_PARAM_FUNC_OP = [
torch.flatten,
torch.reshape,
torch.abs,
torch.cos,
torch.exp,
operator.neg,
torch.multiply,
torch.nn.functional.relu,
torch.nn.functional.dropout,
torch.flatten,
torch.where,
operator.pow,
torch.pow,
torch.tanh,
torch.add,
torch.sub,
torch.mul,
torch.div,
torch.floor_divide,
torch.true_divide,
operator.add,
operator.sub,
operator.mul,
operator.floordiv,
operator.truediv,
# softmax should not be here
torch.nn.functional.softmax
]
INFINITY_COST = 1e13
from typing import List
import math
from torch.fx.node import Node
from .constants import INFINITY_COST
class CostGraph:
'''
A graph data structure to simplify the edge cost graph. It has two main functions:
1. To feed the quadratic resharding costs into the solver, we need to linearize them. We build edge_cost in
CostGraph, which stores every combination of strategies for a src-dst node pair in a 1D list.
2. To reduce the search space, we merge computationally-trivial operators, such as
element-wise operators, transpose, and reduction, into their following nodes. The merging information is
given by the StrategiesVector depending on the type of the target node and its following nodes.
Argument:
leaf_strategies(List[StrategiesVector]): It stores the StrategiesVector of every node in the graph.
simplify(bool, optional): The generated cost graph will be simplified if it is True. (defaults to True)
'''
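# Layout sketch (not in the original source): after _build_cost_graph runs,
#   self.edge_costs[(src_node, dst_node)][(src_strategy_index, dst_strategy_index)]
# holds the resharding cost of feeding src_node's chosen strategy into dst_node's chosen
# strategy, i.e. the linearized 1D form of the pairwise cost table described above.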
def __init__(self, leaf_strategies, simplify=True):
self.leaf_strategies = leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
# stores number of strategies in each node
self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
# extra_node_costs will store the extra costs introduced by merging nodes
self.extra_node_costs = {}
self.following_dict = {}
self.simplify = simplify
self._build_cost_graph()
def _remove_invalid_node(self, node, attr_name):
remove_list = []
target_node_list = getattr(node, attr_name, [])
for target_node in target_node_list:
if target_node not in self.nodes:
remove_list.append(target_node)
for element in remove_list:
target_node_list.remove(element)
def _build_cost_graph(self):
'''
This method generates edge_cost for each adjacent node pair. Additionally, 'parents' and 'children' attributes are
set on each node.
'''
self.edge_costs = {}
if self.simplify:
self.merge_pair = []
for strategies_vector in self.leaf_strategies:
# build edge_cost
dst_node = strategies_vector.node
for src_node in strategies_vector.predecessor_nodes:
if src_node not in self.nodes:
continue
node_pair = (src_node, dst_node)
# src_index = strategies_vector.predecessor_nodes.index(src_node)
edge_cost = {}
for i in range(len(strategies_vector)):
for j in range(len(src_node.strategies_vector)):
edge_cost[(j, i)] = strategies_vector[i].resharding_costs[src_node][j]
self.edge_costs[node_pair] = edge_cost
# add parents and children attribute to node
setattr(dst_node, 'parents', strategies_vector.predecessor_nodes)
setattr(dst_node, 'children', strategies_vector.successor_nodes)
self._remove_invalid_node(dst_node, 'parents')
self._remove_invalid_node(dst_node, 'children')
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
self.merge_pair.append((followed_node, dst_node))
def get_edge_cost(self, src_node, dst_node):
return self.edge_costs[(src_node, dst_node)]
def merge_node(self, src_node, dst_node):
'''
To merge dst_node into src_node, we need to do it in the following steps:
1. For each strategy in dst_node, we need to pick an appropriate strategy
of src_node to merge. This is important because the logical resharding costs
between the parent nodes of src_node and the merged node depend on which src_node
strategy is dispatched. For example, for the graph 0->1->2, after merging node 1
into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)],
where x is the strategy of node 1 picked for strategy 0 of the merged node 2.
2. We need to accumulate the extra costs introduced by merging nodes; the extra costs
contain two parts: one is the resharding cost between the src_node strategy and the dst_node strategy,
the other is the original extra cost in the src_node strategy.
3. Build connections between new node pairs, and remove the src_node after all consumer nodes
are detached from it.
Argument:
src_node(Node): The node to be merged into dst_node.
dst_node(Node): The node that integrates src_node.
'''
src_node_index = dst_node.parents.index(src_node)
# build merge_map
merge_map = {}
for src_index, strategy in enumerate(src_node.strategies_vector):
min_cost = INFINITY_COST
lowest_cost_index = -1
for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
resharding_cost = dst_strategy.resharding_costs[src_node][src_index]
if resharding_cost <= min_cost:
min_cost = resharding_cost
lowest_cost_index = dst_index
merge_map[src_index] = lowest_cost_index
# extra_node_cost for src node
self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node]
for src_index, strategy in enumerate(src_node.strategies_vector):
target_strate_index = merge_map[src_index]
target_strategy = dst_node.strategies_vector[target_strate_index]
self.extra_node_costs[src_node][src_index] += target_strategy.resharding_costs[src_node][src_index]
if dst_node in self.extra_node_costs:
self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index]
# add new node pair to cost graph
for child_node in dst_node.children:
new_node_pair = (src_node, child_node)
old_node_pair = (dst_node, child_node)
if new_node_pair in self.edge_costs:
continue
edge_cost = {}
for i in range(self.node_lens[src_node]):
for j in range(self.node_lens[child_node]):
dst_strate_index = merge_map[i]
# dst_strategy = dst_node.strategies_vector[dst_strate_index]
edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
if new_node_pair not in self.edge_costs:
self.edge_costs[new_node_pair] = edge_cost
else:
# we should accumulate the resharding costs if args of child node contain
# both src node and dst node.
for index_pair, resharding_cost in self.edge_costs[new_node_pair]:
self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair]
# connect src node and children of dst node
dst_node.parents.remove(src_node)
src_node.children.remove(dst_node)
self.edge_costs.pop((src_node, dst_node))
for child_node in dst_node.children:
if child_node not in src_node.children:
src_node.children.append(child_node)
if src_node not in child_node.parents:
child_node.parents.append(src_node)
# remove dst node from cost graph when dst node has no producer.
if len(dst_node.parents) == 0:
child_node.parents.remove(dst_node)
node_pair = (dst_node, child_node)
self.edge_costs.pop(node_pair)
if len(dst_node.parents) == 0:
self.following_dict[dst_node] = src_node
dst_node.children = []
def _reindexing_src(self, src):
if src not in self.following_dict:
return src
return self._reindexing_src(self.following_dict[src])
def simplify_graph(self):
if not self.simplify:
return
self.merge_pair.reverse()
for (src_node, dst_node) in self.merge_pair:
self.merge_node(src_node, dst_node)
self.merge_pair.reverse()
reindexing_following_dict = {}
for dst, src in self.following_dict.items():
reindexing_following_dict[dst] = self._reindexing_src(src)
self.following_dict = reindexing_following_dict
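# Usage sketch (illustrative): the constructor builds the edge costs eagerly, so a typical
# caller only needs two steps once the per-node strategies are available.
#
#   cost_graph = CostGraph(leaf_strategies)   # leaf_strategies: List[StrategiesVector]
#   cost_graph.simplify_graph()               # merges trivial nodes when simplify=True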
from dataclasses import dataclass
from torch.fx.node import Node
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from collections import OrderedDict as ODict
from typing import List, OrderedDict, Union, Any
from colossalai.fx.passes.utils import get_node_module
__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
@dataclass
class LiveVariable:
"""
LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
"""
name: str
node: Node
is_inplace: bool
class LiveVariableVector(list):
"""
LiveVariableVector is a data structure to store the list of LiveVariable objects.
"""
def exists(self, name) -> bool:
"""
Check if a variable already exists in the current list by name.
"""
for var in self:
if name == var.name:
return True
return False
def get(self, name) -> LiveVariable:
for var in self:
if name == var.name:
return var
raise KeyError(f"Variable {name} is not found")
def copy(self) -> "LiveVariableVector":
"""
Create a copy of this vector
"""
vector = LiveVariableVector()
for var in self:
vector.append(var)
return vector
@dataclass
class LiveStage:
"""
LiveStage is a data structure to record the live variables at the current node.
"""
name: str
node: Node
all_live_vars: LiveVariableVector
unique_live_vars: LiveVariableVector
class GraphAnalyser:
def __init__(self, gm: GraphModule):
self._gm = gm
self._graph = gm.graph
@property
def gm(self) -> GraphModule:
"""
Return the GraphModule object associated with this analyser.
"""
return self._gm
@property
def graph(self) -> Graph:
"""
Return the Graph object associated with this analyser.
"""
return self._graph
def liveness_analysis(self) -> List[LiveStage]:
"""
Analyse the graph to obtain the variable liveness information. This function returns
a list of LiveStage objects describing the live variables at each recorded compute stage.
"""
compute_nodes = self.graph.nodes
liveness_list = []
# checked: record all variables created since the first stage
# all: record the live variables that exist up to the current stage.
# this can be different from the `checked` list as some variables may be destroyed prior to this stage.
# unique: record the unique live variables that exist up to the current stage.
# this is different from the `all` list as some variables are duplicated.
checked_variables = LiveVariableVector()
all_live_variables = LiveVariableVector()
unique_live_vars = LiveVariableVector()
for idx, node in enumerate(compute_nodes):
#############################
# find new living variables #
#############################
# detect whether the current op is an in-place op
# if it is an in-place op, we would deem it as a duplicate var
is_inplace = False
if node.op == 'call_function':
# check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
if node.kwargs.get('inplace', False):
is_inplace = True
elif node.op == 'call_module':
# check if this is an inplace module such as torch.nn.ReLU(inplace=True)
module = get_node_module(node)
if getattr(module, 'inplace', False):
is_inplace = True
# add the output var
meta = getattr(node, '_meta_data', None)
live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
if not is_inplace:
unique_live_vars.append(live_var)
checked_variables.append(live_var)
all_live_variables.append(live_var)
# check if any input is not checked yet
for arg in node.args:
if not isinstance(arg, Node):
continue
arg_name = arg.name
if not checked_variables.exists(arg_name):
live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False)
all_live_variables.append(live_var_from_arg)
checked_variables.append(live_var_from_arg)
unique_live_vars.append(live_var_from_arg)
# TODO: add the logic to remove live variables
# this should be completed if we are able to trace the backward compute graph
# add this stage to liveness dict
stage = LiveStage(name=node.name,
node=node,
all_live_vars=all_live_variables.copy(),
unique_live_vars=unique_live_vars.copy())
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
replace = False
for index, prev_stage in enumerate(liveness_list):
all_covered = True
for ele in prev_stage.unique_live_vars:
if ele not in stage.unique_live_vars:
all_covered = False
break
if all_covered:
replace = True
break
if replace:
liveness_list[index] = stage
else:
liveness_list.append(stage)
return liveness_list
def get_alias_set(self):
pass
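# Usage sketch (illustrative; symbolic_trace is the standard torch.fx entry point, the
# traced model and printed fields are assumptions for demonstration only):
#
#   import torch
#   from torch.fx import symbolic_trace
#
#   gm = symbolic_trace(torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()))
#   analyser = GraphAnalyser(gm)
#   for stage in analyser.liveness_analysis():
#       print(stage.name, len(stage.unique_live_vars))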
@@ -6,19 +6,9 @@ from .reshape_handler import ReshapeHandler
from .bcast_op_handler import BcastOpHandler
from .embedding_handler import EmbeddingHandler
from .unary_elementwise_handler import UnaryElementwiseHandler
-from .dot_handler_v2 import LinearFunctionHandler, LinearModuleHandler
-from .layer_norm_handler_v2 import LayerNormModuleHandler
-from .batch_norm_handler_v2 import BatchNormModuleHandler
-from .conv_handler_v2 import ConvModuleHandler, ConvFunctionHandler
-from .where_handler_v2 import WhereHandler
-from .unary_elementwise_handler_v2 import UnaryElementwiseHandler_V2
-from .reshape_handler_v2 import ReshapeHandler_V2
-from .placeholder_handler import PlacehodlerHandler
-from .output_handler import OuputHandler
+from .where_handler import WhereHandler
__all__ = [
'OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler', 'ReshapeHandler', 'BcastOpHandler',
-'UnaryElementwiseHandler', 'EmbeddingHandler', 'LinearFunctionHandler', 'LinearModuleHandler',
-'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
-'UnaryElementwiseHandler_V2', 'ReshapeHandler_V2', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler'
+'UnaryElementwiseHandler', 'EmbeddingHandler', 'WhereHandler'
]
import operator
from functools import reduce
import torch
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
-from colossalai.auto_parallel.solver._utils import exception_handler
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
__all__ = ['BatchNormHandler']
...
@@ -2,13 +2,13 @@ import operator
from functools import reduce
import warnings
import torch
-from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
+from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List
-from colossalai.auto_parallel.solver._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
+from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
__all__ = ['BcastOpHandler']
...