Unverified Commit 30e50c8b authored by Frank Lee, committed by GitHub

[autoparallel] implemented all matmul strategy generator (#1650)

parent 03978aad
@@ -50,8 +50,16 @@ class LinearModuleHandler(ModuleHandler):
         if op_data.name == "weight":
             assert op_data.logical_shape != op_data.data.shape
             dim_partition_dict = sharding_spec.dim_partition_dict
             # switch first and last dim of the linear module weight
-            dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
+            first_dim_partition = dim_partition_dict.pop(-1, None)
+            last_dim_partition = dim_partition_dict.pop(0, None)
+
+            if first_dim_partition:
+                dim_partition_dict[0] = first_dim_partition
+
+            if last_dim_partition:
+                dim_partition_dict[-1] = last_dim_partition
+
             # re-init the sharding spec
             sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
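An editorial note on this hunk: the old tuple swap required both dim 0 and dim -1 to be present as keys in dim_partition_dict and raised KeyError as soon as one was missing; the pop-with-default rewrite also handles partially sharded weights. A minimal standalone sketch (not part of the commit):

    # dim_partition_dict maps a tensor dim to the device-mesh dims it is
    # sharded over, e.g. {0: [0], -1: [1]} for a weight sharded both ways.
    def swap_first_last_dim(dim_partition_dict):
        first_dim_partition = dim_partition_dict.pop(-1, None)
        last_dim_partition = dim_partition_dict.pop(0, None)
        if first_dim_partition:
            dim_partition_dict[0] = first_dim_partition
        if last_dim_partition:
            dim_partition_dict[-1] = last_dim_partition
        return dim_partition_dict

    assert swap_first_last_dim({0: [0], -1: [1]}) == {0: [1], -1: [0]}
    # the old tuple swap raised KeyError on a partially sharded weight:
    assert swap_first_last_dim({0: [0]}) == {-1: [0]}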
@@ -111,8 +119,16 @@ class LinearFunctionHandler(NodeHandler):
         if op_data.name == str(self.node.args[1]):
             assert op_data.logical_shape != op_data.data.shape
             dim_partition_dict = sharding_spec.dim_partition_dict
             # switch first and last dim of the linear module weight
-            dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
+            first_dim_partition = dim_partition_dict.pop(-1, None)
+            last_dim_partition = dim_partition_dict.pop(0, None)
+
+            if first_dim_partition:
+                dim_partition_dict[0] = first_dim_partition
+
+            if last_dim_partition:
+                dim_partition_dict[-1] = last_dim_partition
+
             # re-init the sharding spec
             sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
......
@@ -33,12 +33,12 @@ class NodeHandler(ABC):
         Register different sharding strategies for the current node.
         """
         strategy_generators = self.get_strategy_generator()
-        operand_mapping = self.get_operation_data_mapping()
 
         for generator in strategy_generators:
-            strategies = generator.generate(operand_mapping)
+            strategies = generator.generate()
             self.strategies_vector.extend(strategies)
 
-        self.strategies_vector = map(self.post_process, self.strategies_vector)
+        strategies_vector = map(self.post_process, self.strategies_vector)
+        self.strategies_vector = list(strategies_vector)
         return self.strategies_vector
 
     def post_process(self, strategy: ShardingStrategy_V2):
......
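An editorial note on the list(...) change: in Python 3, map() returns a lazy, one-shot iterator, so assigning it directly to self.strategies_vector left an object that has no len(), cannot be indexed, and is empty after a single pass. A minimal demonstration (not part of the commit):

    doubled = map(lambda x: x * 2, [1, 2, 3])
    print(list(doubled))   # [2, 4, 6]
    print(list(doubled))   # []  (the iterator is already exhausted)
    # list(map(...)) materializes the results so the strategies vector
    # can be iterated and indexed repeatedly by later code.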
@@ -75,6 +75,12 @@ class OperationData:
         if self.logical_shape is None:
             self.logical_shape = self.data.shape
 
+    def __repr__(self) -> str:
+        return f'OperationData(name={self.name}, type={self.type})'
+
+    def __hash__(self) -> int:
+        return hash(f'{self.name}-{self.type}')
+
 
 @dataclass
 class TrainCycleItem:
......
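Why the explicit __hash__ matters (editorial; the class below is a cut-down stand-in): a mutable @dataclass with the default eq=True gets __hash__ set to None, so instances could not key the communication_actions dict that a later hunk iterates with .items(). Defining __hash__ in the class body overrides that default:

    from dataclasses import dataclass

    @dataclass
    class OperationData:                 # stand-in with only the hashed fields
        name: str
        type: str

        def __hash__(self) -> int:
            return hash(f'{self.name}-{self.type}')

    weight = OperationData(name='weight', type='PARAM')
    communication_actions = {weight: 'all-reduce weight grad in bwd'}  # now a legal dict key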
@@ -7,7 +7,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from typing import Dict, List, Union, Any
-from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem
+from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem, OperationDataType
 
 
 class StrategyGenerator_V2(ABC):
@@ -21,6 +21,10 @@ class StrategyGenerator_V2(ABC):
         self.op_data = operation_data_mapping
         self.device_mesh = device_mesh
 
+    def is_param(self, op_data_name):
+        other_data = self.op_data[op_data_name]
+        return other_data.type == OperationDataType.PARAM
+
     def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
                               communication_action_mapping: Dict[str, CommSpec]):
         """
@@ -80,7 +84,7 @@ class StrategyGenerator_V2(ABC):
         Compute the communication cost involved in the forward and backward iteration.
         """
 
-        comm_cost = TrainCycleItem(fwd=0, bwd=0)
+        comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
 
         def _compute_and_add(data: OperationData, comm_spec: CommSpec):
             num_ele_in_comm = comm_spec.get_comm_cost()
@@ -92,7 +96,7 @@ class StrategyGenerator_V2(ABC):
             # TODO: comm_spec.get_comm_cost should return a TrainCycleItem instead of the total cost.
             # it works fine here because only ALLREDUCE_FWD_IDENTITY_BWD and IDENTITY_FWD_ALLREDUCE_BWD are used,
             # so the total cost is either for fwd or bwd.
-            if comm_spec.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD:
+            if comm_spec.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
                 comm_cost.fwd += cost
             elif comm_spec.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
                 comm_cost.bwd += cost
@@ -102,9 +106,12 @@ class StrategyGenerator_V2(ABC):
         # check if communication action exists
         # if so, loop over each action and compute the cost of each action
         if strategy.communication_actions is not None:
-            for operand, comm_spec in strategy.communication_actions:
+            for operand, comm_spec in strategy.communication_actions.items():
                 _compute_and_add(operand, comm_spec)
 
+        # update the total cost
+        comm_cost.total = comm_cost.fwd + comm_cost.bwd
+
         # update the communication cost attribute in-place
         strategy.communication_cost = comm_cost
         return strategy
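Two editorial observations on the cost hunks above: iterating a dict directly yields only its keys, so the old `for operand, comm_spec in strategy.communication_actions` tried to unpack each OperationData key as a pair and failed; `.items()` yields the (operand, comm_spec) tuples the loop expects. The total-cost aggregation then behaves as in this stand-in sketch (not the real TrainCycleItem):

    from dataclasses import dataclass

    @dataclass
    class TrainCycleItem:          # stand-in mirroring the fwd/bwd/total fields
        fwd: float
        bwd: float
        total: float

    comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
    comm_cost.fwd += 128.0         # e.g. an ALLREDUCE_FWD_IDENTITY_BWD action
    comm_cost.bwd += 64.0          # e.g. an IDENTITY_FWD_ALLREDUCE_BWD action
    comm_cost.total = comm_cost.fwd + comm_cost.bwd
    assert comm_cost.total == 192.0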
@@ -146,7 +153,7 @@ class StrategyGenerator_V2(ABC):
         pass
 
     @abstractmethod
-    def validate(self, *args, **kwargs) -> bool:
+    def validate(self) -> bool:
         """
         Validate if the operands are of the desired shape.
         If True, it means this generator can be used for the current operation.
......
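Putting the generator-side changes together (an editorial sketch with a stand-in class, assuming, as the `self.op_data = operation_data_mapping` hunk above shows, that generators receive their operands at construction time): this is why NodeHandler.register_strategy can now call generate() with no arguments, and why validate() needs none either.

    class MiniStrategyGenerator:                    # stand-in, not the real class
        def __init__(self, operation_data_mapping, device_mesh):
            self.op_data = operation_data_mapping   # operands are known up front
            self.device_mesh = device_mesh

        def validate(self) -> bool:
            # e.g. a matmul generator checks that the weight operand is 2D
            return self.op_data['other'].data.dim() == 2

        def generate(self):
            # would enumerate ShardingStrategy_V2 objects from self.op_data
            return []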
@@ -8,9 +8,9 @@ from colossalai.device.device_mesh import DeviceMesh
 
 def test_linear_module_handler():
-    model = nn.Sequential(nn.Linear(10, 20).to('meta'))
+    model = nn.Sequential(nn.Linear(16, 32).to('meta'))
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)
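Two asides on the test setup (editorial, not from the commit): the layer dims move from (10, 20) to (16, 32), presumably so every tensor dim divides evenly when sharded across the 2x2 mesh built from the four physical devices; and 'meta' tensors let the tracer record shapes without allocating storage:

    import torch

    x = torch.rand(4, 16).to('meta')   # carries shape/dtype only, no real memory
    print(x.is_meta, x.shape)          # True torch.Size([4, 16])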
@@ -34,32 +34,55 @@ def test_linear_module_handler():
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 10])
+    assert mapping['input'].data.shape == torch.Size([4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 10])
+    assert mapping['input'].logical_shape == torch.Size([4, 16])
 
     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
-    assert mapping['other'].data.shape == torch.Size([20, 10])
+    assert mapping['other'].data.shape == torch.Size([32, 16])
     assert mapping['other'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])
 
     assert mapping['bias'].name == "bias"
     assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([20])
+    assert mapping['bias'].data.shape == torch.Size([32])
     assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])
 
     assert mapping['output'].name == "_0"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].data.shape == torch.Size([4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT
 
     strategies_vector = handler.register_strategy()
     strategy_name_list = [val.name for val in strategies_vector]
 
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+
+    # RS = RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list
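A reading aid for the strategy names (editorial): each name has the form output = input x weight, where R marks a tensor dim replicated across the device mesh and S<k> marks a dim sharded along mesh dim k. 'S0S1 = S0R x RS1', for instance, shards the input rows along mesh dim 0 and the weight columns along mesh dim 1, yielding an output sharded along both. In dim_partition_dict terms (hypothetical values for illustration):

    input_spec  = {0: [0]}             # S0R: rows sharded on mesh dim 0
    weight_spec = {-1: [1]}            # RS1: columns sharded on mesh dim 1
    output_spec = {0: [0], -1: [1]}    # S0S1: sharded on both mesh dims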
 
 def test_linear_function_handler():
-    model = nn.Linear(10, 20).to('meta')
+    model = nn.Linear(16, 32).to('meta')
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)
@@ -77,27 +100,50 @@ def test_linear_function_handler():
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 10])
+    assert mapping['input'].data.shape == torch.Size([4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 10])
+    assert mapping['input'].logical_shape == torch.Size([4, 16])
 
     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
-    assert mapping['other'].data.shape == torch.Size([20, 10])
+    assert mapping['other'].data.shape == torch.Size([32, 16])
     assert mapping['other'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])
 
     assert mapping['bias'].name == "bias"
     assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([20])
+    assert mapping['bias'].data.shape == torch.Size([32])
     assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])
 
     assert mapping['output'].name == "linear"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].data.shape == torch.Size([4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT
 
     strategies_vector = handler.register_strategy()
     strategy_name_list = [val.name for val in strategies_vector]
 
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+
+    # RS = RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list
 
 
 if __name__ == '__main__':
     test_linear_module_handler()
......
-from curses import meta
-from math import dist
-from xml.dom import HierarchyRequestErr
 from colossalai.fx.tracer import meta_patch
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_function import python_ops
......
-from curses import meta
-from math import dist
-from xml.dom import HierarchyRequestErr
 from colossalai.fx.tracer import meta_patch
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_function import python_ops
......
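(Editorial note: the three removed imports look like accidental IDE auto-completions; none is used in the tests, and curses in particular is not available on Windows, so dropping them also keeps the test modules importable cross-platform.)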