Commit 93b788b9 authored by binmakeswell's avatar binmakeswell
Browse files

Merge branch 'main' into fix/format

parents 2fd528b9 1dc003c1
import operator
from functools import reduce
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
__all__ = ['BatchNormHandler']
class BatchNormHandler(OperatorHandler):
"""
A OperatorHandler which deals with the sharding strategies of normalization.
To keep the math consistency, there are two way to do BatchNorm if the input
shards on batch dimension:
1. We gather the input partitions through batch dimension, then do the normal BatchNorm.
2. We do the SyncBatchNorm on the each input partition seperately, the SyncBN op will help
us to keep the computing correctness.
In this handler, both methods will be considered.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_data = self.predecessor_node[0]._meta_data
self.weight = self.module_named_parameters['weight']
self.bias = self.module_named_parameters['bias']
self.output_data = self.node._meta_data
self._sanity_check()
def _sanity_check(self):
'''
In sanity check, we need make sure the input data having correct dimension size.
For BatchNorm1d, the dim of input data should be 3([N, C, L]).
For BatchNorm2d, the dim of input data should be 4([N, C, H, W]).
For BatchNorm3d, the dim of input data should be 5([N, C, H, W, D]).
'''
assert self.input_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
def _generate_compute_cost(self, bs, channel_in):
'''
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
Argument:
bs(int): Batch size of the input data.
channel_in(int): The channel dimension of input data.
Return:
compute_cost(float): Computation cost per device with this specific strategy
'''
# TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
# TODO: a constant coefficient need to be added.
# 1D: (L) * N * Cin
# 2D: (H * W) * N * Cin
# 3D: (H * W * D) * N * Cin
input_size = self.input_data.shape[2:]
input_size_product = reduce(operator.mul, input_size, 1)
forward_compute_cost = input_size_product * bs * channel_in
backward_activation_compute_cost = input_size_product * bs * channel_in
backward_weight_compute_cost = input_size_product * bs * channel_in
backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
compute_cost = forward_compute_cost + backward_compute_cost
return compute_cost
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
'''
Compute the memory cost per device with this specific strategy.
Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward number partions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number partions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight number partions.
Return:
memory_cost(Tuple[float]): Memory cost per device with this
specific strategy, the first element of this tuple is forward
memory cost, and the second element of this tuple is backward
memory cost.
memory_cost_forward(float): Memory cost of forward activation per
device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of backward activation
per device with this specific strategy.
'''
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel_output = self.output_data.numel()
numel_input = numel_output
numel_weight = self.weight.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# forward memory_cost
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
# backward memory_cost
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
# memory_cost pair
memory_cost = (memory_cost_forward, memory_cost_backward)
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation
@ignore_sharding_exception
def split_input_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
# This strategy do not need to do all_reduce operation
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
# shard the output batch dimension to get all possible sharding strategy from this basic strategy
new_name = f'S{mesh_dim_1}S{mesh_dim_0} = RS{mesh_dim_0} x S{mesh_dim_0}'
dim_partition_dict_for_output = {0: [mesh_dim_1], 1: [mesh_dim_0]}
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# the computation cost is all the same
new_compute_cost = compute_cost
# the memory cost need to be recomputed
# compute the memroy cost of new strategy
new_sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
new_sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# the communication cost need to count the sharding cost into this strategy
# compute the communication cost of new strategy
origin_communication_cost = communication_cost
tiny_shard_cost = 10
new_forward_communication_cost = tiny_shard_cost
# we need to all gather the batch dimension for the basic strategy
new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation, mesh_dim_1)
new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost
sharding_strategies = ShardingStrategy(new_name,
output_sharding_spec=new_sharding_spec_for_output,
compute_cost=new_compute_cost,
communication_cost=new_communication_cost,
memory_cost=new_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_channel_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] *
self.device_mesh.shape[mesh_dim_1])
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
# This strategy do not need to do all_reduce operation
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def non_split(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RR x R'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = 1
sharding_size_backward_activation = 1
sharding_size_weight = 1
memory_cost, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
# This strategy do not need to do all_reduce operation
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def _construct_batch_sharding_strategies(mesh_dim_list, new_name):
dim_partition_dict_for_output = {0: mesh_dim_list}
new_sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# the computation cost is all the same
new_compute_cost = compute_cost
# the memory cost need to be recomputed
new_sharding_size_input = 1
for mesh_dim in mesh_dim_list:
new_sharding_size_input = new_sharding_size_input * self.device_mesh.shape[mesh_dim]
new_memory_cost, _, memory_cost_backward_activation = self._generate_memory_cost(
new_sharding_size_input, sharding_size_backward_activation, sharding_size_weight)
# the communication cost need to count the sharding cost into this strategy
origin_communication_cost = communication_cost
tiny_shard_cost = 10
new_forward_communication_cost = tiny_shard_cost
if len(mesh_dim_list) == 1:
new_backward_communication_cost = self.device_mesh.all_gather_cost(memory_cost_backward_activation,
mesh_dim_list[0])
else:
new_backward_communication_cost = self.device_mesh.flatten_device_mesh.all_gather_cost(
memory_cost_backward_activation, 0)
new_communication_cost = origin_communication_cost + new_forward_communication_cost + new_backward_communication_cost
new_sharding_strategy = ShardingStrategy(new_name,
output_sharding_spec=new_sharding_spec_for_output,
compute_cost=new_compute_cost,
communication_cost=new_communication_cost,
memory_cost=new_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input,
sharding_spec_for_weight))
return new_sharding_strategy
# shard the output batch dimension to get all possible sharding strategy from this basic strategy
# shard on mesh_dim_0
new_name = f'S{mesh_dim_0}R = RR x R'
mesh_dim_list = [mesh_dim_0]
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
self.strategies_vector.append(new_sharding_strategy)
# shard on mesh_dim_1
new_name = f'S{mesh_dim_1}R = RR x R'
mesh_dim_list = [mesh_dim_1]
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
self.strategies_vector.append(new_sharding_strategy)
# shard on mesh_dim_0, mesh_dim_1
new_name = f'S{mesh_dim_0}{mesh_dim_1}R = RR x R'
mesh_dim_list = [mesh_dim_0, mesh_dim_1]
new_sharding_strategy = _construct_batch_sharding_strategies(mesh_dim_list, new_name)
self.strategies_vector.append(new_sharding_strategy)
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x R WITH SYNC_BN'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
channel_in = self.input_data.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = 1
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# the all reduce communication will happen during the sync bn computing.
communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_batch_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x R WITH SYNC_BN'
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
channel_in = self.input_data.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = 1
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# the all reduce communication will happen during the sync bn computing.
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost_forward_activation, 0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1} WITH SYNC_BN'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(bs, channel_in)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, _ = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# the all reduce communication will happen during the sync bn computing.
communication_cost = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def register_strategy(self) -> StrategiesVector:
'''
Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
Example:
norm_handler = BatchNormHandler(node, strategies_vector,
self.shape_consistency_manager)
norm_handler.register_strategy()
for strategy in norm_handler.strategies_vector:
print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
Output:
RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
RR = RR x R, computation_cost: 262144, memory_cost: 1048576
RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0
'''
# RS = RS x S and strategies based on it, such as
# SS = RS x S
self.split_input_channel(0, 1)
self.split_input_channel(1, 0)
# RR = RR x R and strategies based on it, such as
# SR = SR x R
self.non_split(0, 1)
# RS01 = RS01 x S01
self.split_input_channel_1d(0, 1)
# SR = SR x R WITH SYNC_BN
self.split_input_batch(0)
self.split_input_batch(1)
# SS = SS x S WITH SYNC_BN
self.split_input_both_dim(0, 1)
self.split_input_both_dim(1, 0)
# S01R = S01R x R WITH SYNC_BN
self.split_input_batch_1d(0, 1)
return self.strategies_vector
import operator
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .operator_handler import OperatorHandler
__all__ = ['BcastOpHandler']
class BcastOpHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of broadcast operators(such as operator.add).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert len(self.predecessor_node) == 2
self.lhs_data = self.predecessor_node[0]._meta_data
self.rhs_data = self.predecessor_node[1]._meta_data
self.lhs = self.predecessor_node[0]
self.rhs = self.predecessor_node[1]
self.output_data = self.node._meta_data
def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
shape = list(input_.shape)
# padding the shape to the same length as output_data
while len(shape) < self.output_data.dim():
shape.insert(0, 1)
shape = torch.Size(shape)
# if the sharding happens on a size one dimension, we should record it as R.
processed_dim_partition_dict = deepcopy(dim_partition_dict)
for dim_index, _ in dim_partition_dict.items():
if shape[dim_index] == 1:
processed_dim_partition_dict.pop(dim_index)
for dim_index, sharding_index_list in processed_dim_partition_dict.items():
sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
sharding_size = reduce(operator.mul, sharding_list, 1)
assert shape[
dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=shape,
dim_partition_dict=processed_dim_partition_dict)
return sharding_spec
def _generate_compute_cost(self, total_sharding_size):
lhs_matrix_shape = self.lhs_data.shape[-2:]
rhs_matrix_shape = self.rhs_data.shape[-2:]
batch_dimensions_shape = self.output_data.shape[:-2]
batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1)
compute_cost = reduce(
operator.mul, lhs_matrix_shape) * rhs_matrix_shape[0] * batch_dimensions_product * 2 / total_sharding_size
return compute_cost
def _generate_resharding_costs(self, sharding_specs):
# The resharding_cost of weight is counted due to sharing weight cases.
dtype = self.node._meta_data.dtype
nodes = self.predecessor_node
resharding_costs = {}
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# shape consistency manager is a singleton class
shape_consistency_manager = ShapeConsistencyManager()
for input_node, input_spec in zip(nodes, sharding_specs):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
# if the input shape is smaller than the target input, we will fill the input to the same length as target.
# Then, use the padded input sharding spec to compute the resharding cost.
if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape):
new_entire_shape = list(input_sharding_spec.entire_shape)
while len(new_entire_shape) < len(input_spec.entire_shape):
new_entire_shape.insert(0, 1)
new_entire_shape = torch.Size(new_entire_shape)
new_device_mesh = input_sharding_spec.device_mesh
new_dim_partition_dict = input_sharding_spec.dim_partition_dict
input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh,
entire_shape=new_entire_shape,
dim_partition_dict=new_dim_partition_dict)
# compute the resharding cost
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
input_sharding_spec, input_spec)
# we need multiply the size of elem dtype to get correct communication cost
resharding_cost = total_resharding_cost["total"] * size_per_elem_bytes
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def _convert_partition_dict_to_sharding_spec(self, dim_partition_list):
sharding_spec_list = []
check_duplicated_list = []
for output_dim_partition_dict in dim_partition_list:
try:
output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
except AssertionError as e:
warnings.warn(f'{e}')
break
sharding_seq = output_sharding_spec.sharding_sequence
if sharding_seq not in check_duplicated_list:
check_duplicated_list.append(sharding_seq)
sharding_spec_list.append(output_sharding_spec)
return sharding_spec_list
def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 in here for N-D device mesh scaliablity.
output_dim_partition_list = []
dim_size = self.output_data.dim()
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
output_dim_partition_list.extend(sharding_list_2d)
# enumerate all the 1D sharding cases
sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
output_dim_partition_list.extend(sharding_list_1d_on_dim_0)
sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
output_dim_partition_list.extend(sharding_list_1d_on_dim_1)
# add empty dict for fully replicated case
output_dim_partition_list.append({})
output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list)
return output_sharding_spec_list
@ignore_sharding_exception
def _register_strategy(self, output_sharding_spec):
dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_input)
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_input)
name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the computation cost of this strategy
sharding_dims = []
for mesh_dims in dim_partition_dict_for_output.values():
for mesh_dim in mesh_dims:
sharding_dims.append(self.device_mesh.shape[mesh_dim])
sharding_size = reduce(operator.mul, sharding_dims, 1)
memory_cost = self.output_data.numel() / sharding_size
compute_cost = memory_cost
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=output_sharding_spec,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
##############################################
#used to generate strategies for torch.matmul#
##############################################
@ignore_sharding_exception
def _registry_no_split_strategies_for_matmul(self, dim_partition_dict_for_batch_dim):
# this dim partition dict only describes the batch dimensions, but in this scenario,
# matrix dimensions are fully replicated, so it do not need extra process.
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_batch_dim)
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_batch_dim)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_batch_dim)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
batch_sharding_dims = []
for mesh_dims in dim_partition_dict_for_batch_dim.values():
for mesh_dim in mesh_dims:
batch_sharding_dims.append(self.device_mesh.shape[mesh_dim])
batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1)
# in this case, total_sharding_size is equal to the batch sharding size
memory_cost = self.output_data.numel() / batch_sharding_size
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(batch_sharding_size)
# in this case, no communication takes place.
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def _split_dim_i(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
# In this scenario, matrix dimensions will be sharded on 'i' dimension.
# in this case, the matrix dimensions of lhs is sharded on 'i' dimension.
dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_lhs.update({-2: mesh_dim_on_matrix})
# in this case, the matrix dimensions of rhs is fully replicated.
dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
# in this case, the matrix dimensions of output is sharded on 'i' dimension.
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_output.update({-2: mesh_dim_on_matrix})
# generate sharding specs
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_dims = []
# append batch sharding dims
for mesh_dims in dim_partition_dict_for_batch_dim.values():
for mesh_dim in mesh_dims:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
# append the sharding dims on matrix dimension
for mesh_dim in mesh_dim_on_matrix:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)
# in this case, output_data uses all the sharding dims.
memory_cost = self.output_data.numel() / total_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def _split_dim_k(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
# In this scenario, matrix dimensions will be sharded on 'k' dimension.
# in this case, the matrix dimensions of lhs is sharded on 'k' dimension.
dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_lhs.update({-1: mesh_dim_on_matrix})
# in this case, the matrix dimensions of rhs is sharded on 'k' dimension.
dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_rhs.update({-2: mesh_dim_on_matrix})
# in this case, the matrix dimensions of output is fully replicated.
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
# generate sharding specs
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_dims = []
batch_sharding_dims = []
# append batch sharding dims
for mesh_dims in dim_partition_dict_for_batch_dim.values():
for mesh_dim in mesh_dims:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
batch_sharding_dims.append(self.device_mesh.shape[mesh_dim])
# append the sharding dims on matrix dimension
for mesh_dim in mesh_dim_on_matrix:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
batch_sharding_size = reduce(operator.mul, batch_sharding_dims, 1)
total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)
# in this case, output_data is fully replicated on matrix dimensions.
memory_cost = self.output_data.numel() / batch_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
# The communication takes place during forward activation computation.
if len(mesh_dim_on_matrix) == 1:
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0])
else:
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def _split_dim_j(self, dim_partition_dict_for_batch_dim, mesh_dim_on_matrix):
# A batched matrix multiplication can be viewed as [b, i, k] x [b, k, j] -> [b, i, j]
# this dim partition dict describe the batch dimensions, so we should append the matrix dimension sharding info on it.
# In this scenario, matrix dimensions will be is sharded on 'j' dimension.
# in this case, the matrix dimensions of lhs is fully replicated.
dim_partition_dict_for_lhs = deepcopy(dim_partition_dict_for_batch_dim)
# in this case, the matrix dimensions of rhs is sharded on 'j' dimension.
dim_partition_dict_for_rhs = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_rhs.update({-1: mesh_dim_on_matrix})
# in this case, the matrix dimensions of output is sharded on 'j' dimension.
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_batch_dim)
dim_partition_dict_for_output.update({-1: mesh_dim_on_matrix})
# generate sharding specs
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_dims = []
# append batch sharding dims
for mesh_dims in dim_partition_dict_for_batch_dim.values():
for mesh_dim in mesh_dims:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
# append the sharding dims on matrix dimension
for mesh_dim in mesh_dim_on_matrix:
total_sharding_dims.append(self.device_mesh.shape[mesh_dim])
total_sharding_size = reduce(operator.mul, total_sharding_dims, 1)
# in this case, output_data uses all the sharding dims.
memory_cost = self.output_data.numel() / total_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
# The communication takes place during backward activation computation.
if len(mesh_dim_on_matrix) == 1:
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_on_matrix[0])
else:
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(memory_cost, 0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
def _registry_1d_strategies_for_matmul(self, dim_partition_dict, mesh_dim_list):
self._split_dim_i(dim_partition_dict, mesh_dim_list)
self._split_dim_k(dim_partition_dict, mesh_dim_list)
self._split_dim_j(dim_partition_dict, mesh_dim_list)
@ignore_sharding_exception
def _split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
dim_partition_dict_for_rhs = {-2: [mesh_dim_1]}
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
dim_partition_dict_for_output = {-2: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
output_sharding_size = reduce(operator.mul, self.output_data.shape, 1)
# in this case, output_data uses all the sharding dims.
memory_cost = self.output_data.numel() / output_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
# The communication takes place during forward activation computation.
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def _split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-1: [mesh_dim_0]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
dim_partition_dict_for_rhs = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
dim_partition_dict_for_output = {-1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
output_sharding_size = reduce(operator.mul, self.output_data.shape, 1)
# in this case, output_data uses all the sharding dims.
memory_cost = self.output_data.numel() / output_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
# The communication takes place during forward and backward activation computation.
communication_cost_forward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_0)
communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
communication_cost = communication_cost_backward_activation + communication_cost_forward_activation
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def _split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
dim_partition_dict_for_lhs = {-2: [mesh_dim_0]}
sharding_spec_for_lhs = self._generate_sharding_spec(self.lhs_data, dim_partition_dict_for_lhs)
dim_partition_dict_for_rhs = {-1: [mesh_dim_1]}
sharding_spec_for_rhs = self._generate_sharding_spec(self.rhs_data, dim_partition_dict_for_rhs)
dim_partition_dict_for_output = {-2: [mesh_dim_0], -1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_lhs.sharding_sequence} x {sharding_spec_for_rhs.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_lhs, sharding_spec_for_rhs])
# compute the memory cost of this strategy
total_sharding_size = reduce(operator.mul, self.device_mesh.shape, 1)
output_sharding_size = reduce(operator.mul, self.output_data.shape, 1)
# in this case, output_data uses all the sharding dims.
memory_cost = self.output_data.numel() / output_sharding_size
compute_cost = self._generate_compute_cost(total_sharding_size)
# TODO: add all-reduce cost if lhs or rhs is type of Parameters.
# The communication takes place during backward activation computation.
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_lhs, sharding_spec_for_rhs))
self.strategies_vector.append(sharding_strategies)
def _registry_2d_strategies_for_matmul(self):
self._split_lhs_space_both_contract(0, 1)
self._split_lhs_space_both_contract(1, 0)
self._split_rhs_space_both_contract(0, 1)
self._split_rhs_space_both_contract(1, 0)
self._split_lhs_space_rhs_space(0, 1)
self._split_lhs_space_rhs_space(1, 0)
def register_strategy(self) -> StrategiesVector:
MESH_DIM_LIST = [0, 1]
if self.node.target != torch.matmul:
output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1])
for output_sharding_spec in output_sharding_specs:
self._register_strategy(output_sharding_spec)
else:
# we only care about the non-computing dimensions,
# therefore, we omit the last two dimensions.
dim_size = self.output_data.dim() - 2
# Both device mesh axises are uesd on batch dimensions
dim_partition_dicts_2d = enumerate_all_possible_2d_sharding(MESH_DIM_LIST[0], MESH_DIM_LIST[1], dim_size)
for dim_partition_dict in dim_partition_dicts_2d:
self._registry_no_split_strategies_for_matmul(dim_partition_dict)
# Only one device mesh axis is uesd on batch dimensions
for mesh_dim_index in [0, 1]:
dim_partition_dicts_1d = enumerate_all_possible_1d_sharding(MESH_DIM_LIST[mesh_dim_index], dim_size)
for dim_partition_dict in dim_partition_dicts_1d:
self._registry_no_split_strategies_for_matmul(dim_partition_dict)
self._registry_1d_strategies_for_matmul(dim_partition_dict, [MESH_DIM_LIST[mesh_dim_index - 1]])
# No device mesh axis is uesd on batch dimensions
dim_partition_dict_on_batch_dim = {}
self._registry_no_split_strategies_for_matmul(dim_partition_dict_on_batch_dim)
self._registry_1d_strategies_for_matmul(dim_partition_dict_on_batch_dim, MESH_DIM_LIST)
self._registry_2d_strategies_for_matmul()
import operator
import warnings
from functools import reduce
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
__all__ = ['ConvHandler']
class ConvHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of Convolution.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_data = self.predecessor_node[0]._meta_data
self.weight = self.module_named_parameters['weight']
self.output_data = self.node._meta_data
self._sanity_check()
def _sanity_check(self):
'''
In sanity check, we need make sure the input data having correct dimension size.
For Conv1d, the dim of input data should be 3([N, C, L]).
For Conv2d, the dim of input data should be 4([N, C, H, W]).
For Conv3d, the dim of input data should be 5([N, C, H, W, D]).
'''
assert self.input_data.dim() in (3, 4,
5), f'We suppose the dim of input fed into conv op should in range of [3, 5].'
def _generate_compute_cost(self, bs, channel_in, channel_out):
'''
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
Argument:
bs(int): Batch size of the input data.
channel_in(int): The channel dimension of input data.
channel_out(int): The out channel of the conv weight.
Return:
compute_cost(float): Computation cost per device with this specific strategy
'''
# TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
# 1D: (L) * N * Cout * Cin * kernel
# 2D: (H * W) * N * Cout * Cin * kernel
# 3D: (H * W * D) * N * Cout * Cin * kernel
output_size = self.output_data.shape[2:]
output_size_product = reduce(operator.mul, output_size, 1)
input_size = self.input_data.shape[2:]
input_size_product = reduce(operator.mul, input_size, 1)
kernel_size = self.weight.shape[2:]
kernel_size_product = reduce(operator.mul, kernel_size, 1)
forward_compute_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
backward_activation_cost = input_size_product * bs * channel_in * channel_out * kernel_size_product
backward_weight_cost = output_size_product * bs * channel_in * channel_out * kernel_size_product
compute_cost = forward_compute_cost + backward_activation_cost + backward_weight_cost
return compute_cost
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
'''
Compute the memory cost per device with this specific strategy.
Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward number partions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number partions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight number partions.
Return:
memory_cost(Tuple[float]): Memory cost per device with this
specific strategy, the first element of this tuple is forward
memory cost, and the second element of this tuple is backward
memory cost.
memory_cost_forward(float): Memory cost of forward activation per
device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of backward activation
per device with this specific strategy.
'''
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel_output = self.output_data.numel()
numel_input = self.input_data.numel()
numel_weight = self.weight.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# forward memory_cost
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
# backward memory_cost
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
# memory_cost pair
memory_cost = (memory_cost_forward, memory_cost_backward)
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
@ignore_sharding_exception
def split_input_batch_weight_out_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
channel_in = self.input_data.shape[1]
channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, _, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# This strategy do not need to do all_reduce operation during forward
communication_cost_forward = 0
# compute the backward communication cost to all reduce the input activation grad
communication_cost_backward_activation = self.device_mesh.all_reduce_cost(memory_cost_backward_activation,
mesh_dim_1)
# compute the backward communication cost to all reduce the weight due to data parallel
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
# total communication cost
communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_batch(self, mesh_dim_0):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}R x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
channel_in = self.input_data.shape[1]
channel_out = self.weight.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = 1
memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# This strategy do not need to do all_reduce operation in forward phase.
communication_cost_forward = 0
# compute the backward communication cost to all reduce the weight due to data parallel
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
# compute the total cost
communication_cost = communication_cost_forward + communication_cost_backward_weight
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_both_dim_weight_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // self.device_mesh.shape[mesh_dim_0]
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_1]
channel_out = self.weight.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, _, memory_cost_backward_weight = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_1)
# This strategy do not need to do all_reduce operation to compute the input activation grad
communication_cost_backward_activation = 0
# compute the backward communication cost to all reduce the weight due to data parallel
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight, mesh_dim_0)
# compute total cost
communication_cost = communication_cost_forward + communication_cost_backward_activation + communication_cost_backward_weight
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_in_channel_weight_both_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
# compute the communication cost of this strategy during backward phase
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_in_channel_weight_in_channel(self, mesh_dim_0):
name = f'RR = RS{mesh_dim_0} x S{mesh_dim_0}R'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1] // self.device_mesh.shape[mesh_dim_0]
channel_out = self.weight.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = 1
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0]
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
# This strategy do NOT need all_reduce during forward phase
communication_cost_backward = 0
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_weight_out_channel(self, mesh_dim_0):
name = f'RS{mesh_dim_0} = RR x RS{mesh_dim_0}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim_0]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1]
channel_out = self.weight.shape[1] // self.device_mesh.shape[mesh_dim_0]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0]
sharding_size_backward_activation = 1
sharding_size_weight = self.device_mesh.shape[mesh_dim_0]
memory_cost, _, memory_cost_backward_activation, _ = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# This strategy do not need to do all_reduce during forward phase
communication_cost_forward = 0
# compute the communication cost of this strategy during backward phase
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_0)
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x RR'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1]
channel_out = self.weight.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = 1
sharding_size_backward_activation = 1
sharding_size_weight = 1
memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
# This strategy do not need to do all_reduce in both forward and backward phase
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_1d_parallel_on_input_batch(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0] // (self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1])
channel_in = self.input_data.shape[1]
channel_out = self.weight.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
mesh_dim_1]
sharding_size_weight = 1
memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
# This strategy do not need to do all_reduce in forward phase
communication_cost_forward = 0
# compute the backward communication cost to all reduce the weight due to data parallel
communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
memory_cost_backward_weight, 0)
# compute the total communication cost
communication_cost = communication_cost_backward_weight + communication_cost_forward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_1d_parallel_on_in_channel(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
bs = self.input_data.shape[0]
channel_in = self.input_data.shape[1] // (self.device_mesh.shape[mesh_dim_0] *
self.device_mesh.shape[mesh_dim_1])
channel_out = self.weight.shape[1]
compute_cost = self._generate_compute_cost(bs, channel_in, channel_out)
# compute the memory cost of this strategy
sharding_size_forward = 1
sharding_size_backward_activation = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[
mesh_dim_1]
sharding_size_weight = self.device_mesh.mesh_shape[mesh_dim_0] * self.device_mesh.mesh_shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, _, _ = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# compute communication cost during forward phase
communication_cost_forward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
memory_cost_forward_activation, 0)
# This strategy do NOT need do all_reduce during backward phase
communication_cost_backward = 0
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def register_strategy(self) -> StrategiesVector:
'''
Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
Example:
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
shape_consistency_manager = ShapeConsistencyManager()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
# return conv
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# [x, mul, conv, output]
nodes = [node for node in gm.graph.nodes]
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
strategies_vector_for_input = StrategiesVector(node=nodes[0], in_nodes=[nodes[1], 2], strategies=strategies_for_input)
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[nodes[1], ])
conv_handler = ConvHandler(input_node=nodes[1], input_index=0, weight=dict(gm.named_modules())[nodes[2].name].weight, output_node=nodes[2],
device_mesh=device_mesh, strategies_vector=strategies_vector, shape_consistency_manager=shape_consistency_manager)
conv_handler.register_strategy_into_strategies_vector()
for strategy in conv_handler.strategies_vector:
print(f'{strategy.name}: compute_cost is {strategy.compute_cost}, communication_cost is {strategy.communication_cost}, memory_cost is {strategy.memory_cost}, resharding_costs is {strategy.resharding_costs}')
Output:
S0S1 = S0R x RS1: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
S1S0 = S1R x RS0: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
S0R = S0R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 32769.001, 131074.2, 0, 32769.1, 131074.2, 98307.201]}
S1R = S1R x RR: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 32769.001, 131074.2, 98307.201, 0, 32769.1]}
S0R = S0S1 x S1R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 65538.002, 0, 0, 0, 65538.002, 196614.402]}
S1R = S1S0 x S0R: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 65538.002, 65538.002, 196614.402, 0, 0]}
RS1 = RS0 x S0S1: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
RS0 = RS1 x S1S0: compute_cost is 8856576, communication_cost is 984065.01, memory_cost is 984064.0, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
RR = RS0 x S0R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 131074.2, 32769.001, 98307.201, 131074.2, 32769.1]}
RR = RS1 x S1R: compute_cost is 17713152, communication_cost is 1968129.01, memory_cost is 1968128, resharding_costs is {mul: [0, 131074.2, 0, 131074.2, 32769.1, 32769.001, 98307.201]}
RS0 = RR x RS0: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
RS1 = RR x RS1: compute_cost is 17713152, communication_cost is 0, memory_cost is 984064.0, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
RR = RR x RR: compute_cost is 35426304, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 65537.1, 65537.1, 65537.1, 131075.30000000002, 65537.1, 131075.30000000002]}
S01R = S01R x RR: compute_cost is 8856576, communication_cost is 0, memory_cost is 492032.0, resharding_costs is {mul: [0, 65538.002, 262148.4, 0, 16385.001, 262148.4, 196614.402]}
RR = RS01 x S01R: compute_cost is 8856576, communication_cost is 0, memory_cost is 1968128, resharding_costs is {mul: [0, 0, 262148.4, 65538.002, 196614.402, 262148.4, 65538.2]}
'''
# SS = SR x RS
self.split_input_batch_weight_out_channel(0, 1)
self.split_input_batch_weight_out_channel(1, 0)
# SR = SR x RR
self.split_input_batch(0)
self.split_input_batch(1)
# SR = SS x SR
self.split_input_both_dim_weight_in_channel(0, 1)
self.split_input_both_dim_weight_in_channel(1, 0)
# RS = RS x SS
self.split_input_in_channel_weight_both_channel(0, 1)
self.split_input_in_channel_weight_both_channel(1, 0)
# RR = RS x SR
self.split_input_in_channel_weight_in_channel(0)
self.split_input_in_channel_weight_in_channel(1)
# RS = RR x RS
self.split_weight_out_channel(0)
self.split_weight_out_channel(1)
# RR= RR x RR
self.non_split()
# S01R = S01R x RR
self.split_1d_parallel_on_input_batch(0, 1)
# RR = RS01 x S01R
self.split_1d_parallel_on_in_channel(0, 1)
return self.strategies_vector
CONV_STRATEGIES_LIST = [
'S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R',
'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1',
'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R'
]
import operator
from enum import Enum
from functools import reduce
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from ..constants import LINEAR_FUNC_OP, LINEAR_MODULE_OP
from .operator_handler import OperatorHandler
from .strategy_generator import IntermediateStrategy, StrategyGenerator
__all__ = ['DotHandler']
class DotProductStrategyGenerator(StrategyGenerator):
"""
DotProductStrategyGenerator is used to generate the sharding strategies for two 1D tensors in dot product computation.
This is created for torch.matmul where two tensors are 1D tensors. As torch.matmul does not include a bias argument, so we
do not consider bias here.
"""
def validate(self, input, other):
assert input.dim() == 1 and other.dim() == 1
def no_split(self):
name = f'R = R dot R'
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_one_dim(self, mesh_dim):
name = f'S{mesh_dim} = S{mesh_dim} dot S{mesh_dim}'
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "output": {}}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])
def generate(self) -> List[IntermediateStrategy]:
strategy_list = []
# do not split dimensions for dot product
# R = R dot R
strategy_list.append(self.no_split())
# split two tensors in the same dimensions
# S = S dot S
strategy_list.append(self.split_one_dim(0))
strategy_list.append(self.split_one_dim(1))
return strategy_list
class MatVecStrategyGenerator(StrategyGenerator):
def validate(self, input, other) -> bool:
assert input.dim() > 1 and other.dim() == 1
def no_split(self):
name = "R = R x R"
dim_partition_dict = {"input": {}, "other": {}, "output": {}}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_input_batch(self, mesh_dim):
name = f'S{mesh_dim}R = S{mesh_dim}R x R'
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {}, "output": {0: [mesh_dim]}}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def generate(self) -> List[IntermediateStrategy]:
strategy_list = []
# no split
strategy_list.append(self.no_split())
# split the batch dim for the first tensor only
strategy_list.append(self.split_input_batch(0))
strategy_list.append(self.split_input_batch(1))
return strategy_list
class MatMulStrategyGenerator(StrategyGenerator):
"""
MatMulStrategyGenerator is used to generate the sharding strategies when the second tensor is
a 2D tensor. This is used for nn.Linear, F.linear, torch.matmul and torch.addmm.
A matmul can be formulated as [n, p] x [p, q] = [n, q]
Args:
is_linear (bool): whether this generator is used for nn.Linear and F.linear.
This will incur extra transformation of the dim partitioning as the weight is transposed.
"""
def __init__(self, is_linear: bool, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_linear = is_linear
# as the weight for the linear module is transposed, we can compute
# the correponding dimension indexfor convenience
if is_linear:
self.dim_q = 0
self.dim_p = 1
else:
self.dim_q = 1
self.dim_p = 0
def validate(self, input, other, bias) -> bool:
# make sure the second tensor is a 2D tensor
assert input.dim() > 0 and other.dim() == 2
# make sure bias is of the same dimension
if self.is_linear:
assert bias is None or bias.shape[-1] == other.shape[0]
else:
assert bias is None or bias.shape[-1] == other.shape[1]
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0]
},
"other": {
self.dim_q: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"other": {
self.dim_p: [mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
dim_partition_dict = {
"input": {
-1: [mesh_dim_0]
},
"other": {
self.dim_p: [mesh_dim_0],
self.dim_q: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
-1: [mesh_dim_1]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
dim_partition_dict = {
"input": {
-1: [mesh_dim]
},
"other": {
self.dim_p: [mesh_dim]
},
"bias": {},
"output": {},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
dim_partition_dict = {
"input": {},
"other": {
self.dim_q: [mesh_dim]
},
"bias": {
-1: [mesh_dim]
},
"output": {
-1: [mesh_dim]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim])
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
dim_partition_dict = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"other": {},
"bias": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict = {
"input": {
-1: [mesh_dim_0, mesh_dim_1]
},
"other": {
self.dim_p: [mesh_dim_0, mesh_dim_1]
},
"bias": {},
"output": {},
}
return IntermediateStrategy(name=name,
dim_partition_dict=dim_partition_dict,
all_reduce_axis=[mesh_dim_0, mesh_dim_1])
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict = {
"input": {},
"other": {
self.dim_q: [mesh_dim_0, mesh_dim_1]
},
"bias": {
-1: [mesh_dim_0, mesh_dim_1]
},
"output": {
-1: [mesh_dim_0, mesh_dim_1]
},
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
class BatchedMatMulStrategyGenerator(StrategyGenerator):
"""
Generate sharding strategies for the batched matrix multiplication.
A batched matrix multiplication can be viewed as
[b, i, k] x [b, k, j] -> [b, i, j]
"""
def __init__(self, is_torch_bmm: bool, *args, **kwargs):
super().__init__(*args, **kwargs)
self.is_torch_bmm = is_torch_bmm
def validate(self, input, other, bias) -> bool:
if self.is_torch_bmm:
assert input.shape == other.shape
assert input.dim() > 2
assert other.shape[-1] == bias.shape[0]
else:
# TODO: validate these inputs are broadcastable
pass
def split_one_batch_dim(self):
if 1 in self.device_mesh.mesh_shape:
mesh_dim = self.device_mesh.mesh_shape.index(1)
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
dim_partition_dict = {
"input": {
0: [mesh_dim]
},
"other": {
0: [mesh_dim]
},
"bias": {},
"output": {
0: [mesh_dim]
}
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
else:
return None
def split_two_batch_dim(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}{mesh_dim_1} = Sb{mesh_dim_0}{mesh_dim_1} x Sb{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0, mesh_dim_1]
},
"other": {
0: [mesh_dim_0, mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0, mesh_dim_1]
}
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_one_batch_dim(self, mesh_dim):
name = f'Sb{mesh_dim} = Sb{mesh_dim} x Sb{mesh_dim}'
dim_partition_dict = {"input": {0: [mesh_dim]}, "other": {0: [mesh_dim]}, "bias": {}, "output": {0: [mesh_dim]}}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_batch_dim_lhs_space(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}Si{mesh_dim_1} = Sb{mesh_dim_0}Si{mesh_dim_1} x Sb{mesh_dim_0}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0]
},
"bias": {},
"output": {
0: mesh_dim_0,
-2: [mesh_dim_1]
}
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_batch_dim_rhs_space(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}Sj{mesh_dim_1} = Sb{mesh_dim_0}R x Sb{mesh_dim_0}Sj{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0]
},
"other": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
}
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict)
def split_batch_dim_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'Sb{mesh_dim_0}R = Sb{mesh_dim_0}Sk{mesh_dim_1} x Sb{mesh_dim_0}Sk{mesh_dim_1}'
dim_partition_dict = {
"input": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
"other": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
},
"bias": {},
"output": {
0: [mesh_dim_0],
-2: [mesh_dim_1]
}
}
return IntermediateStrategy(name=name, dim_partition_dict=dim_partition_dict, all_reduce_axis=[mesh_dim_1])
def generate(self) -> List[IntermediateStrategy]:
strategy_list = []
# split only the batch dimension
# Sb = Sb x Sb
# can be None as it is only for 1D device mesh
strategy = self.split_one_batch_dim()
if strategy:
strategy_list.append(strategy)
# split batch dim of two inputs and the i dim of the first tensor
# SbSi = SbSi x Sb
strategy_list.append(self.split_batch_dim_lhs_space(0, 1))
strategy_list.append(self.split_batch_dim_lhs_space(1, 0))
# split batch dim of two inputs and the j of the second tensor
# SbSj = Sb x SbSj
strategy_list.append(self.split_batch_dim_rhs_space(0, 1))
strategy_list.append(self.split_batch_dim_rhs_space(1, 0))
# split batch dim of two inputs and the k dim of two inputs
# Sb = SbSk x SbSk, need to all-reduce by k dim
strategy_list.append(self.split_batch_dim_both_contract(0, 1))
strategy_list.append(self.split_batch_dim_both_contract(1, 0))
# split two batch dim
strategy_list.append(self.split_two_batch_dim(0, 1))
strategy_list.append(self.split_two_batch_dim(1, 0))
return strategy_list
class DotHandler(OperatorHandler):
"""
A OperatorHandler which deals with the sharding strategies for nn.Linear and F.linear.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_data = self.predecessor_node[0]._meta_data
self.weight = self.module_named_parameters['weight']
self.output_data = self.node._meta_data
def _generate_compute_cost(self, input_shape, weight_shape, total_sharding_size):
# TODO: consider bias addition
compute_cost = reduce(operator.mul, input_shape) * weight_shape[0] * 2 // total_sharding_size
return compute_cost
@ignore_sharding_exception
def split_lhs_space_rhs_space(self, mesh_dim_0, mesh_dim_1):
# handle case SS = SR x RS
name = f'S{mesh_dim_0}S{mesh_dim_1} = S{mesh_dim_0}R x RS{mesh_dim_1}'
dim_partition_dict_for_input = {0: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
# linear layer weight is transposed during init
dim_partition_dict_for_weight = {0: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute computation cost
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost
communication_cost_activation_backward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
communication_cost_weight_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
communication_cost = communication_cost_activation_backward + communication_cost_weight_backward
# create and register strategy
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_lhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
# handle the case SR = SS x SR
name = f'S{mesh_dim_0}R = S{mesh_dim_0}S{mesh_dim_1} x S{mesh_dim_1}R'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
# since weight of the linear layer is transposed
# the actual dim to be sharded is 1
dim_partition_dict_for_weight = {1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
communication_cost_grad_backward = self.device_mesh.all_reduce_cost(weight_memory_cost, mesh_dim_0)
communication_cost = communication_cost_activation_forward + communication_cost_grad_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_rhs_space_both_contract(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_1} = RS{mesh_dim_0} x S{mesh_dim_0}S{mesh_dim_1}'
dim_partition_dict_for_input = {1: [mesh_dim_0]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_input)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
communication_cost_activation_forward = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_0)
communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim_1)
communication_cost = communication_cost_activation_backward + communication_cost_activation_forward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def recompute_split_both_contract(self, mesh_dim):
name = f'RR = RS{mesh_dim} x S{mesh_dim}R'
dim_partition_dict_for_input = {1: [mesh_dim]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_rhs_space_only(self, mesh_dim):
name = f'RS{mesh_dim} = RR x RS{mesh_dim}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
communication_cost_activation_backward = self.device_mesh.all_reduce_cost(input_grad_memory_cost, mesh_dim)
communication_cost = communication_cost_activation_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
communication_cost_weight_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(weight_memory_cost, 0)
communication_cost = communication_cost_weight_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
communication_cost_forward_activation = self.device_mesh.flatten_device_mesh.all_reduce_cost(
activation_memory_cost, 0)
communication_cost = communication_cost_forward_activation
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape, total_sharding_size)
# compute the memory cost of this strategy
toatl_memory_cost, activation_memory_cost, weight_memory_cost, input_grad_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight, dim_partition_dict_for_input)
# compute the communication cost of this strategy
communication_cost_activation_backward = self.device_mesh.flatten_device_mesh.all_reduce_cost(
input_grad_memory_cost, 0)
communication_cost = communication_cost_activation_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def register_strategy(self) -> StrategiesVector:
'''
Generate every possible strategies for a linear node, and record all strategies into the strategies_vector.
Output:
'''
# SS = SR x RS
self.split_lhs_space_rhs_space(0, 1)
self.split_lhs_space_rhs_space(1, 0)
# SR = SS x SR
self.split_lhs_space_both_contract(0, 1)
self.split_lhs_space_both_contract(1, 0)
# RS = RS x SS
self.split_rhs_space_both_contract(0, 1)
self.split_rhs_space_both_contract(1, 0)
# RR= RS x SR
self.recompute_split_both_contract(0)
self.recompute_split_both_contract(1)
# RS = RR x RS
self.split_rhs_space_only(0)
self.split_rhs_space_only(1)
# S01R = S01R x RR
self.split_lhs_1st_dim_1d(0, 1)
# RR = RS01 x S01R
self.split_lhs_2nd_dim_1d(0, 1)
# RS01 = RR x RS01
self.split_rhs_2nd_dim_1d(0, 1)
return self.strategies_vector
import operator
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .operator_handler import OperatorHandler
__all__ = ['EmbeddingHandler']
class EmbeddingHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of Embedding operators(such as nn.embedding).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_data = self.predecessor_node[0]._meta_data
self.weight = self.module_named_parameters['weight']
self.output_data = self.node._meta_data
def _generate_compute_cost(self, total_sharding_size):
input_shape = self.input_data.shape
weight_shape = self.weight.shape
input_shape_product = reduce(operator.mul, input_shape, 1)
weight_shape_product = reduce(operator.mul, weight_shape, 1)
compute_cost = input_shape_product * weight_shape_product * 2 / total_sharding_size
return compute_cost
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
'''
Compute the memory cost per device with this specific strategy.
Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward number partions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number partions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight number partions.
Return:
memory_cost(Tuple[float]): Memory cost per device with this
specific strategy, the first element of this tuple is forward
memory cost, and the second element of this tuple is backward
memory cost.
memory_cost_forward(float): Memory cost of forward activation per
device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of backward activation
per device with this specific strategy.
'''
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel_output = self.output_data.numel()
numel_input = self.input_data.numel()
numel_weight = self.weight.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# forward memory_cost
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
# backward memory_cost
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
# memory_cost pair
memory_cost = (memory_cost_forward, memory_cost_backward)
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
@ignore_sharding_exception
def split_weight_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'RRS{mesh_dim_1} = RR x S{mesh_dim_0}S{mesh_dim_1}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {2: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
compute_cost = self._generate_compute_cost(total_sharding_size)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = 1
sharding_size_weight = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, _ = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# compute the communication cost of this strategy during forward phase
communication_cost_forward = self.device_mesh.all_reduce_cost(memory_cost_forward_activation, mesh_dim_0)
# compute the communication cost of this strategy during backward phase
communication_cost_backward = self.device_mesh.all_reduce_cost(memory_cost_backward_activation, mesh_dim_1)
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_both_dim(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}S{mesh_dim_1}R = S{mesh_dim_0}S{mesh_dim_1} x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0], 1: [mesh_dim_1]}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input, sharding_spec_for_weight])
# compute the computation cost of this strategy
total_sharding_size = self.device_mesh.shape[0] * self.device_mesh.shape[1]
compute_cost = self._generate_compute_cost(total_sharding_size)
# compute the memory cost of this strategy
sharding_size_forward = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_backward_activation = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
sharding_size_weight = 1
memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight = self._generate_memory_cost(
sharding_size_forward, sharding_size_backward_activation, sharding_size_weight)
# This strategy do not need to do all_reduce during forward phase
communication_cost_forward = 0
# compute the communication cost of this strategy during backward phase
communication_cost_backward_activation = 0
communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
memory_cost_backward_weight, 0)
communication_cost_backward = communication_cost_backward_activation + communication_cost_backward_weight
communication_cost = communication_cost_forward + communication_cost_backward
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def register_strategy(self) -> StrategiesVector:
'''
Generate every possible strategies for a Conv node, and record all strategies into the strategies_vector.
'''
# RRS = RR x SS
self.split_weight_both_dim(0, 1)
self.split_weight_both_dim(1, 0)
# SSR = SS x RR
self.split_input_both_dim(0, 1)
self.split_input_both_dim(1, 0)
return self.strategies_vector
import operator
from functools import reduce
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
generate_sharding_size,
ignore_sharding_exception,
)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from .operator_handler import OperatorHandler
__all__ = ['LayerNormHandler']
class LayerNormHandler(OperatorHandler):
"""
A OperatorHandler which deals with the sharding strategies of normalization.
Note: To keep the math consistency, LayerNorm do not allow shards on hidden dimension.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_data = self.predecessor_node[0]._meta_data
self.weight = self.module_named_parameters['weight']
self.bias = self.module_named_parameters['bias']
self.output_data = self.node._meta_data
def _generate_compute_cost(self, total_sharding_size):
'''
Compute the computation cost per device with this specific strategy.
Note: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
Argument:
bs(int): Batch size of the input data.
channel_in(int): The channel dimension of input data.
Return:
compute_cost(float): Computation cost per device with this specific strategy
'''
# TODO: compute_cost need to be devided by TFLOPS, now it just shows the computation size.
# TODO: a constant coefficient need to be added.
norm_kernel_size = self.weight.shape
# in LayerNorm context, batch dimensions mean all the dimensions do not join the normalization.
input_batch_shape = self.input_data.shape[:-len(norm_kernel_size)]
input_batch_product = reduce(operator.mul, input_batch_shape, 1)
norm_kernel_product = reduce(operator.mul, norm_kernel_size, 1)
forward_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
backward_activation_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
# To compute gradient of on norm kernel element requires input_batch_product times computation, so
# the total cost is input_batch_product * norm_kernel_product
backward_weight_compute_cost = input_batch_product * norm_kernel_product / total_sharding_size
backward_compute_cost = backward_activation_compute_cost + backward_weight_compute_cost
compute_cost = forward_compute_cost + backward_compute_cost
return compute_cost
def _generate_memory_cost(self, sharding_size_forward, sharding_size_backward_activation, sharding_size_weight):
'''
Compute the memory cost per device with this specific strategy.
Argument:
sharding_size_forward(int): The forward activation will be divided
into sharding_size_forward number partions.
sharding_size_backward_activation(int): The backward activation will
be divided into sharding_size_backward_activation number partions.
sharding_size_weight(int): The backward weight will be divided
into sharding_size_weight number partions.
Return:
memory_cost(Tuple[float]): Memory cost per device with this
specific strategy, the first element of this tuple is forward
memory cost, and the second element of this tuple is backward
memory cost.
memory_cost_forward(float): Memory cost of forward activation per
device with this specific strategy.
memory_cost_backward_activation(float): Memory cost of backward activation
per device with this specific strategy.
'''
# compute the memory cost of this strategy
dtype = self.input_data.dtype
numel_output = self.output_data.numel()
# this operation will not change the shape of input
numel_input = numel_output
numel_weight = self.weight.numel()
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# forward memory_cost
memory_cost_forward_activation = numel_output * size_per_elem_bytes / sharding_size_forward
memory_cost_forward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_forward = memory_cost_forward_activation + memory_cost_forward_weight
# backward memory_cost
memory_cost_backward_activation = numel_input * size_per_elem_bytes / sharding_size_backward_activation
memory_cost_backward_weight = numel_weight * size_per_elem_bytes / sharding_size_weight
memory_cost_backward = memory_cost_backward_activation + memory_cost_backward_weight
# memory_cost pair
memory_cost = (memory_cost_forward, memory_cost_backward)
return memory_cost, memory_cost_forward_activation, memory_cost_backward_activation, memory_cost_backward_weight
def _generate_strategy_with_dim_partition(self, dim_partition):
dim_partition_dict_for_input = dim_partition
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = dim_partition
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
name = f'{sharding_spec_for_output.sharding_sequence} = {sharding_spec_for_input.sharding_sequence} x {sharding_spec_for_weight.sharding_sequence}'
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
total_sharding_size = generate_sharding_size(dim_partition, self.device_mesh)
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(total_sharding_size)
# compute the memory cost of this strategy
sharding_size_forward = generate_sharding_size(dim_partition_dict_for_input, self.device_mesh)
sharding_size_backward_activation = generate_sharding_size(dim_partition_dict_for_output, self.device_mesh)
sharding_size_weight = generate_sharding_size(dim_partition_dict_for_weight, self.device_mesh)
memory_cost, _, _, memory_cost_backward_weight = self._generate_memory_cost(sharding_size_forward,
sharding_size_backward_activation,
sharding_size_weight)
total_mesh_dim_list = []
for mesh_dim_list in dim_partition.values():
total_mesh_dim_list.extend(mesh_dim_list)
# This strategy do not need to do all_reduce operation for activation
communication_cost_forward_activation = 0
communication_cost_backward_activation = 0
if len(total_mesh_dim_list) == 1:
communication_cost_backward_weight = self.device_mesh.all_reduce_cost(memory_cost_backward_weight,
total_mesh_dim_list[0])
else:
assert len(total_mesh_dim_list) == 2, f'temporally we just support 2d device mesh.'
communication_cost_backward_weight = self.device_mesh.flatten_device_mesh.all_reduce_cost(
memory_cost_backward_weight, 0)
communication_cost = communication_cost_forward_activation + communication_cost_backward_activation + communication_cost_backward_weight
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
@ignore_sharding_exception
def split_input_batch_single_mesh_dim(self, mesh_dim_0):
batch_dimension_length = self.input_data.dim() - self.weight.dim()
dim_partition_list = enumerate_all_possible_1d_sharding(mesh_dim_0, batch_dimension_length)
for dim_partition in dim_partition_list:
self._generate_strategy_with_dim_partition(dim_partition)
@ignore_sharding_exception
def split_input_batch_both_mesh_dim(self, mesh_dim_0, mesh_dim_1):
batch_dimension_length = self.input_data.dim() - self.weight.dim()
dim_partition_list = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, batch_dimension_length)
for dim_partition in dim_partition_list:
self._generate_strategy_with_dim_partition(dim_partition)
@ignore_sharding_exception
def non_split(self):
name = f'RR = RR x R'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_output = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
total_sharding_size = 1
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(total_sharding_size)
# compute the memory cost of this strategy
sharding_size_forward = 1
sharding_size_backward_activation = 1
sharding_size_weight = 1
memory_cost, _, _, _ = self._generate_memory_cost(sharding_size_forward, sharding_size_backward_activation,
sharding_size_weight)
# This strategy do not need to do all_reduce operation
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_output,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def register_strategy(self) -> StrategiesVector:
'''
Generate every possible strategies for a BatchNorm node, and record all strategies into the strategies_vector.
Example:
norm_handler = BatchNormHandler(node, strategies_vector,
self.shape_consistency_manager)
norm_handler.register_strategy()
for strategy in norm_handler.strategies_vector:
print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
Output:
RS0 = RS0 x S0, computation_cost: 131072, memory_cost: 524288.0
RS1 = RS1 x S1, computation_cost: 131072, memory_cost: 524288.0
RR = RR x R, computation_cost: 262144, memory_cost: 1048576
RS01 = RS01 x S01, computation_cost: 65536, memory_cost: 262144.0
'''
# SR = SR x R with single mesh dim on batch dimensions
self.split_input_batch_single_mesh_dim(0)
self.split_input_batch_single_mesh_dim(1)
# SR = SR x R with both mesh dims on batch dimensions
self.split_input_batch_both_mesh_dim(0, 1)
# RR = RR x R
self.non_split()
return self.strategies_vector
from abc import ABC, abstractmethod
from typing import Dict, List
from webbrowser import Opera
import torch
import torch.nn as nn
from torch.fx.node import Node
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from .._utils import generate_resharding_costs, generate_sharding_spec
from ..sharding_strategy import StrategiesVector
__all__ = ['OperatorHandler']
class OperatorHandler(ABC):
'''
The OperatorHandler is an abstract class used to generate every possible strategies for an operator node.
Args:
node (Node): the input node in node argument list.
device_mesh (DeviceMesh): A logical view of a physical mesh.
strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
handle_backward (Optional[bool]): whether to consider the backward pass. The default value is True. False can be used for inference.
'''
def __init__(self,
node: Node,
device_mesh: DeviceMesh,
strategies_vector: StrategiesVector,
handle_backward: bool = True):
self.node = node
self.predecessor_node = list(node._input_nodes.keys())
self.successor_node = list(node.users.keys())
self.device_mesh = device_mesh
self.strategies_vector = strategies_vector
self.handle_backward = handle_backward
# find the module and its parameters associated with this node
# this can be used to compute the compute/communication/sharding cost
if self.node.op == 'call_module':
module = node.graph.owning_module.get_submodule(node.target)
named_parameters = list(module.named_parameters(recurse=False))
# convert named parameters from list to dict
named_parameters = {k: v for k, v in named_parameters}
elif self.node.op == 'call_function' and self.node.target not in NON_PARAM_FUNC_OP:
module = None
parameters = list(self.node.args)[1]
if isinstance(parameters, Node):
named_parameters = {'weight': parameters._meta_data}
else:
named_parameters = {}
else:
module = None
named_parameters = None
self.module = module
self.module_named_parameters = named_parameters
@abstractmethod
def register_strategy(self) -> StrategiesVector:
"""
Register
"""
pass
def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight,
sharding_spec_for_input):
'''
Compute the memory cost per device with this specific strategy.
Argument:
dim_partition_dict_for_output(List[int]): The key is the dimension of output to be sharded,
and the value of the key decribe which logical axis will be sharded in that dimension.
dim_partition_dict_for_weight(List[int]): The key is the dimension of weight to be sharded,
and the value of the key decribe which logical axis will be sharded in that dimension.
Return:
total_memory_cost(float): total memory cost per device with this specific strategy
activation_cost(float): the memory cost of activation per device with this specific strategy
weight_memory_cost(float): the memory cost of weight per device with this specific strategy
'''
# compute the size of one element with specific dtype
dtype = self.input_data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# compute the memory cost of activation
activation_numel = self.output_data.numel()
output_mesh_dims = []
for sharding_dim, mesh_dims in dim_partition_dict_for_output.items():
output_mesh_dims.extend(mesh_dims)
activation_sharding_size = 1
for mesh_dim in output_mesh_dims:
activation_sharding_size *= self.device_mesh.shape[mesh_dim]
activation_memory_cost = activation_numel / activation_sharding_size * size_per_elem_bytes
# compute the memory cost of weight
weight_numel = self.weight.numel()
weight_sharding_size = 1
weight_mesh_dims = []
for sharding_dim, mesh_dims in dim_partition_dict_for_weight.items():
weight_mesh_dims.extend(mesh_dims)
for mesh_dim in weight_mesh_dims:
weight_sharding_size *= self.device_mesh.shape[mesh_dim]
weight_memory_cost = weight_numel / weight_sharding_size * size_per_elem_bytes
# compute the memory cost of input grad
input_grad_numel = self.input_data.numel()
input_grad_sharding_size = 1
input_grad_mesh_dims = []
for sharding_dim, mesh_dims in sharding_spec_for_input.items():
input_grad_mesh_dims.extend(mesh_dims)
for mesh_dim in input_grad_mesh_dims:
input_grad_sharding_size *= self.device_mesh.shape[mesh_dim]
input_grad_memory_cost = input_grad_numel / input_grad_sharding_size * size_per_elem_bytes
memory_cost_forward = activation_memory_cost + weight_memory_cost
memory_cost_backward = input_grad_memory_cost + weight_memory_cost
return (memory_cost_forward,
memory_cost_backward), activation_memory_cost, weight_memory_cost, input_grad_memory_cost
def _generate_resharding_costs(self, sharding_specs):
# The resharding_cost of weight is counted due to sharing weight cases.
if hasattr(self.node._meta_data, 'dtype'):
dtype = self.node._meta_data.dtype
else:
assert isinstance(self.node._meta_data,
tuple), f'Only torch.Tensor, torch.fx.Node and tuple of torch.Tensor is expected'
dtype = self.node._meta_data[0].dtype
nodes = self.predecessor_node
return generate_resharding_costs(nodes=nodes,
sharding_specs=sharding_specs,
count_backward=self.handle_backward,
dtype=dtype)
def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
return generate_sharding_spec(input_=input_,
device_mesh=self.device_mesh,
dim_partition_dict=dim_partition_dict)
@abstractmethod
def _generate_compute_cost(self, *args, **kwargs):
"""
Compute the flops involved in the node.
"""
pass
import colorsys
import math
import warnings
from copy import deepcopy
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from ..constants import INFINITY_COST
from .operator_handler import OperatorHandler
class ReshapeHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of Reshape Operator, such as torch.reshape, torch.flatten, etc.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.input_data = self.predecessor_node[0]._meta_data
self.output_data = self.node._meta_data
def _generate_compute_cost(self, *args, **kwargs):
return super()._generate_compute_cost(*args, **kwargs)
@ignore_sharding_exception
def register_strategy(self):
# TODO: add strategies with more output sharding specs other than only fully replicated.
input_node = self.strategies_vector.predecessor_nodes[0]
# For reshape function, to keep the computing correctness we keep the sharding
# spec of input is fully replicated. In addition, we will keep the output in
# replica status and let the successor node choose the way to resharding the
# output node. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for reshape function.
sharding_spec_checklist = []
for strategy in input_node.strategies_vector:
# It looks a little bit confusing, the input of the processing node
# is the output of the input_node.
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
if input_sharding_spec in sharding_spec_checklist:
continue
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict_for_output = {}
if isinstance(self.output_data, tuple):
dim_partition_dict_for_output = [{} for _ in range(len(self.output_data))]
try:
if isinstance(self.output_data, tuple):
output_sharding_spec = []
for output, dim_partition_dict in zip(self.output_data, dim_partition_dict_for_output):
output_sharding_spec.append(self._generate_sharding_spec(output, dim_partition_dict))
else:
output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
except AssertionError as e:
warnings.warn(f'{e}')
continue
name = f'{input_sharding_spec.sharding_sequence} -> FULLY REPLICATED'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = 0
# consider node._meta_data is in type of tuple
memory_cost = 0
# compute the communication cost, in reshape op, the communication happens during casting the input sharding spec to fully replicating.
dim_partition_dict_for_replicate_input = {}
replicate_input_sharding_spec = self._generate_sharding_spec(self.input_data,
dim_partition_dict_for_replicate_input)
# shape consistency manager is a singleton class
shape_consistency_manager = ShapeConsistencyManager()
_, _, communication_cost = shape_consistency_manager.shape_consistency(input_sharding_spec,
replicate_input_sharding_spec)
communication_cost = communication_cost["total"]
# generate resharding cost
resharding_costs = self._generate_resharding_costs([input_sharding_spec])
# to prevent the resharding happening, set their resharding cost to inf.
resharding_costs[input_node] = [0 if cost == 0 else INFINITY_COST for cost in resharding_costs[input_node]]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
self.strategies_vector.append(sharding_strategy)
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, List
from colossalai.device.device_mesh import DeviceMesh
__all__ = ['IntermediateStrategy', 'StrategyGenerator']
@dataclass
class IntermediateStrategy:
"""
IntermediateStrategy contains the subset of meta information for ShardingStrategy. It is
to store the essential information regarding the tensor sharding and leave other meta information to OperatorHandler.
Args:
name (str): name of the sharding strategy.
dim_partition_dict (Dict[Dict]): stores the tensor to dim partition dict mapping.
all_reduce_dims (List[int]): stores the dimensions which require an all-reduce operation.
"""
name: str
dim_partition_dict: Dict[str, Dict[int, List[int]]]
all_reduce_axis: List[int] = None
class StrategyGenerator(ABC):
"""
StrategyGenerator is used to generate the same group of sharding strategies.
"""
def __init__(self, device_mesh: DeviceMesh):
self.device_mesh = device_mesh
@abstractmethod
def generate(self) -> List[IntermediateStrategy]:
"""
"""
pass
@abstractmethod
def validate(self, *args, **kwargs) -> bool:
"""
Validate if the operands are of desired shape.
If True, means this generator can be used for the current operation.
"""
pass
import math
import operator
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import \
ignore_sharding_exception
from colossalai.auto_parallel.tensor_shard.deprecated.constants import \
INFINITY_COST
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import (ShardingStrategy, StrategiesVector)
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .operator_handler import OperatorHandler
__all__ = ['UnaryElementwiseHandler']
class UnaryElementwiseHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of UnaryElementwiseOp.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.node.op == 'call_module':
target = self.node.target
submod = self.node.graph.owning_module.get_submodule(target)
submod_type = type(submod)
if submod_type == torch.nn.Dropout:
print(f'predecessor nodes of dropout node are {self.predecessor_node}')
input_nodes_len = 0
for check_node in self.predecessor_node:
if isinstance(check_node._meta_data, torch.Tensor):
input_nodes_len += 1
assert input_nodes_len == 1, f'Temporally, we just support single input element-wise op, node name is {self.node}, node args is {self.node.args}.'
self.input_data = self.predecessor_node[0]._meta_data
self.input_node = self.predecessor_node[0]
self.output_data = self.node._meta_data
def _generate_compute_cost(self, *args, **kwargs):
return super()._generate_compute_cost(*args, **kwargs)
@ignore_sharding_exception
def register_strategy(self):
# TODO: integrate element-wise func and module together
# create sharding strategy for element-wise function
# For element-wise function, we keep the sharding spec of output node same as
# the input. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for element-wise function.
for index, strategy in enumerate(self.input_node.strategies_vector):
# It looks a little bit confusing, the input of the processing node
# is the output of the input_node.
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
try:
output_sharding_spec = self._generate_sharding_spec(self.output_data, dim_partition_dict)
except AssertionError as e:
warnings.warn(f'{e}')
continue
# add index into name to pass the duplicated check
# we keep same strategies with different name for node merging, and it will not increase the searching space,
# because in solver, this node will be merged into other nodes, and solver will not create a new variable for this node.
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}_{index}'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = self.output_data.numel()
memory_cost = 0
resharding_costs = self._generate_resharding_costs([input_sharding_spec])
# to prevent the resharding happening, set their resharding cost to inf.
resharding_costs[self.input_node] = [
0 if cost == 0 else INFINITY_COST for cost in resharding_costs[self.input_node]
]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
self.strategies_vector.append(sharding_strategy)
import operator
import warnings
from copy import deepcopy
from functools import reduce
from typing import Dict, List
import torch
from colossalai.auto_parallel.tensor_shard.deprecated._utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
ignore_sharding_exception,
)
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from .operator_handler import OperatorHandler
__all__ = ['WhereHandler']
class WhereHandler(OperatorHandler):
"""
An OperatorHandler which deals with the sharding strategies of torch.where.
"""
def __init__(self, *args, **kwargs):
# TODO: x or y could be scalar
super().__init__(*args, **kwargs)
assert len(self.predecessor_node) == 3
self.condition_data = self.predecessor_node[0]._meta_data
self.x_data = self.predecessor_node[1]._meta_data
self.y_data = self.predecessor_node[2]._meta_data
self.condition = self.predecessor_node[0]
self.x = self.predecessor_node[1]
self.y = self.predecessor_node[2]
self.output_data = self.node._meta_data
def _generate_sharding_spec(self, input_: torch.Tensor, dim_partition_dict: Dict[int, List[int]]) -> ShardingSpec:
shape = list(input_.shape)
# padding the shape to the same length as output_data
while len(shape) < self.output_data.dim():
shape.insert(0, 1)
shape = torch.Size(shape)
# if the sharding happens on a size one dimension, we should record it as R.
processed_dim_partition_dict = deepcopy(dim_partition_dict)
for dim_index, _ in dim_partition_dict.items():
if shape[dim_index] == 1:
processed_dim_partition_dict.pop(dim_index)
for dim_index, sharding_index_list in processed_dim_partition_dict.items():
sharding_list = [self.device_mesh.mesh_shape[sharding_index] for sharding_index in sharding_index_list]
sharding_size = reduce(operator.mul, sharding_list, 1)
assert shape[
dim_index] % sharding_size == 0, f'we cannot shard the {dim_index} dimension of tensor into {sharding_size} partitions.'
sharding_spec = ShardingSpec(device_mesh=self.device_mesh,
entire_shape=shape,
dim_partition_dict=processed_dim_partition_dict)
return sharding_spec
def _generate_compute_cost(self, total_sharding_size):
lhs_matrix_shape = self.lhs_data.shape[-2:]
rhs_matrix_shape = self.rhs_data.shape[-2:]
batch_dimensions_shape = self.output_data.shape[:-2]
batch_dimensions_product = reduce(operator.mul, batch_dimensions_shape, 1)
compute_cost = reduce(
operator.mul, lhs_matrix_shape) * rhs_matrix_shape[0] * batch_dimensions_product * 2 / total_sharding_size
return compute_cost
def _generate_resharding_costs(self, sharding_specs):
# The resharding_cost of weight is counted due to sharing weight cases.
dtype = self.node._meta_data.dtype
nodes = self.predecessor_node
resharding_costs = {}
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# shape consistency manager is a singleton class
shape_consistency_manager = ShapeConsistencyManager()
for input_node, input_spec in zip(nodes, sharding_specs):
resharding_costs[input_node] = []
for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
# if the input shape is smaller than the target input, we will fill the input to the same length as target.
# Then, use the padded input sharding spec to compute the resharding cost.
if len(input_sharding_spec.entire_shape) < len(input_spec.entire_shape):
new_entire_shape = list(input_sharding_spec.entire_shape)
while len(new_entire_shape) < len(input_spec.entire_shape):
new_entire_shape.insert(0, 1)
new_entire_shape = torch.Size(new_entire_shape)
new_device_mesh = input_sharding_spec.device_mesh
new_dim_partition_dict = input_sharding_spec.dim_partition_dict
input_sharding_spec = ShardingSpec(device_mesh=new_device_mesh,
entire_shape=new_entire_shape,
dim_partition_dict=new_dim_partition_dict)
# compute the resharding cost
_, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
input_sharding_spec, input_spec)
total_resharding_cost = total_resharding_cost['total']
# we need multiply the size of elem dtype to get correct communication cost
resharding_cost = total_resharding_cost * size_per_elem_bytes
resharding_costs[input_node].append(resharding_cost)
return resharding_costs
def _convert_partition_dict_to_sharding_spec(self, dim_partition_list):
sharding_spec_list = []
check_duplicated_list = []
for output_dim_partition_dict in dim_partition_list:
try:
output_sharding_spec = self._generate_sharding_spec(self.output_data, output_dim_partition_dict)
except AssertionError as e:
warnings.warn(f'{e}')
break
sharding_seq = output_sharding_spec.sharding_sequence
if sharding_seq not in check_duplicated_list:
check_duplicated_list.append(sharding_seq)
sharding_spec_list.append(output_sharding_spec)
return sharding_spec_list
def _enumerate_all_possible_output(self, mesh_dim_0, mesh_dim_1):
# use mesh_dim_0, mesh_dim_1 instead of constant 0, 1 in here for N-D device mesh scaliablity.
output_dim_partition_list = []
dim_size = self.output_data.dim()
# enumerate all the 2D sharding cases
sharding_list_2d = enumerate_all_possible_2d_sharding(mesh_dim_0, mesh_dim_1, dim_size)
output_dim_partition_list.extend(sharding_list_2d)
# enumerate all the 1D sharding cases
sharding_list_1d_on_dim_0 = enumerate_all_possible_1d_sharding(mesh_dim_0, dim_size)
output_dim_partition_list.extend(sharding_list_1d_on_dim_0)
sharding_list_1d_on_dim_1 = enumerate_all_possible_1d_sharding(mesh_dim_1, dim_size)
output_dim_partition_list.extend(sharding_list_1d_on_dim_1)
# add empty dict for fully replicated case
output_dim_partition_list.append({})
output_sharding_spec_list = self._convert_partition_dict_to_sharding_spec(output_dim_partition_list)
return output_sharding_spec_list
@ignore_sharding_exception
def _register_strategy(self, output_sharding_spec):
dim_partition_dict_for_input = output_sharding_spec.dim_partition_dict
sharding_spec_for_condition = self._generate_sharding_spec(self.condition_data, dim_partition_dict_for_input)
sharding_spec_for_x = self._generate_sharding_spec(self.x_data, dim_partition_dict_for_input)
sharding_spec_for_y = self._generate_sharding_spec(self.y_data, dim_partition_dict_for_input)
name = f'{output_sharding_spec.sharding_sequence} = {sharding_spec_for_condition.sharding_sequence} x {sharding_spec_for_x.sharding_sequence} x {sharding_spec_for_y.sharding_sequence}'
dim_partition_dict_for_output = output_sharding_spec.dim_partition_dict
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs(
[sharding_spec_for_condition, sharding_spec_for_x, sharding_spec_for_y])
# compute the computation cost of this strategy
sharding_dims = []
for mesh_dims in dim_partition_dict_for_output.values():
for mesh_dim in mesh_dims:
sharding_dims.append(self.device_mesh.shape[mesh_dim])
sharding_size = reduce(operator.mul, sharding_dims, 1)
memory_cost = self.output_data.numel() / sharding_size
compute_cost = memory_cost
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=output_sharding_spec,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_condition, sharding_spec_for_x,
sharding_spec_for_y))
self.strategies_vector.append(sharding_strategies)
def register_strategy(self) -> StrategiesVector:
MESH_DIM_LIST = [0, 1]
output_sharding_specs = self._enumerate_all_possible_output(MESH_DIM_LIST[0], MESH_DIM_LIST[1])
for output_sharding_spec in output_sharding_specs:
self._register_strategy(output_sharding_spec)
from dataclasses import dataclass
__all__ = ['SolverOptions']
@dataclass
class SolverOptions:
"""
SolverOptions is a dataclass used to configure the preferences for the parallel execution plan search.
"""
fast: bool = False
from copy import deepcopy
from dataclasses import dataclass
from abc import ABC, abstractmethod
from enum import Enum
import operator
import torch
from functools import reduce
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
from typing import Dict, List, Union, Tuple, Any
from torch.fx.node import Node
from .constants import *
__all__ = ['ShardingStrategy', 'StrategiesVector']
@dataclass
class ShardingStrategy:
'''
ShardingStrategy is a structure containing sharding strategies of inputs and output of this node
and costs information using in solver.
Argument:
name(str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'.
output_sharding_spec(ShardingSpec): ShardingSpec of the output node.
compute_cost(float): Computation cost to complete this strategy.(default to 0)
communication_cost(float): Communication cost to complete this strategy.(default to 0)
memory_cost(float): Memory cost of the output node using this strategy.(default to 0)
resharding_costs(Dict[int, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the output node argument list
with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
strategy.(default to None)
input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes.
'''
name: str
# TODO: output of fx node,such as torch.var_mean, could be a tuple, so we cannot simply suppose it is a tensor.
output_sharding_spec: Union[ShardingSpec, Tuple[ShardingSpec]]
compute_cost: float = 0.
communication_cost: float = 0.
memory_cost: float = 0.
resharding_costs: Dict[Node, List[float]] = None
# sometimes the input node could be a tuple of nodes, but most of op won't accept tuple of node as input.
# Therefore, we could process them at the specific op(operator.getitem)
input_shardings: List[ShardingSpec] = None
class StrategiesVector(list):
'''
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
strategies of the node.
Argument:
node (Node): node for which the list of sharding strategies are generated.
'''
def __init__(self, node: Node):
super().__init__()
self.node = node
# fetch its input and output nodes
# TODO: placeholder input nodes
self.predecessor_nodes = list(node._input_nodes.keys())
if self.node.op == 'output':
self.predecessor_nodes = list(node._input_nodes.keys())[:1]
self.successor_nodes = list(node.users.keys())
def check_merge(self):
merge_label = False
if self.node.op == 'call_module':
target = self.node.target
root_module = self.node.graph.owning_module
submod = root_module.get_submodule(target)
submod_type = type(submod)
# merge elementwise module node into source nodes
# we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
if submod_type in ELEMENTWISE_MODULE_OP:
merge_label = True
if self.node.op == 'call_function':
# we could merge element-wise op, because the output sharding spec is always same as the input sharding spec.
if self.node.target in ELEMENTWISE_FUNC_OP:
merge_label = True
# we could merge bcast op if the rhs is a scalar, because it will fall back to the element-wise case.
if self.node.target in BCAST_FUNC_OP and len(self.predecessor_nodes) == 1:
merge_label = True
# we could merge reshape op, because the output sharding spec of reshape op is always fully replicated.
if self.node.target in RESHAPE_FUNC_OP:
merge_label = True
return merge_label
import multiprocessing
import time
import warnings
from typing import Dict
import numpy as np
from torch.fx.graph import Graph
from torch.fx.node import Node
from .constants import INFINITY_COST
from .cost_graph import CostGraph
from .graph_analysis import GraphAnalyser
from .strategies_constructor import StrategiesConstructor
try:
import pulp
from pulp import LpMinimize, LpProblem, LpStatus, LpVariable, lpDot, lpSum
except:
warnings.warn(f'please install the pulp')
__all___ = ['Solver']
class Solver:
def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
graph_analyser: GraphAnalyser,
memory_budget: float = -1.0,
solution_numbers: int = 1,
memory_increasing_coefficient: float = 1.3):
'''
Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
Argument:
graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
cost_graph: A graph data structure to simplify the edge cost graph.
graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget.
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget.
'''
self.graph = graph
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
self.graph_analyser = graph_analyser
self.leaf_strategies = self.strategies_constructor.leaf_strategies
self.nodes = [strategies_vector.node for strategies_vector in self.leaf_strategies]
self.strategy_map = self.strategies_constructor.strategy_map
self.memory_budget = memory_budget
self.solution_numbers = solution_numbers
if self.solution_numbers > 1:
self.memory_increasing_coefficient = memory_increasing_coefficient
else:
self.memory_increasing_coefficient = 1
self.liveness_list = self.graph_analyser.liveness_analysis()
self.node_index_dict = self._generate_node_index_dict()
# The last solution vector of auto sharding.
self.last_s_val = None
# The last objective value of the best ILP solution.
self.last_objective = None
def _recover_merged_node_strategy(self):
'''
During cost graph constructing, some nodes, such as unary element-wise node or ReshapeOp, were merged into the previous node.
Therefore, the index of those strategies are copied from the previous node. This method is used to recover the strategy index of those merged
node.
'''
for node_index, node in enumerate(self.nodes):
if node.strategies_vector.check_merge():
# the merged node has only one input, and its strategies follow the input sharding strategy
input_strategies_vector = node.args[0].strategies_vector
input_best_strategy_index = self.last_s_val[node_index - 1]
input_sharding_spec = input_strategies_vector[input_best_strategy_index].output_sharding_spec
for strategy_index, strategy in enumerate(node.strategies_vector):
if strategy.input_shardings[0].sharding_sequence == input_sharding_spec.sharding_sequence:
self.last_s_val[node_index] = strategy_index
break
def _generate_node_index_dict(self) -> Dict[Node, int]:
node_index_dict = {}
for index, strategies_vector in enumerate(self.leaf_strategies):
node_index_dict[strategies_vector.node] = index
return node_index_dict
def _prepare_data_for_solver(self):
'''
Extract information from components for solver.
'''
node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget
# prepare strategies_len
strategies_len = []
for node in self.nodes:
strategies_len.append(self.cost_graph.node_lens[node])
strategies_len = np.array(strategies_len)
# prepare following_nodes
following_nodes = self.cost_graph.following_dict
index_following_nodes = {}
for src, target in following_nodes.items():
src_index = self.node_index_dict[src]
target_index = self.node_index_dict[target]
index_following_nodes[src_index] = target_index
following_nodes = index_following_nodes
for index in range(node_nums):
if index not in following_nodes:
following_nodes[index] = -1
# prepare edge_pairs and resharding costs
edge_pairs = []
resharding_costs = []
for pairs, edge_cost in self.cost_graph.edge_costs.items():
src_node = pairs[0]
dst_node = pairs[1]
src_node_index = self.node_index_dict[src_node]
dst_node_index = self.node_index_dict[dst_node]
edge_pairs.append(src_node_index)
edge_pairs.append(dst_node_index)
for i in range(strategies_len[src_node_index]):
for j in range(strategies_len[dst_node_index]):
resharding_costs.append(edge_cost[(i, j)])
edge_pairs = np.array(edge_pairs)
resharding_costs = np.array(resharding_costs)
# prepare liveness_set
liveness_set = self.liveness_list
# omit alias_set now
alias_set = None
alias_convert_costs = None
# prepare compute_costs, communication_costs and memory_costs
compute_costs = []
communication_costs = []
memory_costs = []
extra_node_costs = self.cost_graph.extra_node_costs
for strategies_vector in self.leaf_strategies:
node = strategies_vector.node
for index, strategy in enumerate(strategies_vector):
compute_costs.append(strategy.compute_cost)
# node in extra_node_costs means it has some extra communication
# cost from node merging, so we need to add those extra communication
# cost into
if node in extra_node_costs:
origin_communication_cost = strategy.communication_cost
extra_node_cost = extra_node_costs[node][index]
communication_cost = origin_communication_cost + extra_node_cost
communication_costs.append(communication_cost)
else:
communication_costs.append(strategy.communication_cost)
# temporarily we just consider the forward memory cost
memory_cost = strategy.memory_cost
if isinstance(memory_cost, tuple):
memory_costs.append(memory_cost[0])
else:
memory_costs.append(memory_cost)
compute_costs = np.array(compute_costs)
communication_costs = np.array(communication_costs)
memory_costs = np.array(memory_costs)
# omit initial value for nodes
s_init_np = None
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
def _call_solver_serialized_args(self,
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np=None):
"""
Call the solver with serialized arguments.
"""
tic = time.time()
for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]:
assert isinstance(x, np.ndarray)
assert len(strategies_len) == node_nums, "strategies_len"
def get_non_zero_index(binary_vector):
"""
Get the index of non-zero item in a vector.
"""
ct = 0
ret = None
for i, elem in enumerate(binary_vector):
if pulp.value(elem):
ret = i
ct += 1
assert ct == 1
return ret
# 0. Unpack flatten numpy arrays
s_follow = following_nodes
E = edge_pairs.reshape((-1, 2)) # noqa
r = []
pt = 0
edge_set = set()
for (i, j) in E:
prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j))
r.append(resharding_costs[pt:pt + prod_length])
pt += prod_length
assert pt == len(resharding_costs)
######################
# omit alias set now #
######################
# A = alias_set.reshape((-1, 2)) # noqa
# for (i, j) in A:
# prod_length = strategies_len[i] * strategies_len[j]
# v.append(alias_convert_costs[pt:pt + prod_length])
# pt += prod_length
# assert pt == len(alias_convert_costs)
# L = [] # noqa
# pt = node_nums
# for i in range(node_nums):
# length = liveness_set[i]
# L.append(liveness_set[pt:pt + length])
# pt += length
# assert pt == len(liveness_set)
v = []
pt = 0
c = []
d = []
m = []
pt = 0
for i in range(node_nums):
length = strategies_len[i]
c.append(compute_costs[pt:pt + length])
d.append(communication_costs[pt:pt + length])
m.append(memory_costs[pt:pt + length])
pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"
# 1. Create variables
#############################
# create variables for node #
#############################
s = []
num_nodes = 0
reverse_follow_backpatch = []
for i in range(node_nums):
if s_follow[i] < 0:
if strategies_len[i] == 1:
s.append([1])
else:
num_nodes += 1
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
else:
if s_follow[i] < len(s):
s.append(s[s_follow[i]])
else:
s.append(None)
reverse_follow_backpatch.append(i)
for i in reverse_follow_backpatch:
s[i] = s[s_follow[i]]
#############################
# create variables for edge #
#############################
e = []
num_edges = 0
for (idx, (i, j)) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
e.append(s[i])
else:
num_edges += 1
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
assert len(e[idx]) == len(r[idx])
for element in s:
assert len(element) > 0
# 2. Set initial value
######################################
# set a initial value for warm start #
######################################
if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3))
for (idx, value, fix) in s_init:
for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value)
if fix:
s[idx][i].fixValue()
# 3. Objective
prob = LpProblem("myProblem", LpMinimize)
###################################################################
# computing the node cost(computing cost and communication cost) #
###################################################################
obj = 0
for i in range(node_nums):
assert len(s[i]) == len(c[i])
assert len(s[i]) == len(d[i])
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
#############################################
# computing the edge cost(resharding cost) #
#############################################
for i in range(len(E)):
assert len(e[i]) == len(r[i])
obj += lpDot(e[i], r[i])
prob += obj
# 4. Constraints
# (a). specified by `cat="Binary"`
# (b)
#################################################
# make sure each node only choose one strategy #
#################################################
for i in range(node_nums):
if s_follow[i] < 0:
prob += lpSum(s[i]) == 1
# (c)
#################################################
# compute memory consumption with liveness set #
#################################################
if memory_budget > 0:
for liveness_stage in liveness_set:
mem = 0
for live_variable in liveness_stage.unique_live_vars:
node_index = self.node_index_dict[live_variable.node]
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget
# (d). specified by `cat="Binary"`
for (idx, (i, j)) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1:
continue
# (e)
prob += lpSum(e[idx]) == 1
# (f)
for row in range(len(s[i])):
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
# (g)
for col in range(len(s[j])):
R = len(s[i]) # noqa
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
# (h)
######################
# omit alias set now #
######################
# alias_set = set()
# for (idx, (i, j)) in enumerate(A):
# R = len(s[i]) # noqa
# C = len(s[j]) # noqa
# if (i, j) in alias_set:
# raise ValueError(f"Duplicated edges: {(i, j)}")
# alias_set.add((i, j))
# alias_set.add((j, i))
# for row in range(len(s[i])):
# for col in range(len(s[j])):
# if v[idx][row * C + col] > 0.5:
# prob += s[i][row] + s[j][col] <= 1
verbose = True
msg = verbose
time_limit = 600
assert "COIN_CMD" in pulp.listSolvers(
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
prob.solve(solver)
status = prob.status
objective = pulp.value(prob.objective)
objective = float(objective) if objective is not None else -1.0
if verbose:
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
f"Time: {time.time() - tic}")
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
if prob.status in [pulp.LpStatusInfeasible]:
raise RuntimeError("Cannot run the function under the given memory budget. "
"Please increase the memory budget.")
# Get and check results
s_val = np.full((node_nums,), -1, dtype=np.int32)
for i in range(node_nums):
s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E),), -1, dtype=np.int32)
for (idx, (i, j)) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j])
assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
if verbose and r[idx][e_val[idx]] > 0:
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
self.last_s_val = list(s_val)
self._recover_merged_node_strategy()
self.last_objective = objective
if objective > INFINITY_COST:
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
return self.last_s_val, e_val, self.last_objective, status
def call_solver_serialized_args(self):
"""
Call the solver with serialized arguments and handle python errors. Additionally,
we could give a serious of solutions with different memory budget.
"""
if self.solution_numbers == 1:
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
return ret
origin_memory_budget = self.memory_budget
memory_budget_list = [
origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers)
]
ret_list = []
for memory_budget in memory_budget_list:
self.memory_budget = memory_budget
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
ret_list.append(ret)
return ret_list
import builtins
import math
import operator
from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx import Graph, Node
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from ._utils import generate_resharding_costs, generate_sharding_spec
from .constants import *
from .op_handler import *
from .options import SolverOptions
from .sharding_strategy import ShardingStrategy, StrategiesVector
__all__ = ['StrategiesConstructor']
class StrategiesConstructor:
"""
StrategiesConstructor is used to construct the parallelization plan for the model execution.
Args:
graph (Graph): a Graph object used for analysis and strategy generation.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
"""
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
self.leaf_strategies = []
self.strategy_map = {}
self.solver_options = solver_options
def remove_duplicated_strategy(self, strategies_vector):
'''
In build_strategies_and_cost method, we may produce some duplicated strategies.
In this method, we will remove the duplicated strategies depending on the strategies name.
'''
name_checklist = []
remove_list = []
for strategy in strategies_vector:
if strategy.name not in name_checklist:
name_checklist.append(strategy.name)
else:
remove_list.append(strategy)
for strategy in remove_list:
strategies_vector.remove(strategy)
def _is_bcast_matmul(self, node):
is_bcast_matmul = False
if node.target is torch.matmul and len(node.args) == 2:
lhs_data = node.args[0]._meta_data
rhs_data = node.args[1]._meta_data
if lhs_data.dim() >= 3 and rhs_data.dim() >= 3:
is_bcast_matmul = True
return is_bcast_matmul
def build_strategies_and_cost(self):
for node in self.nodes:
strategies_vector = StrategiesVector(node)
input_nodes_len = 0
for check_node in strategies_vector.predecessor_nodes:
if isinstance(check_node._meta_data, torch.Tensor):
input_nodes_len += 1
# input_nodes_len = len(strategies_vector.predecessor_nodes)
# placeholder node
if node.op == 'placeholder':
# For placeholder nodes, if solver_options.fast is True, we just let them in
# fully replicate status, then strategies of following node will be treated equally due
# to replicate status has no resharding cost to other status. At the same time, the searching
# space is smaller than enumerating all the possible sharding spec for the placeholder node.
# Otherwise, all the possible sharding spec for the placeholder node will be enumerated.
if self.solver_options.fast:
# create sharding strategy for placeholder
name = 'Replica Placeholder'
dim_partition_dict = {}
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
sharding_strategy_placeholder = ShardingStrategy(name,
output_sharding_spec,
memory_cost=memory_cost)
strategies_vector.append(sharding_strategy_placeholder)
# get_attr node
if node.op == 'get_attr':
# Same as placeholder nodes, if solver_options.fast is True, we just let them in
# fully replicate status, then strategies of following node will be treated equally due
# to replicate status has no resharding cost to other status. At the same time, the searching
# space is smaller than enumerating all the possible sharding spec for the get_attr node.
# Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
if self.solver_options.fast:
# create sharding strategy for get_attr
name = 'Replica Attribute'
dim_partition_dict = {}
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
strategies_vector.append(sharding_strategy_attribute)
# call_module node
if node.op == 'call_module':
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
# conv module
if submod_type in CONV_MODULE_OP:
# use ConvHandler to create sharding strategies for conv module node
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
conv_handler.register_strategy()
# linear module
elif submod_type in LINEAR_MODULE_OP:
# use DotHandler to create sharding strategies for linear module node
dot_handler = DotHandler(node, self.device_mesh, strategies_vector)
dot_handler.register_strategy()
# element-wise module
elif submod_type in ELEMENTWISE_MODULE_OP:
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()
# BatchNormNd module
elif submod_type in BATCHNORM_MODULE_OP:
# create sharding strategy for element-wise module
norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector)
norm_handler.register_strategy()
# for strategy in norm_handler.strategies_vector:
# print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
# assert False
# MaxPool module
elif submod_type in POOL_MODULE_OP:
# TODO: add sharding constraints on image dimension
# e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on H and W dimension
# create sharding strategy for element-wise module
assert input_nodes_len == 1, f'Temporally, we just support single input element-wise op.'
input_node = strategies_vector.predecessor_nodes[0]
# For element-wise module, we keep the sharding spec of output node same as
# the input. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for element-wise module.
sharding_spec_checklist = []
for strategy in input_node.strategies_vector:
# It looks a little bit confusing, the input of the processing node
# is the output of the input_node.
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
if input_sharding_spec in sharding_spec_checklist:
continue
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = node._meta_data.numel()
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# embedding module
elif submod_type in EMBEDDING_MODULE_OP:
embedding_handler = EmbeddingHandler(node, self.device_mesh, strategies_vector)
embedding_handler.register_strategy()
# layernorm module
elif submod_type in LAYERNORM_MODULE_OP:
layernorm_handler = LayerNormHandler(node, self.device_mesh, strategies_vector)
layernorm_handler.register_strategy()
# other module
else:
raise RuntimeError(f'{submod_type} module is NOT supported now.')
# call_function node
if node.op == 'call_function':
target = node.target
# conv function
if target in CONV_FUNC_OP:
# use ConvHandler to create sharding strategies for conv node
# TODO: the operator_handler does NOT support function node processing now.
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
conv_handler.register_strategy()
# linear function
elif target in LINEAR_FUNC_OP and not self._is_bcast_matmul(node):
# use DotHandler to create sharding strategies for linear node
# TODO: the operator_handler does NOT support function node processing now.
linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
linear_handler.register_strategy()
# where function
elif target == torch.where:
if input_nodes_len == 1:
# both of x and y are scalar
pass
elif input_nodes_len == 2:
# one of x or y is type of scalar
pass
else:
# general case
where_handler = WhereHandler(node, self.device_mesh, strategies_vector)
where_handler.register_strategy()
# reshape function
elif target in RESHAPE_FUNC_OP:
# use ReshapeHandler to create sharding strategies for rehsape node
reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
reshape_handler.register_strategy()
# element-wise function
elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()
# bcast op
elif target in BCAST_FUNC_OP:
if isinstance(node._meta_data, torch.Tensor):
bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector)
bcast_op_handler.register_strategy()
# torch.var_mean
elif target == torch.var_mean:
dim = node.kwargs['dim']
input_tensor_node = strategies_vector.predecessor_nodes[0]
for strategy in input_tensor_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
entire_shape_input = input_sharding_spec.entire_shape
dim_partition_dict_input = input_sharding_spec.dim_partition_dict
name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
if dim in dim_partition_dict_input:
# We need to make the action dimension in replicate status
dim_partition_dict_for_input = deepcopy(dim_partition_dict_input)
dim_partition_dict_for_input.pop(dim)
new_input_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_input,
dim_partition_dict=dim_partition_dict_for_input)
entire_shape_output = deepcopy(entire_shape_input)
entire_shape_output.pop(dim)
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_input)
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[new_input_sharding_spec])
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[new_input_sharding_spec])
else:
entire_shape_output = deepcopy(entire_shape_input)
entire_shape_output.pop(dim)
dim_partition_dict_for_output = deepcopy(dim_partition_dict_input)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partion_dict=dim_partition_dict_input)
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# operator.getitem
elif target == operator.getitem:
index = node.args[1]
input_tensor_node = strategies_vector.predecessor_nodes[0]
for strategy in input_tensor_node.strategies_vector:
if isinstance(strategy.output_sharding_spec, ShardingSpec):
input_sharding_spec = strategy.output_sharding_spec
else:
input_sharding_spec = strategy.output_sharding_spec[index]
assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec],
index=index)
# to prevent the resharding happening, set their resharding cost to inf.
resharding_costs[input_tensor_node] = [
cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node]
]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[strategy.output_sharding_spec])
strategies_vector.append(sharding_strategy)
# torch.arange function
elif target == torch.arange:
name = f'FULLY REPLICATED ARANGE'
entire_shape_output = node._meta_data.shape
dim_partition_dict_for_output = {}
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
memory_cost = node._meta_data.numel()
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=0,
memory_cost=memory_cost)
strategies_vector.append(sharding_strategy)
# op list to be processed to support gpt2
elif target in (builtins.getattr, operator.le, torch.addmm):
pass
# other function
else:
raise RuntimeError(f'{target} function is NOT supported now.')
# call_method node
if node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
if method in (torch.Tensor.size,):
pass
elif method in ELEMENTWISE_METHOD_OP:
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()
elif method in RESHAPE_METHOD_OP:
reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
reshape_handler.register_strategy()
# print(strategies_vector)
# if len(strategies_vector) == 0:
# print(node)
# assert False
else:
raise RuntimeError(f'{method} function is NOT supported now.')
# output node
if node.op == 'output':
if self.solver_options.fast:
# create sharding strategy for output
name = 'Replica Output'
input_nodes = strategies_vector.predecessor_nodes
input_sharding_specs = []
for input_node in input_nodes:
dim_partition_dict_for_input = {}
entire_shape = input_node._meta_data.shape
sharding_spec = ShardingSpec(self.device_mesh,
entire_shape,
dim_partition_dict=dim_partition_dict_for_input)
input_sharding_specs.append(sharding_spec)
dim_partition_dict = {}
output_sharding_spec = input_sharding_specs
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
input_sharding_specs)
# clear the resharding cost for the output node
# TODO: we may remove this in final version
for prev_node, resharding_cost_list in resharding_costs.items():
resharding_costs[prev_node] = [0] * len(resharding_cost_list)
sharding_strategy_attribute = ShardingStrategy(name,
output_sharding_spec,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=tuple(input_sharding_specs))
strategies_vector.append(sharding_strategy_attribute)
self.remove_duplicated_strategy(strategies_vector)
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
# remove no strategy nodes
remove_list = []
for strategies_vector in self.leaf_strategies:
if len(strategies_vector) == 0:
remove_list.append(strategies_vector.node)
for node in remove_list:
if node.strategies_vector in self.leaf_strategies:
self.leaf_strategies.remove(node.strategies_vector)
if node in self.strategy_map:
self.strategy_map.pop(node)
...@@ -8,14 +8,9 @@ from torch.fx.graph import Graph ...@@ -8,14 +8,9 @@ from torch.fx.graph import Graph
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
from colossalai.auto_parallel.tensor_shard.solver import ( from colossalai.auto_parallel.tensor_shard.solver import CostGraph, GraphAnalyser, Solver, StrategiesConstructor
CostGraph,
GraphAnalyser,
Solver,
SolverOptions,
StrategiesConstructor,
)
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.graph_module import ColoGraphModule from colossalai.fx.graph_module import ColoGraphModule
...@@ -69,13 +64,43 @@ def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[f ...@@ -69,13 +64,43 @@ def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[f
pass pass
def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh): def build_strategy_constructor(graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str,
shard_option: str):
''' '''
This method is used to build the strategy_constructor for the given graph. This method is used to build the strategy_constructor for the given graph.
After this method, each node in the graph will have a strategies_vector which After this method, each node in the graph will have a strategies_vector which
is constructed by the related node handler. is constructed by the related node handler.
''' '''
solver_options = SolverOptions() if solver_preference == 'standard':
solver_preference = SolverPerference.STANDARD
elif solver_preference == 'tp':
solver_preference = SolverPerference.TP
elif solver_preference == 'dp':
solver_preference = SolverPerference.DP
else:
raise ValueError(f'Invalid solver_preference: {solver_preference}')
if dataloader_option == 'replicated':
dataloader_option = DataloaderOption.REPLICATED
elif dataloader_option == 'distributed':
dataloader_option = DataloaderOption.DISTRIBUTED
else:
raise ValueError(f'Invalid dataloader_option: {dataloader_option}')
if shard_option == 'standard':
shard_option = ShardOption.STANDARD
elif shard_option == 'shard':
shard_option = ShardOption.SHARD
elif shard_option == 'shard_last_axis':
shard_option = ShardOption.SHARD_LAST_AXIS
elif shard_option == 'full_shard':
shard_option = ShardOption.FULL_SHARD
else:
raise ValueError(f'Invalid shard_option: {shard_option}')
solver_options = SolverOptions(solver_perference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options) strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost() strategies_constructor.build_strategies_and_cost()
...@@ -183,6 +208,9 @@ def initialize_model(model: nn.Module, ...@@ -183,6 +208,9 @@ def initialize_model(model: nn.Module,
device_mesh: DeviceMesh, device_mesh: DeviceMesh,
memory_budget: float = -1.0, memory_budget: float = -1.0,
overlap: bool = False, overlap: bool = False,
solver_preference: str = 'standard',
dataloader_option: str = 'replicated',
shard_option: str = 'standard',
save_solver_solution: bool = False, save_solver_solution: bool = False,
load_solver_solution: bool = False, load_solver_solution: bool = False,
solution_path: str = None, solution_path: str = None,
...@@ -198,6 +226,12 @@ def initialize_model(model: nn.Module, ...@@ -198,6 +226,12 @@ def initialize_model(model: nn.Module,
the memory budget will be infinity. the memory budget will be infinity.
overlap(optional): the overlap is used to specify whether to overlap gradient communication and overlap(optional): the overlap is used to specify whether to overlap gradient communication and
backward computing. backward computing.
solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
be used. The valid dataloader_option could be 'replicated' or 'distributed'.
shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path. to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
...@@ -212,7 +246,12 @@ def initialize_model(model: nn.Module, ...@@ -212,7 +246,12 @@ def initialize_model(model: nn.Module,
graph = tracer.trace(root=model, meta_args=meta_args) graph = tracer.trace(root=model, meta_args=meta_args)
gm = ColoGraphModule(model, graph, model.__class__.__name__) gm = ColoGraphModule(model, graph, model.__class__.__name__)
gm.recompile() gm.recompile()
strategies_constructor = build_strategy_constructor(graph, device_mesh)
strategies_constructor = build_strategy_constructor(graph,
device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
shard_option=shard_option)
if load_solver_solution: if load_solver_solution:
solution = torch.load(solution_path) solution = torch.load(solution_path)
else: else:
...@@ -240,6 +279,9 @@ def autoparallelize(model: nn.Module, ...@@ -240,6 +279,9 @@ def autoparallelize(model: nn.Module,
alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None, alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
logical_mesh_shape: Tuple[int] = None, logical_mesh_shape: Tuple[int] = None,
logical_mesh_id: torch.Tensor = None, logical_mesh_id: torch.Tensor = None,
solver_preference: str = 'standard',
dataloader_option: str = 'replicated',
shard_option: str = 'standard',
save_solver_solution: bool = False, save_solver_solution: bool = False,
load_solver_solution: bool = False, load_solver_solution: bool = False,
solver_solution_path: str = None, solver_solution_path: str = None,
...@@ -262,6 +304,12 @@ def autoparallelize(model: nn.Module, ...@@ -262,6 +304,12 @@ def autoparallelize(model: nn.Module,
mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be mesh shape. If the logical_mesh_shape is None, the logical_mesh_shape will be
generated by search_best_logical_mesh_shape function. generated by search_best_logical_mesh_shape function.
logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id. logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
be used. The valid dataloader_option could be 'replicated' or 'distributed'.
shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved save_solver_solution(optional): if the save_solver_solution is True, the solution will be saved
to the solution_path. to the solution_path.
load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
...@@ -280,6 +328,8 @@ def autoparallelize(model: nn.Module, ...@@ -280,6 +328,8 @@ def autoparallelize(model: nn.Module,
rst_to_unpack = initialize_model(model, rst_to_unpack = initialize_model(model,
meta_args, meta_args,
device_mesh, device_mesh,
solver_preference=solver_preference,
dataloader_option=dataloader_option,
save_solver_solution=save_solver_solution, save_solver_solution=save_solver_solution,
load_solver_solution=load_solver_solution, load_solver_solution=load_solver_solution,
solution_path=solver_solution_path, solution_path=solver_solution_path,
......
...@@ -11,7 +11,6 @@ from .layer_norm_handler import LayerNormModuleHandler ...@@ -11,7 +11,6 @@ from .layer_norm_handler import LayerNormModuleHandler
from .linear_handler import LinearFunctionHandler, LinearModuleHandler from .linear_handler import LinearFunctionHandler, LinearModuleHandler
from .matmul_handler import MatMulHandler from .matmul_handler import MatMulHandler
from .normal_pooling_handler import NormPoolingHandler from .normal_pooling_handler import NormPoolingHandler
from .option import ShardOption
from .output_handler import OutputHandler from .output_handler import OutputHandler
from .permute_handler import PermuteHandler from .permute_handler import PermuteHandler
from .placeholder_handler import PlaceholderHandler from .placeholder_handler import PlaceholderHandler
...@@ -31,6 +30,6 @@ __all__ = [ ...@@ -31,6 +30,6 @@ __all__ = [
'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler', 'UnaryElementwiseHandler', 'DefaultReshapeHandler', 'PlaceholderHandler', 'OutputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler', 'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler',
'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler', 'GetItemHandler', 'GetattrHandler', 'ViewHandler', 'PermuteHandler', 'TensorConstructorHandler',
'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'ShardOption', 'EmbeddingModuleHandler', 'EmbeddingFunctionHandler', 'SumHandler', 'SoftmaxHandler', 'TransposeHandler',
'TransposeHandler', 'SplitHandler' 'SplitHandler'
] ]
...@@ -152,7 +152,10 @@ class LinearModuleHandler(MetaInfoModuleHandler): ...@@ -152,7 +152,10 @@ class LinearModuleHandler(MetaInfoModuleHandler):
op_data_mapping = self.get_operation_data_mapping() op_data_mapping = self.get_operation_data_mapping()
generators = [] generators = []
generators.append( generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear')) LinearProjectionStrategyGenerator(op_data_mapping,
self.device_mesh,
linear_projection_type='linear',
solver_perference=self.solver_perference))
return generators return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]: def get_operation_data_mapping(self) -> Dict[str, OperationData]:
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
from torch.fx.node import Node from torch.fx.node import Node
from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register from colossalai.auto_parallel.meta_profiler.metainfo import MetaInfo, meta_register
from colossalai.auto_parallel.tensor_shard.node_handler.option import ShardOption from colossalai.auto_parallel.tensor_shard.options import ShardOption, SolverPerference
from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData, OperationData,
OperationDataType, OperationDataType,
...@@ -32,19 +32,19 @@ class NodeHandler(ABC): ...@@ -32,19 +32,19 @@ class NodeHandler(ABC):
strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector. strategies_vector (StrategiesVector): all the strategies generated in this handler will be recorded into the strategies_vector.
''' '''
def __init__( def __init__(self,
self, node: Node,
node: Node, device_mesh: DeviceMesh,
device_mesh: DeviceMesh, strategies_vector: StrategiesVector,
strategies_vector: StrategiesVector, shard_option: ShardOption = ShardOption.STANDARD,
shard_option: ShardOption = ShardOption.STANDARD, solver_perference: SolverPerference = SolverPerference.STANDARD) -> None:
) -> None:
self.node = node self.node = node
self.predecessor_node = list(node._input_nodes.keys()) self.predecessor_node = list(node._input_nodes.keys())
self.successor_node = list(node.users.keys()) self.successor_node = list(node.users.keys())
self.device_mesh = device_mesh self.device_mesh = device_mesh
self.strategies_vector = strategies_vector self.strategies_vector = strategies_vector
self.shard_option = shard_option self.shard_option = shard_option
self.solver_perference = solver_perference
def update_resharding_cost(self, strategy: ShardingStrategy) -> None: def update_resharding_cost(self, strategy: ShardingStrategy) -> None:
""" """
...@@ -187,15 +187,24 @@ class NodeHandler(ABC): ...@@ -187,15 +187,24 @@ class NodeHandler(ABC):
remove_strategy_list = [] remove_strategy_list = []
for strategy in self.strategies_vector: for strategy in self.strategies_vector:
shard_level = 0 shard_axis_list = []
last_axis = len(self.device_mesh.mesh_shape) - 1
for op_data, sharding_spec in strategy.sharding_specs.items(): for op_data, sharding_spec in strategy.sharding_specs.items():
if op_data.data is not None and isinstance(op_data.data, torch.Tensor): if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
for dim, shard_axis in sharding_spec.dim_partition_dict.items(): for dim, shard_axes in sharding_spec.dim_partition_dict.items():
shard_level += len(shard_axis) for shard_axis in shard_axes:
if shard_axis not in shard_axis_list:
shard_axis_list.append(shard_axis)
shard_level = len(shard_axis_list)
using_last_axis = last_axis in shard_axis_list or -1 in shard_axis_list
if self.shard_option == ShardOption.SHARD and shard_level == 0: if self.shard_option == ShardOption.SHARD and shard_level == 0:
remove_strategy_list.append(strategy) remove_strategy_list.append(strategy)
if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1: if self.shard_option == ShardOption.FULL_SHARD and shard_level <= 1:
remove_strategy_list.append(strategy) remove_strategy_list.append(strategy)
if self.shard_option == ShardOption.SHARD_LAST_AXIS:
if shard_level != 1 or using_last_axis == False:
remove_strategy_list.append(strategy)
for strategy in remove_strategy_list: for strategy in remove_strategy_list:
self.strategies_vector.remove(strategy) self.strategies_vector.remove(strategy)
......
from enum import Enum
__all__ = ['ShardOption']
class ShardOption(Enum):
"""
This enum class is to define the shard level required in node strategies.
Notes:
STANDARD: We do not add any extra shard requirements.
SHARD: We require the node to be shard using at least one device mesh axis.
FULL_SHARD: We require the node to be shard using all device mesh axes.
"""
STANDARD = 0
SHARD = 1
FULL_SHARD = 2
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment