Unverified Commit 702dbc52 authored by YuliangLiu0306, committed by GitHub

[tensor] use communication autograd func (#1617)

* [tensor] use communication autograd func

* change all-to-all comm spec info

* rename pattern and distinguish fwd/bwd

* polish code
parent c7ac0f4a
@@ -8,6 +8,7 @@ import warnings
 from functools import reduce
 import functools
 import operator
+from .constants import INFINITY_COST


 def generate_sharding_spec(input_: Union[Node, torch.Tensor], device_mesh: DeviceMesh,
@@ -68,19 +69,16 @@ def generate_resharding_costs(nodes: List[Node],
         for strategy in input_node.strategies_vector:
             input_sharding_spec = strategy.output_sharding_spec
             assert isinstance(input_sharding_spec, ShardingSpec), f'The input node should NOT be a tuple of tensor.'
-            # compute the resharding cost during forward phase
-            _, _, resharding_cost_forward = shape_consistency_manager.shape_consistency(input_sharding_spec, input_spec)
-
-            if count_backward:
-                # In backward phase, we should convert grad with target_spec into input_sharding_spec
-                _, _, resharding_cost_backward = shape_consistency_manager.shape_consistency(
-                    input_spec, input_sharding_spec)
-                total_resharding_cost = resharding_cost_forward + resharding_cost_backward
-            else:
-                total_resharding_cost = resharding_cost_forward
-
-            # we need multiply the size of elem dtype to get correct communication cost
-            resharding_cost = total_resharding_cost * size_per_elem_bytes
+            try:
+                # compute the resharding cost
+                _, _, total_resharding_cost = shape_consistency_manager.shape_consistency(
+                    input_sharding_spec, input_spec)
+
+                # we need multiply the size of elem dtype to get correct communication cost
+                resharding_cost = total_resharding_cost * size_per_elem_bytes
+            except AssertionError as e:
+                warnings.warn(f'{e}')
+                resharding_cost = INFINITY_COST
             resharding_costs[input_node].append(resharding_cost)
     return resharding_costs
......
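Two things change in the hunk above: the explicit forward/backward split (count_backward, resharding_cost_backward) is dropped, since backward communication is now accounted for by the communication autograd functions, and a pair of sharding specs with no valid transform path no longer propagates an assertion error but is priced at INFINITY_COST so the solver never selects that strategy. A hedged illustration of the fallback, assuming INFINITY_COST is simply a very large sentinel exported from .constants and using a hypothetical helper name:

import warnings


def edge_resharding_cost(src_spec, dst_spec, size_per_elem_bytes, manager):
    """Sketch (hypothetical helper, not the repository API): price a resharding
    edge in bytes, or at INFINITY_COST when no transform path exists."""
    INFINITY_COST = 1e13    # assumed sentinel; the real value lives in .constants
    try:
        _, _, total_cost = manager.shape_consistency(src_spec, dst_spec)
        return total_cost * size_per_elem_bytes
    except AssertionError as err:
        warnings.warn(f'{err}')
        return INFINITY_COST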
@@ -4,7 +4,7 @@ import operator
 __all__ = [
     'ELEMENTWISE_MODULE_OP', 'ELEMENTWISE_FUNC_OP', 'RESHAPE_FUNC_OP', 'CONV_MODULE_OP', 'CONV_FUNC_OP',
     'LINEAR_MODULE_OP', 'LINEAR_FUNC_OP', 'BATCHNORM_MODULE_OP', 'POOL_MODULE_OP', 'NON_PARAM_FUNC_OP', 'BCAST_FUNC_OP',
-    'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP'
+    'EMBEDDING_MODULE_OP', 'LAYERNORM_MODULE_OP', 'ELEMENTWISE_METHOD_OP', 'RESHAPE_METHOD_OP', 'INFINITY_COST'
 ]

 ELEMENTWISE_MODULE_OP = [torch.nn.Dropout, torch.nn.ReLU]
......
This diff is collapsed.
@@ -33,7 +33,10 @@ def check_all_gather(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_check.shape, dim_partition_dict=dim_partition_dict)

     # CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1)
-    comm_spec = CommSpec(CollectiveCommPattern.ALLGATHER, sharding_spec, gather_dim=1, logical_process_axis=1)
+    comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
+                         sharding_spec,
+                         gather_dim=1,
+                         logical_process_axis=1)
     comm_spec.covert_spec_to_action(sharded_tensor_to_comm)

     assert sharded_tensor_to_comm.equal(tensor_to_check)
@@ -56,7 +59,7 @@ def check_shard(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_shard.shape, dim_partition_dict=dim_partition_dict)

     # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
-    comm_spec = CommSpec(CollectiveCommPattern.SHARD, sharding_spec, shard_dim=1, logical_process_axis=1)
+    comm_spec = CommSpec(CollectiveCommPattern.SPLIT_FWD_GATHER_BWD, sharding_spec, shard_dim=1, logical_process_axis=1)
     comm_spec.covert_spec_to_action(tensor_to_shard)

     if rank in (0, 2):
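The renamed patterns used in these tests encode both the forward collective and the gradient-time collective. A hedged sketch of what the enum plausibly looks like after this commit (the member names appear in the tests; the values and any omitted members are assumptions):

from enum import Enum


class CollectiveCommPattern(Enum):
    # name format: <forward collective>_FWD_<backward collective>_BWD
    GATHER_FWD_SPLIT_BWD = 'gather_fwd_split_bwd'              # all-gather fwd, split grads bwd
    SPLIT_FWD_GATHER_BWD = 'split_fwd_gather_bwd'              # split fwd, all-gather grads bwd
    ALL2ALL_FWD_ALL2ALL_BWD = 'all2all_fwd_all2all_bwd'        # all-to-all in both passes
    REDUCE_FWD_IDENTITY_BWD = 'reduce_fwd_identity_bwd'        # all-reduce fwd, no-op bwd
    IDENTITY_FWD_ALLREDUCE_BWD = 'identity_fwd_allreduce_bwd'  # no-op fwd, all-reduce grads bwd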
@@ -102,7 +105,7 @@ def check_all_to_all(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, torch.Size((4, 2)), dim_partition_dict=dim_partition_dict)

     # CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1)
-    comm_spec = CommSpec(CollectiveCommPattern.ALLTOALL,
+    comm_spec = CommSpec(CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD,
                          sharding_spec,
                          gather_dim=0,
                          shard_dim=1,
@@ -112,7 +115,7 @@ def check_all_to_all(device_mesh, rank):
     assert tensor_to_comm.equal(tensor_to_check)


-def check_all_reduce(device_mesh, rank):
+def check_all_reduce_fwd(device_mesh, rank):
     # tensor to comm
     tensor_to_comm = torch.ones(2, 2).cuda() * rank
@@ -133,8 +136,25 @@ def check_all_reduce(device_mesh, rank):
     #     device_mesh_shape: (2, 2)
     sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

-    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:0)
-    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=0)
+    comm_spec = CommSpec(CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=0)
+    comm_spec.covert_spec_to_action(tensor_to_comm)
+
+    assert tensor_to_comm.equal(tensor_to_check)
+
+
+def check_all_reduce_bwd(device_mesh, rank):
+    # tensor to comm
+    tensor_to_comm = torch.ones(2, 2).cuda() * rank
+
+    tensor_to_check = torch.ones(2, 2).cuda() * rank
+
+    dim_partition_dict = {}
+
+    # DistSpec:
+    #     shard_sequence: R,R
+    #     device_mesh_shape: (2, 2)
+    sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)
+
+    comm_spec = CommSpec(CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD, sharding_spec, logical_process_axis=0)
     comm_spec.covert_spec_to_action(tensor_to_comm)

     assert tensor_to_comm.equal(tensor_to_check)
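check_all_reduce_bwd deliberately expects the tensor to come back unchanged: with IDENTITY_FWD_ALLREDUCE_BWD the forward action is a no-op and the all-reduce only fires on gradients during backward, which covert_spec_to_action does not exercise here. A minimal sketch of that behaviour, assuming a hypothetical _IdentityFwdAllReduceBwd class rather than the repository code:

import torch
import torch.distributed as dist


class _IdentityFwdAllReduceBwd(torch.autograd.Function):
    """Sketch (hypothetical name): pass the tensor through untouched in forward,
    all-reduce the gradient across the process group in backward."""

    @staticmethod
    def forward(ctx, tensor, group):
        ctx.group = group
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        dist.all_reduce(grad_output, group=ctx.group)
        return grad_output, None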
@@ -157,7 +177,7 @@ def check_all_reduce_in_flatten_device_mesh(device_mesh, rank):
     sharding_spec = ShardingSpec(device_mesh, tensor_to_comm.shape, dim_partition_dict=dim_partition_dict)

     # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
-    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE, sharding_spec, logical_process_axis=[0, 1])
+    comm_spec = CommSpec(CollectiveCommPattern.REDUCE_FWD_IDENTITY_BWD, sharding_spec, logical_process_axis=[0, 1])
     comm_spec.covert_spec_to_action(tensor_to_comm)

     assert tensor_to_comm.equal(tensor_to_check)
@@ -184,7 +204,8 @@ def check_comm(rank, world_size, port):
     check_all_to_all(device_mesh, rank)

     # test all reduce
-    check_all_reduce(device_mesh, rank)
+    check_all_reduce_fwd(device_mesh, rank)
+    check_all_reduce_bwd(device_mesh, rank)

     # test all reduce in 1D flatten device mesh
     check_all_reduce_in_flatten_device_mesh(device_mesh, rank)
......
@@ -106,18 +106,18 @@ def test_shape_consistency():
     assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]'

     # all-gather(S01) -> S0
-    assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.ALLGATHER
+    assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
     assert comm_action_sequence[0].gather_dim == 1
     assert comm_action_sequence[0].logical_process_axis == 1

     # all-to-all(R, S0) -> [S0, R]
-    assert comm_action_sequence[1].comm_pattern == CollectiveCommPattern.ALLTOALL
+    assert comm_action_sequence[1].comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
     assert comm_action_sequence[1].gather_dim == 1
     assert comm_action_sequence[1].shard_dim == 0
     assert comm_action_sequence[1].logical_process_axis == 0

     # shard(S0) -> [S01]
-    assert comm_action_sequence[2].comm_pattern == CollectiveCommPattern.SHARD
+    assert comm_action_sequence[2].comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
     assert comm_action_sequence[2].shard_dim == 0
     assert comm_action_sequence[2].logical_process_axis == 1
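Reading the updated assertions against the transform path '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]', each step maps to one comm action under the new names:

* step 0, GATHER_FWD_SPLIT_BWD: all-gather tensor dim 1 over mesh axis 1, turning S01 into S0
* step 1, ALL2ALL_FWD_ALL2ALL_BWD: gather dim 1 and shard dim 0 over mesh axis 0, moving the shard from dim 1 to dim 0
* step 2, SPLIT_FWD_GATHER_BWD: shard tensor dim 0 over mesh axis 1, turning S0 into S01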
......