Unverified commit 4973157a authored by Frank Lee, committed by GitHub

[autoparallel] added sharding spec conversion for linear handler (#1687)

parent af718e83
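The change below generates sharding strategies for the linear op against a flattened 2D logical shape and then converts the specs back to the physical tensor shapes in post_process. A minimal sketch of that shape relationship, illustrative only and not part of the commit (shapes borrowed from the updated test):

import torch

# physical input to nn.Linear(16, 32): a 4D activation tensor on the meta device
physical_input = torch.rand(2, 2, 4, 16).to('meta')

# 2D logical view used for strategy generation: flatten every leading dimension
logical_input_shape = physical_input.view(-1, physical_input.shape[-1]).shape
assert tuple(logical_input_shape) == (16, 16)    # (2 * 2 * 4, 16)

# the weight is stored as (out_features, in_features) but used in its transposed logical shape
weight = torch.rand(32, 16).to('meta')
assert tuple(weight.shape[::-1]) == (16, 32)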
import torch
import torch.nn.functional as F
from colossalai.tensor.sharding_spec import ShardingException
from .node_handler import ModuleHandler, NodeHandler
from ..sharding_strategy import ShardingStrategy_V2, OperationDataType, OperationData
from ..strategy import LinearProjectionStrategyGenerator, StrategyGenerator_V2, BatchedMatMulStrategyGenerator
from typing import List, Dict, Union
from .registry import operator_registry
from copy import deepcopy
from .utils import switch_partition_dim, update_partition_dim

__all__ = ['LinearModuleHandler', 'LinearFunctionHandler', 'BMMFunctionHandler']
@@ -24,14 +27,22 @@ class LinearModuleHandler(ModuleHandler):
    def get_operation_data_mapping(self) -> Dict[str, OperationData]:
        # use transposed shape for strategies
        # the strategies will be transformed back to their original shape in self.post_process
        input_meta_data = self.node.args[0]._meta_data
        input_logical_shape = input_meta_data.view(-1, input_meta_data.shape[-1]).shape
        physical_input_operand = OperationData(name=str(self.node.args[0]),
                                               type=OperationDataType.ARG,
                                               data=input_meta_data,
                                               logical_shape=input_logical_shape)
        physical_other_operand = OperationData(name="weight",
                                               type=OperationDataType.PARAM,
                                               data=self.named_parameters['weight'],
                                               logical_shape=self.named_parameters['weight'].shape[::-1])
        output_meta_data = self.node._meta_data
        output_logical_shape = output_meta_data.view(-1, output_meta_data.shape[-1]).shape
        physical_output = OperationData(name=str(self.node),
                                        type=OperationDataType.OUTPUT,
                                        data=output_meta_data,
                                        logical_shape=output_logical_shape)

        mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
@@ -42,28 +53,46 @@ class LinearModuleHandler(ModuleHandler):
            mapping['bias'] = physical_bias_operand
        return mapping
    def post_process(self, strategy: ShardingStrategy_V2) -> Union[ShardingStrategy_V2, List[ShardingStrategy_V2]]:
        """
        Convert the sharding spec from the logical shape to the physical shape.
        """
        # switch the dimensions of the transposed weight
        for op_data, sharding_spec in strategy.input_sharding_specs.items():
            if op_data.name == "weight":
                assert op_data.logical_shape != op_data.data.shape
                switch_partition_dim(sharding_spec, 0, -1)

        # create multiple sharding strategies for the inputs
        # as the input can be multi-dimensional and the partition dim is only 2D,
        # we need to map the partition at dim 0 to one of the first few dimensions of the input
        sharding_strategies = []
        input_op_data = strategy.get_op_data_by_name(str(self.node.args[0]))
        output_op_data = strategy.get_op_data_by_name(str(self.node))
        num_input_dims = input_op_data.data.dim()
        input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)

        if 0 in input_sharding_spec.dim_partition_dict:
            for i in range(num_input_dims - 1):
                new_strategy = strategy.clone()
                input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name)
                output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name)
                try:
                    update_partition_dim(sharding_spec=input_sharding_spec,
                                         dim_mapping={0: i},
                                         physical_shape=input_op_data.data.shape,
                                         inplace=True)
                    update_partition_dim(sharding_spec=output_sharding_spec,
                                         dim_mapping={0: i},
                                         physical_shape=output_op_data.data.shape,
                                         inplace=True)
                    sharding_strategies.append(new_strategy)
                except ShardingException:
                    pass
        else:
            sharding_strategies.append(strategy)

        return sharding_strategies
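For intuition only (not part of the diff): a single logical strategy that shards the flattened input at dim 0 fans out into one physical strategy per leading dimension of the input, since any of those dimensions can carry the shard. A toy sketch with assumed shapes:

# logical input shape (16, 16): dim 0 sharded along mesh dim 0
logical_dim_partition = {0: [0]}
physical_shape = (2, 2, 4, 16)

# one candidate physical partition per leading dimension (the last dim is the feature dim)
candidate_partitions = [{i: logical_dim_partition[0]} for i in range(len(physical_shape) - 1)]
assert candidate_partitions == [{0: [0]}, {1: [0]}, {2: [0]}]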
@operator_registry.register(F.linear)
@@ -118,20 +147,37 @@ class LinearFunctionHandler(NodeHandler):
        for op_data, sharding_spec in strategy.input_sharding_specs.items():
            if op_data.name == str(self.node.args[1]):
                assert op_data.logical_shape != op_data.data.shape
                switch_partition_dim(sharding_spec, 0, -1)

        # create multiple sharding strategies for the inputs
        # as the input can be multi-dimensional and the partition dim is only 2D,
        # we need to map the partition at dim 0 to one of the first few dimensions of the input
        sharding_strategies = []
        input_op_data = strategy.get_op_data_by_name(str(self.node.args[0]))
        output_op_data = strategy.get_op_data_by_name(str(self.node))
        num_input_dims = input_op_data.data.dim()
        input_sharding_spec = strategy.get_sharding_spec_by_name(input_op_data.name)

        if 0 in input_sharding_spec.dim_partition_dict:
            for i in range(num_input_dims - 1):
                new_strategy = strategy.clone()
                input_sharding_spec = new_strategy.get_sharding_spec_by_name(input_op_data.name)
                output_sharding_spec = new_strategy.get_sharding_spec_by_name(output_op_data.name)
                try:
                    update_partition_dim(sharding_spec=input_sharding_spec,
                                         dim_mapping={0: i},
                                         physical_shape=input_op_data.data.shape,
                                         inplace=True)
                    update_partition_dim(sharding_spec=output_sharding_spec,
                                         dim_mapping={0: i},
                                         physical_shape=output_op_data.data.shape,
                                         inplace=True)
                    sharding_strategies.append(new_strategy)
                except ShardingException:
                    pass
        else:
            sharding_strategies.append(strategy)

        return sharding_strategies
......
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
from torch.fx.node import Node
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from typing import Dict, List, Union
from ..sharding_strategy import ShardingStrategy_V2, StrategiesVector, OperationData, TrainCycleItem
from ..strategy import StrategyGenerator_V2
@@ -72,17 +72,27 @@ class NodeHandler(ABC):
        for generator in strategy_generators:
            strategies = generator.generate()

            # postprocess a strategy
            # postprocess can produce one strategy or multiple strategies
            post_processed_strategies_map = map(self.post_process, strategies)
            post_processed_strategies = []

            for strategy in post_processed_strategies_map:
                if isinstance(strategy, (list, tuple)):
                    post_processed_strategies.extend(strategy)
                else:
                    post_processed_strategies.append(strategy)

            # compute the resharding costs based on the previous node
            # strategies if specified
            if compute_resharding_cost:
                post_processed_strategies = list(map(self.update_resharding_cost, post_processed_strategies))

            self.strategies_vector.extend(post_processed_strategies)

        return self.strategies_vector

    def post_process(self, strategy: ShardingStrategy_V2) -> Union[ShardingStrategy_V2, List[ShardingStrategy_V2]]:
        # transform the strategy generated
        # e.g. to process the sharding strategy for the transposed weights
        return strategy
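A tiny self-contained illustration (placeholder strings, not the real strategy objects) of why the flattening step above is needed: post_process may hand back either one strategy or a list of them.

def fake_post_process(strategy):
    # pretend a logical strategy with a dim-0 shard expands into two physical ones
    return [strategy + "@dim0", strategy + "@dim1"] if "S0R" in strategy else strategy

raw_strategies = ["S0R = S0R x RR", "RR = RR x RR"]
flattened = []
for item in map(fake_post_process, raw_strategies):
    if isinstance(item, (list, tuple)):
        flattened.extend(item)
    else:
        flattened.append(item)
assert flattened == ["S0R = S0R x RR@dim0", "S0R = S0R x RR@dim1", "RR = RR x RR"]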
......
import torch
from typing import Dict
from colossalai.tensor.sharding_spec import ShardingSpec
from copy import deepcopy
def switch_partition_dim(sharding_spec: ShardingSpec, dim1: int, dim2: int) -> ShardingSpec:
    """
    Switch the sharding mesh dimensions for two tensor dimensions. This operation is in-place.

    Args:
        sharding_spec (ShardingSpec): the sharding spec for which partition dims are switched
        dim1 (int): the first tensor dimension to switch
        dim2 (int): the second tensor dimension to switch
    """
    assert len(sharding_spec.entire_shape) == 2
    dim_partition_dict = sharding_spec.dim_partition_dict
    dim1_partition = dim_partition_dict.pop(dim1, None)
    dim2_partition = dim_partition_dict.pop(dim2, None)

    if dim1_partition:
        dim_partition_dict[dim2] = dim1_partition

    if dim2_partition:
        dim_partition_dict[dim1] = dim2_partition

    # re-init the sharding spec
    sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
    return sharding_spec
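A hedged usage sketch for switch_partition_dim; the DeviceMesh and ShardingSpec constructor calls mirror the test file changed in this commit and are assumptions, not part of this file:

import torch
from colossalai.device.device_mesh import DeviceMesh

# assumed mesh setup: four devices arranged as a 2x2 logical mesh
device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

# the transposed linear weight has logical shape (in_features, out_features) = (16, 32),
# sharded at dim 0 along mesh dim 1; after the switch the shard sits on the last dim
spec = ShardingSpec(device_mesh, torch.Size([16, 32]), dim_partition_dict={0: [1]})
switch_partition_dim(spec, 0, -1)
print(spec.dim_partition_dict)    # expected: {-1: [1]}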
def update_partition_dim(sharding_spec: ShardingSpec,
                         dim_mapping: Dict[int, int],
                         physical_shape: torch.Size,
                         inplace: bool = False) -> ShardingSpec:
    """
    Update the partition dim dict from the logical shape to the physical shape.

    Args:
        sharding_spec (ShardingSpec): the sharding spec for which partition dims are updated
        dim_mapping (Dict[int, int]): the mapping from the logical tensor dimension to the physical tensor dimension
        physical_shape (torch.Size): the physical shape of the tensor
        inplace (bool): whether to modify the given sharding spec in place (default to False)
    """
    if inplace:
        current_sharding_spec = sharding_spec
    else:
        current_sharding_spec = deepcopy(sharding_spec)

    old_dim_partition_dict = current_sharding_spec.dim_partition_dict
    new_dim_partition_dict = {}

    # assign the mapped partitions to their new physical dimensions
    for old_dim, new_dim in dim_mapping.items():
        mesh_dims = old_dim_partition_dict.pop(old_dim)
        new_dim_partition_dict[new_dim] = mesh_dims

    # carry over the partitions whose dimensions are unchanged
    for tensor_dim, mesh_dims in old_dim_partition_dict.items():
        if tensor_dim in new_dim_partition_dict:
            raise KeyError(f"There are duplicated entries for the tensor sharding dimension {tensor_dim}")
        else:
            new_dim_partition_dict[tensor_dim] = mesh_dims

    # re-init the sharding spec against the physical shape
    current_sharding_spec.__init__(device_mesh=sharding_spec.device_mesh,
                                   entire_shape=physical_shape,
                                   dim_partition_dict=new_dim_partition_dict)
    return current_sharding_spec
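And a matching sketch for update_partition_dim, again with assumed constructors: mapping the logical batch dim 0 of a flattened (16, 16) input onto physical dim 1 of the original (2, 2, 4, 16) tensor.

import torch
from colossalai.device.device_mesh import DeviceMesh

# assumed mesh setup, same as the sketch above
device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))

logical_spec = ShardingSpec(device_mesh, torch.Size([16, 16]), dim_partition_dict={0: [0]})
physical_spec = update_partition_dim(logical_spec,
                                     dim_mapping={0: 1},
                                     physical_shape=torch.Size([2, 2, 4, 16]))
print(physical_spec.dim_partition_dict)    # expected: {1: [0]} against the physical (2, 2, 4, 16) shape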
from copy import deepcopy
from dataclasses import dataclass
from abc import ABC, abstractmethod
from enum import Enum
@@ -121,16 +122,12 @@ class ShardingStrategy_V2:
        communication_cost (TrainCycleItem): Communication cost to complete this strategy. (default to None)
        memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None)
        input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes.
    """
    name: str
    sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None
    compute_cost: TrainCycleItem = None
    communication_cost: TrainCycleItem = None
    memory_cost: TrainCycleItem = None
    communication_actions: Dict[OperationData, CommSpec] = None
    resharding_costs: Dict[OperationData, Dict[ShardingSpec, TrainCycleItem]] = None
@@ -169,6 +166,26 @@ class ShardingStrategy_V2:
                return sharding_spec
        raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")
    def clone(self):

        def _deepcopy_dict_vals(data: Dict):
            return {k: deepcopy(v) for k, v in data.items()}

        sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs else None
        communication_actions = _deepcopy_dict_vals(self.communication_actions) if self.communication_actions else None
        resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs else None
        compute_cost = deepcopy(self.compute_cost)
        communication_cost = deepcopy(self.communication_cost)
        memory_cost = deepcopy(self.memory_cost)
        return ShardingStrategy_V2(name=self.name,
                                   sharding_specs=sharding_specs,
                                   compute_cost=compute_cost,
                                   communication_cost=communication_cost,
                                   memory_cost=memory_cost,
                                   communication_actions=communication_actions,
                                   resharding_costs=resharding_costs)
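For intuition on the clone semantics (toy dictionaries standing in for the real OperationData and ShardingSpec objects, my reading rather than something stated in the commit): the dict values are deep-copied so a clone can be rewritten by post_process without touching the original, while the keys are shared so name-based lookups keep working.

from copy import deepcopy

original = {"input_1": {0: [0]}}
cloned = {k: deepcopy(v) for k, v in original.items()}
cloned["input_1"][0] = [1]
assert original["input_1"] == {0: [0]}    # the original mapping is untouched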
class StrategiesVector(list):
    '''
......
@@ -6,6 +6,8 @@ from enum import Enum
from functools import reduce
import operator
__all__ = ['_DimSpec', 'ShardingException', 'ShardingSpec']
ALLGATHER_COST = 20
SHARD_COST = 5
STEP_PENALTY = 6
@@ -136,6 +138,10 @@ class _DimSpec:
        return difference
class ShardingException(Exception):
    pass


class ShardingSpec:
    '''
    Sharding spec for a tensor, it contains info of the logical device mesh this tensor belong
......
@@ -3,14 +3,15 @@ import torch
import torch.nn as nn
from colossalai.fx import ColoTracer, ColoGraphModule
from colossalai.auto_parallel.solver.op_handler.dot_handler_v2 import LinearModuleHandler, LinearFunctionHandler
from colossalai.auto_parallel.solver.sharding_strategy import OperationData, OperationDataType, StrategiesVector, ShardingStrategy_V2
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
def test_linear_module_handler():
    model = nn.Sequential(nn.Linear(16, 32).to('meta'))
    tracer = ColoTracer()
    graph = tracer.trace(model, meta_args={"input": torch.rand(2, 2, 4, 16).to('meta')})
    gm = ColoGraphModule(model, graph)
    physical_mesh_id = torch.arange(0, 4)
@@ -34,9 +35,9 @@ def test_linear_module_handler():
    assert mapping['input'].name == "input_1"
    assert mapping['input'].data.is_meta
    assert mapping['input'].data.shape == torch.Size([2, 2, 4, 16])
    assert mapping['input'].type == OperationDataType.ARG
    assert mapping['input'].logical_shape == torch.Size([16, 16])

    assert mapping['other'].name == "weight"
    assert mapping['other'].data.is_meta
@@ -52,11 +53,14 @@ def test_linear_module_handler():
    assert mapping['output'].name == "_0"
    assert mapping['output'].data.is_meta
    assert mapping['output'].data.shape == torch.Size([2, 2, 4, 32])
    assert mapping['output'].type == OperationDataType.OUTPUT
    assert mapping['output'].logical_shape == torch.Size([16, 32])

    strategies_vector = handler.register_strategy()
    strategy_name_list = [val.name for val in strategies_vector]

    # one logical strategy will be converted to different physical sharding specs
    assert len(strategy_name_list) > 8

    # SS = SR x RS
    assert 'S0S1 = S0R x RS1' in strategy_name_list
@@ -78,6 +82,19 @@ def test_linear_module_handler():
    assert 'RS0 = RR x RS0' in strategy_name_list
    assert 'RS1 = RR x RS1' in strategy_name_list

    for strategy in strategies_vector:
        strategy: ShardingStrategy_V2
        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
        weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
        output_sharding_spec = strategy.get_sharding_spec_by_name('_0')

        # make sure the sharding matches across different operation data
        assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
        assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
        assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1]
        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
def test_linear_function_handler():
    model = nn.Linear(16, 32).to('meta')
@@ -123,6 +140,8 @@ def test_linear_function_handler():
    strategies_vector = handler.register_strategy()
    strategy_name_list = [val.name for val in strategies_vector]

    # one logical strategy will be converted to different physical sharding specs
    assert len(strategy_name_list) > 8

    # SS = SR x RS
    assert 'S0S1 = S0R x RS1' in strategy_name_list
@@ -144,6 +163,19 @@ def test_linear_function_handler():
    assert 'RS0 = RR x RS0' in strategy_name_list
    assert 'RS1 = RR x RS1' in strategy_name_list
    for strategy in strategies_vector:
        strategy: ShardingStrategy_V2
        input_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
        weight_sharding_spec = strategy.get_sharding_spec_by_name('weight')
        bias_sharding_spec = strategy.get_sharding_spec_by_name('bias')
        output_sharding_spec = strategy.get_sharding_spec_by_name('linear')

        # make sure the sharding matches across different operation data
        assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
        assert weight_sharding_spec.sharding_sequence[1] == input_sharding_spec.sharding_sequence[-1]
        assert weight_sharding_spec.sharding_sequence[0] == output_sharding_spec.sharding_sequence[-1]
        assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
if __name__ == '__main__':
    test_linear_module_handler()
......