Unverified Commit fea3cb66 authored by YuliangLiu0306, committed by GitHub

[autoparallel] support addmm in tracer and solver (#1961)

* [fx] patch addmm

* [autoparallel] support addmm in tracer and solver
parent f7e276fa
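For context on the operation this commit adds support for: torch.addmm(input, mat1, mat2, beta=1, alpha=1) computes beta * input + alpha * (mat1 @ mat2), where the bias-like input is broadcast to the matmul result. A minimal sketch of the shapes involved, using the same sizes as the test at the end of this diff:

import torch
bias = torch.rand(16)             # broadcastable to the (4, 16) output
m1 = torch.rand(4, 8)
m2 = torch.rand(8, 16)
out = torch.addmm(bias, m1, m2)   # equivalent to bias + m1 @ m2 for beta = alpha = 1
print(out.shape)                  # torch.Size([4, 16])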
from .addmm_handler import ADDMMFunctionHandler
from .batch_norm_handler import BatchNormModuleHandler
from .binary_elementwise_handler import BinaryElementwiseHandler
from .bmm_handler import AddBMMFunctionHandler, BMMFunctionHandler
......@@ -18,5 +19,5 @@ __all__ = [
'LinearFunctionHandler', 'LinearModuleHandler', 'BMMFunctionHandler', 'AddBMMFunctionHandler',
'LayerNormModuleHandler', 'BatchNormModuleHandler', 'ConvModuleHandler', 'ConvFunctionHandler',
'UnaryElementwiseHandler', 'ReshapeHandler', 'PlacehodlerHandler', 'OuputHandler', 'WhereHandler',
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry'
'NormPoolingHandler', 'BinaryElementwiseHandler', 'MatMulHandler', 'operator_registry', 'ADDMMFunctionHandler'
]
from typing import Dict, List, Union
import torch
from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec, ShapeConsistencyManager
from ..sharding_strategy import CommAction, CommType, OperationData, OperationDataType, ShardingStrategy
from ..utils import comm_actions_for_oprands, recover_sharding_spec_for_broadcast_shape
from .node_handler import NodeHandler
from .registry import operator_registry
from .strategy import LinearProjectionStrategyGenerator, StrategyGenerator
__all__ = ['ADDMMFunctionHandler']
@operator_registry.register(torch.addmm)
@operator_registry.register(torch.Tensor.addmm)
class ADDMMFunctionHandler(NodeHandler):
"""
This is a NodeHandler class which deals with the batched matrix multiplication operation in PyTorch.
Such operations including `torch.bmm` and `torch.Tensor.bmm` require the tensor to be 3D, thus, there is
no logical-physical shape conversion in this handler.
"""
def _infer_op_data_type(self, tensor: torch.Tensor) -> OperationDataType:
if isinstance(tensor, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM
else:
data_type = OperationDataType.ARG
return data_type
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
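# For torch.addmm(input, mat1, mat2): node.args[0] is the bias-like `input`,
# node.args[1] is mat1 (mapped to the 'input' operand of the projection below),
# and node.args[2] is mat2 (mapped to the 'other' operand).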
# input operand
input_data = self.node.args[1]._meta_data
physical_input_operand = OperationData(name=str(self.node.args[1]),
type=self._infer_op_data_type(input_data),
data=input_data)
# other operand
other_data = self.node.args[2]._meta_data
physical_other_operand = OperationData(name=str(self.node.args[2]),
type=self._infer_op_data_type(other_data),
data=other_data)
# bias logical shape: addmm broadcasts the bias to the output shape
bias_logical_shape = self.node._meta_data.shape
bias_data = self.node.args[0]._meta_data
physical_bias_operand = OperationData(name=str(self.node.args[0]),
type=self._infer_op_data_type(bias_data),
data=bias_data,
logical_shape=bias_logical_shape)
# output
physical_output = OperationData(name=str(self.node), type=OperationDataType.OUTPUT, data=self.node._meta_data)
mapping = {
"input": physical_input_operand,
"other": physical_other_operand,
"output": physical_output,
'bias': physical_bias_operand
}
return mapping
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='addmm'))
return generators
def post_process(self, strategy: ShardingStrategy) -> Union[ShardingStrategy, List[ShardingStrategy]]:
# convert bias from its logical sharding spec to its physical sharding spec
op_data_mapping = self.get_operation_data_mapping()
bias_op_data = op_data_mapping['bias']
bias_physical_shape = bias_op_data.data.shape
bias_logical_shape = bias_op_data.logical_shape
bias_sharding_spec = strategy.get_sharding_spec_by_name(bias_op_data.name)
bias_sharding_spec, removed_dims = recover_sharding_spec_for_broadcast_shape(
bias_sharding_spec, bias_logical_shape, bias_physical_shape)
strategy.sharding_specs[bias_op_data] = bias_sharding_spec
if len(removed_dims) > 0:
comm_action = comm_actions_for_oprands(node=self.node,
removed_dims=removed_dims,
op_data=bias_op_data,
sharding_spec=bias_sharding_spec)
strategy.communication_actions[bias_op_data] = comm_action
return strategy
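To make the post-processing above concrete, a hedged illustration (the shapes mirror the test at the end of this diff; the dicts below are only a sketch of the idea, not the real ShardingSpec API): an addmm bias with physical shape (16,) is given the output's logical shape (4, 16), so a strategy may shard its logical dim 0; that dim does not exist on the 1-D tensor, so the sharding is dropped when recovering the physical spec and a communication action is created for it instead.

# Hypothetical values only, for illustration.
logical_bias_shape = (4, 16)    # addmm bias is treated as output-shaped by the handler
physical_bias_shape = (16,)     # the real bias tensor is 1-D
logical_partition = {0: ['mesh_dim_0'], -1: ['mesh_dim_1']}    # e.g. the S0S1 strategy
physical_partition = {-1: ['mesh_dim_1']}                      # only the last-dim sharding survives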
......@@ -140,7 +140,8 @@ class LinearModuleHandler(ModuleHandler):
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
......@@ -199,7 +200,8 @@ class LinearFunctionHandler(NodeHandler):
def get_strategy_generator(self) -> List[StrategyGenerator]:
op_data_mapping = self.get_operation_data_mapping()
generators = []
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
......
......@@ -363,7 +363,8 @@ class MatMulHandler(NodeHandler):
elif self.matmul_type == MatMulType.MV:
generators.append(MatVecStrategyGenerator(op_data_mapping, self.device_mesh))
elif self.matmul_type == MatMulType.MM:
generators.append(LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh))
generators.append(
LinearProjectionStrategyGenerator(op_data_mapping, self.device_mesh, linear_projection_type='linear'))
return generators
def get_operation_data_mapping(self) -> Dict[str, OperationData]:
......
......@@ -209,6 +209,10 @@ class MatVecStrategyGenerator(MatMulStrategyGenerator):
class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
def __init__(self, operation_data_mapping, device_mesh, linear_projection_type='linear'):
super().__init__(operation_data_mapping, device_mesh)
self.linear_projection_type = linear_projection_type
def update_compute_cost(self, strategy: ShardingStrategy) -> ShardingStrategy:
# C = AB
# C: [M, N], A: [M, P], B: [P, N]
......@@ -272,14 +276,21 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
"other": {
-1: [mesh_dim_1]
},
"bias": {
-1: [mesh_dim_1]
},
"output": {
0: [mesh_dim_0],
-1: [mesh_dim_1]
},
}
# a linear bias has only one dimension, while an addmm bias logically has
# the same dimensions as the output.
if self.linear_projection_type == 'linear':
dim_partition_dict_mapping['bias'] = {-1: [mesh_dim_1]}
elif self.linear_projection_type == 'addmm':
dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0], -1: [mesh_dim_1]}
else:
raise ValueError('Unsupported linear projection type')
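# e.g. for an output of shape [M, N]: a linear bias is 1-D of shape [N], so only its
# last dim can be sharded; an addmm bias is logically [M, N], so both dims can follow
# the output sharding.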
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# set communication action
......@@ -293,13 +304,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["output"],
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["output"],
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
......@@ -308,7 +319,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
communication_action_mapping['input'] = input_comm_action
communication_action_mapping['other'] = other_comm_action
if self.has_bias:
# we only add an allreduce comm action for the linear bias, because the
# allreduce comm action for the addmm bias is handled in post-processing
if self.has_bias and self.linear_projection_type == 'linear':
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
......@@ -347,6 +360,16 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
0: [mesh_dim_0]
},
}
# a linear bias has only one dimension, while an addmm bias logically has
# the same dimensions as the output.
if self.linear_projection_type == 'linear':
dim_partition_dict_mapping['bias'] = {}
elif self.linear_projection_type == 'addmm':
dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0]}
else:
raise ValueError('Unsupported linear projection type')
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action mapping
......@@ -360,13 +383,13 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
if self.is_param('other'):
other_comm_action = self.get_communication_action(
sharding_spec_mapping["output"],
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.HOOK)
else:
other_comm_action = self.get_communication_action(
sharding_spec_mapping["output"],
sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=mesh_dim_0,
comm_type=CommType.BEFORE,
......@@ -375,7 +398,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
communication_action_mapping['other'] = other_comm_action
communication_action_mapping['output'] = output_comm_action
if self.has_bias:
# we only add an allreduce comm action for the linear bias, because the
# allreduce comm action for the addmm bias is handled in post-processing
if self.has_bias and self.linear_projection_type == 'linear':
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec_mapping["bias"],
......@@ -415,6 +440,10 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
-1: [mesh_dim_1]
},
}
# We don't have to do anything special for the bias here, because
# the bias already has the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication actions
......@@ -451,7 +480,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
"bias": {},
"output": {},
}
# We don't have to do anything special for the bias here, because
# the bias already has the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
......@@ -484,7 +514,8 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
-1: [mesh_dim]
},
}
# We don't have to do anything special for the bias here, because
# the bias already has the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication actions
......@@ -515,6 +546,16 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
0: [mesh_dim_0, mesh_dim_1]
},
}
# a linear bias has only one dimension, while an addmm bias logically has
# the same dimensions as the output.
if self.linear_projection_type == 'linear':
dim_partition_dict_mapping['bias'] = {}
elif self.linear_projection_type == 'addmm':
dim_partition_dict_mapping['bias'] = {0: [mesh_dim_0, mesh_dim_1]}
else:
raise ValueError('Unsupported linear projection type')
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
......@@ -534,7 +575,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
arg_index=1)
communication_action_mapping['other'] = other_comm_action
if self.has_bias:
# we only add an allreduce comm action for the linear bias, because the
# allreduce comm action for the addmm bias is handled in post-processing
if self.has_bias and self.linear_projection_type == 'linear':
if self.is_param('bias'):
bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping['bias'],
......@@ -568,6 +611,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
"bias": {},
"output": {},
}
# We don't have to do anything special for the bias here, because
# the bias already has the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
......@@ -600,6 +646,9 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
-1: [mesh_dim_0, mesh_dim_1]
},
}
# We don't have to do anything special for the bias here, because
# the bias already has the same sharding spec as the output.
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
# get communication action
......@@ -626,10 +675,7 @@ class LinearProjectionStrategyGenerator(MatMulStrategyGenerator):
assert input_data.data.dim() > 0 and other_data.data.dim() == 2
assert other_data.logical_shape[0] == input_data.logical_shape[-1]
# check that the bias, if any, has a valid last dim
has_bias = "bias" in self.op_data
if has_bias:
if self.has_bias:
bias_data = self.op_data['bias']
assert bias_data.logical_shape[-1] == other_data.logical_shape[-1]
......
......@@ -72,11 +72,21 @@ def torch_linear(input, mat2, bias=None, *, out=None):
def torch_addbmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
if out is not None:
raise ValueError("Don't support in-place abs for MetaTensor analysis")
batch_size, n, m = mat1.shape
_, n, _ = mat1.shape
_, _, p = mat2.shape
return torch.empty(n, p, device="meta")
@meta_patched_function.register(torch.addmm)
@meta_patched_function.register(torch.Tensor.addmm)
def torch_addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None):
if out is not None:
raise ValueError("Don't support in-place abs for MetaTensor analysis")
n, _ = mat1.shape
_, p = mat2.shape
return torch.empty(n, p, device="meta")
@meta_patched_function.register(torch.var_mean)
def torch_var_mean(input, dim, unbiased=True, keepdim=False, *, out=None):
assert out is None, 'saving to out is not supported yet'
......
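A quick sanity check of the shape rule implemented by the new torch_addmm meta patch (this calls the patched function directly, bypassing the registration machinery):

import torch
mat1 = torch.empty(4, 8, device='meta')
mat2 = torch.empty(8, 16, device='meta')
bias = torch.empty(16, device='meta')
out = torch_addmm(bias, mat1, mat2)   # (n, k) @ (k, p) with a broadcastable bias -> (n, p)
print(out.shape)                      # torch.Size([4, 16])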
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.auto_parallel.tensor_shard.node_handler import ADDMMFunctionHandler
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
OperationData,
OperationDataType,
ShardingStrategy,
StrategiesVector,
)
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, parameterize, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port
from tests.test_auto_parallel.test_tensor_shard.test_node_handler.utils import numerical_test_for_node_strategy
class AddmmModel(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, m1, m2):
x = torch.addmm(input, m1, m2)
return x
def check_linear_function_handler(rank, input_shape, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
model = AddmmModel().cuda()
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
input = torch.rand(input_shape).cuda()
m1 = torch.rand(4, 8).cuda()
m2 = torch.rand(8, 16).cuda()
# the index of addmm node in computation graph
node_index = 3
# strategy number of the addmm node
strategy_number = 10
# construct input args
input_args = [input, m1, m2]
# construct meta arg names
meta_arg_names = ['input', 'm1', 'm2']
numerical_test_for_node_strategy(model=model,
device_mesh=device_mesh,
node_index=node_index,
strategy_number=strategy_number,
input_args=input_args,
meta_arg_names=meta_arg_names)
tracer = ColoTracer()
graph = tracer.trace(model,
meta_args={
"input": torch.rand(input_shape).to('meta'),
'm1': torch.rand(4, 8).to('meta'),
'm2': torch.rand(8, 16).to('meta'),
})
gm = ColoGraphModule(model, graph)
# [input_1, m1, m2, addmm, output]
node_list = list(graph.nodes)
addmm_node = node_list[3]
strategies_vector = StrategiesVector(addmm_node)
# build handler
handler = ADDMMFunctionHandler(node=addmm_node, device_mesh=device_mesh, strategies_vector=strategies_vector)
handler.register_strategy(compute_resharding_cost=False)
strategy_name_list = [val.name for val in strategies_vector]
# check operation data mapping
mapping = handler.get_operation_data_mapping()
assert mapping['input'].name == "m1"
assert mapping['input'].data.shape == torch.Size([4, 8])
assert mapping['input'].type == OperationDataType.ARG
assert mapping['input'].logical_shape == torch.Size([4, 8])
assert mapping['other'].name == "m2"
assert mapping['other'].data.shape == torch.Size([8, 16])
assert mapping['other'].type == OperationDataType.ARG
assert mapping['other'].logical_shape == torch.Size([8, 16])
assert mapping['bias'].name == "input_1"
assert mapping['bias'].data.shape == torch.Size(input_shape)
assert mapping['bias'].type == OperationDataType.ARG
assert mapping['bias'].logical_shape == torch.Size([4, 16])
assert mapping['output'].name == "addmm"
assert mapping['output'].data.shape == torch.Size([4, 16])
assert mapping['output'].type == OperationDataType.OUTPUT
# one logical strategy may be converted into several different physical sharding specs
assert len(strategy_name_list) > 8
# SS = SR x RS
assert 'S0S1 = S0R x RS1' in strategy_name_list
assert 'S1S0 = S1R x RS0' in strategy_name_list
# SR = SS x SR
assert 'S0R = S0S1 x S1R' in strategy_name_list
assert 'S1R = S1S0 x S0R' in strategy_name_list
# RS = RS x SS
assert 'RS0 = RS1 x S1S0' in strategy_name_list
assert 'RS1 = RS0 x S0S1' in strategy_name_list
# RR = RS x SR
assert 'RR = RS0 x S0R' in strategy_name_list
assert 'RR = RS1 x S1R' in strategy_name_list
# RS= RR x RS
assert 'RS0 = RR x RS0' in strategy_name_list
assert 'RS1 = RR x RS1' in strategy_name_list
for strategy in strategies_vector:
strategy: ShardingStrategy
input_sharding_spec = strategy.get_sharding_spec_by_name('m1')
weight_sharding_spec = strategy.get_sharding_spec_by_name('m2')
output_sharding_spec = strategy.get_sharding_spec_by_name('addmm')
bias_sharding_spec = strategy.get_sharding_spec_by_name('input_1')
# make sure the sharding matches across different operation data
assert input_sharding_spec.sharding_sequence[:-1] == output_sharding_spec.sharding_sequence[:-1]
assert weight_sharding_spec.sharding_sequence[0] == input_sharding_spec.sharding_sequence[1]
assert weight_sharding_spec.sharding_sequence[1] == output_sharding_spec.sharding_sequence[1]
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
@parameterize('input_shape', [(16,), (4, 16)])
@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_addmm_handler(input_shape):
world_size = 4
run_func_function = partial(check_linear_function_handler,
input_shape=input_shape,
world_size=world_size,
port=free_port())
mp.spawn(run_func_function, nprocs=world_size)
if __name__ == '__main__':
test_addmm_handler()