"pcdet/ops/vscode:/vscode.git/clone" did not exist on "32567b044c327a4d3cee179094f32646d8311c95"
Unverified Commit 0b2a7383 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[autoparallel] remove deprecated codes (#2664)

parent 7fa6be49
import builtins
import math
import operator
from copy import deepcopy
from typing import Dict, List
import torch
from torch.fx import Graph, Node
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.tensor.sharding_spec import ShardingSpec
from ._utils import generate_resharding_costs, generate_sharding_spec
from .constants import *
from .op_handler import *
from .options import SolverOptions
from .sharding_strategy import ShardingStrategy, StrategiesVector
__all__ = ['StrategiesConstructor']
class StrategiesConstructor:
"""
StrategiesConstructor is used to construct the parallelization plan for the model execution.
Args:
graph (Graph): a Graph object used for analysis and strategy generation.
device_mesh (DeviceMesh): a DeviceMesh object which contains the meta information about the cluster.
solver_options (SolverOptions): a SolverOptions object which specifies the preferences for plan searching.
"""
def __init__(self, graph: Graph, device_mesh: DeviceMesh, solver_options: SolverOptions):
self.graph = graph
assert graph.owning_module is not None, 'The given graph is not associated with a owning_module'
self.root_module = self.graph.owning_module
self.nodes = list(graph.nodes)
self.device_mesh = device_mesh
self.leaf_strategies = []
self.strategy_map = {}
self.solver_options = solver_options
def remove_duplicated_strategy(self, strategies_vector):
'''
In build_strategies_and_cost method, we may produce some duplicated strategies.
In this method, we will remove the duplicated strategies depending on the strategies name.
'''
name_checklist = []
remove_list = []
for strategy in strategies_vector:
if strategy.name not in name_checklist:
name_checklist.append(strategy.name)
else:
remove_list.append(strategy)
for strategy in remove_list:
strategies_vector.remove(strategy)
def _is_bcast_matmul(self, node):
is_bcast_matmul = False
if node.target is torch.matmul and len(node.args) == 2:
lhs_data = node.args[0]._meta_data
rhs_data = node.args[1]._meta_data
if lhs_data.dim() >= 3 and rhs_data.dim() >= 3:
is_bcast_matmul = True
return is_bcast_matmul
def build_strategies_and_cost(self):
for node in self.nodes:
strategies_vector = StrategiesVector(node)
input_nodes_len = 0
for check_node in strategies_vector.predecessor_nodes:
if isinstance(check_node._meta_data, torch.Tensor):
input_nodes_len += 1
# input_nodes_len = len(strategies_vector.predecessor_nodes)
# placeholder node
if node.op == 'placeholder':
# For placeholder nodes, if solver_options.fast is True, we just let them in
# fully replicate status, then strategies of following node will be treated equally due
# to replicate status has no resharding cost to other status. At the same time, the searching
# space is smaller than enumerating all the possible sharding spec for the placeholder node.
# Otherwise, all the possible sharding spec for the placeholder node will be enumerated.
if self.solver_options.fast:
# create sharding strategy for placeholder
name = 'Replica Placeholder'
dim_partition_dict = {}
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
sharding_strategy_placeholder = ShardingStrategy(name,
output_sharding_spec,
memory_cost=memory_cost)
strategies_vector.append(sharding_strategy_placeholder)
# get_attr node
if node.op == 'get_attr':
# Same as placeholder nodes, if solver_options.fast is True, we just let them in
# fully replicate status, then strategies of following node will be treated equally due
# to replicate status has no resharding cost to other status. At the same time, the searching
# space is smaller than enumerating all the possible sharding spec for the get_attr node.
# Otherwise, all the possible sharding spec for the get_attr node will be enumerated.
if self.solver_options.fast:
# create sharding strategy for get_attr
name = 'Replica Attribute'
dim_partition_dict = {}
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
sharding_strategy_attribute = ShardingStrategy(name, output_sharding_spec, memory_cost=memory_cost)
strategies_vector.append(sharding_strategy_attribute)
# call_module node
if node.op == 'call_module':
target = node.target
submod = self.root_module.get_submodule(target)
submod_type = type(submod)
# conv module
if submod_type in CONV_MODULE_OP:
# use ConvHandler to create sharding strategies for conv module node
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
conv_handler.register_strategy()
# linear module
elif submod_type in LINEAR_MODULE_OP:
# use DotHandler to create sharding strategies for linear module node
dot_handler = DotHandler(node, self.device_mesh, strategies_vector)
dot_handler.register_strategy()
# element-wise module
elif submod_type in ELEMENTWISE_MODULE_OP:
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()
# BatchNormNd module
elif submod_type in BATCHNORM_MODULE_OP:
# create sharding strategy for element-wise module
norm_handler = BatchNormHandler(node, self.device_mesh, strategies_vector)
norm_handler.register_strategy()
# for strategy in norm_handler.strategies_vector:
# print(f'{strategy.name}, computation_cost: {strategy.compute_cost}, memory_cost: {strategy.memory_cost}')
# assert False
# MaxPool module
elif submod_type in POOL_MODULE_OP:
# TODO: add sharding constraints on image dimension
# e.g.: for a 2D pooling input NCHW, we should promise no sharding happens on H and W dimension
# create sharding strategy for element-wise module
assert input_nodes_len == 1, f'Temporally, we just support single input element-wise op.'
input_node = strategies_vector.predecessor_nodes[0]
# For element-wise module, we keep the sharding spec of output node same as
# the input. Therefore, the different strategies of input node with same
# output sharding spec will generate same strategy for element-wise module.
sharding_spec_checklist = []
for strategy in input_node.strategies_vector:
# It looks a little bit confusing, the input of the processing node
# is the output of the input_node.
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
if input_sharding_spec in sharding_spec_checklist:
continue
sharding_spec_checklist.append(input_sharding_spec)
dim_partition_dict = deepcopy(input_sharding_spec.dim_partition_dict)
output_sharding_spec = generate_sharding_spec(node, self.device_mesh, dim_partition_dict)
name = f'{input_sharding_spec.sharding_sequence} -> {output_sharding_spec.sharding_sequence}'
# TODO: use meta_info_prop to profile memory cost and compute cost
compute_cost = node._meta_data.numel()
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# embedding module
elif submod_type in EMBEDDING_MODULE_OP:
embedding_handler = EmbeddingHandler(node, self.device_mesh, strategies_vector)
embedding_handler.register_strategy()
# layernorm module
elif submod_type in LAYERNORM_MODULE_OP:
layernorm_handler = LayerNormHandler(node, self.device_mesh, strategies_vector)
layernorm_handler.register_strategy()
# other module
else:
raise RuntimeError(f'{submod_type} module is NOT supported now.')
# call_function node
if node.op == 'call_function':
target = node.target
# conv function
if target in CONV_FUNC_OP:
# use ConvHandler to create sharding strategies for conv node
# TODO: the operator_handler does NOT support function node processing now.
conv_handler = ConvHandler(node, self.device_mesh, strategies_vector)
conv_handler.register_strategy()
# linear function
elif target in LINEAR_FUNC_OP and not self._is_bcast_matmul(node):
# use DotHandler to create sharding strategies for linear node
# TODO: the operator_handler does NOT support function node processing now.
linear_handler = DotHandler(node, self.device_mesh, strategies_vector)
linear_handler.register_strategy()
# where function
elif target == torch.where:
if input_nodes_len == 1:
# both of x and y are scalar
pass
elif input_nodes_len == 2:
# one of x or y is type of scalar
pass
else:
# general case
where_handler = WhereHandler(node, self.device_mesh, strategies_vector)
where_handler.register_strategy()
# reshape function
elif target in RESHAPE_FUNC_OP:
# use ReshapeHandler to create sharding strategies for rehsape node
reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
reshape_handler.register_strategy()
# element-wise function
elif target in ELEMENTWISE_FUNC_OP or (target in BCAST_FUNC_OP and input_nodes_len == 1):
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()
# bcast op
elif target in BCAST_FUNC_OP:
if isinstance(node._meta_data, torch.Tensor):
bcast_op_handler = BcastOpHandler(node, self.device_mesh, strategies_vector)
bcast_op_handler.register_strategy()
# torch.var_mean
elif target == torch.var_mean:
dim = node.kwargs['dim']
input_tensor_node = strategies_vector.predecessor_nodes[0]
for strategy in input_tensor_node.strategies_vector:
input_sharding_spec = strategy.output_sharding_spec
assert isinstance(input_sharding_spec,
ShardingSpec), f'The input node should NOT be a tuple of tensor.'
entire_shape_input = input_sharding_spec.entire_shape
dim_partition_dict_input = input_sharding_spec.dim_partition_dict
name = f'{new_input_sharding_spec.sharding_sequence} -> ({output_sharding_spec.sharding_sequence}, {output_sharding_spec.sharding_sequence})'
if dim in dim_partition_dict_input:
# We need to make the action dimension in replicate status
dim_partition_dict_for_input = deepcopy(dim_partition_dict_input)
dim_partition_dict_for_input.pop(dim)
new_input_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_input,
dim_partition_dict=dim_partition_dict_for_input)
entire_shape_output = deepcopy(entire_shape_input)
entire_shape_output.pop(dim)
dim_partition_dict_for_output = deepcopy(dim_partition_dict_for_input)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_input)
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[new_input_sharding_spec])
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[new_input_sharding_spec])
else:
entire_shape_output = deepcopy(entire_shape_input)
entire_shape_output.pop(dim)
dim_partition_dict_for_output = deepcopy(dim_partition_dict_input)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partion_dict=dim_partition_dict_input)
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec])
sharding_strategy = ShardingStrategy(name, (output_sharding_spec, output_sharding_spec),
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[input_sharding_spec])
strategies_vector.append(sharding_strategy)
# operator.getitem
elif target == operator.getitem:
index = node.args[1]
input_tensor_node = strategies_vector.predecessor_nodes[0]
for strategy in input_tensor_node.strategies_vector:
if isinstance(strategy.output_sharding_spec, ShardingSpec):
input_sharding_spec = strategy.output_sharding_spec
else:
input_sharding_spec = strategy.output_sharding_spec[index]
assert isinstance(input_sharding_spec, ShardingSpec), f'This assertion is used to debug.'
dim_partition_dict_for_output = deepcopy(input_sharding_spec.dim_partition_dict)
entire_shape_output = deepcopy(input_sharding_spec.entire_shape)
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
# TODO: use meta_info_prop to profile origin memory cost and compute cost, then divide them depending on sharding spec.
compute_cost = 0
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
[input_sharding_spec],
index=index)
# to prevent the resharding happening, set their resharding cost to inf.
resharding_costs[input_tensor_node] = [
cost if cost == 0 else INFINITY_COST for cost in resharding_costs[input_tensor_node]
]
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=compute_cost,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=[strategy.output_sharding_spec])
strategies_vector.append(sharding_strategy)
# torch.arange function
elif target == torch.arange:
name = f'FULLY REPLICATED ARANGE'
entire_shape_output = node._meta_data.shape
dim_partition_dict_for_output = {}
output_sharding_spec = ShardingSpec(self.device_mesh,
entire_shape_output,
dim_partition_dict=dim_partition_dict_for_output)
memory_cost = node._meta_data.numel()
sharding_strategy = ShardingStrategy(name,
output_sharding_spec,
compute_cost=0,
memory_cost=memory_cost)
strategies_vector.append(sharding_strategy)
# op list to be processed to support gpt2
elif target in (builtins.getattr, operator.le, torch.addmm):
pass
# other function
else:
raise RuntimeError(f'{target} function is NOT supported now.')
# call_method node
if node.op == 'call_method':
method = getattr(node.args[0]._meta_data.__class__, node.target)
if method in (torch.Tensor.size,):
pass
elif method in ELEMENTWISE_METHOD_OP:
unary_elementwise_handler = UnaryElementwiseHandler(node, self.device_mesh, strategies_vector)
unary_elementwise_handler.register_strategy()
elif method in RESHAPE_METHOD_OP:
reshape_handler = ReshapeHandler(node, self.device_mesh, strategies_vector)
reshape_handler.register_strategy()
# print(strategies_vector)
# if len(strategies_vector) == 0:
# print(node)
# assert False
else:
raise RuntimeError(f'{method} function is NOT supported now.')
# output node
if node.op == 'output':
if self.solver_options.fast:
# create sharding strategy for output
name = 'Replica Output'
input_nodes = strategies_vector.predecessor_nodes
input_sharding_specs = []
for input_node in input_nodes:
dim_partition_dict_for_input = {}
entire_shape = input_node._meta_data.shape
sharding_spec = ShardingSpec(self.device_mesh,
entire_shape,
dim_partition_dict=dim_partition_dict_for_input)
input_sharding_specs.append(sharding_spec)
dim_partition_dict = {}
output_sharding_spec = input_sharding_specs
# TODO: use meta_info_prop to profile memory cost
memory_cost = 0
resharding_costs = generate_resharding_costs(strategies_vector.predecessor_nodes,
input_sharding_specs)
# clear the resharding cost for the output node
# TODO: we may remove this in final version
for prev_node, resharding_cost_list in resharding_costs.items():
resharding_costs[prev_node] = [0] * len(resharding_cost_list)
sharding_strategy_attribute = ShardingStrategy(name,
output_sharding_spec,
memory_cost=memory_cost,
resharding_costs=resharding_costs,
input_shardings=tuple(input_sharding_specs))
strategies_vector.append(sharding_strategy_attribute)
self.remove_duplicated_strategy(strategies_vector)
setattr(node, 'strategies_vector', strategies_vector)
self.leaf_strategies.append(strategies_vector)
self.strategy_map[node] = strategies_vector
# remove no strategy nodes
remove_list = []
for strategies_vector in self.leaf_strategies:
if len(strategies_vector) == 0:
remove_list.append(strategies_vector.node)
for node in remove_list:
if node.strategies_vector in self.leaf_strategies:
self.leaf_strategies.remove(node.strategies_vector)
if node in self.strategy_map:
self.strategy_map.pop(node)
from copy import deepcopy
from pickletools import optimize
import pytest
import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
self.relu = nn.ReLU()
def forward(self, x):
x = x * 2
x = self.conv1(x)
x = x / 2
x = self.relu(x)
return x
def test_cost_graph():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
tracer = ColoTracer()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {})
# %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv1, 2), kwargs = {})
# %relu : [#users=1] = call_module[target=relu](args = (%truediv,), kwargs = {})
# return relu
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
# (x, mul):{(0, 0): 0}
# (mul, conv1):{(0, 0): 65547.1, (0, 1): 65547.1, (0, 2): 65547.1, (0, 3): 65547.1, (0, 4): 131105.30000000002, (0, 5): 131105.30000000002, (0, 6): 65547.1, (0, 7): 65547.1, (0, 8): 65547.1, (0, 9): 65547.1, (0, 10): 0, (0, 11): 0, (0, 12): 0, (0, 13): 131105.30000000002, (0, 14): 131105.30000000002}
# (conv1, truediv):{(0, 0): 0, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): 0, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (9, 0): inf, (10, 0): inf, (11, 0): inf, (12, 0): inf, (13, 0): inf, (14, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): inf, (4, 1): inf, (5, 1): 0, (6, 1): inf, (7, 1): inf, (8, 1): inf, (9, 1): inf, (10, 1): inf, (11, 1): inf, (12, 1): inf, (13, 1): inf, (14, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (9, 2): inf, (10, 2): inf, (11, 2): inf, (12, 2): inf, (13, 2): inf, (14, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (9, 3): inf, (10, 3): inf, (11, 3): inf, (12, 3): inf, (13, 3): inf, (14, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): inf, (5, 4): inf, (6, 4): 0, (7, 4): inf, (8, 4): 0, (9, 4): inf, (10, 4): inf, (11, 4): inf, (12, 4): inf, (13, 4): inf, (14, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): inf, (6, 5): inf, (7, 5): 0, (8, 5): inf, (9, 5): 0, (10, 5): inf, (11, 5): inf, (12, 5): inf, (13, 5): inf, (14, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): inf, (7, 6): inf, (8, 6): inf, (9, 6): inf, (10, 6): 0, (11, 6): 0, (12, 6): 0, (13, 6): inf, (14, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): inf, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): inf, (8, 7): inf, (9, 7): inf, (10, 7): inf, (11, 7): inf, (12, 7): inf, (13, 7): 0, (14, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): inf, (7, 8): inf, (8, 8): inf, (9, 8): inf, (10, 8): inf, (11, 8): inf, (12, 8): inf, (13, 8): inf, (14, 8): 0}
# (truediv, relu):{(0, 0): 0, (1, 0): inf, (2, 0): inf, (3, 0): inf, (4, 0): inf, (5, 0): inf, (6, 0): inf, (7, 0): inf, (8, 0): inf, (0, 1): inf, (1, 1): 0, (2, 1): inf, (3, 1): inf, (4, 1): inf, (5, 1): inf, (6, 1): inf, (7, 1): inf, (8, 1): inf, (0, 2): inf, (1, 2): inf, (2, 2): 0, (3, 2): inf, (4, 2): inf, (5, 2): inf, (6, 2): inf, (7, 2): inf, (8, 2): inf, (0, 3): inf, (1, 3): inf, (2, 3): inf, (3, 3): 0, (4, 3): inf, (5, 3): inf, (6, 3): inf, (7, 3): inf, (8, 3): inf, (0, 4): inf, (1, 4): inf, (2, 4): inf, (3, 4): inf, (4, 4): 0, (5, 4): inf, (6, 4): inf, (7, 4): inf, (8, 4): inf, (0, 5): inf, (1, 5): inf, (2, 5): inf, (3, 5): inf, (4, 5): inf, (5, 5): 0, (6, 5): inf, (7, 5): inf, (8, 5): inf, (0, 6): inf, (1, 6): inf, (2, 6): inf, (3, 6): inf, (4, 6): inf, (5, 6): inf, (6, 6): 0, (7, 6): inf, (8, 6): inf, (0, 7): inf, (1, 7): inf, (2, 7): inf, (3, 7): inf, (4, 7): inf, (5, 7): inf, (6, 7): inf, (7, 7): 0, (8, 7): inf, (0, 8): inf, (1, 8): inf, (2, 8): inf, (3, 8): inf, (4, 8): inf, (5, 8): inf, (6, 8): inf, (7, 8): inf, (8, 8): 0}
# (relu, output):{(0, 0): 246019.30000000002, (1, 0): 246019.30000000002, (2, 0): 123009.1, (3, 0): 123009.1, (4, 0): 123009.1, (5, 0): 123009.1, (6, 0): 0, (7, 0): 246019.30000000002, (8, 0): 246019.30000000002}
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
# construct all node pairs
all_node_pairs = []
for node in graph.nodes:
if node.op == 'output':
continue
for child in node.users.keys():
all_node_pairs.append((node, child))
for node_pair in all_node_pairs:
assert node_pair in cost_graph.edge_costs
# construct merged node pairs
merged_node_pairs = []
node_list = list(graph.nodes)
# add (conv1_weight, conv2d), (conv1_bias, view), (conv2d, add), (view, add), (add, output), (x, conv2d) into check node pairs
merged_node_pairs.append((node_list[0], node_list[4]))
merged_node_pairs.append((node_list[2], node_list[4]))
merged_node_pairs.append((node_list[3], node_list[5]))
merged_node_pairs.append((node_list[5], node_list[6]))
merged_node_pairs.append((node_list[4], node_list[6]))
merged_node_pairs.append((node_list[6], node_list[-1]))
cost_graph.simplify_graph()
for node_pair in all_node_pairs:
if node_pair in merged_node_pairs:
assert node_pair in cost_graph.edge_costs
else:
assert node_pair not in cost_graph.edge_costs
if __name__ == '__main__':
test_cost_graph()
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.batch_norm_handler import BatchNormHandler
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
class BNModel(nn.Module):
def __init__(self, c):
super().__init__()
self.bn = nn.BatchNorm2d(c)
def forward(self, x):
x = x * 2
x = self.bn(x)
return x
def test_bn_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
tracer = ColoTracer()
model = BNModel(16)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %bn : [#users=1] = call_module[target=bn](args = (%mul,), kwargs = {})
# return bn
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# [x, mul, bn, output]
nodes = [node for node in gm.graph.nodes]
# find the sharding strategies for the input node of the bn node
# strategies_for_input = [[R, R, R, R], [R, S0, R, R], [R, S1, R, R], [S0, R, R, R], [S0, S1, R, R], [S1, R, R, R], [S1, S0, R, R]]
strategies_vector_for_input = StrategiesVector(nodes[1])
sharding_option = (None, 0, 1)
for first_sharding_index in sharding_option:
for second_sharding_index in sharding_option:
if first_sharding_index is not None and second_sharding_index == first_sharding_index:
continue
if first_sharding_index is None:
first_dim_spec = _DimSpec([])
else:
first_dim_spec = _DimSpec([first_sharding_index])
if second_sharding_index is None:
second_dim_spec = _DimSpec([])
else:
second_dim_spec = _DimSpec([second_sharding_index])
replica_dim_spec = _DimSpec([])
sharding_sequence = [first_dim_spec, second_dim_spec, replica_dim_spec, replica_dim_spec]
sharding_spec = ShardingSpec(device_mesh=device_mesh,
entire_shape=entire_shape,
sharding_sequence=sharding_sequence)
strategy_name = str(sharding_spec.sharding_sequence)
sharding_strategy = ShardingStrategy(name=strategy_name, output_sharding_spec=sharding_spec)
strategies_vector_for_input.append(sharding_strategy)
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
# generate bn strategy
strategies_vector = StrategiesVector(node=nodes[2])
bn_handler = BatchNormHandler(
node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
)
bn_handler.register_strategy()
# ['RS0 = RS0 x S0', 'S1S0 = RS0 x S0', 'RS1 = RS1 x S1', 'S0S1 = RS1 x S1', 'RR = RR x R', 'S0R = RR x R', 'S1R = RR x R', 'S01R = RR x R', 'RS01 = RS01 x S01',
# 'S0R = S0R x R WITH SYNC_BN', 'S1R = S1R x R WITH SYNC_BN', 'S0S1 = S0S1 x S1 WITH SYNC_BN', 'S1S0 = S1S0 x S0 WITH SYNC_BN', 'S01R = S01R x R WITH SYNC_BN']
strategy_name_list = [strategy.name for strategy in bn_handler.strategies_vector]
# RS = RS x S and strategies based on it, such as
# SS = RS x S
assert 'RS0 = RS0 x S0' in strategy_name_list
assert 'S1S0 = RS0 x S0' in strategy_name_list
assert 'RS1 = RS1 x S1' in strategy_name_list
assert 'S0S1 = RS1 x S1' in strategy_name_list
# RR = RR x R and strategies based on it, such as
# SR = SR x R
assert 'RR = RR x R' in strategy_name_list
assert 'S0R = RR x R' in strategy_name_list
assert 'S1R = RR x R' in strategy_name_list
assert 'S01R = RR x R' in strategy_name_list
# RS01 = RS01 x S01
assert 'RS01 = RS01 x S01' in strategy_name_list
# SR = SR x R WITH SYNC_BN
assert 'S0R = S0R x R WITH SYNC_BN' in strategy_name_list
assert 'S1R = S1R x R WITH SYNC_BN' in strategy_name_list
# SS = SS x S WITH SYNC_BN
assert 'S0S1 = S0S1 x S1 WITH SYNC_BN' in strategy_name_list
assert 'S1S0 = S1S0 x S0 WITH SYNC_BN' in strategy_name_list
# S01R = S01R x R WITH SYNC_BN
assert 'S01R = S01R x R WITH SYNC_BN' in strategy_name_list
if __name__ == '__main__':
test_bn_handler()
from cProfile import run
import pytest
import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1)
self.conv2 = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, stride=2)
def forward(self, x):
x1 = self.conv1(x)
x2 = x1 + 1
x1 = torch.reshape(x1, [1, -1, 64, 1])
x3 = self.conv2(x1)
x3 = torch.reshape(x3, [4, 1, 64, -1])
x = x1 + x3
return x
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_conv_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
tracer = ColoTracer()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv1 : [#users=2] = call_module[target=conv1](args = (%x,), kwargs = {})
# %add : [#users=0] = call_function[target=operator.add](args = (%conv1, 1), kwargs = {})
# %reshape : [#users=2] = call_function[target=torch.reshape](args = (%conv1, [1, -1, 64, 1]), kwargs = {})
# %conv2 : [#users=1] = call_module[target=conv2](args = (%reshape,), kwargs = {})
# %reshape_1 : [#users=1] = call_function[target=torch.reshape](args = (%conv2, [4, 1, 64, -1]), kwargs = {})
# %add_1 : [#users=1] = call_function[target=operator.add](args = (%reshape, %reshape_1), kwargs = {})
# return add_1
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
# [x, conv1, add, reshape, conv2, reshape_1, add_1, output]
nodes = [node for node in gm.graph.nodes]
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
strategy_map = strategies_constructor.strategy_map
# check a tensor add with a scalar case
conv1_strategies = strategy_map[nodes[1]]
add_strategies = strategy_map[nodes[2]]
add_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in add_strategies]
for strategy in conv1_strategies:
assert strategy.output_sharding_spec.sharding_sequence in add_strategies_cover_list
# check two tensors element-wise add case
add_1_strategies = strategy_map[nodes[6]]
assert len(add_1_strategies) == 25
if __name__ == '__main__':
test_conv_handler()
import pytest
import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class MatmulModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x1, x2):
x = torch.matmul(x1, x2)
return x
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_conv_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
tracer = ColoTracer()
model = MatmulModel()
input_sample = {'x1': torch.rand(4, 4, 8).to('meta'), 'x2': torch.rand(4, 1, 8, 4).to('meta')}
# graph():
# %x1 : torch.Tensor [#users=1] = placeholder[target=x1]
# %x2 : torch.Tensor [#users=1] = placeholder[target=x2]
# %matmul : [#users=1] = call_function[target=torch.matmul](args = (%x1, %x2), kwargs = {})
# return matmul
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
# [x1, x2, matmul, output]
nodes = [node for node in gm.graph.nodes]
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
strategy_map = strategies_constructor.strategy_map
matmul_strategies = strategy_map[nodes[2]]
assert len(matmul_strategies) == 30
if __name__ == '__main__':
test_conv_handler()
import pytest
import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import ConvHandler
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
def forward(self, x):
x = x * 2
x = self.conv(x)
return x
def test_conv_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
tracer = ColoTracer()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
# %conv_bias : [#users=1] = get_attr[target=conv.bias]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%mul, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)})
# %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
# return add
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
conv_node = list(graph.nodes)[4]
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0R x RR', 'S1R = S1R x RR', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RR = RS0 x S0R', 'RR = RS1 x S1R', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR', 'S01R = S01R x RR', 'RR = RS01 x S01R']
strategy_name_list = [strategy.name for strategy in conv_node.strategies_vector]
# SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list
# SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list
# RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list
assert 'RS1 = RS0 x S0S1' in strategy_name_list
# RS = RR x RS
assert 'RS0 = RR x RS0' in strategy_name_list
assert 'RS1 = RR x RS1' in strategy_name_list
# RR= RR x RR
assert 'RR = RR x RR' in strategy_name_list
# SR = SR x RR
assert 'S0R = S0R x RR' in strategy_name_list
assert 'S1R = S1R x RR' in strategy_name_list
assert 'S01R = S01R x RR' in strategy_name_list
# RR = RS x SR
assert 'RR = RS0 x S0R' in strategy_name_list
assert 'RR = RS1 x S1R' in strategy_name_list
assert 'RR = RS01 x S01R' in strategy_name_list
if __name__ == '__main__':
test_conv_handler()
import pytest
import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.dot_handler import DotHandler
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
class LinearModel(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.linear = nn.Linear(in_features, out_features)
def forward(self, x):
x = x * 2
x = self.linear(x)
return x
@pytest.mark.skip('F.linear is not supported in deprecated handler')
def test_dot_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 8))
tracer = ColoTracer()
model = LinearModel(8, 16)
input_sample = {'x': torch.rand(4, 8).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %linear_weight : [#users=1] = get_attr[target=linear.weight]
# %linear_bias : [#users=1] = get_attr[target=linear.bias]
# %linear : [#users=1] = call_function[target=torch._C._nn.linear](args = (%mul, %linear_weight), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%linear, %linear_bias), kwargs = {})
# return add
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
linear_node = list(graph.nodes)[4]
# ['S0S1 = S0R x RS1', 'S1S0 = S1R x RS0', 'S0R = S0S1 x S1R', 'S1R = S1S0 x S0R', 'RS1 = RS0 x S0S1', 'RS0 = RS1 x S1S0', 'RS0 = RR x RS0', 'RS1 = RR x RS1', 'RR = RR x RR']
strategy_name_list = [strategy.name for strategy in linear_node.strategies_vector]
# SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list
# SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list
# RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list
assert 'RS1 = RS0 x S0S1' in strategy_name_list
# RR = RS x SR
assert 'RR = RS0 x S0R' in strategy_name_list
assert 'RR = RS1 x S1R' in strategy_name_list
# RS= RR x RS
assert 'RS0 = RR x RS0' in strategy_name_list
assert 'RS1 = RR x RS1' in strategy_name_list
if __name__ == '__main__':
test_dot_handler()
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.auto_parallel.tensor_shard.deprecated import sharding_strategy
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.layer_norm_handler import LayerNormHandler
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.device.device_mesh import DeviceMesh
class LNModel(nn.Module):
def __init__(self, c):
super().__init__()
self.ln = nn.LayerNorm(c)
def forward(self, x):
x = x * 2
x = self.ln(x)
return x
def test_bn_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 4, 128))
tracer = ColoTracer()
model = LNModel(128)
input_sample = {'x': torch.rand(4, 4, 128).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %ln : [#users=1] = call_module[target=ln](args = (%mul,), kwargs = {})
# return ln
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
# [x, mul, ln, output]
nodes = [node for node in gm.graph.nodes]
sharding_spec_for_input = ShardingSpec(device_mesh, entire_shape, {})
sharding_strategy_for_input = ShardingStrategy('node_1', sharding_spec_for_input)
strategies_vector_for_input = StrategiesVector(nodes[1])
strategies_vector_for_input.append(sharding_strategy_for_input)
setattr(nodes[1], 'strategies_vector', strategies_vector_for_input)
# generate bn strategy
strategies_vector = StrategiesVector(node=nodes[2])
ln_handler = LayerNormHandler(
node=nodes[2],
device_mesh=device_mesh,
strategies_vector=strategies_vector,
)
ln_handler.register_strategy()
# ['[S0, R, R] = [S0, R, R] x [R]', '[R, S0, R] = [R, S0, R] x [R]', '[S1, R, R] = [S1, R, R] x [R]', '[R, S1, R] = [R, S1, R] x [R]',
# '[S0, S1, R] = [S0, S1, R] x [R]', '[S1, S0, R] = [S1, S0, R] x [R]', '[S01, R, R] = [S01, R, R] x [R]', '[R, S01, R] = [R, S01, R] x [R]', 'RR = RR x R']
strategy_name_list = [strategy.name for strategy in ln_handler.strategies_vector]
assert len(strategy_name_list) == 9
if __name__ == '__main__':
test_bn_handler()
import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
def forward(self, x):
x = self.conv(x)
x = torch.flatten(x)
return x
def test_conv_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
tracer = ColoTracer()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
# %conv_bias : [#users=1] = get_attr[target=conv.bias]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%x, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)})
# %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
# %flatten : [#users=1] = call_function[target=torch.flatten](args = (%add,), kwargs = {})
# return flatten
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
# [x, conv, flatten, output]
nodes = [node for node in gm.graph.nodes]
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
strategy_map = strategies_constructor.strategy_map
add_strategies = strategy_map[nodes[5]]
flatten_strategies = strategy_map[nodes[6]]
flatten_strategies_cover_list = [strategy.input_shardings[0].sharding_sequence for strategy in flatten_strategies]
for strategy in add_strategies:
assert strategy.output_sharding_spec.sharding_sequence in flatten_strategies_cover_list
if __name__ == '__main__':
test_conv_handler()
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ConvModel(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.dim_in = dim_in
self.dim_out = dim_out
def forward(self, condition, x, y):
output = torch.where(condition, x, y)
return output
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_where_handler():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
tracer = ColoTracer()
model = ConvModel(16, 32)
input_sample = {
'condition': torch.rand(16, 32).to('meta'),
'x': torch.rand(16, 32).to('meta'),
'y': torch.rand(16, 32).to('meta')
}
# graph():
# %condition : torch.Tensor [#users=1] = placeholder[target=condition]
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %y : torch.Tensor [#users=1] = placeholder[target=y]
# %where : [#users=1] = call_function[target=torch.where](args = (%condition, %x, %y), kwargs = {})
# return where
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
# [condition, x, y, where, output]
nodes = [node for node in gm.graph.nodes]
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
strategy_map = strategies_constructor.strategy_map
# check a tensor add with a scalar case
where_node = strategy_map[nodes[3]]
# ['[S0, S1] = [S0, S1] x [S0, S1] x [S0, S1]', '[S1, S0] = [S1, S0] x [S1, S0] x [S1, S0]', '[S01, R] = [S01, R] x [S01, R] x [S01, R]',
# '[R, S01] = [R, S01] x [R, S01] x [R, S01]', '[S0, R] = [S0, R] x [S0, R] x [S0, R]', '[R, S0] = [R, S0] x [R, S0] x [R, S0]',
# '[S1, R] = [S1, R] x [S1, R] x [S1, R]', '[R, S1] = [R, S1] x [R, S1] x [R, S1]', '[R, R] = [R, R] x [R, R] x [R, R]']
assert len(where_node) == 9
if __name__ == '__main__':
test_where_handler()
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.initialize import launch
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.logging import disable_existing_loggers
from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.passes.experimental.adding_shape_consistency_pass import shape_consistency_pass, solution_annotatation_pass
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False)
def forward(self, x):
x = self.conv(x)
return x
def check_apply(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
input = torch.rand(4, 4, 4, 4).cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
entire_shape = torch.Size((4, 4, 8, 8))
tracer = ColoTracer()
model = ConvModel(4, 4).cuda()
origin_output = model(input)
input_sample = {'x': torch.rand(4, 4, 4, 4).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %conv : [#users=1] = call_module[target=conv](args = (%mul,), kwargs = {})
# return conv
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
solution = list(ret[0])
sharding_spec_dict, origin_spec_dict = solution_annotatation_pass(gm, solution, device_mesh)
shape_consistency_pass(gm)
gm.recompile()
nodes = [node for node in gm.graph.nodes]
# TODO: wrap the gm to avoid the influence of the user training code
output = gm(input, sharding_spec_dict, origin_spec_dict)
assert output.equal(origin_output)
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_apply():
world_size = 4
run_func = partial(check_apply, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_apply()
from copy import deepcopy
import pytest
import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv1 = nn.Conv2d(c_in, c_out, kernel_size=3)
self.conv2 = nn.Conv2d(c_out, c_out, kernel_size=3)
self.conv3 = nn.Conv2d(c_out, c_out, kernel_size=3)
self.relu = nn.ReLU()
def forward(self, x):
x = x * 2
x = self.conv1(x)
x = self.conv2(x)
x = x / 2
x = self.conv3(x)
x = self.relu(x)
return x
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_solver():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv1 : [#users=1] = call_module[target=conv1](args = (%mul,), kwargs = {})
# %conv2 : [#users=1] = call_module[target=conv2](args = (%conv1,), kwargs = {})
# %truediv : [#users=1] = call_function[target=operator.truediv](args = (%conv2, 2), kwargs = {})
# %conv3 : [#users=1] = call_module[target=conv3](args = (%truediv,), kwargs = {})
# %relu : [#users=1] = call_module[target=relu](args = (%conv3,), kwargs = {})
# return relu
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
graph_analyser = GraphAnalyser(gm)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
# [ 0 0 13 13 13 13 13 0]
strategies_combination_list = ret[0]
assert solver.leaf_strategies[2][13].name == 'S01R = S01R x RR'
if __name__ == '__main__':
test_solver()
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
from copy import deepcopy
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
import transformers
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag
BATCH_SIZE = 8
SEQ_LENGHT = 8
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_cost_graph():
physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
config = transformers.GPT2Config(n_position=1024, n_layer=1, n_head=12)
model = transformers.GPT2LMHeadModel(config=config)
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
token_type_ids = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGHT), dtype=torch.int64)
kwargs = dict(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
meta_args = {k: v.to('meta') for k, v in kwargs.items()}
graph = tracer.trace(root=model, meta_args=meta_args)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
print(graph)
strategies_constructor.build_strategies_and_cost()
for check_node, strategies_vector in strategies_constructor.strategy_map.items():
print(check_node, len(strategies_vector))
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
# solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=1620017824.0)
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
ret = solver.call_solver_serialized_args()
print(ret)
strategies_list = list(ret[0])
print(strategies_list)
computation_cost = 0
communication_cost = 0
memory_cost = 0
nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
for index, node in enumerate(nodes):
print(node.name, node.strategies_vector[strategies_list[index]].name)
computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
if isinstance(node_memory_cost, tuple):
node_memory_cost = node_memory_cost[0]
memory_cost += node_memory_cost
print(f'computation cost is {computation_cost}')
print(f'communication cost is {communication_cost}')
print(f'memory cost is {memory_cost}')
if __name__ == '__main__':
test_cost_graph()
import torch
from torch.fx import GraphModule
import torch.nn as nn
import pytest
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.tensor.shape_consistency import ShapeConsistencyManager
from colossalai.device.device_mesh import DeviceMesh
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.deprecated.cost_graph import CostGraph
from copy import deepcopy
from colossalai.auto_parallel.tensor_shard.deprecated import Solver
from torchvision.models import resnet34, resnet50
from colossalai.auto_parallel.tensor_shard.deprecated.constants import *
from colossalai.auto_parallel.tensor_shard.deprecated.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.testing.pytest_wrapper import run_on_environment_flag
class MLP(torch.nn.Module):
def __init__(self, dim: int):
super().__init__()
self.linear1 = torch.nn.Linear(dim, dim * 4)
self.linear2 = torch.nn.Linear(dim * 4, dim)
self.dropout = torch.nn.Dropout(0)
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.linear1(x)
x = self.dropout(x)
x = self.relu(x)
x = self.linear2(x)
return x
@run_on_environment_flag(name='AUTO_PARALLEL')
def test_cost_graph():
physical_mesh_id = torch.arange(0, 8)
mesh_shape = (2, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
shape_consistency_manager = ShapeConsistencyManager()
tracer = ColoTracer()
model = MLP(32)
input_sample = {'x': torch.rand(16, 32).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %linear1 : [#users=1] = call_module[target=linear1](args = (%x,), kwargs = {})
# %dropout : [#users=1] = call_module[target=dropout](args = (%linear1,), kwargs = {})
# %relu : [#users=1] = call_module[target=relu](args = (%dropout,), kwargs = {})
# %linear2 : [#users=1] = call_module[target=linear2](args = (%relu,), kwargs = {})
# return linear2
graph = tracer.trace(root=model, meta_args=input_sample)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
graph_analyser = GraphAnalyser(gm)
liveness_list = graph_analyser.liveness_analysis()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
strategies_constructor.build_strategies_and_cost()
cost_graph = CostGraph(strategies_constructor.leaf_strategies)
cost_graph.simplify_graph()
# # megatron mode if no memory constraints
# solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
# all sharding on out feature dim if memory budget is not sufficient for megatron mode
solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser, memory_budget=5500.0)
ret = solver.call_solver_serialized_args()
strategies_list = list(ret[0])
computation_cost = 0
communication_cost = 0
memory_cost = 0
for index, node in enumerate(graph.nodes):
print(node.name, node.strategies_vector[strategies_list[index]].name)
computation_cost += node.strategies_vector[strategies_list[index]].compute_cost
communication_cost += node.strategies_vector[strategies_list[index]].communication_cost
node_memory_cost = node.strategies_vector[strategies_list[index]].memory_cost
if isinstance(node_memory_cost, tuple):
node_memory_cost = node_memory_cost[0]
memory_cost += node_memory_cost
print(f'computation cost is {computation_cost}')
print(f'communication cost is {communication_cost}')
print(f'memory cost is {memory_cost}')
if __name__ == '__main__':
test_cost_graph()
from copy import deepcopy
import pytest
import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.auto_parallel.tensor_shard.deprecated.op_handler.conv_handler import CONV_STRATEGIES_LIST
from colossalai.auto_parallel.tensor_shard.deprecated.options import SolverOptions
from colossalai.auto_parallel.tensor_shard.deprecated.sharding_strategy import ShardingStrategy, StrategiesVector
from colossalai.auto_parallel.tensor_shard.deprecated.strategies_constructor import StrategiesConstructor
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.proxy import ColoProxy
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
class ConvModel(nn.Module):
def __init__(self, c_in, c_out):
super().__init__()
self.conv = nn.Conv2d(c_in, c_out, kernel_size=3)
def forward(self, x):
x = x * 2
x = self.conv(x)
return x
def test_strategies_constructor():
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
# [[0, 1]
# [2, 3]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((4, 16, 64, 64))
tracer = ColoTracer()
model = ConvModel(16, 32)
input_sample = {'x': torch.rand(4, 16, 64, 64).to('meta')}
# graph():
# %x : torch.Tensor [#users=1] = placeholder[target=x]
# %mul : [#users=1] = call_function[target=operator.mul](args = (%x, 2), kwargs = {})
# %conv_weight : [#users=1] = get_attr[target=conv.weight]
# %conv_bias : [#users=1] = get_attr[target=conv.bias]
# %conv2d : [#users=1] = call_function[target=torch.conv2d](args = (%mul, %conv_weight), kwargs = {groups: 1, dilation: (1, 1), stride: (1, 1), padding: (0, 0)})
# %view : [#users=1] = call_method[target=view](args = (%conv_bias, [1, -1, 1, 1]), kwargs = {})
# %add : [#users=1] = call_function[target=operator.add](args = (%conv2d, %view), kwargs = {})
# return add
graph = tracer.trace(root=model, meta_args=input_sample)
print(graph)
gm = GraphModule(model, graph, model.__class__.__name__)
gm.recompile()
solver_options = SolverOptions(fast=True)
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
assert strategies_constructor.leaf_strategies == []
assert strategies_constructor.strategy_map == {}
strategies_constructor.build_strategies_and_cost()
# check leaf_strategies
# In fast mode, placeholder node only has replica strategy.
assert strategies_constructor.leaf_strategies[0][0].name == 'Replica Placeholder'
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
assert strategies_constructor.leaf_strategies[1][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
# Third node is conv.
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
for strategy in strategies_constructor.leaf_strategies[4]:
conv_check_list.remove(strategy.name)
assert len(conv_check_list) == 0
# In fast mode, output node only has replica strategy.
assert strategies_constructor.leaf_strategies[7][0].name == 'Replica Output'
# check strategy_map
nodes = [node for node in graph.nodes]
# In fast mode, placeholder node only has replica strategy.
x = nodes[0]
assert strategies_constructor.strategy_map[x][0].name == 'Replica Placeholder'
# Second node is mul which is a element-wise node, therefore the output sharding spec is same as input sharding spec.
mul = nodes[1]
assert strategies_constructor.strategy_map[mul][0].name == '[R, R, R, R] -> [R, R, R, R]_0'
# fifth node is conv.
conv = nodes[4]
conv_check_list = deepcopy(CONV_STRATEGIES_LIST)
for strategy in strategies_constructor.strategy_map[conv]:
conv_check_list.remove(strategy.name)
assert len(conv_check_list) == 0
# In fast mode, output node only has replica strategy.
output = nodes[-1]
assert strategies_constructor.strategy_map[output][0].name == 'Replica Output'
if __name__ == '__main__':
test_strategies_constructor()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment