Unverified commit 30e50c8b, authored by Frank Lee and committed by GitHub

[autoparallel] implemented all matmul strategy generator (#1650)

parent 03978aad
@@ -50,8 +50,16 @@ class LinearModuleHandler(ModuleHandler):
         if op_data.name == "weight":
             assert op_data.logical_shape != op_data.data.shape
             dim_partition_dict = sharding_spec.dim_partition_dict
             # switch first and last dim of the linear module weight
-            dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
+            first_dim_partition = dim_partition_dict.pop(-1, None)
+            last_dim_partition = dim_partition_dict.pop(0, None)
+            if first_dim_partition:
+                dim_partition_dict[0] = first_dim_partition
+            if last_dim_partition:
+                dim_partition_dict[-1] = last_dim_partition
             # re-init the sharding spec
             sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
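
Note on the change above: the old tuple swap indexes dim_partition_dict[-1] and dim_partition_dict[0] unconditionally, so it raises a KeyError whenever only one of the two dims is sharded. Popping with a None default handles partial sharding. A minimal standalone sketch, with a plain dict standing in for sharding_spec.dim_partition_dict:

    # Keys are tensor dims, values are lists of device-mesh axes.
    dim_partition_dict = {0: [0]}    # only the first dim is sharded

    # Old behavior: KeyError, because dim -1 has no entry.
    try:
        dim_partition_dict[0], dim_partition_dict[-1] = \
            dim_partition_dict[-1], dim_partition_dict[0]
    except KeyError as err:
        print(f'old swap fails: KeyError({err})')

    # New behavior: pop with a default, re-insert only what exists.
    first_dim_partition = dim_partition_dict.pop(-1, None)
    last_dim_partition = dim_partition_dict.pop(0, None)
    if first_dim_partition:
        dim_partition_dict[0] = first_dim_partition
    if last_dim_partition:
        dim_partition_dict[-1] = last_dim_partition
    print(dim_partition_dict)    # {-1: [0]}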
@@ -111,8 +119,16 @@ class LinearFunctionHandler(NodeHandler):
         if op_data.name == str(self.node.args[1]):
             assert op_data.logical_shape != op_data.data.shape
             dim_partition_dict = sharding_spec.dim_partition_dict
             # switch first and last dim of the linear module weight
-            dim_partition_dict[0], dim_partition_dict[-1] = dim_partition_dict[-1], dim_partition_dict[0]
+            first_dim_partition = dim_partition_dict.pop(-1, None)
+            last_dim_partition = dim_partition_dict.pop(0, None)
+            if first_dim_partition:
+                dim_partition_dict[0] = first_dim_partition
+            if last_dim_partition:
+                dim_partition_dict[-1] = last_dim_partition
             # re-init the sharding spec
             sharding_spec.__init__(sharding_spec.device_mesh, sharding_spec.entire_shape, dim_partition_dict)
...
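
Both handlers rely on the same fact: nn.Linear stores its weight physically as (out_features, in_features), while the matmul strategies reason about the logical (in_features, out_features) layout, hence the assert that logical_shape differs from data.shape and the first/last dim swap. A quick illustration:

    import torch.nn as nn

    linear = nn.Linear(16, 32)
    # Physical (stored) shape: (out_features, in_features).
    print(linear.weight.shape)        # torch.Size([32, 16])
    # Logical shape seen by the matmul strategies: (in_features, out_features).
    print(linear.weight.t().shape)    # torch.Size([16, 32])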
@@ -33,12 +33,12 @@ class NodeHandler(ABC):
         Register different sharding strategies for the current node.
         """
         strategy_generators = self.get_strategy_generator()
-        operand_mapping = self.get_operation_data_mapping()
         for generator in strategy_generators:
-            strategies = generator.generate(operand_mapping)
+            strategies = generator.generate()
             self.strategies_vector.extend(strategies)
-        self.strategies_vector = map(self.post_process, self.strategies_vector)
+        strategies_vector = map(self.post_process, self.strategies_vector)
+        self.strategies_vector = list(strategies_vector)
         return self.strategies_vector

     def post_process(self, strategy: ShardingStrategy_V2):
...
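
The added list(...) call matters because map returns a lazy, single-use iterator in Python 3; storing the bare map object in self.strategies_vector would leave it empty after the first traversal. A standalone sketch of the pitfall, with str.upper standing in for post_process:

    strategies = ['S0R = S0S1 x S1R', 'RR = RS0 x S0R']

    lazy = map(str.upper, strategies)
    print(list(lazy))    # ['S0R = S0S1 X S1R', 'RR = RS0 X S0R']
    print(list(lazy))    # [] -- the iterator is exhausted after one pass

    # Materializing once, as the new code does, keeps the vector reusable:
    materialized = list(map(str.upper, strategies))
    print(len(materialized))    # 2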
@@ -75,6 +75,12 @@ class OperationData:
         if self.logical_shape is None:
             self.logical_shape = self.data.shape

+    def __repr__(self) -> str:
+        return f'OperationData(name={self.name}, type={self.type})'
+
+    def __hash__(self) -> int:
+        return hash(f'{self.name}-{self.type}')
+

 @dataclass
 class TrainCycleItem:
...
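
The explicit __hash__ is what lets OperationData instances serve as dict keys (e.g. in the communication_actions mapping iterated with .items() further down); a @dataclass with the default eq=True otherwise sets __hash__ to None. A minimal stand-in:

    from dataclasses import dataclass

    @dataclass
    class Operand:
        # Hypothetical stand-in for OperationData, identified by name and type.
        name: str
        type: str

        def __hash__(self) -> int:
            # Without this, @dataclass(eq=True) sets __hash__ to None,
            # so instances could not be used as dict keys.
            return hash(f'{self.name}-{self.type}')

    comm_actions = {Operand('weight', 'PARAM'): 'all-reduce weight grad in bwd'}
    for operand, comm_spec in comm_actions.items():
        print(operand, '->', comm_spec)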
@@ -7,7 +7,7 @@ from colossalai.tensor.shape_consistency import CollectiveCommPattern, CommSpec
 from colossalai.tensor.sharding_spec import ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
 from typing import Dict, List, Union, Any
-from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem
+from ..sharding_strategy import OperationData, ShardingStrategy_V2, TrainCycleItem, OperationDataType


 class StrategyGenerator_V2(ABC):
@@ -21,6 +21,10 @@ class StrategyGenerator_V2(ABC):
         self.op_data = operation_data_mapping
         self.device_mesh = device_mesh

+    def is_param(self, op_data_name):
+        other_data = self.op_data[op_data_name]
+        return other_data.type == OperationDataType.PARAM
+
     def get_sharding_strategy(self, name: str, sharding_spec_mapping: Dict[str, ShardingSpec],
                               communication_action_mapping: Dict[str, CommSpec]):
         """
@@ -80,7 +84,7 @@ class StrategyGenerator_V2(ABC):
         Compute the communication cost involved in the forward and backward iteration.
         """

-        comm_cost = TrainCycleItem(fwd=0, bwd=0)
+        comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)

         def _compute_and_add(data: OperationData, comm_spec: CommSpec):
             num_ele_in_comm = comm_spec.get_comm_cost()
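
TrainCycleItem now carries an explicit total alongside fwd and bwd, so it must be supplied at construction and kept in sync (done further down via comm_cost.total = comm_cost.fwd + comm_cost.bwd). A stand-in dataclass showing the contract, with illustrative cost values:

    from dataclasses import dataclass

    @dataclass
    class TrainCycleItem:    # stand-in mirroring the three-field constructor
        fwd: float
        bwd: float
        total: float

    comm_cost = TrainCycleItem(fwd=0, bwd=0, total=0)
    comm_cost.fwd += 128.0    # e.g. cost of a forward all-reduce
    comm_cost.bwd += 256.0    # e.g. cost of a backward weight-grad all-reduce
    comm_cost.total = comm_cost.fwd + comm_cost.bwd
    print(comm_cost)    # TrainCycleItem(fwd=128.0, bwd=256.0, total=384.0)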
@@ -92,7 +96,7 @@ class StrategyGenerator_V2(ABC):
             # TODO: comm_spec.get_comm_cost should return a TrainCycleItem instead of the total cost.
             # it works fine here because only REDUCE_FWD_IDENTITY_BWD and IDENTITY_FWD_ALLREDUCE_BWD are used,
             # so total cost is either for fwd or bwd.
-            if comm_spec.comm_pattern == CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD:
+            if comm_spec.comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
                 comm_cost.fwd += cost
             elif comm_spec.comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
                 comm_cost.fwd += cost
@@ -102,9 +106,12 @@ class StrategyGenerator_V2(ABC):
         # check if communication action exists
         # if so, loop over each action and compute the cost of each action
         if strategy.communication_actions is not None:
-            for operand, comm_spec in strategy.communication_actions:
+            for operand, comm_spec in strategy.communication_actions.items():
                 _compute_and_add(operand, comm_spec)

+        # update the total cost
+        comm_cost.total = comm_cost.fwd + comm_cost.bwd
+
         # update the communication cost attribute in-place
         strategy.communication_cost = comm_cost

         return strategy
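
The .items() fix is needed because iterating a dict directly yields only its keys, so the old loop tried to unpack each OperationData key into the (operand, comm_spec) pair. A standalone reproduction with string keys:

    actions = {'weight': 'all-reduce', 'input': 'identity'}

    # Old loop: iterating the dict yields keys, and unpacking the
    # string key 'weight' into two names raises a ValueError.
    try:
        for operand, comm_spec in actions:
            pass
    except ValueError as err:
        print(f'old loop fails: {err}')    # too many values to unpack

    # New loop: .items() yields (key, value) pairs as intended.
    for operand, comm_spec in actions.items():
        print(operand, '->', comm_spec)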
@@ -146,7 +153,7 @@ class StrategyGenerator_V2(ABC):
         pass

     @abstractmethod
-    def validate(self, *args, **kwargs) -> bool:
+    def validate(self) -> bool:
         """
         Validate if the operands are of desired shape.
         If True, means this generator can be used for the current operation.
...
@@ -8,9 +8,9 @@ from colossalai.device.device_mesh import DeviceMesh
 def test_linear_module_handler():
-    model = nn.Sequential(nn.Linear(10, 20).to('meta'))
+    model = nn.Sequential(nn.Linear(16, 32).to('meta'))
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)
@@ -34,32 +34,55 @@ def test_linear_module_handler():
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 10])
+    assert mapping['input'].data.shape == torch.Size([4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 10])
+    assert mapping['input'].logical_shape == torch.Size([4, 16])

     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
-    assert mapping['other'].data.shape == torch.Size([20, 10])
+    assert mapping['other'].data.shape == torch.Size([32, 16])
     assert mapping['other'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['bias'].name == "bias"
     assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([20])
+    assert mapping['bias'].data.shape == torch.Size([32])
     assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['output'].name == "_0"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].data.shape == torch.Size([4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT

+    strategies_vector = handler.register_strategy()
+    strategy_name_list = [val.name for val in strategies_vector]
+
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+
+    # RS = RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list
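
The strategy names read output = input x weight, where S0/S1 mean sharded along device-mesh axis 0/1 and R means replicated; 'S0S1 = S0R x RS1', for instance, shards the batch dim over axis 0 and the output-feature dim over axis 1. A rough local-shape check, assuming the test's 4 devices form a (2, 2) mesh and using x: (4, 16) and logical weight: (16, 32) from above:

    # 'S0S1 = S0R x RS1' for y = x @ w_logical on a hypothetical (2, 2) mesh.
    mesh = (2, 2)                       # (size of mesh axis 0, size of axis 1)

    x_local = (4 // mesh[0], 16)        # S0R: batch dim split over axis 0
    w_local = (16, 32 // mesh[1])       # RS1: output dim split over axis 1
    y_local = (x_local[0], w_local[1])  # local matmul, no communication needed

    print(x_local, w_local, y_local)    # (2, 16) (16, 16) (2, 16)
    # Globally y is (4, 32), sharded S0 on dim 0 and S1 on dim 1 -- the
    # 'S0S1' on the left-hand side of the strategy name.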

 def test_linear_function_handler():
-    model = nn.Linear(10, 20).to('meta')
+    model = nn.Linear(16, 32).to('meta')
     tracer = ColoTracer()
-    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 10).to('meta')})
+    graph = tracer.trace(model, meta_args={"input": torch.rand(4, 16).to('meta')})
     gm = ColoGraphModule(model, graph)
     physical_mesh_id = torch.arange(0, 4)
@@ -77,27 +100,50 @@ def test_linear_function_handler():
     assert mapping['input'].name == "input_1"
     assert mapping['input'].data.is_meta
-    assert mapping['input'].data.shape == torch.Size([4, 10])
+    assert mapping['input'].data.shape == torch.Size([4, 16])
     assert mapping['input'].type == OperationDataType.ARG
-    assert mapping['input'].logical_shape == torch.Size([4, 10])
+    assert mapping['input'].logical_shape == torch.Size([4, 16])

     assert mapping['other'].name == "weight"
     assert mapping['other'].data.is_meta
-    assert mapping['other'].data.shape == torch.Size([20, 10])
+    assert mapping['other'].data.shape == torch.Size([32, 16])
     assert mapping['other'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['bias'].name == "bias"
     assert mapping['bias'].data.is_meta
-    assert mapping['bias'].data.shape == torch.Size([20])
+    assert mapping['bias'].data.shape == torch.Size([32])
     assert mapping['bias'].type == OperationDataType.PARAM
-    assert mapping['other'].logical_shape == torch.Size([10, 20])
+    assert mapping['other'].logical_shape == torch.Size([16, 32])

     assert mapping['output'].name == "linear"
     assert mapping['output'].data.is_meta
-    assert mapping['output'].data.shape == torch.Size([4, 20])
+    assert mapping['output'].data.shape == torch.Size([4, 32])
     assert mapping['output'].type == OperationDataType.OUTPUT

+    strategies_vector = handler.register_strategy()
+    strategy_name_list = [val.name for val in strategies_vector]
+
+    # SS = SR x RS
+    assert 'S0S1 = S0R x RS1' in strategy_name_list
+    assert 'S1S0 = S1R x RS0' in strategy_name_list
+
+    # SR = SS x SR
+    assert 'S0R = S0S1 x S1R' in strategy_name_list
+    assert 'S1R = S1S0 x S0R' in strategy_name_list
+
+    # RS = RS x SS
+    assert 'RS0 = RS1 x S1S0' in strategy_name_list
+    assert 'RS1 = RS0 x S0S1' in strategy_name_list
+
+    # RR = RS x SR
+    assert 'RR = RS0 x S0R' in strategy_name_list
+    assert 'RR = RS1 x S1R' in strategy_name_list
+
+    # RS = RR x RS
+    assert 'RS0 = RR x RS0' in strategy_name_list
+    assert 'RS1 = RR x RS1' in strategy_name_list

 if __name__ == '__main__':
     test_linear_module_handler()
...
-from curses import meta
-from math import dist
-from xml.dom import HierarchyRequestErr
 from colossalai.fx.tracer import meta_patch
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_function import python_ops
...
-from curses import meta
-from math import dist
-from xml.dom import HierarchyRequestErr
 from colossalai.fx.tracer import meta_patch
 from colossalai.fx.tracer.tracer import ColoTracer
 from colossalai.fx.tracer.meta_patch.patched_function import python_ops
...