Unverified Commit eee84908 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[autoparallel] handled illegal sharding strategy (#1728)

* [autoparallel] handled illegal sharding strategy

* polish code
parent cbe9a4cb
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
import torch
from torch.fx.node import Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Union, Dict, List, Optional
import warnings
from functools import reduce
import functools
import operator
import warnings
from functools import reduce
from typing import Dict, List, Optional, Union
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from torch.fx.node import Node
from .constants import INFINITY_COST
......@@ -87,7 +89,7 @@ def generate_resharding_costs(nodes: List[Node],
return resharding_costs
def exception_handler(func):
def ignore_sharding_exception(func):
"""
A function wrapper which executes the function with a specified seed.
"""
......
import operator
from functools import reduce
import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
__all__ = ['BatchNormHandler']
......@@ -110,7 +113,7 @@ class BatchNormHandler(OperatorHandler):
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
@exception_handler
@ignore_sharding_exception
def split_input_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
......@@ -185,7 +188,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
......@@ -226,7 +229,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def non_split(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RR x R'
......@@ -322,7 +325,7 @@ class BatchNormHandler(OperatorHandler):
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
self.strategies_vector.append(new_sharding_strategy)
@exception_handler
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
......@@ -363,7 +366,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
......@@ -404,7 +407,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
......
import operator
from functools import reduce
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
from .operator_handler import OperatorHandler
__all__ = ['BcastOpHandler']
......@@ -136,7 +140,7 @@ class BcastOpHandler(OperatorHandler):
return output_sharding_spec_list
@exception_handler
@ignore_sharding_exception
def _register_strategy(self, output_sharding_spec):
dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input)
......@@ -171,7 +175,7 @@ class BcastOpHandler(OperatorHandler):
##############################################
#used to generate strategies for torch.matmul#
##############################################
@exception_handler
@ignore_sharding_exception
def _registry_no_split_strategies_for_matmul(self, dim_partition_dict_for_batch_dim):
# this dim partition dict only describes the batch dimensions, but in this scenario,
# matrix dimensions are fully replicated, so it do not need extra process.
......@@ -210,7 +214,7 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def _split_dim_i(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
......@@ -268,7 +272,7 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def _split_dim_k(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
......@@ -332,7 +336,7 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def _split_dim_j(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
......@@ -398,7 +402,7 @@ class BcastOpHandler(OperatorHandler):
self._split_dim_k(dim_partition_dict, mesh_dim_list)
self._split_dim_j(dim_partition_dict, mesh_dim_list)
@exception_handler
@ignore_sharding_exception
def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
......@@ -435,7 +439,7 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-1: [mesh_dim_0]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
......@@ -474,7 +478,7 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-2: [mesh_dim_0]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
......
import operator
from functools import reduce
import warnings
from functools import reduce
import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
__all__ = ['ConvHandler']
......@@ -105,7 +108,7 @@ class ConvHandler(OperatorHandler):
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
@exception_handler
@ignore_sharding_exception
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
......@@ -153,7 +156,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
......@@ -199,7 +202,7 @@ class ConvHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
......@@ -245,7 +248,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
......@@ -288,7 +291,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
......@@ -331,7 +334,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_weight_out_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
......@@ -374,7 +377,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x RR'
......@@ -415,7 +418,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
......@@ -463,7 +466,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
......
import operator
from enum import Enum
from functools import reduce
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
from functools import reduce
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
from enum import Enum
from .strategy_generator import StrategyGenerator, IntermediateStrategy
from typing import List
from .operator_handler import OperatorHandler
from .strategy_generator import IntermediateStrategy, StrategyGenerator
__all__ = ['DotHandler']
......@@ -415,7 +418,7 @@ class DotHandler(OperatorHandler):
compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size
return compute_cost
@exception_handler
@ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
......@@ -456,7 +459,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
......@@ -496,7 +499,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
......@@ -534,7 +537,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
......@@ -569,7 +572,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
......@@ -605,7 +608,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
......@@ -641,7 +644,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
......@@ -678,7 +681,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
......
import operator
from functools import reduce
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
from .operator_handler import OperatorHandler
__all__ = ['EmbeddingHandler']
......@@ -76,7 +79,7 @@ class EmbeddingHandler(OperatorHandler):
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
@exception_handler
@ignore_sharding_exception
def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}'
......@@ -117,7 +120,7 @@ class EmbeddingHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR'
......
import operator
from functools import reduce
import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
generate_sharding_size, ignore_sharding_exception)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_2d_sharding, enumerate_all_possible_1d_sharding, generate_sharding_size
__all__ = ['LayerNormHandler']
......@@ -149,21 +153,21 @@ class LayerNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies)
@exception_handler
@ignore_sharding_exception
def split_input_batch_single_mesh_dim(self, mesh_dim_0):
batch_dimension_length = self.input_data.dim() - self.weight.dim()
dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
for dim_partition in dim_partition_list:
self._generate_strategy_with_dim_partition(dim_partition)
@exception_handler
@ignore_sharding_exception
def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1):
batch_dimension_length = self.input_data.dim() - self.weight.dim()
dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
for dim_partition in dim_partition_list:
self._generate_strategy_with_dim_partition(dim_partition)
@exception_handler
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x R'
......
import colorsys
from .operator_handler import OperatorHandler
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from copy import deepcopy
import math
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
import warnings
from copy import deepcopy
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from ..constants import INFINITY_COST
from .operator_handler import OperatorHandler
class ReshapeHandler(OperatorHandler):
......@@ -24,7 +27,7 @@ class ReshapeHandler(OperatorHandler):
def _generate_compute_cost(self, *args, **kwargs):
return super()._generate_compute_cost(*args, **kwargs)
@exception_handler
@ignore_sharding_exception
def register_strategy(self):
# TODO: add strategies with more output sharding specs other than only fully replicated.
input_node = self.strategies_vector.predecessor_nodes[0]
......
import math
import operator
from functools import reduce
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated.constants import INFINITY_COST
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.constants import \
INFINITY_COST
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List
import math
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
from .operator_handler import OperatorHandler
__all__ = ['UnaryElementwiseHandler']
......@@ -40,7 +44,7 @@ class UnaryElementwiseHandler(OperatorHandler):
def _generate_compute_cost(self, *args, **kwargs):
return super()._generate_compute_cost(*args, **kwargs)
@exception_handler
@ignore_sharding_exception
def register_strategy(self):
# TODO: integrate element-wise func and module together
# create sharding strategy for element-wise function
......
......@@ -6,12 +6,10 @@ from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
exception_handler,
)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
......@@ -146,7 +144,7 @@ class WhereHandler(OperatorHandler):
return output_sharding_spec_list
@exception_handler
@ignore_sharding_exception
def _register_strategy(self, output_sharding_spec):
dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
sharding_spec_for_condition = self._generate_sharding_spec(self.condition_data, dim_partition_dict_for_input)
......
......@@ -5,7 +5,8 @@ import torch
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.utils import (switch_partition_dim, update_partition_dim)
from colossalai.tensor.sharding_spec import ShardingException
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
from .node_handler import ModuleHandler, NodeHandler
......@@ -15,6 +16,100 @@ from .strategy import (BatchedMatMulStrategyGenerator, LinearProjectionStrategyG
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']
def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy,
weight_name: str) -> ShardingStrategy:
"""
This function is a helper function used by both module node handler and function node handler. This function will
convert the sharding spec for the transposed weight to the correct partititon spec.
Args:
strategy (ShardingStrategy): the strategy generated by the strategy generator.
weight_name (str): the name of the OperationData object for the weight.
"""
# switch the dimensions of the transposed weight
sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
op_data = strategy.get_op_data_by_name(weight_name)
assert op_data.logical_shape != op_data.data.shape, \
"Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
switch_partition_dim(sharding_spec, 0, -1)
return strategy
def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy, input_name: str,
output_name: str) -> List[ShardingStrategy]:
"""
This function converts the logical sharding spec to the physical sharding spec for both the input and output of the linear operation. The input and output
should have the same sharding spec.
Args:
strategy (ShardingStrategy): the logical strategy generated by the strategy generator.
input_name (str): the name of the OperationData object for the input.
output_name (str): the name of the OperationData object for the output.
"""
# the result will be a list of strategies
sharding_strategies = []
# get operation data
input_op_data = strategy.get_op_data_by_name(input_name)
output_op_data = strategy.get_op_data_by_name(output_name)
input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
# get logger for debug message
logger = get_dist_logger()
# for the input of the linear operation, it can be multi-dimensional. The sharding spec generated is only
# 2D, where the first dimension is non-matrix dimension and the last dimension is the matrix dimension.
# the logical non-matrix dimension can belong to the 0th to (N-1)th dimension of the physical input shape.
# Thus, we enumerate to get all possible cases.
if 0 in input_sharding_spec.dim_partition_dict:
# if 0 is in the dim_partition_dict, it means that the
# the generated sharding strategy does shard the non-matrix dimension,
# in this case, we need to do enumeration
num_input_dims = input_op_data.data.dim()
for i in range(num_input_dims - 1):
strategy_copy = strategy.clone()
input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
try:
# replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={0: i},
physical_shape=input_op_data.data.shape,
inplace=True)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={0: i},
physical_shape=output_op_data.data.shape,
inplace=True)
sharding_strategies.append(strategy_copy)
except ShardingNotDivisibleError as e:
logger.debug(
f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}'
)
else:
# the generated sharding strategy does not shard the non-matrix dimension,
# in this case, we don't need to do enumeration
# but instead, we still need to convert the logical shape to physical shape
strategy_copy = strategy.clone()
input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
# after updating, the logical shape will be replaced by the physical shape
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={},
physical_shape=input_op_data.data.shape,
inplace=True)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={},
physical_shape=output_op_data.data.shape,
inplace=True)
print(input_op_data.data.shape)
print(output_op_data.data.shape)
sharding_strategies.append(strategy_copy)
return sharding_strategies
@operator_registry.register(torch.nn.Linear)
class LinearModuleHandler(ModuleHandler):
"""
......@@ -58,44 +153,20 @@ class LinearModuleHandler(ModuleHandler):
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
"""
Convert the sharding spec from the logical shape to the physical shape.
Convert the sharding spec from the logical shape to the physical shape. In this function, two tasks are completed:
1. the sharding spec is updated for the transposed weight
2. the input and output sharding specs are updated to physical shape.
"""
# switch the dimensions of the transposed weight
for op_data, sharding_spec in strategy.input_sharding_specs.items():
if op_data.name == "weight":
assert op_data.logical_shape != op_data.data.shape
switch_partition_dim(sharding_spec, 0, -1)
strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight')
# create multiple sharding strategies for the inputs
# as input can be multi-dimensinal and the partition dim is only 2D,
# we need to map the partition at dim 0 to one of the first few dimensions of the input
sharding_strategies = []
input_op_data = strategy.get_op_data_by_name(str(self.node.args[0]))
output_op_data = strategy.get_op_data_by_name(str(self.node))
num_input_dims = input_op_data.data.dim()
input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
if 0 in input_sharding_spec.dim_partition_dict:
for i in range(num_input_dims - 1):
new_strategy = strategy.clone()
input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name)
try:
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={0: i},
physical_shape=input_op_data.data.shape,
inplace=True)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={0: i},
physical_shape=output_op_data.data.shape,
inplace=True)
sharding_strategies.append(new_strategy)
except ShardingException:
pass
else:
sharding_strategies.append(strategy)
return sharding_strategies
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
input_name=str(self.node.args[0]),
output_name=str(self.node))
return strategies
@operator_registry.register(F.linear)
......@@ -113,9 +184,12 @@ class LinearFunctionHandler(NodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
data=self.node.args[0]._meta_data,
logical_shape=input_logical_shape)
# check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
......@@ -144,44 +218,17 @@ class LinearFunctionHandler(NodeHandler):
return mapping
def post_process(self, strategy: ShardingStrategy):
"""
Convert the sharding spec of the weight parameter back to its original shape.
"""
for op_data, sharding_spec in strategy.input_sharding_specs.items():
if op_data.name == str(self.node.args[1]):
assert op_data.logical_shape != op_data.data.shape
switch_partition_dim(sharding_spec, 0, -1)
# switch the dimensions of the transposed weight
strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
weight_name=str(self.node.args[1]))
# create multiple sharding strategies for the inputs
# as input can be multi-dimensinal and the partition dim is only 2D,
# we need to map the partition at dim 0 to one of the first few dimensions of the input
sharding_strategies = []
input_op_data = strategy.get_op_data_by_name(str(self.node.args[0]))
output_op_data = strategy.get_op_data_by_name(str(self.node))
num_input_dims = input_op_data.data.dim()
input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)
if 0 in input_sharding_spec.dim_partition_dict:
for i in range(num_input_dims - 1):
new_strategy = strategy.clone()
input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name)
output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name)
try:
update_partition_dim(sharding_spec=input_sharding_spec,
dim_mapping={0: i},
physical_shape=input_op_data.data.shape,
inplace=True)
update_partition_dim(sharding_spec=output_sharding_spec,
dim_mapping={0: i},
physical_shape=output_op_data.data.shape,
inplace=True)
sharding_strategies.append(new_strategy)
except ShardingException:
pass
else:
sharding_strategies.append(strategy)
return strategy
strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
input_name=str(self.node.args[0]),
output_name=str(self.node))
return strategies
@operator_registry.register(torch.bmm)
......
import copy
import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
......@@ -292,7 +293,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
'''
Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
'''
......@@ -325,9 +326,4 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# S01R = S01R x R WITH SYNC_BN
# strategy_list.append(self.split_input_batch_1d(0, 1))
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list
......@@ -5,7 +5,8 @@ from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import exception_handler
from colossalai.auto_parallel.tensor_shard.utils import \
ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
......@@ -25,8 +26,8 @@ class ConvStrategyGenerator(StrategyGenerator):
For Conv3d, the dim of input data should be 5([N, C, H, W, D]).
'''
input_op_data = self.op_data['input']
assert input_op_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
assert input_op_data.data.dim() in (
3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
def update_compute_cost(self, strategy: ShardingStrategy):
'''
......@@ -99,7 +100,7 @@ class ConvStrategyGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@exception_handler
@ignore_sharding_exception
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
......@@ -146,7 +147,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
......@@ -183,7 +184,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
......@@ -230,7 +231,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
......@@ -270,7 +271,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
......@@ -301,7 +302,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def split_weight_out_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
......@@ -334,7 +335,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x RR'
......@@ -353,7 +354,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={})
@exception_handler
@ignore_sharding_exception
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
......@@ -391,7 +392,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict_mapping = {
......@@ -421,7 +422,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@exception_handler
@ignore_sharding_exception
def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_mapping = {
......@@ -453,7 +454,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self) -> List[ShardingStrategy]:
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
# SS = SR x RS
strategies.append(self.split_input_batch_weight_out_channel(0, 1))
......@@ -491,20 +492,4 @@ class ConvStrategyGenerator(StrategyGenerator):
# RS01 = RR x RS01
strategies.append(self.split_1d_parallel_on_out_channel(0, 1))
rm_list = [strategy for strategy in strategies if strategy is None]
for rm_element in rm_list:
strategies.remove(rm_element)
illegal_strategy_list = []
# update mete info on cost
for strategy in strategies:
try:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
except AssertionError as e:
illegal_strategy_list.append(strategy)
warnings.warn(f'{e}')
for strategy in illegal_strategy_list:
strategies.remove(strategy)
return strategies
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
......@@ -61,7 +62,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
Deal with case 1 and 2.
'''
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for strategy in self.predecessor_node.strategies_vector:
dim_partition_dict_mapping = {}
......@@ -109,7 +110,7 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
Deal with case 3.
'''
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
index = self.op_data["index"].data
......@@ -133,9 +134,4 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
strategy_list.append(strategy)
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list
import copy
import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
......@@ -159,7 +160,7 @@ class LayerNormGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
'''
Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector.
'''
......@@ -178,11 +179,5 @@ class LayerNormGenerator(StrategyGenerator):
# RR = RR x R
strategy_list.append(self.non_split())
# update mete info on cost
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list
......@@ -3,6 +3,8 @@ from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import \
ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator
......@@ -169,7 +171,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
total=fwd_compute_cost + bwd_compute_cost)
strategy.compute_cost = compute_cost
def generate(self) -> List[ShardingStrategy]:
def collate_strategies(self) -> List[ShardingStrategy]:
strategies = []
# SS = SR x RS
......@@ -201,14 +203,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# RS01 = RR x RS01
strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
# update mete info on cost
for strategy in strategies:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategies
@ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
......@@ -249,6 +246,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
......@@ -289,6 +287,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
......@@ -324,6 +323,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
......@@ -351,6 +351,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
......@@ -380,6 +381,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
# get sharding spec
......@@ -410,6 +412,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communcation_action_mapping)
@ignore_sharding_exception
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
......@@ -437,6 +440,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
......@@ -542,7 +546,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mappingp['bias'] = bias_comm_spec
communication_action_mapping['bias'] = bias_comm_spec
return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
......@@ -662,7 +666,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
def generate(self) -> List[ShardingStrategy]:
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
device_mesh_is_1d = True
if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
......
......@@ -25,8 +25,8 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
For Pool3d, the dim of input data should be 5([N, C, H, W, D]).
'''
input_op_data = self.op_data['input']
assert input_op_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].'
assert input_op_data.data.dim() in (
3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].'
def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem:
'''
......@@ -103,7 +103,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
return dim_partition_list
def generate(self) -> List[ShardingStrategy]:
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1)
......@@ -111,9 +111,4 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
strategy = self._generate_strategy_with_dim_partition(dim_partition)
strategy_list.append(strategy)
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from .strategy_generator import OutputStrategyGenerator
......@@ -30,7 +32,7 @@ class OutputGenerator(OutputStrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
dim_partition_dict_mapping = {
"output": {},
}
......@@ -47,8 +49,4 @@ class OutputGenerator(OutputStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return [strategy]
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from .strategy_generator import StrategyGenerator
......@@ -35,7 +37,7 @@ class PlaceholderGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
dim_partition_dict_mapping = {
"output": {},
}
......@@ -48,8 +50,4 @@ class PlaceholderGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return [strategy]
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
......@@ -49,7 +50,7 @@ class ReshapeGenerator(FollowingStrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# For reshape function, to keep the computing correctness we keep the sharding
# spec of input is fully replicated. In addition, we will keep the output in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment