Unverified Commit 26a37b5c authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[autoparallel] Add conv handler to generate strategies and costs info for conv (#1467)

parent 1b491ad7
This diff is collapsed.
class ShardingStrategy:
'''
ShardingStrategy is a structure containing sharding strategies of inputs and output of this node
and costs information using in solver.
Argument:
name(str): express the sharding strategies in string, such as 'S0S1 = S0R x RS1'.
output_sharding_spec(ShardingSpec): ShardingSpec of the output node.
compute_cost(float): Computation cost to complete this strategy.(default to 0)
communication_cost(float): Communication cost to complete this strategy.(default to 0)
memory_cost(float): Memory cost of the output node using this strategy.(default to 0)
resharding_costs(Dict[int, List[float]]): resharding_cost[i][j] means the cost of i-th argument in the output node argument list
with j-th strategy in its strategies_vector transforms to sharding spec wanted in this
strategy.(default to None)
input_shardings(List(ShardingSpec)): The ShardingSpecs of the input nodes.
'''
def __init__(self,
name,
output_sharding_spec,
compute_cost=0,
communication_cost=0,
memory_cost=0,
resharding_costs=None,
input_shardings=None):
self.name = name
self.output_sharding_spec = output_sharding_spec
self.compute_cost = compute_cost
self.communication_cost = communication_cost
self.memory_cost = memory_cost
self.resharding_costs = resharding_costs
self.input_shardings = input_shardings
class StrategiesVector:
'''
Each node in fx graph will have a corresponding StrategiesVector, to store all the possible
strategies of the node.
Argument:
node(Node): node to build corresponding strategies_vector.
in_nodes(List[Node]): input nodes in the argument list of the node.
following_nodes(List[Node]): the nodes take the target node as their argument.
strategies(List[ShardingStrategy]): enumerate all the possible sharding strategies of the node.
'''
def __init__(self, node, in_nodes, following_nodes=None, strategies=[]):
self.node = node
self.in_nodes = in_nodes
self.following_nodes = following_nodes
self.strategies = strategies
def check_merge(self):
pass
...@@ -199,7 +199,7 @@ class ShardingSpec: ...@@ -199,7 +199,7 @@ class ShardingSpec:
if not dim_spec.is_replica: if not dim_spec.is_replica:
if index not in new_dim_partition_dict: if index not in new_dim_partition_dict:
new_dim_partition_dict[index] = [] new_dim_partition_dict[index] = []
new_dim_partition_dict[index].append(dim_spec.shard_list) new_dim_partition_dict[index].extend(dim_spec.shard_list)
self.dim_partition_dict = new_dim_partition_dict self.dim_partition_dict = new_dim_partition_dict
def sharding_sequence_difference(self, other): def sharding_sequence_difference(self, other):
......
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.solver.conv_handler import ConvHandler
from colossalai.auto_parallel.solver.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
def forward(self, x):
x = x * 2
x = self.conv(x)
return x
def test_conv_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
# return conv
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# [x, mul, conv, output]
nodes = [node for node in gm.graph.nodes]
strategies_for_input = []
sharding_option = (None, 0, 1)
for first_sharding_index in sharding_option:
for second_sharding_index in sharding_option:
if first_sharding_index is not None and second_sharding_index == first_sharding_index:
continue
if first_sharding_index is None:
first_dim_spec = _DimSpec([])
else:
first_dim_spec = _DimSpec([first_sharding_index])
if second_sharding_index is None:
second_dim_spec = _DimSpec([])
else:
second_dim_spec = _DimSpec([second_sharding_index])
replica_dim_spec = _DimSpec([])
sharding_sequence = [first_dim_spec, second_dim_spec, replica_dim_spec, replica_dim_spec]
sharding_spec = ShardingSpec(device_mesh=device_mesh,
entire_shape=entire_shape,
sharding_sequence=sharding_sequence)
strategies_for_input.append(sharding_spec)
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
strategies_vector_for_input = StrategiesVector(node=nodes[0],
in_nodes=[nodes[1], 2],
strategies=strategies_for_input)
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
strategies_vector = StrategiesVector(node=nodes[2], in_nodes=[
nodes[1],
])
conv_handler = ConvHandler(input_node=nodes[1],
input_index=0,
weight=dict(gm.named_modules())[nodes[2].name].weight,
output_node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
shape_consistency_manager=shape_consistency_manager)
conv_handler.register_strategy_into_strategies_vector()
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
strategy_name_list = [strategy.name for strategy in conv_handler.strategies_vector.strategies]
# SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list
# SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list
# RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list
assert 'RS1 = RS0 x S0S1' in strategy_name_list
# RS = RR x RS
assert 'RS0 = RR x RS0' in strategy_name_list
assert 'RS1 = RR x RS1' in strategy_name_list
# RR= RR x RR
assert 'RR = RR x RR' in strategy_name_list
if __name__ == '__main__':
test_conv_handler()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment