Unverified Commit b0f7c8bd authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[autoparallel] update CommSpec to CommActions (#1768)

* [autoparallel] update CommSpec to CommActions

* polish code
parent 16b0abf9
...@@ -202,16 +202,17 @@ class LinearFunctionHandler(NodeHandler): ...@@ -202,16 +202,17 @@ class LinearFunctionHandler(NodeHandler):
mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output} mapping = {"input": physical_input_operand, "other": physical_other_operand, "output": physical_output}
if self.node.args[2] is not None: if 'bias' in self.node.kwargs and self.node.kwargs['bias'] is not None:
# check if the other operand is a parameter # check if the other operand is a parameter
if isinstance(self.node.args[2]._meta_data, torch.nn.parameter.Parameter): if isinstance(self.node.kwargs["bias"]._meta_data, torch.nn.parameter.Parameter):
data_type = OperationDataType.PARAM data_type = OperationDataType.PARAM
else: else:
data_type = OperationDataType.ARG data_type = OperationDataType.ARG
physical_bias_operand = OperationData(name=str(self.node.args[2]), physical_bias_operand = OperationData(name=str(self.node.kwargs["bias"]),
type=data_type, type=data_type,
data=self.node.args[2]._meta_data) data=self.node.kwargs["bias"]._meta_data)
mapping['bias'] = physical_bias_operand mapping['bias'] = physical_bias_operand
return mapping return mapping
def post_process(self, strategy: ShardingStrategy): def post_process(self, strategy: ShardingStrategy):
......
...@@ -3,7 +3,12 @@ import operator ...@@ -3,7 +3,12 @@ import operator
from functools import reduce from functools import reduce
from typing import List from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, ShardingStrategy, TrainCycleItem from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator from .strategy_generator import StrategyGenerator
...@@ -204,12 +209,13 @@ class BatchNormStrategyGenerator(StrategyGenerator): ...@@ -204,12 +209,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# For SyncBN case, we don't need to do communication for weight and bias. # For SyncBN case, we don't need to do communication for weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node. # to SyncBN operation instead of inserting a communication node.
output_comm_spec = self.get_communication_spec( output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"], sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=mesh_dim_0) logical_process_axis=mesh_dim_0,
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_spec} communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(name=name, return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
...@@ -238,12 +244,13 @@ class BatchNormStrategyGenerator(StrategyGenerator): ...@@ -238,12 +244,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# For SyncBN case, we don't need to do communication for gradients of weight and bias. # For SyncBN case, we don't need to do communication for gradients of weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node. # to SyncBN operation instead of inserting a communication node.
output_comm_spec = self.get_communication_spec( output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"], sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0, mesh_dim_1]) logical_process_axis=[mesh_dim_0, mesh_dim_1],
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_spec} communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(name=name, return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
...@@ -282,12 +289,13 @@ class BatchNormStrategyGenerator(StrategyGenerator): ...@@ -282,12 +289,13 @@ class BatchNormStrategyGenerator(StrategyGenerator):
# For SyncBN case, we don't need to do communication for gradients of weight and bias. # For SyncBN case, we don't need to do communication for gradients of weight and bias.
# TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation # TODO: the communication happens interally at SyncBN operation. We need to replace the BN operation
# to SyncBN operation instead of inserting a communication node. # to SyncBN operation instead of inserting a communication node.
output_comm_spec = self.get_communication_spec( output_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["output"], sharding_spec=sharding_spec_mapping["output"],
communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, communication_pattern=CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD,
logical_process_axis=[mesh_dim_0]) logical_process_axis=[mesh_dim_0],
comm_type=CommType.AFTER)
communication_action_mapping = {"output": output_comm_spec} communication_action_mapping = {"output": output_comm_action}
return self.get_sharding_strategy(name=name, return self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
......
import copy import copy
from typing import List from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
CommType,
MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import FollowingStrategyGenerator from .strategy_generator import FollowingStrategyGenerator
...@@ -83,11 +88,13 @@ class TensorStrategyGenerator(GetItemStrategyGenerator): ...@@ -83,11 +88,13 @@ class TensorStrategyGenerator(GetItemStrategyGenerator):
} }
sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping) sharding_spec_mapping = self.to_sharding_spec_mapping(dim_partition_dict_mapping)
if gather_input: if gather_input:
input_communication_spec = self.get_communication_spec( input_communication_action = self.get_communication_action(
sharding_spec_mapping["input"], sharding_spec_mapping["input"],
communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD, communication_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
logical_process_axis=logical_process_axis) logical_process_axis=logical_process_axis,
communication_action_mapping["input"] = input_communication_spec comm_type=CommType.BEFORE,
arg_index=0)
communication_action_mapping["input"] = input_communication_action
name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}' name = f'{sharding_spec_mapping["output"].sharding_sequence} = {sharding_spec_mapping["input"].sharding_sequence}'
......
...@@ -3,9 +3,16 @@ import operator ...@@ -3,9 +3,16 @@ import operator
from functools import reduce from functools import reduce
from typing import List from typing import List
from colossalai.auto_parallel.tensor_shard.sharding_strategy import (MemoryCost, ShardingStrategy, TrainCycleItem) from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
from colossalai.auto_parallel.tensor_shard.utils import (enumerate_all_possible_1d_sharding, CommType,
enumerate_all_possible_2d_sharding) MemoryCost,
ShardingStrategy,
TrainCycleItem,
)
from colossalai.auto_parallel.tensor_shard.utils import (
enumerate_all_possible_1d_sharding,
enumerate_all_possible_2d_sharding,
)
from colossalai.tensor.shape_consistency import CollectiveCommPattern from colossalai.tensor.shape_consistency import CollectiveCommPattern
from .strategy_generator import StrategyGenerator from .strategy_generator import StrategyGenerator
...@@ -107,18 +114,20 @@ class LayerNormGenerator(StrategyGenerator): ...@@ -107,18 +114,20 @@ class LayerNormGenerator(StrategyGenerator):
total_mesh_dim_list = total_mesh_dim_list[0] total_mesh_dim_list = total_mesh_dim_list[0]
communication_action_mapping = {} communication_action_mapping = {}
other_comm_spec = self.get_communication_spec( other_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["other"], sharding_spec=sharding_spec_mapping["other"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list) logical_process_axis=total_mesh_dim_list,
communication_action_mapping["other"] = other_comm_spec comm_type=CommType.HOOK)
communication_action_mapping["other"] = other_comm_action
if self.has_bias: if self.has_bias:
bias_comm_spec = self.get_communication_spec( bias_comm_action = self.get_communication_action(
sharding_spec=sharding_spec_mapping["bias"], sharding_spec=sharding_spec_mapping["bias"],
communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, communication_pattern=CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD,
logical_process_axis=total_mesh_dim_list) logical_process_axis=total_mesh_dim_list,
communication_action_mapping["bias"] = bias_comm_spec comm_type=CommType.HOOK)
communication_action_mapping["bias"] = bias_comm_action
strategy = self.get_sharding_strategy(name=name, strategy = self.get_sharding_strategy(name=name,
sharding_spec_mapping=sharding_spec_mapping, sharding_spec_mapping=sharding_spec_mapping,
......
...@@ -41,7 +41,7 @@ def _split(tensor, comm_spec): ...@@ -41,7 +41,7 @@ def _split(tensor, comm_spec):
dim = comm_spec.shard_dim dim = comm_spec.shard_dim
length = tensor.shape[comm_spec.shard_dim] // len(rank_list) length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
start = length * rank_list.index(dist.get_rank()) start = length * rank_list.index(dist.get_rank())
output = torch.narrow(tensor, dim, start, length) output = torch.narrow(tensor, dim, start, length).contiguous()
return output return output
...@@ -76,6 +76,8 @@ def _all_reduce(tensor, comm_spec): ...@@ -76,6 +76,8 @@ def _all_reduce(tensor, comm_spec):
process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis] process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
for rank_list, process_group in process_groups_list: for rank_list, process_group in process_groups_list:
if dist.get_rank() in rank_list: if dist.get_rank() in rank_list:
if not tensor.is_contiguous():
tensor = tensor.contiguous()
dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group) dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group)
return tensor return tensor
......
...@@ -11,6 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import ( ...@@ -11,6 +11,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import (
) )
from colossalai.device.device_mesh import DeviceMesh from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx import ColoGraphModule, ColoTracer from colossalai.fx import ColoGraphModule, ColoTracer
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.testing.utils import parameterize from colossalai.testing.utils import parameterize
...@@ -109,6 +110,7 @@ def test_linear_module_handler(bias): ...@@ -109,6 +110,7 @@ def test_linear_module_handler(bias):
assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1] assert bias_sharding_spec.sharding_sequence[-1] == output_sharding_spec.sharding_sequence[-1]
@run_on_environment_flag(name='AUTO_PARALLEL')
@parameterize('bias', [True, False]) @parameterize('bias', [True, False])
def test_linear_function_handler(bias): def test_linear_function_handler(bias):
model = nn.Linear(16, 32, bias=bias).to('meta') model = nn.Linear(16, 32, bias=bias).to('meta')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment