Unverified Commit 49216d7a authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[autoparallel] fix bugs caused by negative dim key (#1808)

* [autoparallel] fix bugs caused by negative dim key

* fix import error

* fix matmul test issue

* fix unit test issue
parent 4268ae01
...@@ -454,6 +454,9 @@ class MatMulHandler(NodeHandler): ...@@ -454,6 +454,9 @@ class MatMulHandler(NodeHandler):
if -1 in dim_partition_dict: if -1 in dim_partition_dict:
shard = dim_partition_dict.pop(-1) shard = dim_partition_dict.pop(-1)
dim_partition_dict[0] = shard dim_partition_dict[0] = shard
if 1 in dim_partition_dict:
shard = dim_partition_dict.pop(1)
dim_partition_dict[0] = shard
# re-init the sharding spec # re-init the sharding spec
input_sharding_spec.__init__(input_sharding_spec.device_mesh, input_sharding_spec.__init__(input_sharding_spec.device_mesh,
......
...@@ -9,6 +9,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ...@@ -9,6 +9,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
ShardingStrategy, ShardingStrategy,
TrainCycleItem, TrainCycleItem,
) )
from colossalai.auto_parallel.tensor_shard.utils import ignore_sharding_exception
from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator from .strategy_generator import StrategyGenerator
...@@ -103,6 +104,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): ...@@ -103,6 +104,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost strategy.memory_cost = memory_cost
@ignore_sharding_exception
def split_input_channel(self, mesh_dim_0): def split_input_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}' name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
dim_partition_dict_mapping = { dim_partition_dict_mapping = {
...@@ -134,6 +136,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): ...@@ -134,6 +136,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1): def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}' name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_mapping = { dim_partition_dict_mapping = {
...@@ -165,6 +168,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): ...@@ -165,6 +168,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def non_split(self): def non_split(self):
name = f'RR = RR x R' name = f'RR = RR x R'
dim_partition_dict_mapping = { dim_partition_dict_mapping = {
...@@ -186,6 +190,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): ...@@ -186,6 +190,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0): def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN' name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
dim_partition_dict_mapping = { dim_partition_dict_mapping = {
...@@ -221,6 +226,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): ...@@ -221,6 +226,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1): def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN' name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
dim_partition_dict_mapping = { dim_partition_dict_mapping = {
...@@ -256,6 +262,7 @@ class BatchNormStrategyGenerator(StrategyGenerator): ...@@ -256,6 +262,7 @@ class BatchNormStrategyGenerator(StrategyGenerator):
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping) communication_action_mapping=communication_action_mapping)
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1): def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN' name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
dim_partition_dict_mapping = { dim_partition_dict_mapping = {
......
...@@ -3,9 +3,12 @@ import operator ...@@ -3,9 +3,12 @@ import operator
from functools import reduce from functools import reduce
from typing import List from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding, from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_2d_sharding) enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from .strategy_generator import StrategyGenerator from .strategy_generator import StrategyGenerator
...@@ -79,6 +82,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator): ...@@ -79,6 +82,7 @@ class NormalPoolStrategyGenerator(StrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost) memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost strategy.memory_cost = memory_cost
@ignore_sharding_exception
def _generate_strategy_with_dim_partition(self, dim_partition): def _generate_strategy_with_dim_partition(self, dim_partition):
dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition} dim_partition_dict_mapping = {"input": dim_partition, "output": dim_partition}
......
...@@ -17,6 +17,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ...@@ -17,6 +17,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.utils import convert_dim_partition_dict
class StrategyGenerator(ABC): class StrategyGenerator(ABC):
...@@ -74,11 +75,15 @@ class StrategyGenerator(ABC): ...@@ -74,11 +75,15 @@ class StrategyGenerator(ABC):
op_data = self.op_data[op_data_name] op_data = self.op_data[op_data_name]
if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor): if isinstance(op_data.data, tuple) and isinstance(op_data.data[0], torch.Tensor):
sharding_spec = [] sharding_spec = []
for output, dim_partition_dict_element in zip(op_data.data, dim_partition_dict): for logical_shape, dim_partition_dict_element in zip(op_data.logical_shape, dim_partition_dict):
dim_size = len(logical_shape)
dim_partition_dict_element = convert_dim_partition_dict(dim_size, dim_partition_dict_element)
sharding_spec = ShardingSpec(device_mesh=self.device_mesh, sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=output.shape, entire_shape=logical_shape,
dim_partition_dict=dim_partition_dict_element) dim_partition_dict=dim_partition_dict_element)
else: else:
dim_size = len(op_data.logical_shape)
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
sharding_spec = ShardingSpec(device_mesh=self.device_mesh, sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=op_data.logical_shape, entire_shape=op_data.logical_shape,
dim_partition_dict=dim_partition_dict) dim_partition_dict=dim_partition_dict)
......
from .process_group import ProcessGroup from . import distspec
from .tensor_spec import ColoTensorSpec
from .distspec import ShardSpec
from .distspec import ReplicaSpec
from .compute_spec import ComputeSpec, ComputePattern
from .colo_tensor import ColoTensor
from .colo_parameter import ColoParameter from .colo_parameter import ColoParameter
from .utils import convert_parameter, named_params_with_colotensor from .colo_tensor import ColoTensor
from .comm_spec import CollectiveCommPattern, CommSpec
from .compute_spec import ComputePattern, ComputeSpec
from .dist_spec_mgr import DistSpecManager from .dist_spec_mgr import DistSpecManager
from .distspec import ReplicaSpec, ShardSpec
from .param_op_hook import ParamOpHook, ParamOpHookManager from .param_op_hook import ParamOpHook, ParamOpHookManager
from .comm_spec import CollectiveCommPattern, CommSpec from .process_group import ProcessGroup
from . import distspec from .tensor_spec import ColoTensorSpec
from .utils import convert_dim_partition_dict, convert_parameter, merge_same_dim_mesh_list, named_params_with_colotensor
__all__ = [ __all__ = [
'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter', 'ColoTensor', 'convert_parameter', 'ComputePattern', 'ComputeSpec', 'named_params_with_colotensor', 'ColoParameter',
'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec', 'distspec', 'DistSpecManager', 'ParamOpHook', 'ParamOpHookManager', 'ProcessGroup', 'ColoTensorSpec', 'ShardSpec',
'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern' 'ReplicaSpec', 'CommSpec', 'CollectiveCommPattern', 'convert_dim_partition_dict', 'merge_same_dim_mesh_list'
] ]
import torch
from typing import Optional from typing import Optional
import torch
from colossalai.tensor.colo_tensor import ColoTensor from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor.const import TensorType from colossalai.tensor.const import TensorType
from colossalai.tensor import ColoTensorSpec
from colossalai.tensor.param_op_hook import ParamOpHookManager from colossalai.tensor.param_op_hook import ParamOpHookManager
from colossalai.tensor.tensor_spec import ColoTensorSpec
def filter_args(func, *args): def filter_args(func, *args):
......
...@@ -4,9 +4,10 @@ from typing import Callable, Optional, Set ...@@ -4,9 +4,10 @@ from typing import Callable, Optional, Set
import torch import torch
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ReplicaSpec
from colossalai.tensor.dist_spec_mgr import DistSpecManager from colossalai.tensor.dist_spec_mgr import DistSpecManager
from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec from colossalai.tensor.distspec import DistPlacementPattern, ReplicaSpec, _DistSpec
from colossalai.tensor.process_group import ProcessGroup
from colossalai.tensor.tensor_spec import ColoTensorSpec
from .const import TensorType from .const import TensorType
from .op_wrapper import _COLOSSAL_OPS from .op_wrapper import _COLOSSAL_OPS
......
from colossalai.tensor.distspec import _DistSpec
# from colossalai.nn.layer.utils import divide
from numpy import prod
from contextlib import contextmanager from contextlib import contextmanager
import torch import torch
import torch.distributed as dist import torch.distributed as dist
# from colossalai.nn.layer.utils import divide
from numpy import prod
from packaging import version from packaging import version
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.tensor import ProcessGroup from colossalai.tensor.distspec import _DistSpec
from colossalai.tensor.process_group import ProcessGroup
# TODO(jiaruifang) circle import, move the divide to colossalai.commons. # TODO(jiaruifang) circle import, move the divide to colossalai.commons.
......
import torch
from contextlib import contextmanager
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Tuple, Any from contextlib import contextmanager
from typing import Any, List, Tuple
import torch
from colossalai.tensor.colo_tensor import ColoTensor from colossalai.tensor.colo_tensor import ColoTensor
from colossalai.tensor import ColoTensorSpec from colossalai.tensor.tensor_spec import ColoTensorSpec
class ParamOpHook(ABC): class ParamOpHook(ABC):
......
...@@ -6,6 +6,8 @@ import torch ...@@ -6,6 +6,8 @@ import torch
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from .utils import merge_same_dim_mesh_list
__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec'] __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
ALLGATHER_COST = 20 ALLGATHER_COST = 20
...@@ -181,8 +183,12 @@ class ShardingSpec: ...@@ -181,8 +183,12 @@ class ShardingSpec:
self.dim_partition_dict = dim_partition_dict self.dim_partition_dict = dim_partition_dict
self.sharding_sequence = sharding_sequence self.sharding_sequence = sharding_sequence
if self.sharding_sequence is None: if self.sharding_sequence is None:
assert self.dim_partition_dict is not None, f'dim_partition_dict should not be None, if sharding_sequence is NoneType object.'
self.dim_partition_dict = merge_same_dim_mesh_list(dim_size=len(entire_shape),
dim_partition_dict=self.dim_partition_dict)
self.convert_dict_to_shard_sequence() self.convert_dict_to_shard_sequence()
elif self.dim_partition_dict is None: elif self.dim_partition_dict is None:
assert self.sharding_sequence is not None, f'sharding_sequence should not be None, if dim_partition_dict is NoneType object.'
self.convert_shard_sequence_to_dict() self.convert_shard_sequence_to_dict()
self._sanity_check() self._sanity_check()
......
from dataclasses import dataclass
from typing import Optional from typing import Optional
from colossalai.tensor.distspec import _DistSpec, DistPlacementPattern
from colossalai.tensor.distspec import DistPlacementPattern, _DistSpec
from colossalai.tensor.process_group import ProcessGroup
from .compute_spec import ComputeSpec from .compute_spec import ComputeSpec
from colossalai.tensor import ProcessGroup
from dataclasses import dataclass
@dataclass @dataclass
class ColoTensorSpec: class ColoTensorSpec:
""" ColoTensorSpec """ ColoTensorSpec
A data class for specifications of the `ColoTensor`. A data class for specifications of the `ColoTensor`.
It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`. It contains attributes of `ProcessGroup`, `_DistSpec`, `ComputeSpec`.
The latter two attributes are optional. If not set, they are default value is `Replicate()` and `None`. The latter two attributes are optional. If not set, they are default value is `Replicate()` and `None`.
......
import torch from typing import Dict, Iterator, List, Tuple, Union
from typing import Iterator, Tuple, Union import torch
import torch.nn as nn import torch.nn as nn
from colossalai.tensor.colo_tensor import ColoTensor from colossalai.tensor.colo_tensor import ColoTensor
...@@ -12,7 +13,7 @@ def all_gather_simulator(target_pair): ...@@ -12,7 +13,7 @@ def all_gather_simulator(target_pair):
We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed. We don't allow uncontiguous layout, such as all-gather(S012)->S02 is NOT allowed.
Therefore, all gather operation just remove the last element in shard list, Therefore, all gather operation just remove the last element in shard list,
e.g.: e.g.:
all-gather(S01) -> S0 all-gather(S01) -> S0
Argument: Argument:
...@@ -31,18 +32,18 @@ def all_to_all_simulator(f_target_pair, b_target_pair): ...@@ -31,18 +32,18 @@ def all_to_all_simulator(f_target_pair, b_target_pair):
and simulate the influence of the DimSpec. and simulate the influence of the DimSpec.
We BANNED all representations which shard_list in decreasing order, We BANNED all representations which shard_list in decreasing order,
such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed. such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
Therefore, if the behind shard_list is not None, we just extend it to the front shard_list. Therefore, if the behind shard_list is not None, we just extend it to the front shard_list.
Argument: Argument:
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
and the second element decribes which logical axis will be sharded in that dimension. and the second element decribes which logical axis will be sharded in that dimension.
e.g.: e.g.:
all-to-all(S0, S1) -> [S01, R] all-to-all(S0, S1) -> [S01, R]
all-to-all(S0, R) -> [R, S0] all-to-all(S0, R) -> [R, S0]
Otherwise, we extend the front shard_list to behind. Otherwise, we extend the front shard_list to behind.
e.g.: e.g.:
all-to-all(R, S1) -> [S1, R] all-to-all(R, S1) -> [S1, R]
Argument: Argument:
target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded, target_pair(Tuple[int, List[int]]): The first element is the dimension of tensor to be sharded,
and the second element decribes which logical axis will be sharded in that dimension. and the second element decribes which logical axis will be sharded in that dimension.
...@@ -65,7 +66,7 @@ def shard_simulator(target_pair, legal_sharding_dims): ...@@ -65,7 +66,7 @@ def shard_simulator(target_pair, legal_sharding_dims):
and simulate the influence of the DimSpec. and simulate the influence of the DimSpec.
We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed. We don't allow uncontiguous layout, such as shard(S0)->S02 is NOT allowed.
In addition, We BANNED all representations which shard_list in decreasing order, In addition, We BANNED all representations which shard_list in decreasing order,
such as S10, so shard(S0) -> S10 is NOT allowed. such as S10, so shard(S0) -> S10 is NOT allowed.
Therefore, for the R dimension, we could just append any legal sharding dim on it. Therefore, for the R dimension, we could just append any legal sharding dim on it.
e.g.: e.g.:
...@@ -164,3 +165,37 @@ def convert_parameter(module: torch.nn.Module, param_name: str): ...@@ -164,3 +165,37 @@ def convert_parameter(module: torch.nn.Module, param_name: str):
# Now we can set the attribute appropriately. # Now we can set the attribute appropriately.
setattr(module, param_name, st) setattr(module, param_name, st)
def convert_dim_partition_dict(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
'''
This method is used to convert the negative dim value to positive.
'''
dims_to_convert = []
for dim, mesh_list in dim_partition_dict.items():
if dim < 0:
dims_to_convert.append(dim)
for dim in dims_to_convert:
dim_partition_dict.pop(dim)
dim_partition_dict[dim_size + dim] = mesh_list
return dim_partition_dict
def merge_same_dim_mesh_list(dim_size: int, dim_partition_dict: Dict[int, List[int]]) -> Dict[int, List[int]]:
'''
This method is used to merge the different key value which points to same physical position.
For example:
dim_partition_dict: {1 :[0], -1: [1]} or {1: [0], 1: [1]} for a 2d tensor, the dim 1 and -1 point same physical position.
In this method, above dim_partition_dict will be converted to {1: [0, 1]}
'''
converted_dim_partition_dict = {}
for dim, mesh_list in dim_partition_dict.items():
if dim < 0:
dim = dim_size + dim
if dim not in converted_dim_partition_dict:
converted_dim_partition_dict[dim] = mesh_list
else:
converted_dim_partition_dict[dim].extend(mesh_list)
return converted_dim_partition_dict
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment