Unverified Commit eee84908 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

[autoparallel] handled illegal sharding strategy (#1728)

* [autoparallel] handled illegal sharding strategy

* polish code
parent cbe9a4cb
......@@ -4,13 +4,12 @@ from functools import reduce
from typing import Any, Dict, List, Union
import torch
from torch.fx import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
TrainCycleItem)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from torch.fx import Node
class StrategyGenerator(ABC):
......@@ -24,6 +23,9 @@ class StrategyGenerator(ABC):
self.op_data = operation_data_mapping
self.device_mesh = device_mesh
# validate whether the operation data is of the desired value
self.validate()
@property
def has_bias(self):
"""
......@@ -102,9 +104,9 @@ class StrategyGenerator(ABC):
comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
def _compute_and_add(data: OperationData, comm_spec: CommSpec):
def _compute_and_add(op_data: OperationData, comm_spec: CommSpec):
num_ele_in_comm = comm_spec.get_comm_cost()
dtype = operand.data.dtype
dtype = op_data.data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
for phase, cost in num_ele_in_comm.items():
num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes
......@@ -151,11 +153,30 @@ class StrategyGenerator(ABC):
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
return reduce(operator.mul, sharded_shape) * size_per_elem_bytes
@abstractmethod
def generate(self) -> List[ShardingStrategy]:
    """
    Generate all possible sharding strategies for this operation.

    Returns:
        List[ShardingStrategy]: the valid strategies produced by
        ``collate_strategies`` with their cost metadata filled in.
    """
    strategies = self.collate_strategies()

    # some strategies may be None as ignore_sharding_exception may return None
    # when ShardingSpecException occurs.
    # thus, remove those None values
    strategies = [strategy for strategy in strategies if strategy]

    # update meta info on cost
    # these update methods are all in-place, the default method will do nothing
    # the cost info will only be added if the child class overrides these methods
    for strategy in strategies:
        self.update_communication_cost(strategy)
        self.update_compute_cost(strategy)
        self.update_memory_cost(strategy)

    return strategies
@abstractmethod
def collate_strategies(self) -> List[ShardingStrategy]:
pass
@abstractmethod
......
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
......@@ -48,7 +49,7 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# For element-wise function, we keep the sharding spec of output node same as
# the input. Therefore, the different strategies of input node with same
......@@ -73,9 +74,4 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,
......@@ -78,7 +79,7 @@ class WhereGenerator(StrategyGenerator):
return dim_partition_list
def generate(self):
def collate_strategies(self) -> List[ShardingStrategy]:
'''
Generate every possible strategy for a where node, and record all strategies into the strategies_vector.
'''
......@@ -90,9 +91,4 @@ class WhereGenerator(StrategyGenerator):
strategy = self._generate_strategy_with_dim_partition(dim_partition)
strategy_list.append(strategy)
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list
from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape)
from .factory import generate_resharding_costs, generate_sharding_spec
from .misc import exception_handler
from .misc import ignore_sharding_exception
from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size,
switch_partition_dim, update_partition_dim)
__all__ = [
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
'generate_resharding_costs', 'generate_sharding_spec', 'exception_handler', 'switch_partition_dim',
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'switch_partition_dim',
'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding',
'generate_sharding_size'
]
import functools
import warnings
__all__ = ['exception_handler']
from colossalai.logging import get_dist_logger
from colossalai.tensor.sharding_spec import ShardingSpecException
__all__ = ['ignore_sharding_exception']
def exception_handler(func):
def ignore_sharding_exception(func):
"""
A function wrapper to handle the AssertionError in the function.
A function wrapper to handle the ShardingSpecException in the function.
If ShardingSpecException occurs, this function will return None.
Usage:
# mute the assertion error in the function
@exception_handler
@ignore_sharding_exception
def do_something():
...
"""
......@@ -18,9 +21,11 @@ def exception_handler(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
logger = get_dist_logger()
rst = func(*args, **kwargs)
return rst
except AssertionError as e:
warnings.warn(f'{e}')
except ShardingSpecException as e:
logger.debug(e)
return None
return wrapper
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
import operator
from copy import deepcopy
from enum import Enum
from functools import reduce
import operator
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)
__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
......@@ -138,7 +140,19 @@ class _DimSpec:
return difference
class ShardingException(Exception):
class ShardingSpecException(Exception):
    """Base class for all errors raised while constructing or validating a ShardingSpec."""
    pass


class ShardingOutOfIndexError(ShardingSpecException):
    """Raised when dim_partition_dict shards a tensor dimension that is out of range for entire_shape."""
    pass


class DuplicatedShardingDimensionError(ShardingSpecException):
    """Raised when the same logical device-mesh axis is used more than once in dim_partition_dict."""
    pass


class ShardingNotDivisibleError(ShardingSpecException):
    """Raised when a tensor dimension's size is not divisible by the number of devices sharding it."""
    pass
......@@ -156,7 +170,11 @@ class ShardingSpec:
sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
'''
def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
def __init__(self,
device_mesh: DeviceMesh,
entire_shape: torch.Size,
dim_partition_dict=None,
sharding_sequence=None):
self.device_mesh = device_mesh
self.entire_shape = entire_shape
self.dim_partition_dict = dim_partition_dict
......@@ -174,19 +192,36 @@ class ShardingSpec:
return ' '.join(res_list)
def _sanity_check(self):
'''
In sanity check, we need make sure all axes in logical device mesh only be used
once.
'''
dim_check_list = [i for i in range(self.device_mesh.logical_mesh_id.dim())]
# make sure all axes in logical device mesh only be used once
dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
for dim, shard_list in self.dim_partition_dict.items():
for element in shard_list:
if element in dim_check_list:
dim_check_list.remove(element)
else:
raise ValueError(
raise DuplicatedShardingDimensionError(
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
# make sure that the dimension is not out of index
for dim in self.dim_partition_dict.keys():
if dim >= len(self.entire_shape):
raise ShardingOutOfIndexError(
f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions"
)
# make sure that the sharding for a dimension is divisible by the number of devices
for dim, shard_list in self.dim_partition_dict.items():
tensor_dim_size = self.entire_shape[dim]
num_devices = 1
for element in shard_list:
num_devices *= self.device_mesh.mesh_shape[element]
if tensor_dim_size % num_devices != 0:
raise ShardingNotDivisibleError(
f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.'
)
def convert_dict_to_shard_sequence(self):
'''
Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence.
......
from cProfile import run
import pytest
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ConvModel(nn.Module):
......@@ -27,6 +30,7 @@ class ConvModel(nn.Module):
return x
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_conv_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
......
import pytest
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class MatmulModel(nn.Module):
......@@ -20,6 +21,7 @@ class MatmulModel(nn.Module):
return x
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_conv_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
......
import torch
from colossalai.tensor.sharding_spec import ShardingSpec
def is_sharding_spec_valid(sharding_spec: "ShardingSpec", tensor: torch.Tensor):
    """
    Check whether the ShardingSpec is valid for the physical tensor.

    This check includes 2 items:
        1. the sharding spec covers all dimensions of the physical tensor
        2. each sharded dimension's size is divisible by the number of devices it is sharded over

    Args:
        sharding_spec (ShardingSpec): the sharding spec to validate; only its
            ``sharding_sequence`` and ``device_mesh.mesh_shape`` are read.
        tensor (torch.Tensor): the physical tensor the spec is applied to.

    Raises:
        AssertionError: if the spec's rank does not match the tensor's rank, or
            a sharded dimension is not divisible by its device count.
    """
    # make sure all dims are covered in sharding spec
    sharding_len = len(sharding_spec.sharding_sequence)
    tensor_num_dim = tensor.dim()
    assert sharding_len == tensor_num_dim, \
        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'

    # make sure the sharding is valid for each dim
    for i in range(tensor_num_dim):
        dim_size = tensor.shape[i]
        dim_spec = sharding_spec.sharding_sequence[i]

        # a sharded dim spec renders as 'S' followed by the mesh axes it is
        # sharded over, e.g. 'S0', 'S1', 'S01'; 'R' means replicated
        if str(dim_spec).startswith('S'):
            devices_str = str(dim_spec).lstrip('S')

            # multiply the sizes of every mesh axis named in the dim spec,
            # e.g. 'S01' -> mesh_shape[0] * mesh_shape[1]. This generalizes
            # the previous hard-coded 2D-mesh lookup (mesh_shape[0]/[1]) to
            # device meshes of any rank.
            num_devices = 1
            for axis_str in devices_str:
                num_devices *= sharding_spec.device_mesh.mesh_shape[int(axis_str)]

            assert dim_size >= num_devices and dim_size % num_devices == 0, \
                f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
......@@ -6,12 +6,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationDa
StrategiesVector)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.fx.tracer.meta_patch.patched_module import linear
from colossalai.tensor.sharding_spec import ShardingSpec
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import \
is_sharding_spec_valid
def test_linear_module_handler():
model = nn.Sequential(nn.Linear(16, 32).to('meta'))
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
gm = ColoGraphModule(model, graph)
......@@ -91,6 +92,12 @@ def test_linear_module_handler():
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
output_sharding_spec = strategy.get_sharding_spec_by_name('_0')
# make sure the sharding spec is valid
is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('0.weight'))
is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('0.bias'))
is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
......@@ -101,7 +108,7 @@ def test_linear_module_handler():
def test_linear_function_handler():
model = nn.Linear(16, 32).to('meta')
tracer = ColoTracer()
graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
gm = ColoGraphModule(model, graph)
physical_mesh_id = torch.arange(0, 4)
......@@ -117,11 +124,13 @@ def test_linear_function_handler():
# # check operation data mapping
mapping = handler.get_operation_data_mapping()
print(mapping['input'].logical_shape)
assert mapping['input'].name == "input_1"
assert mapping['input'].data.is_meta
assert mapping['input'].data.shape == torch.Size([4, 16])
assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 16])
assert mapping['input'].logical_shape == torch.Size([16, 16])
assert mapping['other'].name == "weight"
assert mapping['other'].data.is_meta
......@@ -137,7 +146,7 @@ def test_linear_function_handler():
assert mapping['output'].name == "linear"
assert mapping['output'].data.is_meta
assert mapping['output'].data.shape == torch.Size([4, 32])
assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
assert mapping['output'].type == OperationDataType.OUTPUT
strategies_vector = handler.register_strategy(compute_resharding_cost=False)
......@@ -167,11 +176,18 @@ def test_linear_function_handler():
for strategy in strategies_vector:
strategy: ShardingStrategy
print(strategy)
input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
output_sharding_spec = strategy.get_sharding_spec_by_name('linear')
# make sure the sharding spec is valid
is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('weight'))
is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('bias'))
is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
......
import torch
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
ConvFunctionHandler
from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \
......
from functools import partial
from lib2to3 import pgen2
import colossalai
import torch
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from functools import partial
import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor import ColoTensor, ColoParameter, ProcessGroup
from colossalai.nn._ops._utils import gather_forward_split_backward
from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
def run_dist(rank, world_size, port):
......@@ -18,7 +20,7 @@ def run_dist(rank, world_size, port):
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
# create mlp vars
x = ColoTensor.from_torch_tensor(torch.rand(2, 4, 8, requires_grad=True)).cuda()
x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda()
w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda()
b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda()
......
import torch
from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
def test_sharding_spec():
......@@ -11,7 +12,7 @@ def test_sharding_spec():
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 8, 6))
entire_shape = torch.Size((16, 8, 6))
dim_partition_dict = {0: [0, 1]}
# DistSpec:
# shard_sequence: S01,R,R
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment