Commit 08f2920e authored by zhuwenwen's avatar zhuwenwen
Browse files

init colossalai, support dtk2304

parent da3f0934
Pipeline #237 failed with stages
in 0 seconds
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
from .strategy_generator import FollowingStrategyGenerator
__all__ = ['ReshapeGenerator']
class ReshapeGenerator(FollowingStrategyGenerator):
"""
ReshapeGenerator which deals with the sharding strategies of Reshape Op, such as torch.Tensor.permute.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# For reshape function, to keep the computing correctness we keep the sharding
# spec of input is fully replicated. In addition, we will keep the output in
# replica status and let the successor node choose the way to resharding the
# output node. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for reshape function.
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = {}
if isinstance(self.op_data["output"].data, tuple):
dim_partition_dict_for_output = [{} for _ in range(len(self.op_data["output"].data))]
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> FULLY REPLICATED_{index}'
total_mesh_dim_list = []
for mesh_dim_list in dim_partition_dict_for_input.values():
total_mesh_dim_list.extend(mesh_dim_list)
# if there is only one sharding dimension, we should use the value instead of list as logical_process_axis.
if len(total_mesh_dim_list) == 1:
total_mesh_dim_list = total_mesh_dim_list[0]
input_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=total_mesh_dim_list,
comm_type=CommType.BEFORE,
arg_index=0)
input_comm_action.comm_spec.gather_dim = total_mesh_dim_list
elif len(total_mesh_dim_list) >= 2:
source_spec = sharding_spec_mapping["input"]
target_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=source_spec.entire_shape,
dim_partition_dict={})
comm_spec = {'src_spec': source_spec, 'tgt_spec': target_spec}
input_comm_action = CommAction(comm_spec=comm_spec, comm_type=CommType.BEFORE, arg_index=0)
else:
input_comm_action = None
if input_comm_action is not None:
communication_action_mapping["input"] = input_comm_action
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
for strategy in strategy_list:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategy_list
import copy
import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
__all__ = ['SoftmaxGenerator']
class SoftmaxGenerator(FollowingStrategyGenerator):
"""
SoftmaxGenerator is used to generate strategies for torch.nn.Softmax or F.softmax.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
'''
Compute the computation cost per device with this specific strategy.
'''
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
forward_compute_cost = output_size_product * 2
backward_compute_cost = input_size_product
total_compute_cost = forward_compute_cost + backward_compute_cost
compute_cost = TrainCycleItem(fwd=forward_compute_cost, bwd=backward_compute_cost, total=total_compute_cost)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
softmax_dim = self.op_data['softmax_dim'].data
if softmax_dim in dim_partition_dict_for_input:
recover_dims = dim_partition_dict_for_input.pop(softmax_dim)
dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
import operator
from abc import ABC, abstractmethod
from functools import reduce
from typing import Any, Dict, List, Union
import torch
from torch.fx import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationData,
OperationDataType,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.utils import convert_dim_partition_dict
class StrategyGenerator(ABC):
"""
StrategyGenerator is used to generate the same group of sharding strategies.
TODO: remove the original strategy_generator.py after refactoring
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh):
self.op_data = operation_data_mapping
self.device_mesh = device_mesh
# validate the whether operation data is of desired value
self.validate()
@property
def has_bias(self):
"""
A utility method to check for the existence of bias operand for convenience.
"""
return 'bias' in self.op_data
def is_param(self, op_data_name):
other_data = self.op_data[op_data_name]
return other_data.type == OperationDataType.PARAM
def is_buffer(self, op_data_name):
other_data = self.op_data[op_data_name]
return other_data.type == OperationDataType.BUFFER
def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
communication_action_mapping: Dict[str, CommSpec]):
"""
A factory method to produce a ShardingStrategy object.
Args:
sharding_spec_mapping (Dict[str, ShardingSpec]): the mapping between the operation data name and the ShardingSpec object.
communication_action_mapping (Dict[str, CommSpec]): the mapping between the operation data name and the CommSpec object.
"""
sharding_specs = self.replace_op_name_with_op_data(sharding_spec_mapping)
communication_actions = self.replace_op_name_with_op_data(communication_action_mapping)
return ShardingStrategy(name=name, sharding_specs=sharding_specs, communication_actions=communication_actions)
def to_sharding_spec_mapping(self, mapping: Dict[str, Dict[int, List[int]]]):
"""
A utility method to convert the the dim partition dict to a ShardingSpec object.
Args:
mapping (Dict[str, Dict[int, List[int]]]): the key of the mapping is the operation data name and the value is a dim partition dictionary.
Notes:
The op_data.data is commonly type of torch.Tensor, torch.nn.Parameter, so the sharding spec is easy to create from the shape of the data.
However, if the op_data.data is of other non-iterative types, such as float or int, we should return None. If the op_data.data is of some iterative types, such as
list or tuple, we should return a list of ShardingSpec objects follow the same rule as above mentioned.
"""
results = {}
for op_data_name, dim_partition_dict in mapping.items():
if op_data_name in self.op_data:
op_data = self.op_data[op_data_name]
def _to_sharding_spec(
data: any, logical_shape: any,
dim_partition_dict: Dict[int, List[int]]) -> Union[ShardingSpec, List[ShardingSpec], None]:
"""
This is a recursive function to convert the dim partition dict to a ShardingSpec object.
"""
if isinstance(data, torch.Tensor):
dim_size = len(logical_shape)
dim_partition_dict = convert_dim_partition_dict(dim_size, dim_partition_dict)
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=logical_shape,
dim_partition_dict=dim_partition_dict)
return sharding_spec
elif isinstance(data, (list, tuple)):
sharding_spec = []
for data_element, logical_shape_element, dim_partition_dict_element in zip(
data, logical_shape, dim_partition_dict):
sharding_spec.append(
_to_sharding_spec(data_element, logical_shape_element, dim_partition_dict_element))
return sharding_spec
else:
return None
sharding_spec = _to_sharding_spec(op_data.data, op_data.logical_shape, dim_partition_dict)
results[op_data_name] = sharding_spec
return results
def replace_op_name_with_op_data(self, mapping: Dict[str, Any]):
"""
Convert the key of the dictionary from the operation data name to an OperationData object.
"""
results = {}
for k, v in mapping.items():
op_data = self.op_data[k]
results[op_data] = v
return results
def get_communication_spec(self, sharding_spec: ShardingSpec, communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]]):
"""
A factory method to produce a CommSpec object.
"""
return CommSpec(comm_pattern=communication_pattern,
sharding_spec=sharding_spec,
logical_process_axis=logical_process_axis)
def get_communication_action(self,
sharding_spec: ShardingSpec,
communication_pattern: CollectiveCommPattern,
logical_process_axis: Union[int, List[int]],
comm_type: CommType,
arg_index: int = -1,
key_for_kwarg: any = None) -> CommAction:
"""
A factory method to produce a CommAction object.
"""
return CommAction(comm_spec=self.get_communication_spec(sharding_spec=sharding_spec,
communication_pattern=communication_pattern,
logical_process_axis=logical_process_axis),
comm_type=comm_type,
arg_index=arg_index,
key_for_kwarg=key_for_kwarg)
def update_communication_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Compute the communication cost involved in the forward and backward iteration.
"""
comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
def _compute_and_add(op_data: OperationData, comm_spec: CommSpec):
num_ele_in_comm = comm_spec.get_comm_cost()
dtype = op_data.data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
for phase, cost in num_ele_in_comm.items():
num_ele_in_comm[phase] = num_ele_in_comm[phase] * size_per_elem_bytes
comm_cost.fwd += num_ele_in_comm['forward']
comm_cost.bwd += num_ele_in_comm['backward']
comm_cost.total += num_ele_in_comm['total']
# check if communication action exists
# if so, loop over each action and compute the cost of each action
if strategy.communication_actions is not None:
for operand, comm_action in strategy.communication_actions.items():
if isinstance(comm_action, CommAction):
comm_spec = comm_action.comm_spec
else:
# this condition branch will be removed after all the handler updated.
comm_spec = comm_action
if isinstance(comm_spec, dict):
src_spec = comm_spec['src_spec']
tgt_spec = comm_spec['tgt_spec']
shape_consistency_manager = ShapeConsistencyManager()
_, comm_action_sequence, _ = shape_consistency_manager.shape_consistency(src_spec, tgt_spec)
for comm_spec_ in comm_action_sequence:
_compute_and_add(operand, comm_spec_)
else:
_compute_and_add(operand, comm_spec)
# update the communication cost attribute in-place
strategy.communication_cost = comm_cost
return strategy
@abstractmethod
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Customize this method to compute the computation flops.
"""
pass
@abstractmethod
def update_memory_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
"""
Customize this method to compute the memory cost in bytes.
"""
pass
def _compute_size_in_bytes(self, strategy: ShardingStrategy, key: str):
"""
Compute the size of a tensor in bytes.
Args:
strategy (ShardingStrategy): the ShardingStrategy generated.
key (str): the name of the operation data defined by the generator.
"""
op_data = self.op_data[key]
def _compute_size_in_bytes_helper(sharding_spec, meta_data):
sharded_shape = sharding_spec.get_sharded_shape_per_device()
if len(sharded_shape) == 0:
num_elements = 1
else:
num_elements = reduce(operator.mul, sharded_shape)
dtype = getattr(meta_data, 'dtype')
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
return num_elements * size_per_elem_bytes
if isinstance(op_data.data, tuple):
assert isinstance(strategy.sharding_specs[op_data], list), \
'sharding_spec of op_data should be a list of sharding specs if op_data.data is a tuple.'
total_bytes = 0
for index, sharding_spec in enumerate(strategy.sharding_specs[op_data]):
meta_data = op_data.data[index]
if isinstance(meta_data, torch.Tensor):
element_bytes = _compute_size_in_bytes_helper(sharding_spec, meta_data)
else:
# if meta_data is not a tensor, we count the memroy as 0
element_bytes = 0
total_bytes += element_bytes
else:
if isinstance(op_data.data, torch.Tensor):
total_bytes = _compute_size_in_bytes_helper(strategy.sharding_specs[op_data], op_data.data)
else:
# if op_data.data is not a tensor, we count the memroy as 0
total_bytes = 0
return total_bytes
def generate(self) -> List[ShardingStrategy]:
"""
Generate all possible sharding strategies for this operation.
"""
strategies = self.collate_strategies()
# some strategies may be None as ignore_sharding_exception may return None
# when ShardingSpecException occurs.
# thus, remove those None values
strategies = [strategy for strategy in strategies if strategy]
# update the costs
# update mete info on cost
# these update methods are all in-place, the default method will do nothing
# the cost info will only be added if the child class overrides these methods
for strategy in strategies:
self.update_communication_cost(strategy)
self.update_compute_cost(strategy)
self.update_memory_cost(strategy)
return strategies
@abstractmethod
def collate_strategies(self) -> List[ShardingStrategy]:
pass
@abstractmethod
def validate(self) -> bool:
"""
Validate if the operands are of desired shape.
If True, means this generator can be used for the current operation.
"""
pass
class FollowingStrategyGenerator(StrategyGenerator):
"""
FollowingStrategyGenerator is used to generate the sharding strategies which depends on its predecessor node.
TODO: remove the original strategy_generator.py after refactoring
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
predecessor_node: Node):
self.op_data = operation_data_mapping
self.device_mesh = device_mesh
self.predecessor_node = predecessor_node
class OutputStrategyGenerator(StrategyGenerator):
"""
OutputStrategyGenerator is used to generate the sharding strategies for Output Node.
"""
def __init__(self, operation_data_mapping: Dict[str, OperationData], device_mesh: DeviceMesh,
predecessor_nodes: List[Node]):
super().__init__(operation_data_mapping, device_mesh)
self.predecessor_nodes = predecessor_nodes
import copy
import operator
from functools import reduce
from typing import List
from colossalai.auto_parallel.tensor_shard.node_handler.strategy.strategy_generator import FollowingStrategyGenerator
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
check_keep_sharding_status,
detect_reshape_mapping,
infer_output_dim_partition_dict,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = ['SumGenerator']
class SumGenerator(FollowingStrategyGenerator):
"""
SumGenerator deals with the sharding strategies of torch.sum op.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
sharded_input_shape = strategy.sharding_specs[self.op_data['input']].get_sharded_shape_per_device()
sharded_output_shape = strategy.sharding_specs[self.op_data['output']].get_sharded_shape_per_device()
input_size_product = reduce(operator.mul, sharded_input_shape)
output_size_product = reduce(operator.mul, sharded_output_shape)
compute_cost = TrainCycleItem(fwd=input_size_product,
bwd=output_size_product,
total=input_size_product + output_size_product)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = copy.deepcopy(input_sharding_spec.dim_partition_dict)
sum_dims, sum_mapping_dict = self.op_data['sum_info'].data
# TODO: a better way to handle the distributed sum is sum all the data on chip and then do all reduce
# among all the shard groups
recover_dims = []
dim_partition_dict_for_output = {}
for dim in dim_partition_dict_for_input:
if dim in sum_dims:
recover_dims.append(dim)
elif dim in sum_mapping_dict:
dim_partition_dict_for_output[sum_mapping_dict[dim]] = dim_partition_dict_for_input[dim]
else:
raise RuntimeError(f'dim {dim} is not in sum_mapping_dict or sum_dims')
for dim in recover_dims:
dim_partition_dict_for_input.pop(dim)
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern
from colossalai.tensor.sharding_spec import ShardingSpec
from .strategy_generator import StrategyGenerator
__all__ = ['TensorConstructorGenerator']
class TensorConstructorGenerator(StrategyGenerator):
"""
TensorConstructorGenerator which deals with
the sharding strategies for tensor constructor operation, such as torch.arange.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {'output': self._compute_size_in_bytes(strategy, "output")}
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
bwd_mem_cost = MemoryCost(activation=0, parameter=0)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
dim_partition_dict_mapping = {
"output": {},
}
communication_action_mapping = {}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = 'Replica Tensor Constructor'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem)
from .strategy_generator import FollowingStrategyGenerator
__all__ = ['UnaryElementwiseGenerator']
class UnaryElementwiseGenerator(FollowingStrategyGenerator):
"""
UnaryElementwiseGenerator which deals with the sharding strategies of UnaryElementwiseOp.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'input': self._compute_size_in_bytes(strategy, "input"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = input + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items() if not self.is_param(k)])
fwd_parameter_cost = sum([v for k, v in forward_size_mapping.items() if self.is_param(k)])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=fwd_parameter_cost)
# compute bwd cost incurred
# bwd_cost = input_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items() if not self.is_param(k)])
bwd_parameter_cost = sum([v for k, v in backward_size_mapping.items() if self.is_param(k)])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=bwd_parameter_cost)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost,
parameter=fwd_parameter_cost + bwd_parameter_cost)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
def collate_strategies(self) -> List[ShardingStrategy]:
strategy_list = []
# For element-wise function, we keep the sharding spec of output node same as
# the input. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for element-wise function.
for index, strategy in enumerate(self.predecessor_node.strategies_vector):
dim_partition_dict_mapping = {}
communication_action_mapping = {}
input_sharding_spec = strategy.output_sharding_specs[self.op_data["input"]]
dim_partition_dict_for_input = input_sharding_spec.dim_partition_dict
dim_partition_dict_for_output = copy.deepcopy(dim_partition_dict_for_input)
dim_partition_dict_mapping = {
"input": dim_partition_dict_for_input,
"output": dim_partition_dict_for_output,
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{sharding_spec_mapping["input"].sharding_sequence} -> {sharding_spec_mapping["output"].sharding_sequence}_{index}'
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
strategy_list.append(strategy)
return strategy_list
import copy
from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from .strategy_generator import StrategyGenerator
__all__ = ['WhereGenerator']
class WhereGenerator(StrategyGenerator):
"""
WhereGenerator is a generic class to generate strategies for Where operation.
"""
def validate(self) -> bool:
return super().validate()
def update_compute_cost(self, strategy: ShardingStrategy):
compute_cost = TrainCycleItem(fwd=10, bwd=10, total=20)
strategy.compute_cost = compute_cost
def update_memory_cost(self, strategy: ShardingStrategy):
'''
Compute the memory cost per device with this specific strategy.
'''
forward_size_mapping = {
'condition': self._compute_size_in_bytes(strategy, "condition"),
'x': self._compute_size_in_bytes(strategy, "x"),
'y': self._compute_size_in_bytes(strategy, "y"),
'output': self._compute_size_in_bytes(strategy, "output")
}
backward_size_mapping = copy.deepcopy(forward_size_mapping)
backward_size_mapping.pop("output")
# compute fwd cost incurred
# fwd_cost = condition + x + y + output
fwd_activation_cost = sum([v for k, v in forward_size_mapping.items()])
fwd_mem_cost = MemoryCost(activation=fwd_activation_cost, parameter=0)
# compute bwd cost incurred
# bwd_cost = condition_grad + x_grad + y_grad
bwd_activation_cost = sum([v for k, v in backward_size_mapping.items()])
bwd_mem_cost = MemoryCost(activation=bwd_activation_cost, parameter=0)
# compute total cost
total_mem_cost = MemoryCost(activation=fwd_activation_cost + bwd_activation_cost, parameter=0)
memory_cost = TrainCycleItem(fwd=fwd_mem_cost, bwd=bwd_mem_cost, total=total_mem_cost)
strategy.memory_cost = memory_cost
@ignore_sharding_exception
def _generate_strategy_with_dim_partition(self, dim_partition):
dim_partition_dict_mapping = {
"condition": dim_partition,
"x": dim_partition,
"y": dim_partition,
"output": dim_partition
}
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["condition"].sharding_sequence} x {sharding_spec_mapping["x"].sharding_sequence} x {sharding_spec_mapping["y"].sharding_sequence}'
communication_action_mapping = {}
strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping,
communication_action_mapping=communication_action_mapping)
return strategy
def enumerate_all_possible_output_spec(self, mesh_dim_0, mesh_dim_1, dimension_length):
dim_partition_list = []
dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_0, dimension_length))
dim_partition_list.extend(enumerate_all_possible_1d_sharding(mesh_dim_1, dimension_length))
dim_partition_list.extend(enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dimension_length))
# append {} for non_split case
dim_partition_list.append({})
return dim_partition_list
def collate_strategies(self) -> List[ShardingStrategy]:
'''
Generate every possible strategies for a where node, and record all strategies into the strategies_vector.
'''
strategy_list = []
dimension_length = len(self.op_data["output"].logical_shape)
dim_partition_list = self.enumerate_all_possible_output_spec(0, 1, dimension_length)
for dim_partition in dim_partition_list:
strategy = self._generate_strategy_with_dim_partition(dim_partition)
strategy_list.append(strategy)
return strategy_list
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, SumGenerator
__all__ = ['SumHandler']
@operator_registry.register(torch.Tensor.sum)
@operator_registry.register(torch.sum)
class SumHandler(NodeHandler):
"""
A SumHandler which deals with the sharding strategies for torch.sum or torch.Tensor.sum.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(SumGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# check if the input operand is a parameter
if isinstance(self.node.args[0]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
input_data = self.node.args[0]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[0]), type=data_type, data=input_data)
if len(self.node.args) > 1:
sum_dims = self.node.args[1]
else:
sum_dims = tuple(range(self.node.args[0]._meta_data.dim()))
if isinstance(sum_dims, int):
sum_dims = (sum_dims,)
# recover negative value to positive
num_dims = self.node.args[0]._meta_data.dim()
for i in range(len(sum_dims)):
if sum_dims[i] < 0:
sum_dims[i] += num_dims
# mapping the input dims to output dims
# For examples:
# input: torch.rand(2, 3, 4, 5)
# output: torch.sum(input, (0, 2))
# sum_mapping_dict = {1: 0, 3: 1}
# sum_mapping_dict[1] = 0 means the 0th dim of output is the 1st dim of input
# sum_mapping_dict[3] = 1 means the 1st dim of output is the 3rd dim of input
sum_mapping_dict = {}
if 'keepdim' in self.node.kwargs and self.node.kwargs['keepdim']:
for i in range(num_dims):
sum_mapping_dict.update({i: i})
else:
output_index = 0
for i in range(num_dims):
if i not in sum_dims:
sum_mapping_dict.update({i: output_index})
output_index += 1
assert output_index == self.node._meta_data.dim()
sum_info = (sum_dims, sum_mapping_dict)
physical_shape_operand = OperationData(name='sum_info', type=OperationDataType.ARG, data=sum_info)
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {
"input": physical_input_operand,
"sum_info": physical_shape_operand,
"output": physical_output_operand
}
return mapping
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator
from .strategy.tensor_constructor_generator import TensorConstructorGenerator
__all__ = ['TensorConstructorHandler']
@operator_registry.register(torch.arange)
class TensorConstructorHandler(NodeHandler):
"""
A TensorConstructorHandler which deals with the sharding strategies for tensor constructor operations, such as torch.arange.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(TensorConstructorGenerator(op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
output_data = self.node._meta_data
physical_output_operand = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=output_data)
mapping = {"output": physical_output_operand}
return mapping
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, UnaryElementwiseGenerator
__all__ = ['UnaryElementwiseHandler']
@operator_registry.register(torch.Tensor.to)
@operator_registry.register(torch.Tensor.type)
@operator_registry.register(torch.abs)
@operator_registry.register(torch.nn.ReLU)
@operator_registry.register(torch.nn.Tanh)
@operator_registry.register(torch.tanh)
@operator_registry.register(torch.nn.modules.dropout.Dropout)
@operator_registry.register(torch.Tensor.contiguous)
@operator_registry.register(torch.nn.functional.dropout)
class UnaryElementwiseHandler(NodeHandler):
"""
A UnaryElementwiseHandler which deals with the sharding strategies for UnaryElementwise Op.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(UnaryElementwiseGenerator(op_data_mapping, self.device_mesh, self.node.args[0]))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_input_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {"input": physical_input_operand, "output": physical_output}
return mapping
import copy
import operator
from typing import Dict, List
import torch
from ..sharding_strategy import OperationData, OperationDataType, ShardingStrategy, StrategiesVector
from ..utils import recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import StrategyGenerator, WhereGenerator
__all__ = ['WhereHandler']
@operator_registry.register(torch.where)
class WhereHandler(NodeHandler):
"""
A WhereHandler which deals with the sharding strategies for torch.where.
"""
def get_strategy_generator(self) -> List[StrategyGenerator]:
logical_op_data_mapping, _ = self.get_operation_data_mapping()
generators = []
generators.append(WhereGenerator(logical_op_data_mapping, self.device_mesh))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
# use transposed shape for strategies
# the strategies will be transformed back to its original shape in self.post_process
physical_condition_operand = OperationData(name=str(self.node.args[0]),
type=OperationDataType.ARG,
data=self.node.args[0]._meta_data)
physical_x_operand = OperationData(name=str(self.node.args[1]),
type=OperationDataType.ARG,
data=self.node.args[1]._meta_data)
physical_y_operand = OperationData(name=str(self.node.args[2]),
type=OperationDataType.ARG,
data=self.node.args[2]._meta_data)
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
physical_mapping = {
"condition": physical_condition_operand,
"x": physical_x_operand,
"y": physical_y_operand,
"output": physical_output
}
logical_shape_for_all = self.node._meta_data.shape
logical_mapping = {}
for key, physical_operand in physical_mapping.items():
logical_mapping[key] = self.convert_physical_operand_to_logical_operand(physical_operand,
logical_shape_for_all)
return logical_mapping, physical_mapping
def convert_physical_operand_to_logical_operand(self, physical_operand, target_shape):
logical_operand = copy.deepcopy(physical_operand)
logical_operand.logical_shape = target_shape
return logical_operand
def post_process(self, strategy: ShardingStrategy):
logical_op_data_mapping, physical_op_data_mapping = self.get_operation_data_mapping()
for key in logical_op_data_mapping.keys():
logical_sharding_spec = strategy.sharding_specs[logical_op_data_mapping[key]]
logical_shape = logical_op_data_mapping[key].logical_shape
physical_shape = physical_op_data_mapping[key].logical_shape
physical_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
logical_sharding_spec, logical_shape, physical_shape)
strategy.sharding_specs.pop(logical_op_data_mapping[key])
strategy.sharding_specs[physical_op_data_mapping[key]] = physical_sharding_spec
strategy.name = f"{strategy.sharding_specs[physical_op_data_mapping['output']].sharding_sequence} = {strategy.sharding_specs[physical_op_data_mapping['condition']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['x']].sharding_sequence} x {strategy.sharding_specs[physical_op_data_mapping['y']].sharding_sequence}"
return strategy
from copy import deepcopy
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Tuple, Union
import torch
from torch.fx.node import Node
from colossalai.tensor.shape_consistency import CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
from .constants import (
BCAST_FUNC_OP,
ELEMENTWISE_FUNC_OP,
ELEMENTWISE_METHOD_OP,
ELEMENTWISE_MODULE_OP,
RESHAPE_FUNC_OP,
RESHAPE_METHOD_OP,
)
__all__ = ['OperationDataType', 'OperationData', 'TrainCycleItem', 'MemoryCost', 'ShardingStrategy', 'StrategiesVector']
class OperationDataType(Enum):
"""
An operation can come from the argument list of an operator or the parameter list of a module.
"""
INPUT = 0
ARG = 1
PARAM = 2
BUFFER = 3
OUTPUT = 4
@dataclass
class OperationData:
"""
OperationData is the data related to an operator, the data can be the operand or the output.
Args:
name (str): the name of the operation-related data
type (OperationDataType): the type of the operation data
data (Any): the value for this data, usually it is a meta tensor.
logical_shape (Tuple[int]): the logical shape of the data, it can be different from the its actual shape in memory.
"""
name: str
type: OperationDataType
data: Any
logical_shape: Tuple[int] = None
def __post_init__(self):
# if no logical shape is specified, use the data shape as the logical shape
if self.logical_shape is None:
def _infer_logical_shape(data: any):
"""
This function is used to infer the logical shape of the data.
"""
if isinstance(data, torch.Tensor):
return data.shape
elif isinstance(data, torch.Size):
return None
elif isinstance(data, (tuple, list)):
data_type = type(data)
return data_type([_infer_logical_shape(d) for d in data])
else:
return None
self.logical_shape = _infer_logical_shape(self.data)
def __repr__(self) -> str:
return f'OperationData(name={self.name}, type={self.type})'
def __eq__(self, other) -> bool:
return other.name == self.name
def __hash__(self) -> int:
return hash(f'{self.name}')
@dataclass
class TrainCycleItem:
"""
TrainCycleItem is a dataclass to store the items which have different values for the forward and backward pass
in a training iteration.
Args:
fwd (float): the item for the forward pass
bwd (float): the item for the backward pass
"""
fwd: Any
bwd: Any
total: Any
@dataclass
class MemoryCost:
"""
MemoryCost is a dataclass which stores the memory usage in the program.
Args:
activation (int): the memory cost incurred by the activations in bytes.
parameter (int): the memory cost incurred by the module parameter in bytes.
temp (int): the memory cost incurred by the temporary tensors in bytes.
buffer (int): the memory cost incurred by the module buffer in bytes.
"""
activation: int = 0
parameter: int = 0
temp: int = 0
buffer: int = 0
class CommType(Enum):
"""
CommType describes the sequential order of a communication action and a computation action.
Meaning:
BEFORE: the communication action happens just before the computation operation.
AFTER: the communication action happens after the computation operation.
HOOK: the communication action is used to do the grad all reduce.
IMPLICIT: the communication action happens during the kernel execution, such as SyncBatchNorm
"""
BEFORE = 0
AFTER = 1
HOOK = 2
IMPLICIT = 3
@dataclass
class CommAction:
"""
CommAction is used to record the communication action.
Args:
comm_spec: express the communication pattern and the process groups to execute the communication action.
comm_type: describes the sequential order of a communication action and a computation action.
arg_index: record the location of tensor which join the communication, we cannot use name of node or op_data at runtime,
because the args of node may be changed by graph transform passes.
"""
comm_spec: CommSpec = None
comm_type: CommType = None
arg_index: int = -1
key_for_kwarg: any = None
@dataclass
class ShardingStrategy:
"""
ShardingStrategy is a dataclass to store the meta information on tensor sharding for a node.
Args:
name (str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'.
output_sharding_spec (ShardingSpec): ShardingSpec of the output node.
compute_cost (TrainCycleItem): Computation cost to complete this strategy. (default to None)
communication_cost (TrainCycleItem): Communication cost to complete this strategy. (default to None)
memory_cost (TrainCycleItem): Memory cost of the output node using this strategy. (default to None)
input_sharding_specs (List(ShardingSpec)): The ShardingSpecs of the input nodes.
"""
name: str
sharding_specs: Dict[OperationData, Union[ShardingSpec, Tuple[ShardingSpec]]] = None
compute_cost: TrainCycleItem = None
communication_cost: TrainCycleItem = None
memory_cost: TrainCycleItem = None
communication_actions: Dict[OperationData, CommAction] = None
resharding_costs: Dict[Node, List[TrainCycleItem]] = None
@property
def input_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
specs = {}
specs.update(self._get_sharding_spec(OperationDataType.ARG))
specs.update(self._get_sharding_spec(OperationDataType.PARAM))
return specs
@property
def argument_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
return self._get_sharding_spec(OperationDataType.ARG)
@property
def param_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
return self._get_sharding_spec(OperationDataType.PARAM)
@property
def output_sharding_specs(self) -> Dict[OperationData, ShardingSpec]:
return self._get_sharding_spec(OperationDataType.OUTPUT)
def _get_sharding_spec(self, operation_data_type: OperationDataType):
specs = {k: v for k, v in self.sharding_specs.items() if k.type == operation_data_type}
return specs
def get_op_data_by_name(self, name: str):
for op_data in self.sharding_specs.keys():
if op_data.name == name:
return op_data
raise KeyError(f"Could not find the OperationData with name {name}")
def get_sharding_spec_by_name(self, name: str):
for op_data, sharding_spec in self.sharding_specs.items():
if op_data.name == name:
return sharding_spec
raise KeyError(f"Could not find the ShardingSpec for OperationData with name {name}")
def clone(self):
def _deepcopy_dict_vals(data: Dict):
return {k: deepcopy(v) for k, v in data.items()}
sharding_specs = _deepcopy_dict_vals(self.sharding_specs) if self.sharding_specs is not None else None
# We need to deepcopy it when self.communication_actions is not None, instead of checking its __bool__ value.
# Consider the examples below:
# If self.communication_actions is an empty dictionary {}, then self.communication_actions is not None, but its __bool__ value is False.
# In this case, if we set None to the new object, program will crash when we try to access the communication_actions.items.
communication_actions = _deepcopy_dict_vals(
self.communication_actions) if self.communication_actions is not None else None
# same reason as communication_actions
resharding_costs = _deepcopy_dict_vals(self.resharding_costs) if self.resharding_costs is not None else None
compute_cost = deepcopy(self.compute_cost)
communication_cost = deepcopy(self.communication_cost)
memory_cost = deepcopy(self.memory_cost)
return ShardingStrategy(name=self.name,
sharding_specs=sharding_specs,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
communication_actions=communication_actions,
resharding_costs=resharding_costs)
class StrategiesVector(list):
'''
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
strategies of the node.
Argument:
node (Node): node for which the list of sharding strategies are generated.
'''
def __init__(self, node: Node):
super().__init__()
self.node = node
# fetch its input and output nodes
# TODO: placeholder input nodes
self.predecessor_nodes = list(node._input_nodes.keys())
self.successor_nodes = list(node.users.keys())
def check_merge(self):
merge_label = False
if self.node.op == 'call_module':
target = self.node.target
root_module = self.node.graph.owning_module
submod = root_module.get_submodule(target)
submod_type = type(submod)
# merge elementwise module node into source nodes
# we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
if submod_type in ELEMENTWISE_MODULE_OP:
merge_label = True
if self.node.op == 'call_function':
# we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
if self.node.target in ELEMENTWISE_FUNC_OP:
merge_label = True
# we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
# TODO: remove this after we support the fall back logic.
# if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
# merge_label = True
# we could merge reshape op, because their computation costs are negligible.
if self.node.target in RESHAPE_FUNC_OP:
merge_label = True
if self.node.op == 'call_method':
# we could merge reshape op, because their computation costs are negligible.
method = getattr(self.node.args[0]._meta_data.__class__, self.node.target)
if method in RESHAPE_METHOD_OP:
merge_label = True
if method in ELEMENTWISE_METHOD_OP:
merge_label = True
return merge_label
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .options import SolverOptions
from .solver import Solver
from .strategies_constructor import StrategiesConstructor
__all__ = ['GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph', 'SolverOptions']
import torch
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
class CostGraph:
'''
A graph data structure to simplify the edge cost graph. It has two main functions:
1. To feed the quadratic resharding costs into solver, we need to linearize it. We build edge_cost in
CostGraph, and it stored every combinations of strategies for a src-dst node pair in an 1D list.
2. To reduce the searching space, we merge computationally-trivial operators, such as
element-wise operators, transpose, and reduction, into their following nodes. The merging infomation will
be given by the StrategiesVector depending on the type of target node and following nodes.
Argument:
leaf_strategies(List[StrategiesVector]): It stores StrategiesVector of every nodes on the graph.
simplify(bool, optional): The generated cost graph will be simplified if it is true. (default to True)
'''
def __init__(self, leaf_strategies, simplify=True, forward_only=False):
self.leaf_strategies = leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
# stores number of strategies in each node
self.node_lens = {strategies_vector.node: len(strategies_vector) for strategies_vector in self.leaf_strategies}
# extra_node_costs will store the extra costs introduced by merging nodes
self.extra_node_costs = {}
self.following_dict = {}
self.simplify = simplify
self.forward_only = forward_only
self._build_cost_graph()
def _remove_invalid_node(self, node, attr_name):
remove_list = []
target_node_list = getattr(node, attr_name, [])
for target_node in target_node_list:
if target_node not in self.nodes:
remove_list.append(target_node)
for element in remove_list:
target_node_list.remove(element)
def _build_cost_graph(self):
'''
This method will generate edge_cost for adjacent node pair. Additionally, 'parents' and 'children' attribute will be
set to node.
'''
self.edge_costs = {}
if self.simplify:
self.merge_pair = []
for strategies_vector in self.leaf_strategies:
# build edge_cost
dst_node = strategies_vector.node
for src_node in strategies_vector.predecessor_nodes:
if src_node not in self.nodes:
continue
node_pair = (src_node, dst_node)
edge_cost = {}
for i in range(len(strategies_vector)):
for j in range(len(src_node.strategies_vector)):
resharding_cost_item = strategies_vector[i].resharding_costs[src_node][j]
if self.forward_only:
edge_cost[(j, i)] = resharding_cost_item.fwd
else:
edge_cost[(j, i)] = resharding_cost_item.total
self.edge_costs[node_pair] = edge_cost
# add parents and children attribute to node
# parent_nodes = [node for node in strategies_vector.predecessor_nodes]
# children_nodes = [node for node in strategies_vector.successor_nodes]
parent_nodes = []
children_nodes = []
def _check_tensor_in_node(data):
"""
This method is used to check whether the data has a tensor inside or not.
"""
has_tensor_flag = False
if isinstance(data, torch.Tensor):
return True
elif isinstance(data, (tuple, list)):
for d in data:
has_tensor_flag = has_tensor_flag or _check_tensor_in_node(d)
return has_tensor_flag
for node in strategies_vector.predecessor_nodes:
if _check_tensor_in_node(node._meta_data):
parent_nodes.append(node)
for node in strategies_vector.successor_nodes:
if _check_tensor_in_node(node._meta_data):
children_nodes.append(node)
setattr(dst_node, 'parents', parent_nodes)
setattr(dst_node, 'children', children_nodes)
if self.simplify and strategies_vector.check_merge():
for followed_node in strategies_vector.predecessor_nodes:
# we only merge node pairs which src node has a tensor element inside.
# This is necessay because the node without a tensor element inside will not
# be assigned any strategy.
if _check_tensor_in_node(followed_node._meta_data):
self.merge_pair.append((followed_node, dst_node))
def get_edge_cost(self, src_node, dst_node):
return self.edge_costs[(src_node, dst_node)]
def merge_node(self, src_node, dst_node):
'''
To merge dst_node into src_node, we need to do it in following steps:
1. For each strategy in dst_node, we need to pick an appropriate strategy
of src_node to merge, it is important because the logical resharding costs
between the parents node of src_node and merged node depend on the src_node
strategies dispatching. For example, for the graph 0->1->2, after merging node 1
into node 2, edge_costs[(node 0, node 2)][(0, 0)] = edge_costs[(node 0, node 1)][(0, x)]
x represents the picking strategy of node 1 merged into node 2 strategy 0.
2. We need to accumulate the extra costs introduced by merging nodes, the extra costs
contains two parts, one is resharding costs between src_node strategy and dst_node strategy,
another is the origin extra costs in src_node strategy.
3. Build connections between new node pairs, and remove the src_node after all consumer nodes
detached from it.
Argument:
src_node(Node): The node will be merged into dst_node.
dst_node(Node): The node to integrate src_node.
'''
# build merge_map
merge_map = {}
for src_index, _ in enumerate(src_node.strategies_vector):
min_cost = INFINITY_COST
lowest_cost_index = -1
for dst_index, dst_strategy in enumerate(dst_node.strategies_vector):
resharding_cost_item = dst_strategy.resharding_costs[src_node][src_index]
if self.forward_only:
resharding_cost = resharding_cost_item.fwd
else:
resharding_cost = resharding_cost_item.total
if resharding_cost <= min_cost:
min_cost = resharding_cost
lowest_cost_index = dst_index
merge_map[src_index] = lowest_cost_index
# extra_node_cost for src node
self.extra_node_costs[src_node] = [0.0] * self.node_lens[src_node]
for src_index, strategy in enumerate(src_node.strategies_vector):
target_strate_index = merge_map[src_index]
target_strategy = dst_node.strategies_vector[target_strate_index]
resharding_cost_item = target_strategy.resharding_costs[src_node][src_index]
if self.forward_only:
resharding_cost_to_add = resharding_cost_item.fwd
else:
resharding_cost_to_add = resharding_cost_item.total
self.extra_node_costs[src_node][src_index] += resharding_cost_to_add
if dst_node in self.extra_node_costs:
self.extra_node_costs[src_node][src_index] += self.extra_node_costs[dst_node][target_strate_index]
# add new node pair to cost graph
for child_node in dst_node.children:
new_node_pair = (src_node, child_node)
old_node_pair = (dst_node, child_node)
if new_node_pair in self.edge_costs:
continue
edge_cost = {}
for i in range(self.node_lens[src_node]):
for j in range(self.node_lens[child_node]):
dst_strate_index = merge_map[i]
edge_cost[(i, j)] = self.edge_costs[old_node_pair][(dst_strate_index, j)]
if new_node_pair not in self.edge_costs:
self.edge_costs[new_node_pair] = edge_cost
else:
# we should accumulate the resharding costs if args of child node contain
# both src node and dst node.
for index_pair, resharding_cost in self.edge_costs[new_node_pair]:
self.edge_costs[new_node_pair][index_pair] += edge_cost[index_pair]
# connect src node and children of dst node
dst_node.parents.remove(src_node)
src_node.children.remove(dst_node)
self.edge_costs.pop((src_node, dst_node))
for child_node in dst_node.children:
if child_node not in src_node.children:
src_node.children.append(child_node)
if src_node not in child_node.parents:
child_node.parents.append(src_node)
# remove dst node from cost graph when dst node has no producer.
if len(dst_node.parents) == 0:
child_node.parents.remove(dst_node)
node_pair = (dst_node, child_node)
self.edge_costs.pop(node_pair)
if len(dst_node.parents) == 0:
self.following_dict[dst_node] = src_node
dst_node.children = []
def _reindexing_src(self, src):
if src not in self.following_dict:
return src
return self._reindexing_src(self.following_dict[src])
def simplify_graph(self):
if not self.simplify:
return
self.merge_pair.reverse()
for (src_node, dst_node) in self.merge_pair:
self.merge_node(src_node, dst_node)
self.merge_pair.reverse()
reindexing_following_dict = {}
for dst, src in self.following_dict.items():
reindexing_following_dict[dst] = self._reindexing_src(src)
self.following_dict = reindexing_following_dict
from dataclasses import dataclass
from typing import List
from torch.fx.graph import Graph
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from colossalai.fx.passes.utils import get_node_module
__all__ = ['LiveVariable', 'LiveVariableVector', 'LiveStage', 'GraphAnalyser']
@dataclass
class LiveVariable:
"""
LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
"""
name: str
node: Node
is_inplace: bool
class LiveVariableVector(list):
"""
LiveVariableVector is a data structure to store the list of LiveVariable objects.
"""
def exists(self, name) -> bool:
"""
Check if a variable has already existed in the current list by name.
"""
for var in self:
if name == var.name:
return True
return False
def get(self, name) -> LiveVariable:
for var in self:
if name == var.name:
return var
raise KeyError(f"Variable {name} is not found")
def copy(self) -> "LiveVariableVector":
"""
Create a copy of this vector
"""
vector = LiveVariableVector()
for var in self:
vector.append(var)
return vector
@dataclass
class LiveStage:
"""
LiveStage is a data structure to record the living variables at this current node.
"""
name: str
node: Node
all_live_vars: LiveVariableVector
unique_live_vars: LiveVariableVector
class GraphAnalyser:
def __init__(self, gm: GraphModule):
self._gm = gm
self._graph = gm.graph
@property
def gm(self) -> GraphModule:
"""
Return the GraphModule object associated with this analyser.
"""
return self._gm
@property
def graph(self) -> Graph:
"""
Return the Graph object associated with this analyser.
"""
return self._graph
def liveness_analysis(self) -> List[LiveStage]:
"""
Analyse the graph to obtain the variable liveness information. This function returns
an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object.
"""
compute_nodes = self.graph.nodes
liveness_list = []
# checked: record all variables created since the first stage
# all: record the live variables only exist until the current stage.
# this can be different from the `checked list`` as some varialbes may be destroyed prior to this stage.
# unique: record the unique live variables only exist until the current stage.
# this is different from `all list` as some variables are duplicated.
checked_variables = LiveVariableVector()
all_live_variables = LiveVariableVector()
unique_live_vars = LiveVariableVector()
for idx, node in enumerate(compute_nodes):
#############################
# find new living variables #
#############################
# detect whether the current op is an in-place op
# if it is an in-place op, we would deem it as a duplciate var
is_inplace = False
if node.op == 'call_function':
# check if this is an inplace op such as torch.nn.functional.relu(x, inplace=True)
if node.kwargs.get('inplace', False):
is_inplace = True
elif node.op == 'call_module':
# to check if this is an inplace op such as torch.nn.Relu(inplace=True)
module = get_node_module(node)
if getattr(module, 'inplace', False):
is_inplace = True
# add the output var
meta = getattr(node, '_meta_data', None)
live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
if not is_inplace:
unique_live_vars.append(live_var)
checked_variables.append(live_var)
all_live_variables.append(live_var)
# check if any input is not checked yet
for arg in node.args:
if not isinstance(arg, Node):
continue
arg_name = arg.name
if not checked_variables.exists(arg_name):
live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False)
all_live_variables.append(live_var_from_arg)
checked_variables.append(live_var_from_arg)
unique_live_vars.append(live_var_from_arg)
# TODO: add the logic to remove live variables
# this should be completed if we are able to trace the backward compute graph
# add this stage to liveness dict
stage = LiveStage(name=node.name,
node=node,
all_live_vars=all_live_variables.copy(),
unique_live_vars=unique_live_vars.copy())
# if a LiveStage is covered by another LiveStage, we just keep the larger one.
replace = False
for index, prev_stage in enumerate(liveness_list):
all_covered = True
for ele in prev_stage.unique_live_vars:
if ele not in stage.unique_live_vars:
all_covered = False
break
if all_covered:
replace = True
break
if replace:
liveness_list[index] = stage
else:
liveness_list.append(stage)
return liveness_list
def get_alias_set(self):
pass
from dataclasses import dataclass
from enum import Enum
__all__ = ['SolverOptions']
class SolverPerference(Enum):
"""
This enum class is to define the solver preference.
"""
STANDARD = 0
DP = 1
TP = 2
class DataloaderOption(Enum):
"""
This enum class is to define the dataloader option.
"""
REPLICATED = 0
DISTRIBUTED = 1
@dataclass
class SolverOptions:
"""
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
"""
solver_perference: SolverPerference = SolverPerference.STANDARD
dataloader_option: DataloaderOption = DataloaderOption.REPLICATED
import multiprocessing
import time
import warnings
from typing import Dict
import numpy as np
from torch.fx.graph import Graph
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.constants import INFINITY_COST
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .strategies_constructor import StrategiesConstructor
try:
import pulp
from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
except:
warnings.warn(f'please install the pulp')
__all___ = ['Solver']
class Solver:
def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
graph_analyser: GraphAnalyser,
memory_budget: float = -1.0,
solution_numbers: int = 1,
forward_only: bool = False,
memory_increasing_coefficient: float = 1.3,
verbose=True):
'''
Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
Argument:
graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
cost_graph: A graph data structure to simplify the edge cost graph.
graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget.
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget.
'''
self.graph = graph
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
self.graph_analyser = graph_analyser
self.leaf_strategies = self.strategies_constructor.leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
self.strategy_map = self.strategies_constructor.strategy_map
self.memory_budget = memory_budget
self.solution_numbers = solution_numbers
self.forward_only = forward_only
if self.solution_numbers > 1:
self.memory_increasing_coefficient = memory_increasing_coefficient
else:
self.memory_increasing_coefficient = 1
self.liveness_list = self.graph_analyser.liveness_analysis()
self.node_index_dict = self._generate_node_index_dict()
# The last solution vector of auto sharding.
self.last_s_val = None
# The last objective value of the best ILP solution.
self.last_objective = None
self.verbose = verbose
def _recover_merged_node_strategy(self):
'''
During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node.
Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged
node.
'''
for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
input_strategies_vector = node.args[0].strategies_vector
input_best_strategy_index = self.last_s_val[node_index - 1]
input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec
for strategy_index, strategy in enumerate(node.strategies_vector):
if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
self.last_s_val[node_index] = strategy_index
break
def _generate_node_index_dict(self) -> Dict[Node, int]:
node_index_dict = {}
for index, strategies_vector in enumerate(self.leaf_strategies):
node_index_dict[strategies_vector.node] = index
return node_index_dict
def _prepare_data_for_solver(self):
'''
Extract information from components for solver.
'''
node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget
# prepare strategies_len
strategies_len = []
for node in self.nodes:
strategies_len.append(self.cost_graph.node_lens[node])
strategies_len = np.array(strategies_len)
# prepare following_nodes
following_nodes = self.cost_graph.following_dict
index_following_nodes = {}
for src, target in following_nodes.items():
src_index = self.node_index_dict[src]
target_index = self.node_index_dict[target]
index_following_nodes[src_index] = target_index
following_nodes = index_following_nodes
for index in range(node_nums):
if index not in following_nodes:
following_nodes[index] = -1
# prepare edge_pairs and resharding costs
edge_pairs = []
resharding_costs = []
for pairs, edge_cost in self.cost_graph.edge_costs.items():
src_node = pairs[0]
dst_node = pairs[1]
src_node_index = self.node_index_dict[src_node]
dst_node_index = self.node_index_dict[dst_node]
edge_pairs.append(src_node_index)
edge_pairs.append(dst_node_index)
for i in range(strategies_len[src_node_index]):
for j in range(strategies_len[dst_node_index]):
resharding_costs.append(edge_cost[(i, j)])
edge_pairs = np.array(edge_pairs)
resharding_costs = np.array(resharding_costs)
# prepare liveness_set
liveness_set = self.liveness_list
# omit alias_set now
alias_set = None
alias_convert_costs = None
# prepare compute_costs, communication_costs and memory_costs
compute_costs = []
communication_costs = []
memory_costs = []
extra_node_costs = self.cost_graph.extra_node_costs
for strategies_vector in self.leaf_strategies:
node = strategies_vector.node
for index, strategy in enumerate(strategies_vector):
compute_cost_item = strategy.compute_cost
communication_cost_item = strategy.communication_cost
memory_cost_item = strategy.memory_cost
if self.forward_only:
origin_communication_cost = communication_cost_item.fwd
compute_cost = compute_cost_item.fwd
# extract MemoryCost item from the memory TrainCycleItem
memory_cost = memory_cost_item.fwd
else:
origin_communication_cost = communication_cost_item.total
compute_cost = compute_cost_item.total
# extract MemoryCost item from the memory TrainCycleItem
memory_cost = memory_cost_item.total
# extract the memory cost in float from MemoryCost item and sum them up
memory_cost = memory_cost.parameter + memory_cost.activation + memory_cost.buffer
compute_costs.append(compute_cost)
# node in extra_node_costs means it has some extra communication
# cost from node merging, so we need to add those extra communication
# cost into
if node in extra_node_costs:
extra_node_cost = extra_node_costs[node][index]
communication_cost = origin_communication_cost + extra_node_cost
communication_costs.append(communication_cost)
else:
communication_costs.append(origin_communication_cost)
memory_costs.append(memory_cost)
compute_costs = np.array(compute_costs)
communication_costs = np.array(communication_costs)
memory_costs = np.array(memory_costs)
# omit initial value for nodes
s_init_np = None
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np, self.verbose
def _call_solver_serialized_args(self,
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np=None,
verbose=True):
"""
Call the solver with serialized arguments.
"""
tic = time.time()
for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]:
assert isinstance(x, np.ndarray)
assert len(strategies_len) == node_nums, "strategies_len"
def get_non_zero_index(binary_vector):
"""
Get the index of non-zero item in a vector.
"""
ct = 0
ret = None
for i, elem in enumerate(binary_vector):
if pulp.value(elem):
ret = i
ct += 1
assert ct == 1
return ret
# 0. Unpack flatten numpy arrays
s_follow = following_nodes
E = edge_pairs.reshape((-1, 2)) # noqa
r = []
pt = 0
edge_set = set()
for (i, j) in E:
prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j))
r.append(resharding_costs[pt:pt + prod_length])
pt += prod_length
assert pt == len(resharding_costs)
######################
# omit alias set now #
######################
# A = alias_set.reshape((-1, 2)) # noqa
# for (i, j) in A:
# prod_length = strategies_len[i] * strategies_len[j]
# v.append(alias_convert_costs[pt:pt + prod_length])
# pt += prod_length
# assert pt == len(alias_convert_costs)
# L = [] # noqa
# pt = node_nums
# for i in range(node_nums):
# length = liveness_set[i]
# L.append(liveness_set[pt:pt + length])
# pt += length
# assert pt == len(liveness_set)
v = []
pt = 0
c = []
d = []
m = []
pt = 0
for i in range(node_nums):
length = strategies_len[i]
c.append(compute_costs[pt:pt + length])
d.append(communication_costs[pt:pt + length])
m.append(memory_costs[pt:pt + length])
pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"
# 1. Create variables
#############################
# create variables for node #
#############################
s = []
num_nodes = 0
reverse_follow_backpatch = []
for i in range(node_nums):
if s_follow[i] < 0:
if strategies_len[i] == 1:
s.append([1])
else:
num_nodes += 1
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
else:
if s_follow[i] < len(s):
s.append(s[s_follow[i]])
else:
s.append(None)
reverse_follow_backpatch.append(i)
for i in reverse_follow_backpatch:
s[i] = s[s_follow[i]]
#############################
# create variables for edge #
#############################
e = []
num_edges = 0
for (idx, (i, j)) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
e.append(s[i])
else:
num_edges += 1
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
assert len(e[idx]) == len(r[idx])
for element in s:
assert len(element) > 0
# 2. Set initial value
######################################
# set a initial value for warm start #
######################################
if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3))
for (idx, value, fix) in s_init:
for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value)
if fix:
s[idx][i].fixValue()
# 3. Objective
prob = LpProblem("myProblem", LpMinimize)
###################################################################
# computing the node cost(computing cost and communication cost) #
###################################################################
obj = 0
for i in range(node_nums):
assert len(s[i]) == len(c[i])
assert len(s[i]) == len(d[i])
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
#############################################
# computing the edge cost(resharding cost) #
#############################################
for i in range(len(E)):
assert len(e[i]) == len(r[i])
obj += lpDot(e[i], r[i])
prob += obj
# 4. Constraints
# (a). specified by `cat="Binary"`
# (b)
#################################################
# make sure each node only choose one strategy #
#################################################
for i in range(node_nums):
if s_follow[i] < 0:
prob += lpSum(s[i]) == 1
# (c)
#################################################
# compute memory consumption with liveness set #
#################################################
if memory_budget > 0:
for liveness_stage in liveness_set:
mem = 0
for live_variable in liveness_stage.unique_live_vars:
if live_variable.node not in self.node_index_dict:
continue
node_index = self.node_index_dict[live_variable.node]
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget
# (d). specified by `cat="Binary"`
for (idx, (i, j)) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1:
continue
# (e)
prob += lpSum(e[idx]) == 1
# (f)
for row in range(len(s[i])):
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
# (g)
for col in range(len(s[j])):
R = len(s[i]) # noqa
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
# (h)
######################
# omit alias set now #
######################
# alias_set = set()
# for (idx, (i, j)) in enumerate(A):
# R = len(s[i]) # noqa
# C = len(s[j]) # noqa
# if (i, j) in alias_set:
# raise ValueError(f"Duplicated edges: {(i, j)}")
# alias_set.add((i, j))
# alias_set.add((j, i))
# for row in range(len(s[i])):
# for col in range(len(s[j])):
# if v[idx][row * C + col] > 0.5:
# prob += s[i][row] + s[j][col] <= 1
msg = verbose
time_limit = 600
assert "COIN_CMD" in pulp.listSolvers(
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
prob.solve(solver)
status = prob.status
objective = pulp.value(prob.objective)
objective = float(objective) if objective is not None else -1.0
if verbose:
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
f"Time: {time.time() - tic}")
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
if prob.status in [pulp.LpStatusInfeasible]:
raise RuntimeError("Cannot run the function under the given memory budget. "
"Please increase the memory budget.")
# Get and check results
s_val = np.full((node_nums,), -1, dtype=np.int32)
for i in range(node_nums):
s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E),), -1, dtype=np.int32)
for (idx, (i, j)) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j])
assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
if verbose and r[idx][e_val[idx]] > 0:
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
self.last_s_val = list(s_val)
# self._recover_merged_node_strategy()
self.last_objective = objective
if objective > INFINITY_COST:
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
return self.last_s_val, e_val, self.last_objective, status
def call_solver_serialized_args(self):
"""
Call the solver with serialized arguments and handle python errors. Additionally,
we could give a serious of solutions with different memory budget.
"""
if self.solution_numbers == 1:
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
return ret
origin_memory_budget = self.memory_budget
memory_budget_list = [
origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers)
]
ret_list = []
for memory_budget in memory_budget_list:
self.memory_budget = memory_budget
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
ret_list.append(ret)
return ret_list
import builtins
import math
import operator
from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx import Graph, Node
from colossalai.auto_parallel.tensor_shard.node_handler import (
GetattrHandler,
OuputHandler,
PlacehodlerHandler,
operator_registry,
)
from colossalai.auto_parallel.tensor_shard.sharding_strategy import StrategiesVector
from colossalai.auto_parallel.tensor_shard.utils import generate_resharding_costs, generate_sharding_spec
from colossalai.device.device_mesh import DeviceMesh
from .options import DataloaderOption, SolverOptions
__all__ = ['StrategiesConstructor']
class StrategiesConstructor:
"""
StrategiesConstructor is used to construct the parallelization plan for the model execution.
Args:
graph (Graph): a Graph object used for analysis and strategy generation.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
"""
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
self.leaf_strategies = []
self.strategy_map = {}
self.solver_options = solver_options
self.no_strategy_nodes = []
def remove_duplicated_strategy(self, strategies_vector):
'''
In build_strategies_and_cost method, we may produce some duplicated strategies.
In this method, we will remove the duplicated strategies depending on the strategies name.
Note that this operation is in-place.
'''
name_checklist = []
remove_list = []
for strategy in strategies_vector:
if strategy.name not in name_checklist:
name_checklist.append(strategy.name)
else:
remove_list.append(strategy)
for strategy in remove_list:
strategies_vector.remove(strategy)
def build_strategies_and_cost(self):
"""
This method is to build the strategy vector for each node in the computation graph.
"""
def _check_no_strategy_for_node(node):
if node.op in ('placeholder', 'get_attr', 'output'):
return False
def _check_no_strategy_for_data(data):
label = True
if isinstance(data, torch.Tensor):
return False
elif isinstance(data, (tuple, list)):
for d in data:
label = label and _check_no_strategy_for_data(d)
return label
return _check_no_strategy_for_data(node._meta_data)
for node in self.nodes:
strategies_vector = StrategiesVector(node)
if _check_no_strategy_for_node(node):
self.no_strategy_nodes.append(node)
pass
# placeholder node
elif node.op == 'placeholder':
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
placeholder_option = 'distributed'
else:
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
placeholder_option = 'replicated'
placeholder_handler = PlacehodlerHandler(node,
self.device_mesh,
strategies_vector,
placeholder_option=placeholder_option)
placeholder_handler.register_strategy()
# get_attr node
elif node.op == 'get_attr':
getattr_handler = GetattrHandler(node, self.device_mesh, strategies_vector)
getattr_handler.register_strategy()
# call_module node
elif node.op == 'call_module':
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
handler = operator_registry.get(submod_type)(node, self.device_mesh, strategies_vector)
handler.register_strategy()
# call_function node
elif node.op == 'call_function':
target = node.target
handler = operator_registry.get(target)(node, self.device_mesh, strategies_vector)
handler.register_strategy()
# call_method node
elif node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
handler = operator_registry.get(method)(node, self.device_mesh, strategies_vector)
handler.register_strategy()
# output node
elif node.op == 'output':
if self.solver_options.dataloader_option == DataloaderOption.DISTRIBUTED:
output_option = 'distributed'
else:
assert self.solver_options.dataloader_option == DataloaderOption.REPLICATED, f'placeholder_option {self.solver_options.dataloader_option} is not supported'
output_option = 'replicated'
output_handler = OuputHandler(node, self.device_mesh, strategies_vector, output_option=output_option)
output_handler.register_strategy()
self.remove_duplicated_strategy(strategies_vector)
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
# remove no strategy nodes
remove_list = []
for strategies_vector in self.leaf_strategies:
if len(strategies_vector) == 0:
remove_list.append(strategies_vector.node)
for node in remove_list:
if node.strategies_vector in self.leaf_strategies:
self.leaf_strategies.remove(node.strategies_vector)
if node in self.strategy_map:
self.strategy_map.pop(node)
from .broadcast import (
BroadcastType,
comm_actions_for_oprands,
get_broadcast_shape,
is_broadcastable,
recover_sharding_spec_for_broadcast_shape,
)
from .factory import generate_resharding_costs, generate_sharding_spec
from .misc import check_sharding_spec_validity, ignore_sharding_exception, pytree_map
from .reshape import check_keep_sharding_status, detect_reshape_mapping, infer_output_dim_partition_dict
from .sharding import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
generate_sharding_size,
transpose_partition_dim,
update_partition_dim,
)
__all__ = [
'BroadcastType', 'get_broadcast_shape', 'is_broadcastable', 'recover_sharding_spec_for_broadcast_shape',
'generate_resharding_costs', 'generate_sharding_spec', 'ignore_sharding_exception', 'check_sharding_spec_validity'
'transpose_partition_dim', 'update_partition_dim', 'enumerate_all_possible_1d_sharding',
'enumerate_all_possible_2d_sharding', 'generate_sharding_size', 'comm_actions_for_oprands', 'pytree_map',
'detect_reshape_mapping', 'check_keep_sharding_status', 'infer_output_dim_partition_dict'
]
from enum import Enum, auto
from typing import List
import torch
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommAction,
CommType,
OperationData,
OperationDataType,
)
from colossalai.tensor.comm_spec import CollectiveCommPattern, CommSpec
from colossalai.tensor.sharding_spec import ShardingSpec
__all__ = [
'BroadcastType', 'is_broadcastable', 'get_broadcast_shape', 'recover_sharding_spec_for_broadcast_shape',
'comm_actions_for_oprands'
]
class BroadcastType(Enum):
EQUAL = auto()
PADDDING = auto()
MULTIPLE = auto()
def is_broadcastable(shape1: torch.Size, shape2: torch.Size) -> bool:
"""
Check if two shapes are broadcastable to each other.
"""
for s1, s2 in zip(shape1[::-1], shape2[::-1]):
if s1 == 1 or s2 == 1 or s1 == s2:
pass
else:
return False
return True
def get_broadcast_shape(shape1: torch.Size, shape2: torch.Size) -> List[int]:
"""
Compute the broadcast shape given two shapes.
"""
assert is_broadcastable(shape1, shape2), f'{shape1} and {shape2} are not broadcastable'
shape1_reverse = shape1[::-1]
shape2_reverse = shape2[::-1]
min_common_dim = min(len(shape1), len(shape2))
dims = []
for s1, s2 in zip(shape1_reverse, shape2_reverse):
dims.append(max(s1, s2))
# append the remaining dims
dims.extend(shape1_reverse[min_common_dim:])
dims.extend(shape2_reverse[min_common_dim:])
return dims[::-1]
def get_broadcast_dim_info(logical_shape, physical_shape):
# get the number of dimensions
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)
assert logical_num_dims >= physical_num_dims, \
'The number of dimensions in the logical shape is smaller than that of the physical shape, this tensor is not broadcast!'
# track the dim and its broadcasting type
logical_dim_broadcast_info = {}
for i in range(logical_num_dims):
# get the trailing dim size
logical_dim_idx = logical_num_dims - i - 1
phyiscal_dim_idx = physical_num_dims - i - 1
logical_dim_size = logical_shape[logical_dim_idx]
if phyiscal_dim_idx >= 0:
physical_dim_size = physical_shape[phyiscal_dim_idx]
if physical_dim_size == logical_dim_size:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.EQUAL
elif physical_dim_size == 1 and physical_dim_size != logical_dim_size:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.MULTIPLE
else:
logical_dim_broadcast_info[logical_dim_idx] = BroadcastType.PADDDING
return logical_dim_broadcast_info
def recover_sharding_spec_for_broadcast_shape(logical_sharding_spec: ShardingSpec, logical_shape: torch.Size,
physical_shape: torch.Size) -> ShardingSpec:
"""
This function computes the sharding spec for the physical shape of a broadcast tensor.
Args:
logical_sharding_spec (ShardingSpec): the sharding spec for the broadcast tensor
logical_shape (torch.Size): logical shape is the broadcast shape of a tensor
physical_shape (torch.Size): the shape of the tensor before broadcasting
"""
# if the two shapes are the same, no broadcast occurs
# we directly return the current sharding spec
# recording the sharding dimensions removed during logical shape converting to physical one
removed_dims = []
if list(logical_shape) == list(physical_shape):
return logical_sharding_spec, removed_dims
# get the number of dimensions
logical_num_dims = len(logical_shape)
physical_num_dims = len(physical_shape)
# get the broadcast info
logical_dim_broadcast_info = get_broadcast_dim_info(logical_shape, physical_shape)
# generate the sharding spec for the physical shape
physical_dim_partition = {}
logical_dim_partition = logical_sharding_spec.dim_partition_dict
for shape_dim, mesh_dim in logical_dim_partition.items():
logical_broadcast_type = logical_dim_broadcast_info[shape_dim]
if logical_broadcast_type == BroadcastType.PADDDING or logical_broadcast_type == BroadcastType.MULTIPLE:
removed_dims.extend(mesh_dim)
else:
# get the corresponding physical dim
physical_dim = physical_num_dims - (logical_num_dims - shape_dim)
physical_dim_partition[physical_dim] = mesh_dim
physical_sharding_spec = ShardingSpec(device_mesh=logical_sharding_spec.device_mesh,
entire_shape=physical_shape,
dim_partition_dict=physical_dim_partition)
return physical_sharding_spec, removed_dims
def comm_actions_for_oprands(node: Node, removed_dims: List[int], op_data: OperationData,
sharding_spec: ShardingSpec) -> CommAction:
"""
This method is used to generate communication actions for oprands which lose information
during convert logical shape to physical shape.
"""
if len(removed_dims) == 1:
# if list length is 1, extract element from list to avoid using flatten device mesh
removed_dims = removed_dims[0]
comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
sharding_spec=sharding_spec,
logical_process_axis=removed_dims)
if op_data.type == OperationDataType.PARAM:
comm_type = CommType.HOOK
else:
comm_type = CommType.BEFORE
arg_index = -1
for index, arg in enumerate(node.args):
if op_data.name == str(arg):
arg_index = index
assert arg_index >= 0, f'op_data should be an argument of node.'
comm_action = CommAction(
comm_spec=comm_spec,
comm_type=comm_type,
arg_index=arg_index,
)
return comm_action
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment