Unverified commit eee84908, authored by Frank Lee and committed by GitHub

[autoparallel] handled illegal sharding strategy (#1728)

* [autoparallel] handled illegal sharding strategy

* polish code
parent cbe9a4cb
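In brief, this change moves the per-strategy cost updates out of the individual generators and into the base StrategyGenerator.generate(), which now calls a new abstract collate_strategies() hook and discards strategies that were rejected as illegal. The exception_handler decorator is renamed to ignore_sharding_exception and catches a dedicated ShardingSpecException hierarchy instead of AssertionError, ShardingSpec._sanity_check() gains out-of-index and divisibility checks, and the node-handler tests gain an is_sharding_spec_valid() helper.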
@@ -4,13 +4,12 @@ from functools import reduce
 from typing import Any, Dict, List, Union
 import torch
+from torch.fx import Node
 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData, OperationDataType, ShardingStrategy,
                                                                      TrainCycleItem)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
-from torch.fx import Node

 class StrategyGenerator(ABC):
@@ -24,6 +23,9 @@ class StrategyGenerator(ABC):
         self.op_data = operation_data_mapping
         self.device_mesh = device_mesh

+        # validate whether the operation data is of the desired value
+        self.validate()
+
     @property
     def has_bias(self):
         """
@@ -102,9 +104,9 @@ class StrategyGenerator(ABC):
         comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)

-        def _compute_and_add(data: OperationData, comm_spec: CommSpec):
+        def _compute_and_add(op_data: OperationData, comm_spec: CommSpec):
             num_ele_in_comm = comm_spec.get_comm_cost()
-            dtype = operand.data.dtype
+            dtype = op_data.data.dtype
             size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
             for phase, cost in num_ele_in_comm.items():
                 num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes
@@ -151,11 +153,30 @@ class StrategyGenerator(ABC):
         size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
         return reduce(operator.mul, sharded_shape) * size_per_elem_bytes

-    @abstractmethod
     def generate(self) -> List[ShardingStrategy]:
         """
         Generate all possible sharding strategies for this operation.
         """
+        strategies = self.collate_strategies()
+
+        # some strategies may be None because ignore_sharding_exception returns None
+        # when a ShardingSpecException occurs, so remove those None values
+        strategies = [strategy for strategy in strategies if strategy]
+
+        # update the meta info on cost
+        # these update methods are all in-place; the default methods do nothing,
+        # so cost info is only added when the child class overrides them
+        for strategy in strategies:
+            self.update_communication_cost(strategy)
+            self.update_compute_cost(strategy)
+            self.update_memory_cost(strategy)
+
+        return strategies
+
+    @abstractmethod
+    def collate_strategies(self) -> List[ShardingStrategy]:
         pass

     @abstractmethod
...
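The refactor above turns generate() into a concrete template method: subclasses now only implement collate_strategies(), while the base class filters out strategies that were rejected (returned as None by ignore_sharding_exception) and applies the in-place cost updates. Below is a minimal, self-contained sketch of that pattern; the GeneratorBase and MyGenerator names and the simplified Strategy stand-in are illustrative, not part of the codebase.

from abc import ABC, abstractmethod
from typing import List, Optional


class Strategy:
    """Hypothetical stand-in for ShardingStrategy, used only for this sketch."""

    def __init__(self, name: str):
        self.name = name


class GeneratorBase(ABC):
    """Mirrors the new StrategyGenerator contract: generate() is concrete,
    collate_strategies() is the abstract hook."""

    def generate(self) -> List[Strategy]:
        # collate_strategies() may yield None entries when a strategy was
        # rejected (ignore_sharding_exception returns None); drop them first
        strategies = [s for s in self.collate_strategies() if s]

        # cost updates are in-place; the defaults are no-ops unless overridden
        for s in strategies:
            self.update_communication_cost(s)
            self.update_compute_cost(s)
            self.update_memory_cost(s)
        return strategies

    @abstractmethod
    def collate_strategies(self) -> List[Optional[Strategy]]:
        ...

    def update_communication_cost(self, strategy: Strategy) -> None:
        pass

    def update_compute_cost(self, strategy: Strategy) -> None:
        pass

    def update_memory_cost(self, strategy: Strategy) -> None:
        pass


class MyGenerator(GeneratorBase):

    def collate_strategies(self) -> List[Optional[Strategy]]:
        # the None entry mimics a strategy skipped because its sharding was illegal
        return [Strategy('S0R = S0R x RR'), None, Strategy('RR = RR x RR')]


print([s.name for s in MyGenerator().generate()])
# ['S0R = S0R x RR', 'RR = RR x RR']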
 import copy
+from typing import List

 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)

@@ -48,7 +49,7 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
         memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
         strategy.memory_cost = memory_cost

-    def generate(self):
+    def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []
         # For element-wise function, we keep the sharding spec of output node same as
         # the input. Therefore, the different strategies of input node with same

@@ -73,9 +74,4 @@ class UnaryElementwiseGenerator(FollowingStrategyGenerator):
                                                      communication_action_mapping=communication_action_mapping)
             strategy_list.append(strategy)

-        for strategy in strategy_list:
-            self.update_communication_cost(strategy)
-            self.update_compute_cost(strategy)
-            self.update_memory_cost(strategy)
-
         return strategy_list
 import copy
+from typing import List

 from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
 from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding,

@@ -78,7 +79,7 @@ class WhereGenerator(StrategyGenerator):
         return dim_partition_list

-    def generate(self):
+    def collate_strategies(self) -> List[ShardingStrategy]:
         '''
         Generate every possible strategies for a where node, and record all strategies into the strategies_vector.
         '''

@@ -90,9 +91,4 @@ class WhereGenerator(StrategyGenerator):
             strategy = self._generate_strategy_with_dim_partition(dim_partition)
             strategy_list.append(strategy)

-        for strategy in strategy_list:
-            self.update_communication_cost(strategy)
-            self.update_compute_cost(strategy)
-            self.update_memory_cost(strategy)
-
         return strategy_list
 from .broadcast import (BroadcastType, get_broadcast_shape, is_broadcastable, recover_sharding_spec_for_broadcast_shape)
 from .factory import generate_resharding_costs, generate_sharding_spec
-from .misc import exception_handler
+from .misc import ignore_sharding_exception
 from .sharding import (enumerate_all_possible_1d_sharding, enumerate_all_possible_2d_sharding, generate_sharding_size,
                        switch_partition_dim, update_partition_dim)

 __all__ = [
     'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
-    'generate_resharding_costs', 'generate_sharding_spec', 'exception_handler', 'switch_partition_dim',
+    'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'switch_partition_dim',
     'update_partition_dim', 'enumerate_all_possible_1d_sharding', 'enumerate_all_possible_2d_sharding',
     'generate_sharding_size'
 ]
 import functools
-import warnings

-__all__ = ['exception_handler']
+from colossalai.logging import get_dist_logger
+from colossalai.tensor.sharding_spec import ShardingSpecException
+
+__all__ = ['ignore_sharding_exception']


-def exception_handler(func):
+def ignore_sharding_exception(func):
     """
-    A function wrapper to handle the AssertionError in the function.
+    A function wrapper to handle the ShardingSpecException in the function.
+    If ShardingSpecException occurs, this function will return None.

     Usage:
         # mute the assertion error in the function
-        @exception_handler
+        @ignore_sharding_exception
         def do_something():
             ...
     """

@@ -18,9 +21,11 @@ def exception_handler(func):
     @functools.wraps(func)
     def wrapper(*args, **kwargs):
         try:
+            logger = get_dist_logger()
             rst = func(*args, **kwargs)
             return rst
-        except AssertionError as e:
-            warnings.warn(f'{e}')
+        except ShardingSpecException as e:
+            logger.debug(e)
+            return None

     return wrapper
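For reference, the decorator's new behaviour can be sketched standalone: catch ShardingSpecException, log it at debug level, and return None so callers (such as generate() above) can filter the entry out. The local ShardingSpecException class and the plain logging module below are stand-ins for the colossalai imports.

import functools
import logging


class ShardingSpecException(Exception):
    """Stand-in for colossalai.tensor.sharding_spec.ShardingSpecException."""
    pass


def ignore_sharding_exception(func):
    # simplified sketch; the real decorator logs through get_dist_logger()
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ShardingSpecException as e:
            logging.getLogger(__name__).debug(e)
            return None    # callers drop these None entries

    return wrapper


@ignore_sharding_exception
def build_strategy(legal: bool):
    if not legal:
        raise ShardingSpecException('illegal sharding strategy')
    return 'strategy'


print(build_strategy(True))     # strategy
print(build_strategy(False))    # None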
-import torch
-from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
+import operator
 from copy import deepcopy
 from enum import Enum
 from functools import reduce
-import operator

+import torch
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.utils import (all_gather_simulator, all_to_all_simulator, shard_simulator)

 __all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
@@ -138,7 +140,19 @@ class _DimSpec:
         return difference


-class ShardingException(Exception):
+class ShardingSpecException(Exception):
+    pass
+
+
+class ShardingOutOfIndexError(ShardingSpecException):
+    pass
+
+
+class DuplicatedShardingDimensionError(ShardingSpecException):
+    pass
+
+
+class ShardingNotDivisibleError(ShardingSpecException):
     pass
@@ -156,7 +170,11 @@ class ShardingSpec:
         sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
     '''

-    def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
+    def __init__(self,
+                 device_mesh: DeviceMesh,
+                 entire_shape: torch.Size,
+                 dim_partition_dict=None,
+                 sharding_sequence=None):
         self.device_mesh = device_mesh
         self.entire_shape = entire_shape
         self.dim_partition_dict = dim_partition_dict
@@ -174,19 +192,36 @@ class ShardingSpec:
         return ' '.join(res_list)

     def _sanity_check(self):
-        '''
-        In sanity check, we need make sure all axes in logical device mesh only be used
-        once.
-        '''
-        dim_check_list = [i for i in range(self.device_mesh.logical_mesh_id.dim())]
+        # make sure each axis in the logical device mesh is only used once
+        dim_check_list = list(range(self.device_mesh.logical_mesh_id.dim()))
         for dim, shard_list in self.dim_partition_dict.items():
             for element in shard_list:
                 if element in dim_check_list:
                     dim_check_list.remove(element)
                 else:
-                    raise ValueError(
+                    raise DuplicatedShardingDimensionError(
                         f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")

+        # make sure the sharded dimension is not out of index
+        for dim in self.dim_partition_dict.keys():
+            if dim >= len(self.entire_shape):
+                raise ShardingOutOfIndexError(
+                    f"The dim_partition_dict specifies to shard dimension {dim} but the entire_shape only has {len(self.entire_shape)} dimensions"
+                )
+
+        # make sure the size of each sharded dimension is divisible by the number of devices
+        for dim, shard_list in self.dim_partition_dict.items():
+            tensor_dim_size = self.entire_shape[dim]
+            num_devices = 1
+
+            for element in shard_list:
+                num_devices *= self.device_mesh.mesh_shape[element]
+
+            if tensor_dim_size % num_devices != 0:
+                raise ShardingNotDivisibleError(
+                    f'The size of dimension at index {dim} is {tensor_dim_size}, it cannot be sharded over {num_devices} devices.'
+                )

     def convert_dict_to_shard_sequence(self):
         '''
         Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence.
...
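The three new checks in _sanity_check can be illustrated with a plain-Python sketch that uses tuples in place of DeviceMesh so it runs standalone; the exception names mirror the ones introduced above, everything else is only illustrative.

class ShardingSpecException(Exception):
    pass


class ShardingOutOfIndexError(ShardingSpecException):
    pass


class DuplicatedShardingDimensionError(ShardingSpecException):
    pass


class ShardingNotDivisibleError(ShardingSpecException):
    pass


def sanity_check(mesh_shape, entire_shape, dim_partition_dict):
    # 1. every logical mesh axis may be used at most once
    unused_axes = list(range(len(mesh_shape)))
    for dim, shard_list in dim_partition_dict.items():
        for axis in shard_list:
            if axis in unused_axes:
                unused_axes.remove(axis)
            else:
                raise DuplicatedShardingDimensionError(f'axis {axis} is reused in dimension {dim}')

    # 2. sharded tensor dimensions must exist
    for dim in dim_partition_dict:
        if dim >= len(entire_shape):
            raise ShardingOutOfIndexError(f'dimension {dim} is out of range for a {len(entire_shape)}-d tensor')

    # 3. each sharded dimension must be divisible by the number of devices it is sharded over
    for dim, shard_list in dim_partition_dict.items():
        num_devices = 1
        for axis in shard_list:
            num_devices *= mesh_shape[axis]
        if entire_shape[dim] % num_devices != 0:
            raise ShardingNotDivisibleError(
                f'dimension {dim} has size {entire_shape[dim]}, not divisible by {num_devices} devices')


# (16, 8, 6) sharded as S01 on dimension 0 over a (2, 2) mesh passes ...
sanity_check((2, 2), (16, 8, 6), {0: [0, 1]})
# ... while sharding the size-6 dimension over 4 devices is rejected
try:
    sanity_check((2, 2), (4, 8, 6), {2: [0, 1]})
except ShardingNotDivisibleError as e:
    print(e)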
-from cProfile import run
+import pytest
 import torch
-from torch.fx import GraphModule
 import torch.nn as nn
-import pytest
+from torch.fx import GraphModule

 from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag


 class ConvModel(nn.Module):

@@ -27,6 +30,7 @@ class ConvModel(nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_conv_handler():
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
...
+import pytest
 import torch
-from torch.fx import GraphModule
 import torch.nn as nn
-import pytest
+from torch.fx import GraphModule

 from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
 from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
-from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.fx.tracer.tracer import ColoTracer
+from colossalai.testing.pytest_wrapper import run_on_environment_flag


 class MatmulModel(nn.Module):

@@ -20,6 +21,7 @@ class MatmulModel(nn.Module):
         return x


+@run_on_environment_flag(name='AUTO_PARALLEL')
 def test_conv_handler():
     physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
...
+import torch
+
+from colossalai.tensor.sharding_spec import ShardingSpec
+
+
+def is_sharding_spec_valid(sharding_spec: ShardingSpec, tensor: torch.Tensor):
+    """
+    This function checks whether the ShardingSpec is valid for the physical tensor.
+    This check includes 2 items:
+        1. the sharding spec covers all dimensions of the physical tensor
+        2. the size of each sharded dimension is divisible by the number of devices it is sharded over
+    """
+    # make sure all dims are covered in sharding spec
+    sharding_len = len(sharding_spec.sharding_sequence)
+    tensor_num_dim = tensor.dim()
+    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
+    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
+    assert sharding_len == tensor_num_dim, \
+        f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
+
+    # make sure the sharding is valid for each dim
+    for i in range(tensor_num_dim):
+        dim_size = tensor.shape[i]
+        dim_spec = sharding_spec.sharding_sequence[i]
+
+        if str(dim_spec).startswith('S'):
+            devices_str = str(dim_spec).lstrip('S')
+            num_devices = 1
+
+            if '0' in devices_str:
+                num_devices *= num_devices_in_col
+            if '1' in devices_str:
+                num_devices *= num_devices_in_row
+
+            assert dim_size >= num_devices and dim_size % num_devices == 0, \
+                f'The dimension at index {i} has value {dim_size}, but it is sharded over {num_devices} devices.'
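The device-count arithmetic used by this helper ('S0' shards over mesh axis 0, 'S1' over axis 1, 'S01' over both) can be shown with a tiny sketch; the 2 x 2 mesh shape and the devices_for name are assumptions for illustration only.

mesh_shape = (2, 2)    # assumed mesh shape, matching the tests below


def devices_for(dim_spec: str) -> int:
    if not dim_spec.startswith('S'):
        return 1                      # 'R' means replicated: no divisibility constraint
    axes = dim_spec[1:]
    num_devices = 1
    if '0' in axes:
        num_devices *= mesh_shape[0]
    if '1' in axes:
        num_devices *= mesh_shape[1]
    return num_devices


print(devices_for('R'), devices_for('S0'), devices_for('S1'), devices_for('S01'))
# 1 2 2 4  -> a dimension marked S01 must be a multiple of 4 to pass the assert above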
@@ -6,12 +6,13 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (OperationData,
                                                                      StrategiesVector)
 from colossalai.device.device_mesh import DeviceMesh
 from colossalai.fx import ColoGraphModule, ColoTracer
-from colossalai.fx.tracer.meta_patch.patched_module import linear
-from colossalai.tensor.sharding_spec import ShardingSpec
+from tests.test_auto_parallel.test_tensor_shard.test_node_handler.common import \
+    is_sharding_spec_valid


 def test_linear_module_handler():
     model = nn.Sequential(nn.Linear(16, 32).to('meta'))
     tracer = ColoTracer()
     graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
@@ -91,6 +92,12 @@ def test_linear_module_handler():
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
         output_sharding_spec = strategy.get_sharding_spec_by_name('_0')

+        # make sure the sharding spec is valid
+        is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
+        is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('0.weight'))
+        is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('0.bias'))
+        is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
+
         # make sure the sharding matches across different operation data
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
@@ -101,7 +108,7 @@ def test_linear_module_handler():
 def test_linear_function_handler():
     model = nn.Linear(16, 32).to('meta')
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)
@@ -117,11 +124,13 @@ def test_linear_function_handler():
     # # check operation data mapping
     mapping = handler.get_operation_data_mapping()

+    print(mapping['input'].logical_shape)
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 16])
+    assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 16])
+    assert mapping['input'].logical_shape == torch.Size([16, 16])

     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
@@ -137,7 +146,7 @@ def test_linear_function_handler():
     assert mapping['output'].name == "linear"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 32])
+    assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT

     strategies_vector = handler.register_strategy(compute_resharding_cost=False)
@@ -167,11 +176,18 @@ def test_linear_function_handler():
     for strategy in strategies_vector:
         strategy: ShardingStrategy
+        print(strategy)
         input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
         weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
         bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
         output_sharding_spec = strategy.get_sharding_spec_by_name('linear')

+        # make sure the sharding spec is valid
+        is_sharding_spec_valid(input_sharding_spec, torch.rand(2, 2, 4, 16))
+        is_sharding_spec_valid(weight_sharding_spec, model.get_parameter('weight'))
+        is_sharding_spec_valid(bias_sharding_spec, model.get_parameter('bias'))
+        is_sharding_spec_valid(output_sharding_spec, torch.rand([2, 2, 4, 32]))
+
         # make sure the sharding matches across different operation data
         assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
         assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
...
 import torch
 import torch.nn as nn

 from colossalai.auto_parallel.tensor_shard.node_handler.conv_handler import \
     ConvFunctionHandler
 from colossalai.auto_parallel.tensor_shard.node_handler.unary_elementwise_handler import \
...
-from functools import partial
 from lib2to3 import pgen2
-import colossalai
-import torch
 import pytest
+import torch
 import torch.multiprocessing as mp
 import torch.nn.functional as F
-from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.utils import free_port
+import colossalai
+from functools import partial
 from colossalai.device.device_mesh import DeviceMesh
-from colossalai.tensor.sharding_spec import ShardingSpec
-from colossalai.tensor import ColoTensor, ColoParameter, ProcessGroup
 from colossalai.nn._ops._utils import gather_forward_split_backward
+from colossalai.tensor import ColoParameter, ColoTensor, ProcessGroup
+from colossalai.tensor.sharding_spec import ShardingSpec
+from colossalai.testing import rerun_if_address_is_in_use
+from colossalai.utils import free_port


 def run_dist(rank, world_size, port):

@@ -18,7 +20,7 @@ def run_dist(rank, world_size, port):
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     # create mlp vars
-    x = ColoTensor.from_torch_tensor(torch.rand(2, 4, 8, requires_grad=True)).cuda()
+    x = ColoTensor.from_torch_tensor(torch.rand(4, 4, 8, requires_grad=True)).cuda()
     w = ColoParameter.from_torch_tensor(torch.rand(16, 8, requires_grad=True)).cuda()
     b = ColoParameter.from_torch_tensor(torch.rand(16, requires_grad=True)).cuda()
...
 import torch
-from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec

 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec


 def test_sharding_spec():

@@ -11,7 +12,7 @@ def test_sharding_spec():
     #     [8, 9, 10,11],
     #     [12,13,14,15]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    entire_shape = torch.Size((4, 8, 6))
+    entire_shape = torch.Size((16, 8, 6))
     dim_partition_dict = {0: [0, 1]}
     # DistSpec:
     #     shard_sequence: S01,R,R
...
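The enlarged test shapes above (torch.rand(2, 4, 8) becoming torch.rand(4, 4, 8), and entire_shape (4, 8, 6) becoming (16, 8, 6)) appear to be driven by the new divisibility rule: with a 2 x 2 mesh, a dimension sharded as S01 is split over 4 devices, so its size must be a multiple of 4, which 2 is not while 4 and 16 are.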