Unverified Commit 82d4376c authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[autoparallel] adapt solver with resnet (#1583)

* [autoparallel]adapt solver with resnet

* polish code

* polish code
parent f3403ff9
from .operator_handler import OperatorHandler
from .dot_handler import DotHandler
from .conv_handler import ConvHandler
from .sharding_strategy import ShardingStrategy, StrategiesVector from .sharding_strategy import ShardingStrategy, StrategiesVector
from .graph_analysis import GraphAnalyser from .graph_analysis import GraphAnalyser
from .solver import Solver
from .cost_graph import CostGraph
from .strategies_constructor import StrategiesConstructor
from .constants import *
__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'StrategiesVector', 'ShardingStrategy', 'GraphAnalyser'] __all__ = ['StrategiesVector', 'ShardingStrategy', 'GraphAnalyser', 'Solver', 'StrategiesConstructor', 'CostGraph']
...@@ -3,13 +3,14 @@ import operator ...@@ -3,13 +3,14 @@ import operator
__all__ = [ __all__ = [
'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP', 'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP', 'LINEAR_MODULE_OP',
'LINEAR_FUNC_OP' 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP'
] ]
ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU] ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
# TODO: flatten should not be added into this group
ELEMENTWISE_FUNC_OP = [ ELEMENTWISE_FUNC_OP = [
torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv, torch.add, operator.add, torch.abs, torch.cos, torch.exp, torch.mul, operator.mul, operator.floordiv,
operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout operator.truediv, operator.neg, torch.multiply, torch.nn.functional.relu, torch.nn.functional.dropout, torch.flatten
] ]
CONV_MODULE_OP = [ CONV_MODULE_OP = [
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d, torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
...@@ -20,3 +21,7 @@ CONV_FUNC_OP = [ ...@@ -20,3 +21,7 @@ CONV_FUNC_OP = [
] ]
LINEAR_MODULE_OP = [torch.nn.Linear] LINEAR_MODULE_OP = [torch.nn.Linear]
LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm] LINEAR_FUNC_OP = [torch.nn.functional.linear, torch.matmul, torch.bmm]
BATCHNORM_MODULE_OP = [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d, torch.nn.SyncBatchNorm]
POOL_MODULE_OP = [torch.nn.MaxPool1d, torch.nn.MaxPool2d, torch.nn.MaxPool3d, torch.nn.AdaptiveAvgPool2d]
INFINITY_COST = 1e13
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from typing import List from typing import List
import math import math
from torch.fx.node import Node from torch.fx.node import Node
......
...@@ -15,7 +15,7 @@ class LiveVariable: ...@@ -15,7 +15,7 @@ class LiveVariable:
LiveVariable is a data structure to store the meta information of a variable for liveness analysis. LiveVariable is a data structure to store the meta information of a variable for liveness analysis.
""" """
name: str name: str
meta: Union[Any, List[Any]] node: Node
is_inplace: bool is_inplace: bool
...@@ -80,13 +80,13 @@ class GraphAnalyser: ...@@ -80,13 +80,13 @@ class GraphAnalyser:
""" """
return self._graph return self._graph
def liveness_analysis(self) -> OrderedDict[int, LiveStage]: def liveness_analysis(self) -> List[LiveStage]:
""" """
Analyse the graph to obtain the variable liveness information. This function returns Analyse the graph to obtain the variable liveness information. This function returns
an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object. an ordered dictionary where the key is the compute stage ID and the value is a LivenessStage object.
""" """
compute_nodes = self.graph.nodes compute_nodes = self.graph.nodes
liveness_dict = ODict() liveness_list = []
# checked: record all variables created since the first stage # checked: record all variables created since the first stage
# all: record the live variables only exist until the current stage. # all: record the live variables only exist until the current stage.
...@@ -97,25 +97,6 @@ class GraphAnalyser: ...@@ -97,25 +97,6 @@ class GraphAnalyser:
all_live_variables = LiveVariableVector() all_live_variables = LiveVariableVector()
unique_live_vars = LiveVariableVector() unique_live_vars = LiveVariableVector()
def _add_param_or_buf(node, tensor_type):
module = get_node_module(node)
if tensor_type == 'param':
iterator = module.named_parameters()
elif tensor_type == 'buffer':
iterator = module.named_buffers()
else:
raise ValueError(f"Expected tensor_type to be param or buffer, but got {tensor_type}")
for name, tensor in iterator:
tensor_name = f'{node.name}.{name}'
if not checked_variables.exists(tensor_name):
live_tensor = LiveVariable(name=tensor_name, meta=tensor.to('meta'), is_inplace=False)
unique_live_vars.append(live_tensor)
checked_variables.append(live_tensor)
all_live_variables.append(live_tensor)
for idx, node in enumerate(compute_nodes): for idx, node in enumerate(compute_nodes):
############################# #############################
# find new living variables # # find new living variables #
...@@ -135,26 +116,19 @@ class GraphAnalyser: ...@@ -135,26 +116,19 @@ class GraphAnalyser:
# add the output var # add the output var
meta = getattr(node, '_meta_data', None) meta = getattr(node, '_meta_data', None)
live_var = LiveVariable(name=node.name, meta=meta, is_inplace=is_inplace) live_var = LiveVariable(name=node.name, node=node, is_inplace=is_inplace)
if not is_inplace: if not is_inplace:
unique_live_vars.append(live_var) unique_live_vars.append(live_var)
checked_variables.append(live_var) checked_variables.append(live_var)
all_live_variables.append(live_var) all_live_variables.append(live_var)
# add the model parameters
if node.op == 'call_module':
_add_param_or_buf(node, tensor_type='param')
_add_param_or_buf(node, tensor_type='buffer')
# add this output variable to the checked list
checked_variables.append(live_var)
# check if any input is not checked yet # check if any input is not checked yet
for arg in node.args: for arg in node.args:
arg_name = str(arg) if not isinstance(arg, Node):
continue
arg_name = arg.name
if not checked_variables.exists(arg_name): if not checked_variables.exists(arg_name):
meta = getattr(node, '_meta_data', None) live_var_from_arg = LiveVariable(name=arg_name, node=node, is_inplace=False)
live_var_from_arg = LiveVariable(name=arg_name, meta=meta, is_inplace=False)
all_live_variables.append(live_var_from_arg) all_live_variables.append(live_var_from_arg)
checked_variables.append(live_var_from_arg) checked_variables.append(live_var_from_arg)
unique_live_vars.append(live_var_from_arg) unique_live_vars.append(live_var_from_arg)
...@@ -167,8 +141,23 @@ class GraphAnalyser: ...@@ -167,8 +141,23 @@ class GraphAnalyser:
node=node, node=node,
all_live_vars=all_live_variables.copy(), all_live_vars=all_live_variables.copy(),
unique_live_vars=unique_live_vars.copy()) unique_live_vars=unique_live_vars.copy())
liveness_dict[idx] = stage # if a LiveStage is covered by another LiveStage, we just keep the larger one.
return liveness_dict replace = False
for index, prev_stage in enumerate(liveness_list):
all_covered = True
for ele in prev_stage.unique_live_vars:
if ele not in stage.unique_live_vars:
all_covered = False
break
if all_covered:
replace = True
break
if replace:
liveness_list[index] = stage
else:
liveness_list.append(stage)
return liveness_list
def get_alias_set(self): def get_alias_set(self):
pass pass
from .operator_handler import OperatorHandler
from .dot_handler import DotHandler
from .conv_handler import ConvHandler
from .batch_norm_handler import BatchNormHandler
__all__ = ['OperatorHandler', 'DotHandler', 'ConvHandler', 'BatchNormHandler']
\ No newline at end of file
...@@ -44,11 +44,8 @@ class DotHandler(OperatorHandler): ...@@ -44,11 +44,8 @@ class DotHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape) compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy # compute the memory cost of this strategy
dtype = self.input_data.dtype toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
numel = self.output_data.numel() dim_partition_dict_for_output, dim_partition_dict_for_weight)
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0] * self.device_mesh.shape[mesh_dim_1]
memory_cost = numel * size_per_elem_bytes / sharding_size
# compute the communication cost # compute the communication cost
# no all-reduce required for this case # no all-reduce required for this case
...@@ -59,7 +56,7 @@ class DotHandler(OperatorHandler): ...@@ -59,7 +56,7 @@ class DotHandler(OperatorHandler):
output_sharding_spec=sharding_spec_for_ouput, output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost, compute_cost=compute_cost,
communication_cost=communication_cost, communication_cost=communication_cost,
memory_cost=memory_cost, memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs, resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
...@@ -86,19 +83,16 @@ class DotHandler(OperatorHandler): ...@@ -86,19 +83,16 @@ class DotHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape) compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy # compute the memory cost of this strategy
dtype = self.input_data.dtype toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
numel = self.output_data.numel() dim_partition_dict_for_output, dim_partition_dict_for_weight)
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0]
memory_cost = numel * size_per_elem_bytes / sharding_size
# compute the communication cost of this strategy # compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1) communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
sharding_strategies = ShardingStrategy(name, sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput, output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost, compute_cost=compute_cost,
communication_cost=communication_cost, communication_cost=communication_cost,
memory_cost=memory_cost, memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs, resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
...@@ -122,19 +116,16 @@ class DotHandler(OperatorHandler): ...@@ -122,19 +116,16 @@ class DotHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape) compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy # compute the memory cost of this strategy
dtype = self.input_data.dtype toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
numel = self.output_data.numel() dim_partition_dict_for_output, dim_partition_dict_for_weight)
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim_0]
memory_cost = numel * size_per_elem_bytes / sharding_size
# compute the communication cost of this strategy # compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim_1) communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim_1)
sharding_strategies = ShardingStrategy(name, sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput, output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost, compute_cost=compute_cost,
communication_cost=communication_cost, communication_cost=communication_cost,
memory_cost=memory_cost, memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs, resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
...@@ -158,18 +149,16 @@ class DotHandler(OperatorHandler): ...@@ -158,18 +149,16 @@ class DotHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape) compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy # compute the memory cost of this strategy
dtype = self.input_data.dtype toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
numel = self.output_data.numel() dim_partition_dict_for_output, dim_partition_dict_for_weight)
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
memory_cost = numel * size_per_elem_bytes
# compute the communication cost of this strategy # compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim) communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
sharding_strategies = ShardingStrategy(name, sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput, output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost, compute_cost=compute_cost,
communication_cost=communication_cost, communication_cost=communication_cost,
memory_cost=memory_cost, memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs, resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
...@@ -193,19 +182,115 @@ class DotHandler(OperatorHandler): ...@@ -193,19 +182,115 @@ class DotHandler(OperatorHandler):
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape) compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy # compute the memory cost of this strategy
dtype = self.input_data.dtype toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
numel = self.output_data.numel() dim_partition_dict_for_output, dim_partition_dict_for_weight)
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
sharding_size = self.device_mesh.shape[mesh_dim]
memory_cost = numel * size_per_elem_bytes / sharding_size
# compute the communication cost of this strategy # compute the communication cost of this strategy
communication_cost = self.device_mesh.all_reduce_cost(memory_cost, mesh_dim) communication_cost = self.device_mesh.all_reduce_cost(activation_memory_cost, mesh_dim)
sharding_strategies = ShardingStrategy(name, sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput, output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost, compute_cost=compute_cost,
communication_cost=communication_cost, communication_cost=communication_cost,
memory_cost=memory_cost, memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def split_lhs_1st_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'S{mesh_dim_0}{mesh_dim_1}R = S{mesh_dim_0}{mesh_dim_1}R x RR'
dim_partition_dict_for_input = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight)
# compute the communication cost of this strategy
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def split_lhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RR = RS{mesh_dim_0}{mesh_dim_1} x S{mesh_dim_0}{mesh_dim_1}R'
dim_partition_dict_for_input = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {0: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight)
# compute the communication cost of this strategy
communication_cost = self.device_mesh.flatten_device_mesh.all_reduce_cost(activation_memory_cost, 0)
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies)
def split_rhs_2nd_dim_1d(self, mesh_dim_0, mesh_dim_1):
name = f'RS{mesh_dim_0}{mesh_dim_1} = RR x RS{mesh_dim_0}{mesh_dim_1}'
dim_partition_dict_for_input = {}
sharding_spec_for_input = self._generate_sharding_spec(self.input_data, dim_partition_dict_for_input)
dim_partition_dict_for_weight = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_weight = self._generate_sharding_spec(self.weight, dim_partition_dict_for_weight)
dim_partition_dict_for_output = {1: [mesh_dim_0, mesh_dim_1]}
sharding_spec_for_ouput = self._generate_sharding_spec(self.output_data, dim_partition_dict_for_output)
# generate resharding cost for this strategy
resharding_costs = self._generate_resharding_costs([sharding_spec_for_input])
# compute the computation cost of this strategy
compute_cost = self._generate_compute_cost(self.input_data.shape, self.weight.shape)
# compute the memory cost of this strategy
toatl_memory_cost, activation_memory_cost, weight_memory_cost = self._generate_memory_cost(
dim_partition_dict_for_output, dim_partition_dict_for_weight)
# compute the communication cost of this strategy
communication_cost = 0
sharding_strategies = ShardingStrategy(name,
output_sharding_spec=sharding_spec_for_ouput,
compute_cost=compute_cost,
communication_cost=communication_cost,
memory_cost=toatl_memory_cost,
resharding_costs=resharding_costs, resharding_costs=resharding_costs,
input_shardings=(sharding_spec_for_input, sharding_spec_for_weight)) input_shardings=(sharding_spec_for_input, sharding_spec_for_weight))
self.strategies_vector.append(sharding_strategies) self.strategies_vector.append(sharding_strategies)
...@@ -236,4 +321,14 @@ class DotHandler(OperatorHandler): ...@@ -236,4 +321,14 @@ class DotHandler(OperatorHandler):
# RS = RR x RS # RS = RR x RS
self.split_rhs_space_only(0) self.split_rhs_space_only(0)
self.split_rhs_space_only(1) self.split_rhs_space_only(1)
# S01R = S01R x RR
self.split_lhs_1st_dim_1d(0, 1)
# RR = RS01 x S01R
self.split_lhs_2nd_dim_1d(0, 1)
# RS01 = RR x RS01
self.split_rhs_2nd_dim_1d(0, 1)
return self.strategies_vector return self.strategies_vector
...@@ -8,7 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh ...@@ -8,7 +8,7 @@ from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
from .sharding_strategy import StrategiesVector from ..sharding_strategy import StrategiesVector
__all__ = ['OperatorHandler'] __all__ = ['OperatorHandler']
...@@ -70,6 +70,48 @@ class OperatorHandler(ABC): ...@@ -70,6 +70,48 @@ class OperatorHandler(ABC):
dim_partition_dict=dim_partition_dict) dim_partition_dict=dim_partition_dict)
return sharding_spec return sharding_spec
def _generate_memory_cost(self, dim_partition_dict_for_output, dim_partition_dict_for_weight):
'''
Compute the memory cost per device with this specific strategy.
Argument:
dim_partition_dict_for_output(List[int]): The key is the dimension of output to be sharded,
and the value of the key decribe which logical axis will be sharded in that dimension.
dim_partition_dict_for_weight(List[int]): The key is the dimension of weight to be sharded,
and the value of the key decribe which logical axis will be sharded in that dimension.
Return:
total_memory_cost(float): total memory cost per device with this specific strategy
activation_cost(float): the memory cost of activation per device with this specific strategy
weight_memory_cost(float): the memory cost of weight per device with this specific strategy
'''
# compute the size of one element with specific dtype
dtype = self.input_data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
# compute the memory cost of activation
activation_numel = self.output_data.numel()
output_mesh_dims = []
for sharding_dim, mesh_dims in dim_partition_dict_for_output.items():
output_mesh_dims.extend(mesh_dims)
activation_sharding_size = 1
for mesh_dim in output_mesh_dims:
activation_sharding_size *= self.device_mesh.shape[mesh_dim]
activation_memory_cost = activation_numel / activation_sharding_size * size_per_elem_bytes
# compute the memory cost of weight
weight_numel = self.weight.numel()
weight_sharding_size = 1
weight_mesh_dims = []
for sharding_dim, mesh_dims in dim_partition_dict_for_weight.items():
weight_mesh_dims.extend(mesh_dims)
for mesh_dim in weight_mesh_dims:
weight_sharding_size *= self.device_mesh.shape[mesh_dim]
weight_memory_cost = weight_numel / weight_sharding_size * size_per_elem_bytes
total_memory_cost = activation_memory_cost + weight_memory_cost
return total_memory_cost, activation_memory_cost, weight_memory_cost
def _generate_resharding_costs(self, sharding_spec_for_input): def _generate_resharding_costs(self, sharding_spec_for_input):
''' '''
Compute the resharding costs with this specific strategy. Compute the resharding costs with this specific strategy.
...@@ -85,17 +127,20 @@ class OperatorHandler(ABC): ...@@ -85,17 +127,20 @@ class OperatorHandler(ABC):
''' '''
# The resharding_cost of weight is counted due to sharing weight cases. # The resharding_cost of weight is counted due to sharing weight cases.
resharding_costs = {} resharding_costs = {}
for input_node, target_spec in zip(self.predecessor_node, sharding_spec_for_input): dtype = self.node._meta_data.dtype
size_per_elem_bytes = torch.tensor([], dtype=dtype).element_size()
for input_node, input_spec in zip(self.predecessor_node, sharding_spec_for_input):
resharding_costs[input_node] = [] resharding_costs[input_node] = []
for strategy in input_node.strategies_vector: for strategy in input_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.' assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
# compute the resharding cost during forward phase # compute the resharding cost during forward phase
_, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency( _, _, resharding_cost_forward = self.shape_consistency_manager.shape_consistency(
input_sharding_spec, target_spec) input_sharding_spec, input_spec)
# In backward phase, we should convert grad with target_spec into input_sharding_spec # In backward phase, we should convert grad with target_spec into input_sharding_spec
_, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency( _, _, resharding_cost_backward = self.shape_consistency_manager.shape_consistency(
target_spec, input_sharding_spec) input_spec, input_sharding_spec)
resharding_cost = resharding_cost_forward + resharding_cost_backward # we need multiply the size of elem dtype to get correct communication cost
resharding_cost = (resharding_cost_forward + resharding_cost_backward) * size_per_elem_bytes
resharding_costs[input_node].append(resharding_cost) resharding_costs[input_node].append(resharding_cost)
return resharding_costs return resharding_costs
import warnings
import time
import numpy as np
import multiprocessing
from torch.fx.node import Node
from torch.fx.graph import Graph
from . import GraphAnalyser
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from typing import Dict
from .constants import INFINITY_COST
try:
import pulp
from pulp import LpVariable, LpProblem, LpMinimize, lpSum, lpDot, LpStatus
except:
warnings.warn(f'please install the pulp')
__all___ = ['Solver']
class Solver:
def __init__(self,
graph: Graph,
strategies_constructor: StrategiesConstructor,
cost_graph: CostGraph,
graph_analyser: GraphAnalyser,
memory_budget: float = -1.0,
solution_numbers: int = 1,
memory_increasing_coefficient: float = 1.3):
'''
Solver class will integrate information provided by the components and use ILP solver to find a possible optimal strategies combination for target computing graph.
Argument:
graph: The computing graph to be optimized.
strategies_constructor: It will provide all the possible strategies for each node in the computing graph.
cost_graph: A graph data structure to simplify the edge cost graph.
graph_analyser: graph_analyser will analyse the graph to obtain the variable liveness information, which will be used to generate memory constraints.
memory_budget: Memory constraint for the solution.
solution_numbers: If solution_numbers is larger than one, solver will us a serious of solutions based on different memory budget.
memory_increasing_coefficient: If solution_numbers is larger than one, we will use this coefficient to generate new memory budget.
'''
self.graph = graph
self.strategies_constructor = strategies_constructor
self.cost_graph = cost_graph
self.graph_analyser = graph_analyser
self.nodes = list(self.graph.nodes)
self.leaf_strategies = self.strategies_constructor.leaf_strategies
self.strategy_map = self.strategies_constructor.strategy_map
self.memory_budget = memory_budget
self.solution_numbers = solution_numbers
if self.solution_numbers > 1:
self.memory_increasing_coefficient = memory_increasing_coefficient
else:
self.memory_increasing_coefficient = 1
self.liveness_list = self.graph_analyser.liveness_analysis()
self.node_index_dict = self._generate_node_index_dict()
# The last solution vector of auto sharding.
self.last_s_val = None
# The last objective value of the best ILP solution.
self.last_objective = None
def _generate_node_index_dict(self) -> Dict[Node, int]:
node_index_dict = {}
for index, strategies_vector in enumerate(self.leaf_strategies):
node_index_dict[strategies_vector.node] = index
return node_index_dict
def _prepare_data_for_solver(self):
'''
Extract information from components for solver.
'''
node_nums = len(self.leaf_strategies)
memory_budget = self.memory_budget
# prepare strategies_len
strategies_len = []
for node in self.nodes:
strategies_len.append(self.cost_graph.node_lens[node])
strategies_len = np.array(strategies_len)
# prepare following_nodes
following_nodes = self.cost_graph.following_dict
index_following_nodes = {}
for src, target in following_nodes.items():
src_index = self.node_index_dict[src]
target_index = self.node_index_dict[target]
index_following_nodes[src_index] = target_index
following_nodes = index_following_nodes
for index in range(node_nums):
if index not in following_nodes:
following_nodes[index] = -1
# prepare edge_pairs and resharding costs
edge_pairs = []
resharding_costs = []
for pairs, edge_cost in self.cost_graph.edge_costs.items():
src_node = pairs[0]
dst_node = pairs[1]
src_node_index = self.node_index_dict[src_node]
dst_node_index = self.node_index_dict[dst_node]
edge_pairs.append(src_node_index)
edge_pairs.append(dst_node_index)
for i in range(strategies_len[src_node_index]):
for j in range(strategies_len[dst_node_index]):
resharding_costs.append(edge_cost[(i, j)])
edge_pairs = np.array(edge_pairs)
resharding_costs = np.array(resharding_costs)
# prepare liveness_set
liveness_set = self.liveness_list
# omit alias_set now
alias_set = None
alias_convert_costs = None
# prepare compute_costs, communication_costs and memory_costs
compute_costs = []
communication_costs = []
memory_costs = []
extra_node_costs = self.cost_graph.extra_node_costs
for strategies_vector in self.leaf_strategies:
node = strategies_vector.node
for index, strategy in enumerate(strategies_vector):
compute_costs.append(strategy.compute_cost)
# node in extra_node_costs means it has some extra communication
# cost from node merging, so we need to add those extra communication
# cost into
if node in extra_node_costs:
origin_communication_cost = strategy.communication_cost
extra_node_cost = extra_node_costs[node][index]
communication_cost = origin_communication_cost + extra_node_cost
communication_costs.append(communication_cost)
else:
communication_costs.append(strategy.communication_cost)
# temporarily we just consider the forward memory cost
memory_cost = strategy.memory_cost
if isinstance(memory_cost, tuple):
memory_costs.append(memory_cost[0])
else:
memory_costs.append(memory_cost)
compute_costs = np.array(compute_costs)
communication_costs = np.array(communication_costs)
memory_costs = np.array(memory_costs)
# omit initial value for nodes
s_init_np = None
return node_nums, memory_budget, strategies_len, following_nodes, edge_pairs, alias_set, liveness_set, compute_costs, communication_costs, memory_costs, resharding_costs, alias_convert_costs, s_init_np
def _call_solver_serialized_args(self,
node_nums,
memory_budget,
strategies_len,
following_nodes,
edge_pairs,
alias_set,
liveness_set,
compute_costs,
communication_costs,
memory_costs,
resharding_costs,
alias_convert_costs,
s_init_np=None):
"""
Call the solver with serialized arguments.
"""
tic = time.time()
for x in [strategies_len, edge_pairs, compute_costs, communication_costs, memory_costs, resharding_costs]:
assert isinstance(x, np.ndarray)
assert len(strategies_len) == node_nums, "strategies_len"
def get_non_zero_index(binary_vector):
"""
Get the index of non-zero item in a vector.
"""
ct = 0
ret = None
for i, elem in enumerate(binary_vector):
if pulp.value(elem):
ret = i
ct += 1
assert ct == 1
return ret
# 0. Unpack flatten numpy arrays
s_follow = following_nodes
E = edge_pairs.reshape((-1, 2)) # noqa
r = []
pt = 0
edge_set = set()
for (i, j) in E:
prod_length = strategies_len[i] * strategies_len[j]
if (i, j) in edge_set:
raise ValueError(f"Duplicated edges: {(i, j)}")
edge_set.add((i, j))
r.append(resharding_costs[pt:pt + prod_length])
pt += prod_length
assert pt == len(resharding_costs)
######################
# omit alias set now #
######################
# A = alias_set.reshape((-1, 2)) # noqa
# for (i, j) in A:
# prod_length = strategies_len[i] * strategies_len[j]
# v.append(alias_convert_costs[pt:pt + prod_length])
# pt += prod_length
# assert pt == len(alias_convert_costs)
# L = [] # noqa
# pt = node_nums
# for i in range(node_nums):
# length = liveness_set[i]
# L.append(liveness_set[pt:pt + length])
# pt += length
# assert pt == len(liveness_set)
v = []
pt = 0
c = []
d = []
m = []
pt = 0
for i in range(node_nums):
length = strategies_len[i]
c.append(compute_costs[pt:pt + length])
d.append(communication_costs[pt:pt + length])
m.append(memory_costs[pt:pt + length])
pt += length
assert pt == len(compute_costs), f"{pt} == {len(compute_costs)}"
assert pt == len(communication_costs), f"{pt} == {len(communication_costs)}"
assert pt == len(memory_costs), f"{pt} == {len(memory_costs)}"
# 1. Create variables
#############################
# create variables for node #
#############################
s = []
num_nodes = 0
reverse_follow_backpatch = []
for i in range(node_nums):
if s_follow[i] < 0:
if strategies_len[i] == 1:
s.append([1])
else:
num_nodes += 1
s.append(LpVariable.matrix(f"s[{i}]", (range(strategies_len[i]),), cat="Binary"))
else:
if s_follow[i] < len(s):
s.append(s[s_follow[i]])
else:
s.append(None)
reverse_follow_backpatch.append(i)
for i in reverse_follow_backpatch:
s[i] = s[s_follow[i]]
#############################
# create variables for edge #
#############################
e = []
num_edges = 0
for (idx, (i, j)) in enumerate(E):
if len(s[i]) == 1:
e.append(s[j])
elif len(s[j]) == 1:
e.append(s[i])
else:
num_edges += 1
e.append(LpVariable.matrix(f"e[{i},{j}]", (range(len(s[i]) * len(s[j])),), cat="Binary"))
assert len(e[idx]) == len(r[idx])
# 2. Set initial value
######################################
# set a initial value for warm start #
######################################
if s_init_np is not None:
s_init = s_init_np.reshape((-1, 3))
for (idx, value, fix) in s_init:
for i in range(len(s[idx])):
s[idx][i].setInitialValue(i == value)
if fix:
s[idx][i].fixValue()
# 3. Objective
prob = LpProblem("myProblem", LpMinimize)
###################################################################
# computing the node cost(computing cost and communication cost) #
###################################################################
obj = 0
for i in range(node_nums):
obj += lpDot(s[i], c[i]) + lpDot(s[i], d[i])
#############################################
# computing the edge cost(resharding cost) #
#############################################
for i in range(len(E)):
obj += lpDot(e[i], r[i])
prob += obj
# 4. Constraints
# (a). specified by `cat="Binary"`
# (b)
#################################################
# make sure each node only choose one strategy #
#################################################
for i in range(node_nums):
if s_follow[i] < 0:
prob += lpSum(s[i]) == 1
# (c)
#################################################
# compute memory consumption with liveness set #
#################################################
if memory_budget > 0:
for liveness_stage in liveness_set:
mem = 0
for live_variable in liveness_stage.unique_live_vars:
node_index = self.node_index_dict[live_variable.node]
mem += lpSum(s[node_index][j] * m[node_index][j] for j in range(len(s[node_index])))
prob += mem <= memory_budget
# (d). specified by `cat="Binary"`
for (idx, (i, j)) in enumerate(E):
if strategies_len[i] == 1 or strategies_len[j] == 1:
continue
# (e)
prob += lpSum(e[idx]) == 1
# (f)
for row in range(len(s[i])):
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for col in range(0, C)) <= s[i][row]
# (g)
for col in range(len(s[j])):
R = len(s[i]) # noqa
C = len(s[j]) # noqa
prob += lpSum(e[idx][row * C + col] for row in range(0, R)) <= s[j][col]
# (h)
######################
# omit alias set now #
######################
# alias_set = set()
# for (idx, (i, j)) in enumerate(A):
# R = len(s[i]) # noqa
# C = len(s[j]) # noqa
# if (i, j) in alias_set:
# raise ValueError(f"Duplicated edges: {(i, j)}")
# alias_set.add((i, j))
# alias_set.add((j, i))
# for row in range(len(s[i])):
# for col in range(len(s[j])):
# if v[idx][row * C + col] > 0.5:
# prob += s[i][row] + s[j][col] <= 1
verbose = True
msg = verbose
time_limit = 600
assert "COIN_CMD" in pulp.listSolvers(
onlyAvailable=True), ("Please install ILP solvers by 'sudo apt install coinor-cbc'")
solver = pulp.COIN_CMD(mip=True, msg=msg, timeLimit=time_limit, threads=multiprocessing.cpu_count())
# solver = pulp.GLPK_CMD(mip=True, msg=msg, timeLimit=time_limit)
prob.solve(solver)
status = prob.status
objective = pulp.value(prob.objective)
objective = float(objective) if objective is not None else -1.0
if verbose:
print(f"ILP Status: {LpStatus[status]}\tObjective: {objective}\t"
f"Time: {time.time() - tic}")
print(f"#nodes: {num_nodes}, #edges: {num_edges}")
if prob.status in [pulp.LpStatusInfeasible]:
raise RuntimeError("Cannot run the function under the given memory budget. "
"Please increase the memory budget.")
# Get and check results
s_val = np.full((node_nums,), -1, dtype=np.int32)
for i in range(node_nums):
s_val[i] = get_non_zero_index(s[i])
e_val = np.full((len(E),), -1, dtype=np.int32)
for (idx, (i, j)) in enumerate(E):
e_val[idx] = get_non_zero_index(e[idx])
i_spec_index = e_val[idx] // len(s[j])
j_spec_index = e_val[idx] % len(s[j])
assert i_spec_index == s_val[i], f"e_val[{i}][{j}]"
assert j_spec_index == s_val[j], f"e_val[{i}][{j}]"
if verbose and r[idx][e_val[idx]] > 0:
print(f"Edge cost {(i, j)} : {r[idx][e_val[idx]]}")
self.last_s_val = s_val
self.last_objective = objective
if objective > INFINITY_COST:
warnings.warn("Detect unexpected behaviors in the auto-sharding pass.")
return s_val, e_val, objective, status
def call_solver_serialized_args(self):
"""
Call the solver with serialized arguments and handle python errors. Additionally,
we could give a serious of solutions with different memory budget.
"""
if self.solution_numbers == 1:
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
return ret
origin_memory_budget = self.memory_budget
memory_budget_list = [
origin_memory_budget * self.memory_increasing_coefficient**i for i in range(self.solution_numbers)
]
ret_list = []
for memory_budget in memory_budget_list:
self.memory_budget = memory_budget
args = self._prepare_data_for_solver()
ret = self._call_solver_serialized_args(*args)
ret_list.append(ret)
return ret_list
from torch.fx import Graph, Node from torch.fx import Graph, Node
from colossalai.tensor.sharding_spec import ShardingSpec from colossalai.tensor.sharding_spec import ShardingSpec
from .sharding_strategy import ShardingStrategy, StrategiesVector from . import ShardingStrategy, StrategiesVector
from .conv_handler import ConvHandler from .op_handler import *
from .constants import * from .constants import *
from copy import deepcopy from copy import deepcopy
import math import math
...@@ -175,6 +175,58 @@ class StrategiesConstructor: ...@@ -175,6 +175,58 @@ class StrategiesConstructor:
input_shardings=[input_sharding_spec]) input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy) strategies_vector.append(sharding_strategy)
# BatchNormNd module
elif submod_type in BATCHNORM_MODULE_OP:
# bn1 call_module bn1 (conv1,)
# print(node, node.op, node.target, node.args)
# create sharding strategy for element-wise module
# input_node = strategies_vector.predecessor_nodes[0]
norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector,
self.shape_consistency_manager)
norm_handler.register_strategy()
# for strategy in norm_handler.strategies_vector:
# print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
# assert False
# MaxPool module
elif submod_type in POOL_MODULE_OP:
# create sharding strategy for element-wise module
assert len(strategies_vector.predecessor_nodes
) == 1, f'Temporally, we just support single input element-wise op.'
input_node = strategies_vector.predecessor_nodes[0]
# For element-wise module, we keep the sharding spec of output node same as
# the input. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for element-wise module.
sharding_spec_checklist = []
for strategy in input_node.strategies_vector:
# It looks a little bit confusing, the input of the processing node
# is the output of the input_node.
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
if input_sharding_spec in sharding_spec_checklist:
continue
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
output_sharding_spec = self._generate_sharding_spec(node, dim_partition_dict)
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = node._meta_data.numel()
memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# other module # other module
else: else:
raise RuntimeError(f'{submod_type} module is NOT supported now.') raise RuntimeError(f'{submod_type} module is NOT supported now.')
...@@ -203,7 +255,7 @@ class StrategiesConstructor: ...@@ -203,7 +255,7 @@ class StrategiesConstructor:
# TODO: integrate element-wise func and module together # TODO: integrate element-wise func and module together
# create sharding strategy for element-wise function # create sharding strategy for element-wise function
assert len(strategies_vector.predecessor_nodes assert len(strategies_vector.predecessor_nodes
) == 1, f'Temporally, we just support single input element-wise op.' ) == 1, f'Temporally, we just support single input element-wise op, node name is {node}.'
input_node = strategies_vector.predecessor_nodes[0] input_node = strategies_vector.predecessor_nodes[0]
# For element-wise function, we keep the sharding spec of output node same as # For element-wise function, we keep the sharding spec of output node same as
# the input. Therefore, the different strategies of input node with same # the input. Therefore, the different strategies of input node with same
...@@ -349,6 +401,13 @@ class StrategiesConstructor: ...@@ -349,6 +401,13 @@ class StrategiesConstructor:
memory_cost = 0 memory_cost = 0
resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes, resharding_costs = self._generate_resharding_costs(strategies_vector.predecessor_nodes,
input_sharding_specs) input_sharding_specs)
# clear the resharding cost for the output node
# TODO: we may remove this in final version
if True:
for prev_node, resharding_cost_list in resharding_costs.items():
resharding_costs[prev_node] = [0] * len(resharding_cost_list)
sharding_strategy_attribute = ShardingStrategy(name, sharding_strategy_attribute = ShardingStrategy(name,
output_sharding_spec, output_sharding_spec,
memory_cost=memory_cost, memory_cost=memory_cost,
......
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.op_handler.batch_norm_handler import BatchNormHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
class BNModel(nn.Module):
def __init__(self, c):
super().__init__()
self.bn = nn.BatchNorm2d(c)
def forward(self, x):
x = x * 2
x = self.bn(x)
return x
def test_bn_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = BNModel(16)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %bn : [#users=1] = call_module[target=bn](args = (%mul,), kwargs = {})
# return bn
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# [x, mul, bn, output]
nodes = [node for node in gm.graph.nodes]
# find the sharding strategies for the input node of the bn node
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
strategies_vector_for_input = StrategiesVector(nodes[1])
sharding_option = (None, 0, 1)
for first_sharding_index in sharding_option:
for second_sharding_index in sharding_option:
if first_sharding_index is not None and second_sharding_index == first_sharding_index:
continue
if first_sharding_index is None:
first_dim_spec = _DimSpec([])
else:
first_dim_spec = _DimSpec([first_sharding_index])
if second_sharding_index is None:
second_dim_spec = _DimSpec([])
else:
second_dim_spec = _DimSpec([second_sharding_index])
replica_dim_spec = _DimSpec([])
sharding_sequence = [first_dim_spec, second_dim_spec, replica_dim_spec, replica_dim_spec]
sharding_spec = ShardingSpec(device_mesh=device_mesh,
entire_shape=entire_shape,
sharding_sequence=sharding_sequence)
strategy_name = str(sharding_spec.sharding_sequence)
sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
strategies_vector_for_input.append(sharding_strategy)
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
# generate bn strategy
strategies_vector = StrategiesVector(node=nodes[2])
bn_handler = BatchNormHandler(node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
shape_consistency_manager=shape_consistency_manager)
bn_handler.register_strategy()
# ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01',
# 'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN']
strategy_name_list = [strategy.name for strategy in bn_handler.strategies_vector]
# RS = RS x S and strategies based on it, such as
# SS = RS x S
assert 'RS0 = RS0 x S0' in strategy_name_list
assert 'S1S0 = RS0 x S0' in strategy_name_list
assert 'RS1 = RS1 x S1' in strategy_name_list
assert 'S0S1 = RS1 x S1' in strategy_name_list
# RR = RR x R and strategies based on it, such as
# SR = SR x R
assert 'RR = RR x R' in strategy_name_list
assert 'S0R = RR x R' in strategy_name_list
assert 'S1R = RR x R' in strategy_name_list
assert 'S01R = RR x R' in strategy_name_list
# RS01 = RS01 x S01
assert 'RS01 = RS01 x S01' in strategy_name_list
# SR = SR x R WITH SYNC_BN
assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list
assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list
# SS = SS x S WITH SYNC_BN
assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list
assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list
# S01R = S01R x R WITH SYNC_BN
assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list
if __name__ == '__main__':
test_bn_handler()
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
from colossalai.fx.proxy import ColoProxy from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler from colossalai.auto_parallel.solver.op_handler.conv_handler import ConvHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
......
...@@ -6,8 +6,6 @@ import pytest ...@@ -6,8 +6,6 @@ import pytest
from colossalai.fx.proxy import ColoProxy from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
......
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
from colossalai.fx.proxy import ColoProxy from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.dot_handler import DotHandler from colossalai.auto_parallel.solver.op_handler.dot_handler import DotHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
......
...@@ -32,21 +32,21 @@ def test_liveness_analysis(): ...@@ -32,21 +32,21 @@ def test_liveness_analysis():
gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__) gm = ColoGraphModule(root=model, graph=graph, class_name=model.__class__.__name__)
graph_analyser = GraphAnalyser(gm) graph_analyser = GraphAnalyser(gm)
liveness_dict = graph_analyser.liveness_analysis() liveness_list = graph_analyser.liveness_analysis()
stage_count = len(liveness_dict) stage_count = len(liveness_list)
# 8 stages including input and output # if a LiveStage is covered by another LiveStage, we just keep the larger one.
assert stage_count == 8 assert stage_count == 1
# a variable named `relu` must exist # a variable named `relu` must exist
# and this live var must have inplace = True # and this live var must have inplace = True
assert liveness_dict[5].all_live_vars.exists('relu') assert liveness_list[0].all_live_vars.exists('relu')
relu_var = liveness_dict[5].all_live_vars.get('relu') relu_var = liveness_list[0].all_live_vars.get('relu')
assert relu_var.is_inplace assert relu_var.is_inplace
# the unique vars must be fewer than the all vars since in-place ops exist # the unique vars must be fewer than the all vars since in-place ops exist
all_live_vars = liveness_dict[7].all_live_vars all_live_vars = liveness_list[0].all_live_vars
unique_live_vars = liveness_dict[7].unique_live_vars unique_live_vars = liveness_list[0].unique_live_vars
assert len(unique_live_vars) + 1 == len(all_live_vars) assert len(unique_live_vars) + 1 == len(all_live_vars)
......
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.solver.cost_graph import CostGraph
from colossalai.auto_parallel.solver.graph_analysis import GraphAnalyser
from copy import deepcopy
from colossalai.auto_parallel.solver import Solver
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
self.conv2 = nn.Conv2d(c_out, c_out, kernel_size=3)
self.conv3 = nn.Conv2d(c_out, c_out, kernel_size=3)
self.relu = nn.ReLU()
def forward(self, x):
x = x * 2
x = self.conv1(x)
x = self.conv2(x)
x = x / 2
x = self.conv3(x)
x = self.relu(x)
return x
@pytest.mark.skip("for higher testing speed")
def test_solver():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {})
# %conv2 : [#users=1] = call_module[target=conv2](args = (%conv1,), kwargs = {})
# %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv2, 2), kwargs = {})
# %conv3 : [#users=1] = call_module[target=conv3](args = (%truediv,), kwargs = {})
# %relu : [#users=1] = call_module[target=relu](args = (%conv3,), kwargs = {})
# return relu
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
solver_options = {'fast_mode': True}
strategies_constructor = StrategiesConstructor(graph, device_mesh, shape_consistency_manager, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
# [ 0 0 13 13 13 13 13 0]
strategies_combination_list = ret[0]
assert solver.leaf_strategies[2][13].name == 'S01R = S01R x RR'
if __name__ == '__main__':
test_solver()
...@@ -6,7 +6,7 @@ import pytest ...@@ -6,7 +6,7 @@ import pytest
from colossalai.fx.proxy import ColoProxy from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler, CONV_STRATEGIES_LIST from colossalai.auto_parallel.solver.op_handler.conv_handler import CONV_STRATEGIES_LIST
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment