Unverified Commit eee84908 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[autoparallel] handled illegal sharding strategy (#1728)

* [autoparallel] handled illegal sharding strategy

* polish code
parent cbe9a4cb
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
import torch
from torch.fx.node import Node
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from typing import Union, Dict, List, Optional
import warnings
from functools import reduce
import functools import functools
import operator import operator
import warnings
from functools import reduce
from typing import Dict, List, Optional, Union
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from torch.fx.node import Node
from .constants import INFINITY_COST from .constants import INFINITY_COST
...@@ -87,7 +89,7 @@ def generate_resharding_costs(nodes: List[Node], ...@@ -87,7 +89,7 @@ def generate_resharding_costs(nodes: List[Node],
return resharding_costs return resharding_costs
def exception_handler(func): def ignore_sharding_exception(func):
""" """
A function wrapper which executes the function with a specified seed. A function wrapper which executes the function with a specified seed.
""" """
......
import operator import operator
from functools import reduce from functools import reduce
import torch import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from .operator_handler import OperatorHandler from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
__all__ = ['BatchNormHandler'] __all__ = ['BatchNormHandler']
...@@ -110,7 +113,7 @@ class BatchNormHandler(OperatorHandler): ...@@ -110,7 +113,7 @@ class BatchNormHandler(OperatorHandler):
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
@exception_handler @ignore_sharding_exception
def split_input_channel(self, mesh_dim_0, mesh_dim_1): def split_input_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
...@@ -185,7 +188,7 @@ class BatchNormHandler(OperatorHandler): ...@@ -185,7 +188,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1): def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}' name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
...@@ -226,7 +229,7 @@ class BatchNormHandler(OperatorHandler): ...@@ -226,7 +229,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def non_split(self, mesh_dim_0, mesh_dim_1): def non_split(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RR x R' name = f'RR = RR x R'
...@@ -322,7 +325,7 @@ class BatchNormHandler(OperatorHandler): ...@@ -322,7 +325,7 @@ class BatchNormHandler(OperatorHandler):
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name) new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
self.strategies_vector.append(new_sharding_strategy) self.strategies_vector.append(new_sharding_strategy)
@exception_handler @ignore_sharding_exception
def split_input_batch(self, mesh_dim_0): def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN' name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
...@@ -363,7 +366,7 @@ class BatchNormHandler(OperatorHandler): ...@@ -363,7 +366,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1): def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN' name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
...@@ -404,7 +407,7 @@ class BatchNormHandler(OperatorHandler): ...@@ -404,7 +407,7 @@ class BatchNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN' name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
......
import operator import operator
from functools import reduce
import warnings import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
from .operator_handler import OperatorHandler enumerate_all_possible_2d_sharding,
ignore_sharding_exception)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding
__all__ = ['BcastOpHandler'] __all__ = ['BcastOpHandler']
...@@ -136,7 +140,7 @@ class BcastOpHandler(OperatorHandler): ...@@ -136,7 +140,7 @@ class BcastOpHandler(OperatorHandler):
return output_sharding_spec_list return output_sharding_spec_list
@exception_handler @ignore_sharding_exception
def _register_strategy(self, output_sharding_spec): def _register_strategy(self, output_sharding_spec):
dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input) sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input)
...@@ -171,7 +175,7 @@ class BcastOpHandler(OperatorHandler): ...@@ -171,7 +175,7 @@ class BcastOpHandler(OperatorHandler):
############################################## ##############################################
#used to generate strategies for torch.matmul# #used to generate strategies for torch.matmul#
############################################## ##############################################
@exception_handler @ignore_sharding_exception
def _registry_no_split_strategies_for_matmul(self, dim_partition_dict_for_batch_dim): def _registry_no_split_strategies_for_matmul(self, dim_partition_dict_for_batch_dim):
# this dim partition dict only describes the batch dimensions, but in this scenario, # this dim partition dict only describes the batch dimensions, but in this scenario,
# matrix dimensions are fully replicated, so it do not need extra process. # matrix dimensions are fully replicated, so it do not need extra process.
...@@ -210,7 +214,7 @@ class BcastOpHandler(OperatorHandler): ...@@ -210,7 +214,7 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def _split_dim_i(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix): def _split_dim_i(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j] # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it. # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
...@@ -268,7 +272,7 @@ class BcastOpHandler(OperatorHandler): ...@@ -268,7 +272,7 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def _split_dim_k(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix): def _split_dim_k(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j] # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it. # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
...@@ -332,7 +336,7 @@ class BcastOpHandler(OperatorHandler): ...@@ -332,7 +336,7 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def _split_dim_j(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix): def _split_dim_j(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j] # A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it. # this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
...@@ -398,7 +402,7 @@ class BcastOpHandler(OperatorHandler): ...@@ -398,7 +402,7 @@ class BcastOpHandler(OperatorHandler):
self._split_dim_k(dim_partition_dict, mesh_dim_list) self._split_dim_k(dim_partition_dict, mesh_dim_list)
self._split_dim_j(dim_partition_dict, mesh_dim_list) self._split_dim_j(dim_partition_dict, mesh_dim_list)
@exception_handler @ignore_sharding_exception
def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]} dim_partition_dict_for_lhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
...@@ -435,7 +439,7 @@ class BcastOpHandler(OperatorHandler): ...@@ -435,7 +439,7 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-1: [mesh_dim_0]} dim_partition_dict_for_lhs = {-1: [mesh_dim_0]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
...@@ -474,7 +478,7 @@ class BcastOpHandler(OperatorHandler): ...@@ -474,7 +478,7 @@ class BcastOpHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-2: [mesh_dim_0]} dim_partition_dict_for_lhs = {-2: [mesh_dim_0]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs) sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
......
import operator import operator
from functools import reduce
import warnings import warnings
from functools import reduce
import torch import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from .operator_handler import OperatorHandler from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
__all__ = ['ConvHandler'] __all__ = ['ConvHandler']
...@@ -105,7 +108,7 @@ class ConvHandler(OperatorHandler): ...@@ -105,7 +108,7 @@ class ConvHandler(OperatorHandler):
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
@exception_handler @ignore_sharding_exception
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
...@@ -153,7 +156,7 @@ class ConvHandler(OperatorHandler): ...@@ -153,7 +156,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_input_batch(self, mesh_dim_0): def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
...@@ -199,7 +202,7 @@ class ConvHandler(OperatorHandler): ...@@ -199,7 +202,7 @@ class ConvHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
...@@ -245,7 +248,7 @@ class ConvHandler(OperatorHandler): ...@@ -245,7 +248,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
...@@ -288,7 +291,7 @@ class ConvHandler(OperatorHandler): ...@@ -288,7 +291,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_input_in_channel_weight_in_channel(self, mesh_dim_0): def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
...@@ -331,7 +334,7 @@ class ConvHandler(OperatorHandler): ...@@ -331,7 +334,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_weight_out_channel(self, mesh_dim_0): def split_weight_out_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
...@@ -374,7 +377,7 @@ class ConvHandler(OperatorHandler): ...@@ -374,7 +377,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def non_split(self): def non_split(self):
name = f'RR = RR x RR' name = f'RR = RR x RR'
...@@ -415,7 +418,7 @@ class ConvHandler(OperatorHandler): ...@@ -415,7 +418,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
...@@ -463,7 +466,7 @@ class ConvHandler(OperatorHandler): ...@@ -463,7 +466,7 @@ class ConvHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
......
import operator import operator
from enum import Enum
from functools import reduce
from typing import List
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
from .operator_handler import OperatorHandler ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
from functools import reduce from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler from .strategy_generator import IntermediateStrategy, StrategyGenerator
from enum import Enum
from .strategy_generator import StrategyGenerator, IntermediateStrategy
from typing import List
__all__ = ['DotHandler'] __all__ = ['DotHandler']
...@@ -415,7 +418,7 @@ class DotHandler(OperatorHandler): ...@@ -415,7 +418,7 @@ class DotHandler(OperatorHandler):
compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size
return compute_cost return compute_cost
@exception_handler @ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS # handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
...@@ -456,7 +459,7 @@ class DotHandler(OperatorHandler): ...@@ -456,7 +459,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR # handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
...@@ -496,7 +499,7 @@ class DotHandler(OperatorHandler): ...@@ -496,7 +499,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
...@@ -534,7 +537,7 @@ class DotHandler(OperatorHandler): ...@@ -534,7 +537,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def recompute_split_both_contract(self, mesh_dim): def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R' name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
...@@ -569,7 +572,7 @@ class DotHandler(OperatorHandler): ...@@ -569,7 +572,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_rhs_space_only(self, mesh_dim): def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}' name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
...@@ -605,7 +608,7 @@ class DotHandler(OperatorHandler): ...@@ -605,7 +608,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
...@@ -641,7 +644,7 @@ class DotHandler(OperatorHandler): ...@@ -641,7 +644,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
...@@ -678,7 +681,7 @@ class DotHandler(OperatorHandler): ...@@ -678,7 +681,7 @@ class DotHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
......
import operator import operator
from functools import reduce
import warnings import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
from .operator_handler import OperatorHandler ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
__all__ = ['EmbeddingHandler'] __all__ = ['EmbeddingHandler']
...@@ -76,7 +79,7 @@ class EmbeddingHandler(OperatorHandler): ...@@ -76,7 +79,7 @@ class EmbeddingHandler(OperatorHandler):
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
@exception_handler @ignore_sharding_exception
def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1): def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}' name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}'
...@@ -117,7 +120,7 @@ class EmbeddingHandler(OperatorHandler): ...@@ -117,7 +120,7 @@ class EmbeddingHandler(OperatorHandler):
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR' name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR'
......
import operator import operator
from functools import reduce from functools import reduce
import torch import torch
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
generate_sharding_size, ignore_sharding_exception)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from .operator_handler import OperatorHandler from .operator_handler import OperatorHandler
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler, enumerate_all_possible_2d_sharding, enumerate_all_possible_1d_sharding, generate_sharding_size
__all__ = ['LayerNormHandler'] __all__ = ['LayerNormHandler']
...@@ -149,21 +153,21 @@ class LayerNormHandler(OperatorHandler): ...@@ -149,21 +153,21 @@ class LayerNormHandler(OperatorHandler):
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
@exception_handler @ignore_sharding_exception
def split_input_batch_single_mesh_dim(self, mesh_dim_0): def split_input_batch_single_mesh_dim(self, mesh_dim_0):
batch_dimension_length = self.input_data.dim() - self.weight.dim() batch_dimension_length = self.input_data.dim() - self.weight.dim()
dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length) dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
for dim_partition in dim_partition_list: for dim_partition in dim_partition_list:
self._generate_strategy_with_dim_partition(dim_partition) self._generate_strategy_with_dim_partition(dim_partition)
@exception_handler @ignore_sharding_exception
def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1): def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1):
batch_dimension_length = self.input_data.dim() - self.weight.dim() batch_dimension_length = self.input_data.dim() - self.weight.dim()
dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length) dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
for dim_partition in dim_partition_list: for dim_partition in dim_partition_list:
self._generate_strategy_with_dim_partition(dim_partition) self._generate_strategy_with_dim_partition(dim_partition)
@exception_handler @ignore_sharding_exception
def non_split(self): def non_split(self):
name = f'RR = RR x R' name = f'RR = RR x R'
......
import colorsys import colorsys
from .operator_handler import OperatorHandler
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from copy import deepcopy
import math import math
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
import warnings import warnings
from copy import deepcopy
import torch import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from ..constants import INFINITY_COST from ..constants import INFINITY_COST
from .operator_handler import OperatorHandler
class ReshapeHandler(OperatorHandler): class ReshapeHandler(OperatorHandler):
...@@ -24,7 +27,7 @@ class ReshapeHandler(OperatorHandler): ...@@ -24,7 +27,7 @@ class ReshapeHandler(OperatorHandler):
def _generate_compute_cost(self, *args, **kwargs): def _generate_compute_cost(self, *args, **kwargs):
return super()._generate_compute_cost(*args, **kwargs) return super()._generate_compute_cost(*args, **kwargs)
@exception_handler @ignore_sharding_exception
def register_strategy(self): def register_strategy(self):
# TODO: add strategies with more output sharding specs other than only fully replicated. # TODO: add strategies with more output sharding specs other than only fully replicated.
input_node = self.strategies_vector.predecessor_nodes[0] input_node = self.strategies_vector.predecessor_nodes[0]
......
import math
import operator import operator
from functools import reduce
import warnings import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch import torch
from colossalai.auto_parallel.tensor_shard.deprecated.constants import INFINITY_COST from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector ignore_sharding_exception
from .operator_handler import OperatorHandler from colossalai.auto_parallel.tensor_shard.deprecated.constants import \
INFINITY_COST
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
from typing import Dict, List from .operator_handler import OperatorHandler
import math
from colossalai.auto_parallel.tensor_shard.deprecated._utils import exception_handler
__all__ = ['UnaryElementwiseHandler'] __all__ = ['UnaryElementwiseHandler']
...@@ -40,7 +44,7 @@ class UnaryElementwiseHandler(OperatorHandler): ...@@ -40,7 +44,7 @@ class UnaryElementwiseHandler(OperatorHandler):
def _generate_compute_cost(self, *args, **kwargs): def _generate_compute_cost(self, *args, **kwargs):
return super()._generate_compute_cost(*args, **kwargs) return super()._generate_compute_cost(*args, **kwargs)
@exception_handler @ignore_sharding_exception
def register_strategy(self): def register_strategy(self):
# TODO: integrate element-wise func and module together # TODO: integrate element-wise func and module together
# create sharding strategy for element-wise function # create sharding strategy for element-wise function
......
...@@ -6,12 +6,10 @@ from typing import Dict, List ...@@ -6,12 +6,10 @@ from typing import Dict, List
import torch import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ( from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding,
enumerate_all_possible_2d_sharding, ignore_sharding_exception)
exception_handler, from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
...@@ -146,7 +144,7 @@ class WhereHandler(OperatorHandler): ...@@ -146,7 +144,7 @@ class WhereHandler(OperatorHandler):
return output_sharding_spec_list return output_sharding_spec_list
@exception_handler @ignore_sharding_exception
def _register_strategy(self, output_sharding_spec): def _register_strategy(self, output_sharding_spec):
dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
sharding_spec_for_condition = self._generate_sharding_spec(self.condition_data, dim_partition_dict_for_input) sharding_spec_for_condition = self._generate_sharding_spec(self.condition_data, dim_partition_dict_for_input)
......
...@@ -5,7 +5,8 @@ import torch ...@@ -5,7 +5,8 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.utils import (switch_partition_dim, update_partition_dim) from colossalai.auto_parallel.tensor_shard.utils import (switch_partition_dim, update_partition_dim)
from colossalai.tensor.sharding_spec import ShardingException from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingNotDivisibleError
from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy) from ..sharding_strategy import (OperationData, OperationDataType, ShardingStrategy)
from .node_handler import ModuleHandler, NodeHandler from .node_handler import ModuleHandler, NodeHandler
...@@ -15,6 +16,100 @@ from .strategy import (BatchedMatMulStrategyGenerator, LinearProjectionStrategyG ...@@ -15,6 +16,100 @@ from .strategy import (BatchedMatMulStrategyGenerator, LinearProjectionStrategyG
__all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler'] __all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']
def _update_sharding_spec_for_transposed_weight_for_linear(strategy: ShardingStrategy,
                                                           weight_name: str) -> ShardingStrategy:
    """
    This function is a helper function used by both module node handler and function node handler. This function will
    convert the sharding spec for the transposed weight to the correct partition spec.

    Args:
        strategy (ShardingStrategy): the strategy generated by the strategy generator.
        weight_name (str): the name of the OperationData object for the weight.

    Returns:
        ShardingStrategy: the same strategy, with the weight's sharding spec mutated in place.
    """
    # switch the dimensions of the transposed weight
    sharding_spec = strategy.get_sharding_spec_by_name(weight_name)
    op_data = strategy.get_op_data_by_name(weight_name)
    # NOTE(review): this assert fails for square weights (in_features == out_features),
    # where the transposed logical shape equals the physical shape — confirm intended.
    assert op_data.logical_shape != op_data.data.shape, \
        "Expected the logical and physical shape of the linear operator's weight to be different, but found them to be the same"
    # swap dim 0 and the last dim so the spec matches the physical (untransposed) weight layout
    switch_partition_dim(sharding_spec, 0, -1)

    return strategy
def _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy: ShardingStrategy, input_name: str,
                                                                   output_name: str) -> List[ShardingStrategy]:
    """
    This function converts the logical sharding spec to the physical sharding spec for both the input and output of the
    linear operation. The input and output should have the same sharding spec.

    Args:
        strategy (ShardingStrategy): the logical strategy generated by the strategy generator.
        input_name (str): the name of the OperationData object for the input.
        output_name (str): the name of the OperationData object for the output.

    Returns:
        List[ShardingStrategy]: one strategy per valid physical placement of the logical batch shard.
    """
    # the result will be a list of strategies
    sharding_strategies = []

    # get operation data
    input_op_data = strategy.get_op_data_by_name(input_name)
    output_op_data = strategy.get_op_data_by_name(output_name)
    input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)

    # get logger for debug message
    logger = get_dist_logger()

    # for the input of the linear operation, it can be multi-dimensional. The sharding spec generated is only
    # 2D, where the first dimension is non-matrix dimension and the last dimension is the matrix dimension.
    # the logical non-matrix dimension can belong to the 0th to (N-1)th dimension of the physical input shape.
    # Thus, we enumerate to get all possible cases.
    if 0 in input_sharding_spec.dim_partition_dict:
        # if 0 is in the dim_partition_dict, it means that the
        # the generated sharding strategy does shard the non-matrix dimension,
        # in this case, we need to do enumeration
        num_input_dims = input_op_data.data.dim()
        for i in range(num_input_dims - 1):
            strategy_copy = strategy.clone()
            input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
            output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)
            try:
                # replace the 0th dimension in the logical sharding with ith dimension in the physical sharding
                update_partition_dim(sharding_spec=input_sharding_spec,
                                     dim_mapping={0: i},
                                     physical_shape=input_op_data.data.shape,
                                     inplace=True)
                update_partition_dim(sharding_spec=output_sharding_spec,
                                     dim_mapping={0: i},
                                     physical_shape=output_op_data.data.shape,
                                     inplace=True)
                sharding_strategies.append(strategy_copy)
            except ShardingNotDivisibleError as e:
                # this placement does not evenly divide the physical dim; skip it
                logger.debug(
                    f'Errored occurred when converting the logical sharding spec to the physical one. Error details: {e}'
                )
    else:
        # the generated sharding strategy does not shard the non-matrix dimension,
        # in this case, we don't need to do enumeration
        # but instead, we still need to convert the logical shape to physical shape
        strategy_copy = strategy.clone()
        input_sharding_spec = strategy_copy.get_sharding_spec_by_name(input_op_data.name)
        output_sharding_spec = strategy_copy.get_sharding_spec_by_name(output_op_data.name)

        # after updating, the logical shape will be replaced by the physical shape
        update_partition_dim(sharding_spec=input_sharding_spec,
                             dim_mapping={},
                             physical_shape=input_op_data.data.shape,
                             inplace=True)
        update_partition_dim(sharding_spec=output_sharding_spec,
                             dim_mapping={},
                             physical_shape=output_op_data.data.shape,
                             inplace=True)
        # (fix) removed leftover debug print() calls of the input/output shapes —
        # library code must not write to stdout; use logger.debug if tracing is needed
        sharding_strategies.append(strategy_copy)

    return sharding_strategies
@operator_registry.register(torch.nn.Linear) @operator_registry.register(torch.nn.Linear)
class LinearModuleHandler(ModuleHandler): class LinearModuleHandler(ModuleHandler):
""" """
...@@ -58,44 +153,20 @@ class LinearModuleHandler(ModuleHandler): ...@@ -58,44 +153,20 @@ class LinearModuleHandler(ModuleHandler):
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
    """
    Convert the sharding spec from the logical shape to the physical shape. In this function, two tasks are completed:
    1. the sharding spec is updated for the transposed weight
    2. the input and output sharding specs are updated to the physical shape.
    """
    # switch the dimensions of the transposed weight
    strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy, weight_name='weight')

    # create multiple sharding strategies for the inputs
    # as input can be multi-dimensional and the partition dim is only 2D,
    # we need to map the partition at dim 0 to one of the first few dimensions of the input
    strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
                                                                               input_name=str(self.node.args[0]),
                                                                               output_name=str(self.node))
    return strategies
@operator_registry.register(F.linear) @operator_registry.register(F.linear)
...@@ -113,9 +184,12 @@ class LinearFunctionHandler(NodeHandler): ...@@ -113,9 +184,12 @@ class LinearFunctionHandler(NodeHandler):
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies # use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process # the strategies will be transformed back to its original shape in self.post_process
input_meta_data = self.node.args[0]._meta_data
input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
physical_input_operand = OperationData(name=str(self.node.args[0]), physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG, type=OperationDataType.ARG,
data=self.node.args[0]._meta_data) data=self.node.args[0]._meta_data,
logical_shape=input_logical_shape)
# check if the other operand is a parameter # check if the other operand is a parameter
if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter): if isinstance(self.node.args[1]._meta_data, torch.nn.parameter.Parameter):
...@@ -144,44 +218,17 @@ class LinearFunctionHandler(NodeHandler): ...@@ -144,44 +218,17 @@ class LinearFunctionHandler(NodeHandler):
return mapping return mapping
def post_process(self, strategy: ShardingStrategy):
    """
    Convert the sharding spec from the logical shape to the physical shape: update the spec for the
    transposed weight (args[1]) and map the logical input/output sharding to the physical shape.
    """
    # switch the dimensions of the transposed weight
    strategy = _update_sharding_spec_for_transposed_weight_for_linear(strategy=strategy,
                                                                     weight_name=str(self.node.args[1]))

    # create multiple sharding strategies for the inputs
    # as input can be multi-dimensional and the partition dim is only 2D,
    # we need to map the partition at dim 0 to one of the first few dimensions of the input
    strategies = _convert_logical_sharding_to_physical_sharding_spec_for_linear(strategy=strategy,
                                                                               input_name=str(self.node.args[0]),
                                                                               output_name=str(self.node))
    return strategies
@operator_registry.register(torch.bmm) @operator_registry.register(torch.bmm)
......
import copy import copy
import operator import operator
from functools import reduce from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.shape_consistency import CollectiveCommPattern
...@@ -292,7 +293,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): ...@@ -292,7 +293,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
def generate(self): def collate_strategies(self) -> List[ShardingStrategy]:
''' '''
Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector. Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
''' '''
...@@ -325,9 +326,4 @@ class BatchNormStrategyGenerator(StrategyGenerator): ...@@ -325,9 +326,4 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# S01R = S01R x R WITH SYNC_BN # S01R = S01R x R WITH SYNC_BN
# strategy_list.append(self.split_input_batch_1d(0, 1)) # strategy_list.append(self.split_input_batch_1d(0, 1))
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list return strategy_list
...@@ -5,7 +5,8 @@ from functools import reduce ...@@ -5,7 +5,8 @@ from functools import reduce
from typing import List from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import exception_handler from colossalai.auto_parallel.tensor_shard.utils import \
ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator from .strategy_generator import StrategyGenerator
...@@ -25,8 +26,8 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -25,8 +26,8 @@ class ConvStrategyGenerator(StrategyGenerator):
For Conv3d, the dim of input data should be 5([N, C, H, W, D]). For Conv3d, the dim of input data should be 5([N, C, H, W, D]).
''' '''
input_op_data = self.op_data['input'] input_op_data = self.op_data['input']
assert input_op_data.dim() in (3, 4, assert input_op_data.data.dim() in (
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].' 3, 4, 5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
def update_compute_cost(self, strategy: ShardingStrategy): def update_compute_cost(self, strategy: ShardingStrategy):
''' '''
...@@ -99,7 +100,7 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -99,7 +100,7 @@ class ConvStrategyGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost strategy.memory_cost = memory_cost
@exception_handler @ignore_sharding_exception
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1): def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
...@@ -146,7 +147,7 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -146,7 +147,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@exception_handler @ignore_sharding_exception
def split_input_batch(self, mesh_dim_0): def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR' name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
...@@ -183,7 +184,7 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -183,7 +184,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@exception_handler @ignore_sharding_exception
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1): def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
...@@ -230,7 +231,7 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -230,7 +231,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@exception_handler @ignore_sharding_exception
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1): def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
...@@ -270,7 +271,7 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -270,7 +271,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@exception_handler @ignore_sharding_exception
def split_input_in_channel_weight_in_channel(self, mesh_dim_0): def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R' name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
...@@ -301,7 +302,7 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -301,7 +302,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@exception_handler @ignore_sharding_exception
def split_weight_out_channel(self, mesh_dim_0): def split_weight_out_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}' name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
...@@ -334,7 +335,7 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -334,7 +335,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@exception_handler @ignore_sharding_exception
def non_split(self): def non_split(self):
name = f'RR = RR x RR' name = f'RR = RR x RR'
...@@ -353,7 +354,7 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -353,7 +354,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping={}) communication_action_mapping={})
@exception_handler @ignore_sharding_exception
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1): def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
...@@ -391,7 +392,7 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -391,7 +392,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@exception_handler @ignore_sharding_exception
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1): def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict_mapping = { dim_partition_dict_mapping = {
...@@ -421,7 +422,7 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -421,7 +422,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@exception_handler @ignore_sharding_exception
def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1): def split_1d_parallel_on_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_mapping = { dim_partition_dict_mapping = {
...@@ -453,7 +454,7 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -453,7 +454,7 @@ class ConvStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
def generate(self) -> List[ShardingStrategy]: def collate_strategies(self) -> List[ShardingStrategy]:
strategies = [] strategies = []
# SS = SR x RS # SS = SR x RS
strategies.append(self.split_input_batch_weight_out_channel(0, 1)) strategies.append(self.split_input_batch_weight_out_channel(0, 1))
...@@ -491,20 +492,4 @@ class ConvStrategyGenerator(StrategyGenerator): ...@@ -491,20 +492,4 @@ class ConvStrategyGenerator(StrategyGenerator):
# RS01 = RR x RS01 # RS01 = RR x RS01
strategies.append(self.split_1d_parallel_on_out_channel(0, 1)) strategies.append(self.split_1d_parallel_on_out_channel(0, 1))
rm_list = [strategy for strategy in strategies if strategy is None]
for rm_element in rm_list:
strategies.remove(rm_element)
illegal_strategy_list = []
# update mete info on cost
for strategy in strategies:
try:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
except AssertionError as e:
illegal_strategy_list.append(strategy)
warnings.warn(f'{e}')
for strategy in illegal_strategy_list:
strategies.remove(strategy)
return strategies return strategies
import copy import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.shape_consistency import CollectiveCommPattern
...@@ -61,7 +62,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator): ...@@ -61,7 +62,7 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
Deal with case 1 and 2. Deal with case 1 and 2.
''' '''
def generate(self): def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = [] strategy_list = []
for strategy in self.predecessor_node.strategies_vector: for strategy in self.predecessor_node.strategies_vector:
dim_partition_dict_mapping = {} dim_partition_dict_mapping = {}
...@@ -109,7 +110,7 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator): ...@@ -109,7 +110,7 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
Deal with case 3. Deal with case 3.
''' '''
def generate(self): def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = [] strategy_list = []
index = self.op_data["index"].data index = self.op_data["index"].data
...@@ -133,9 +134,4 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator): ...@@ -133,9 +134,4 @@ class TensorTupleStrategyGenerator(GetItemStrategyGenerator):
strategy_list.append(strategy) strategy_list.append(strategy)
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list return strategy_list
import copy import copy
import operator import operator
from functools import reduce from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding, from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
...@@ -159,7 +160,7 @@ class LayerNormGenerator(StrategyGenerator): ...@@ -159,7 +160,7 @@ class LayerNormGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
def generate(self): def collate_strategies(self) -> List[ShardingStrategy]:
''' '''
Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector. Generate every possible strategies for a LayerNorm node, and record all strategies into the strategies_vector.
''' '''
...@@ -178,11 +179,5 @@ class LayerNormGenerator(StrategyGenerator): ...@@ -178,11 +179,5 @@ class LayerNormGenerator(StrategyGenerator):
# RR = RR x R # RR = RR x R
strategy_list.append(self.non_split()) strategy_list.append(self.non_split())
# update mete info on cost
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list return strategy_list
...@@ -3,6 +3,8 @@ from functools import reduce ...@@ -3,6 +3,8 @@ from functools import reduce
from typing import List from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import \
ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator from .strategy_generator import StrategyGenerator
...@@ -169,7 +171,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): ...@@ -169,7 +171,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
total=fwd_compute_cost + bwd_compute_cost) total=fwd_compute_cost + bwd_compute_cost)
strategy.compute_cost = compute_cost strategy.compute_cost = compute_cost
def generate(self) -> List[ShardingStrategy]: def collate_strategies(self) -> List[ShardingStrategy]:
strategies = [] strategies = []
# SS = SR x RS # SS = SR x RS
...@@ -201,14 +203,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): ...@@ -201,14 +203,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
# RS01 = RR x RS01 # RS01 = RR x RS01
strategies.append(self.split_rhs_2nd_dim_1d(0, 1)) strategies.append(self.split_rhs_2nd_dim_1d(0, 1))
# update mete info on cost
for strategy in strategies:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategies return strategies
@ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1): def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS # handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}' name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
...@@ -249,6 +246,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): ...@@ -249,6 +246,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR # handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R' name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
...@@ -289,6 +287,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): ...@@ -289,6 +287,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1): def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}' name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
...@@ -324,6 +323,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): ...@@ -324,6 +323,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def recompute_split_both_contract(self, mesh_dim): def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R' name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
...@@ -351,6 +351,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): ...@@ -351,6 +351,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_rhs_space_only(self, mesh_dim): def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}' name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
...@@ -380,6 +381,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): ...@@ -380,6 +381,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1): def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR' name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
# get sharding spec # get sharding spec
...@@ -410,6 +412,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): ...@@ -410,6 +412,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communcation_action_mapping) communication_action_mapping=communcation_action_mapping)
@ignore_sharding_exception
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R' name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
...@@ -437,6 +440,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator): ...@@ -437,6 +440,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1): def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}' name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
...@@ -542,7 +546,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): ...@@ -542,7 +546,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
sharding_spec=sharding_spec_mapping['bias'], sharding_spec=sharding_spec_mapping['bias'],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1]) logical_process_axis=[mesh_dim_0, mesh_dim_1])
communication_action_mappingp['bias'] = bias_comm_spec communication_action_mapping['bias'] = bias_comm_spec
return self.get_sharding_strategy(name=name, return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
...@@ -662,7 +666,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator): ...@@ -662,7 +666,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
def generate(self) -> List[ShardingStrategy]: def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = [] strategy_list = []
device_mesh_is_1d = True device_mesh_is_1d = True
if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape: if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
......
...@@ -25,8 +25,8 @@ class NormalPoolStrategyGenerator(StrategyGenerator): ...@@ -25,8 +25,8 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
For Pool3d, the dim of input data should be 5([N, C, H, W, D]). For Pool3d, the dim of input data should be 5([N, C, H, W, D]).
''' '''
input_op_data = self.op_data['input'] input_op_data = self.op_data['input']
assert input_op_data.dim() in (3, 4, assert input_op_data.data.dim() in (
5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].' 3, 4, 5), f'We suppose the dim of input fed into Pool op should in range of [3, 5].'
def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem: def update_compute_cost(self, strategy: ShardingStrategy) -> TrainCycleItem:
''' '''
...@@ -103,7 +103,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator): ...@@ -103,7 +103,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
return dim_partition_list return dim_partition_list
def generate(self) -> List[ShardingStrategy]: def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = [] strategy_list = []
dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1) dim_partition_list = self.enumerate_all_possible_batch_dimensions_dim_partition(0, 1)
...@@ -111,9 +111,4 @@ class NormalPoolStrategyGenerator(StrategyGenerator): ...@@ -111,9 +111,4 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
strategy = self._generate_strategy_with_dim_partition(dim_partition) strategy = self._generate_strategy_with_dim_partition(dim_partition)
strategy_list.append(strategy) strategy_list.append(strategy)
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list return strategy_list
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from .strategy_generator import OutputStrategyGenerator from .strategy_generator import OutputStrategyGenerator
...@@ -30,7 +32,7 @@ class OutputGenerator(OutputStrategyGenerator): ...@@ -30,7 +32,7 @@ class OutputGenerator(OutputStrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost strategy.memory_cost = memory_cost
def generate(self): def collate_strategies(self) -> List[ShardingStrategy]:
dim_partition_dict_mapping = { dim_partition_dict_mapping = {
"output": {}, "output": {},
} }
...@@ -47,8 +49,4 @@ class OutputGenerator(OutputStrategyGenerator): ...@@ -47,8 +49,4 @@ class OutputGenerator(OutputStrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return [strategy] return [strategy]
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from .strategy_generator import StrategyGenerator from .strategy_generator import StrategyGenerator
...@@ -35,7 +37,7 @@ class PlaceholderGenerator(StrategyGenerator): ...@@ -35,7 +37,7 @@ class PlaceholderGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost strategy.memory_cost = memory_cost
def generate(self): def collate_strategies(self) -> List[ShardingStrategy]:
dim_partition_dict_mapping = { dim_partition_dict_mapping = {
"output": {}, "output": {},
} }
...@@ -48,8 +50,4 @@ class PlaceholderGenerator(StrategyGenerator): ...@@ -48,8 +50,4 @@ class PlaceholderGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return [strategy] return [strategy]
import copy import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.shape_consistency import CollectiveCommPattern
...@@ -49,7 +50,7 @@ class ReshapeGenerator(FollowingStrategyGenerator): ...@@ -49,7 +50,7 @@ class ReshapeGenerator(FollowingStrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost strategy.memory_cost = memory_cost
def generate(self): def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = [] strategy_list = []
# For reshape function, to keep the computing correctness we keep the sharding # For reshape function, to keep the computing correctness we keep the sharding
# spec of input is fully replicated. In addition, we will keep the output in # spec of input is fully replicated. In addition, we will keep the output in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment