Unverified Commit 0f304236 authored by YuliangLiu0306, committed by GitHub

[tensor] shape consistency generate transform path and communication cost (#1435)

* [tensor] shape consistency output transform path and communication cost

* polish code
parent 5774fe02
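The new API, exercised by the test changes at the bottom of this diff, converts a (source, target) ShardingSpec pair into a transform path, the collective communication actions along it, and a total cost. A minimal usage sketch assembled from the imports and test code in this diff (an illustration, not part of the commit itself):

```python
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec
from colossalai.tensor.shape_consistency import ShapeConsistencyManager

# A 4x4 logical mesh over 16 physical devices, as in the tests below.
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
device_mesh = DeviceMesh(physical_mesh_id, (4, 4))
entire_shape = torch.Size((64, 32, 16))

source = ShardingSpec(device_mesh, entire_shape, {1: [0, 1]})    # [R, S01, R]
target = ShardingSpec(device_mesh, entire_shape, {0: [0, 1]})    # [S01, R, R]

manager = ShapeConsistencyManager()
transform_path, comm_actions, total_cost = manager.shape_consistency(source, target)
# transform_path: [R, S01, R] -> [R, S0, R] -> [S0, R, R] -> [S01, R, R]
```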
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from copy import deepcopy
from enum import Enum
from functools import reduce
import operator
# Heuristic costs used to build the difference table below; NAN marks
# spec pairs whose difference is undefined.
ALLGATHER_COST = 20
SHARD_COST = 5
STEP_PENALTY = 6
NAN = 'nan'
class _DimSpec:
@@ -15,6 +26,7 @@ class _DimSpec:
    def __init__(self, shard_list):
        self.is_replica = len(shard_list) == 0
        self.shard_list = shard_list
        self.build_difference_2d_dict()

    def __eq__(self, other):
        return str(self) == str(other)
@@ -27,11 +39,101 @@ class _DimSpec:
            target += str(dim)
        return target
    def _convert_str_to_shard_list(self, str_spec):
        '''
        Convert str_spec into shard_list.

        Argument:
            str_spec(str): dim spec in str type.
        '''
        if str_spec == 'R':
            return []
        if str_spec == 'S0':
            return [0]
        if str_spec == 'S1':
            return [1]
        if str_spec == 'S01':
            return [0, 1]
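For orientation, a small sketch of the string notation this mapping encodes, relying on the _DimSpec repr used by the tests in this diff: 'R' means replicated, 'S0'/'S1' sharded along one logical mesh axis, 'S01' sharded along both.

```python
from colossalai.tensor.sharding_spec import _DimSpec

assert str(_DimSpec([])) == 'R'        # empty shard list: replicated
assert str(_DimSpec([0])) == 'S0'      # sharded along logical mesh axis 0
assert str(_DimSpec([0, 1])) == 'S01'  # sharded along both mesh axes
```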
    def build_difference_2d_dict(self):
        '''
        Build a difference mapping for the 2D device mesh case. It will be used to
        compute the difference between DimSpec pairs.
        '''
        source_spec_list = ['R', 'S0', 'S1', 'S01']
        target_spec_list = ['R', 'S0', 'S1', 'S01']
        difference_dict = {}
        for source_spec in source_spec_list:
            for target_spec in target_spec_list:
                spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
                source_shard_list = self._convert_str_to_shard_list(source_spec)
                target_shard_list = self._convert_str_to_shard_list(target_spec)

                # source same as target
                if source_shard_list == target_shard_list:
                    difference = 0
                # all_gather(source) -> target
                elif len(source_shard_list) == len(target_shard_list) + 1 \
                        and source_shard_list[:-1] == target_shard_list:
                    difference = ALLGATHER_COST
                # shard(source) -> target
                elif len(source_shard_list) == len(target_shard_list) - 1 \
                        and source_shard_list == target_shard_list[:-1] \
                        and target_shard_list[-1] not in source_shard_list:
                    difference = SHARD_COST
                # S1 -> S0 or S0 -> S1
                elif len(source_shard_list) == len(target_shard_list):
                    # source -> R -> target
                    difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST
                # R -> S01
                elif len(source_shard_list) == len(target_shard_list) - 2:
                    difference = SHARD_COST + STEP_PENALTY + SHARD_COST
                # S01 -> R
                elif len(source_shard_list) == len(target_shard_list) + 2:
                    difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST
                # S1 -> S01
                elif len(source_shard_list) == len(target_shard_list) - 1:
                    difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST
                # S01 -> S1
                elif len(source_shard_list) == len(target_shard_list) + 1:
                    difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST
                else:
                    difference = NAN
                difference_dict[spec_pair] = difference

        self.difference_dict = difference_dict
    def difference(self, other):
        '''
        The difference between two _DimSpec.

        Argument:
            other(_DimSpec): the dim spec to compare with.

        Return:
            difference(int): the difference between two _DimSpec.

        Example:
            dim_spec = _DimSpec([0])
            other_dim_spec = _DimSpec([0, 1])
            print(dim_spec.difference(other_dim_spec))

        Output:
            5
        '''
        difference = self.difference_dict[(str(self), str(other))]
        return difference
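A sketch of the lookup in action; the expected values follow from the cost constants above (ALLGATHER_COST = 20, SHARD_COST = 5, STEP_PENALTY = 6) and the branch comments in build_difference_2d_dict:

```python
from colossalai.tensor.sharding_spec import _DimSpec

s0, s01, r = _DimSpec([0]), _DimSpec([0, 1]), _DimSpec([])

assert s0.difference(s01) == 5     # shard(S0) -> S01: SHARD_COST
assert s01.difference(s0) == 20    # all-gather(S01) -> S0: ALLGATHER_COST
assert s01.difference(r) == 46     # S01 -> S0 -> R: two all-gathers plus a step penalty
```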
class ShardingSpec:
@@ -43,8 +145,9 @@ class ShardingSpec:
    Argument:
        device_mesh(DeviceMesh): A logical view of a physical mesh.
        entire_shape(torch.Size): The entire shape of the tensor before it is sharded.
        dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of the tensor to be sharded,
            and the value describes which logical mesh axes shard that dimension.
        sharding_sequence(List[_DimSpec], optional): A straight view of the ShardingSpec, e.g. [R, R, S0, S1].
    '''
    def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
@@ -79,12 +182,18 @@ class ShardingSpec:
f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
    def convert_dict_to_shard_sequence(self):
        '''
        Convert dim_partition_dict into a list of _DimSpec, and assign it to sharding_sequence.
        '''
        sharding_sequence = [_DimSpec([])] * len(self.entire_shape)
        for dim, shard_list in self.dim_partition_dict.items():
            sharding_sequence[dim] = _DimSpec(shard_list)
        self.sharding_sequence = sharding_sequence
    def convert_shard_sequence_to_dict(self):
        '''
        Convert sharding_sequence into dim_partition_dict.
        '''
        new_dim_partition_dict = {}
        for index, dim_spec in enumerate(self.sharding_sequence):
            if not dim_spec.is_replica:
@@ -95,6 +204,45 @@ class ShardingSpec:
    def sharding_sequence_difference(self, other):
        '''
        This function is a naive version of difference computation. It simply accumulates the
        per-dimension difference between the pair of sharding sequences.

        Example:
            dim_partition_dict = {0: [0, 1]}
            # DistSpec:
            #     shard_sequence: S01,R,R
            #     device_mesh_shape: (4, 4)
            sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
            dim_partition_dict_to_compare = {0: [0], 1: [1]}
            # DistSpec:
            #     shard_sequence: S0,S1,R
            #     device_mesh_shape: (4, 4)
            sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
            print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))

        Output:
            25

        Argument:
            other(ShardingSpec): The ShardingSpec to compare with.

        Return:
            difference(int): Difference between two ShardingSpecs.
        '''
        assert len(self.sharding_sequence) == len(other.sharding_sequence), \
            'Cannot compare difference for two sharding specs with different length.'
        difference = 0
        for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence):
            difference += orig_dim_spec.difference(other_dim_spec)
        return difference
    def get_sharded_shape_per_device(self):
        '''
        Compute the shape of the local shard held by each device.
        '''
        sharded_shape = list(self.entire_shape)
        for dim, shard_list in self.dim_partition_dict.items():
            mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
            shard_partitions = reduce(operator.mul, mesh_list, 1)
            assert sharded_shape[dim] % shard_partitions == 0, \
                f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
            sharded_shape[dim] //= shard_partitions
        return torch.Size(sharded_shape)
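A worked example of get_sharded_shape_per_device on the same 4x4 mesh the tests use; the numbers follow directly from the code above (a sketch, not part of the commit):

```python
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.sharding_spec import ShardingSpec

physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
device_mesh = DeviceMesh(physical_mesh_id, (4, 4))

# [S0, S1, R]: dim 0 split across mesh axis 0 (4 ways), dim 1 across axis 1 (4 ways).
spec = ShardingSpec(device_mesh, torch.Size((64, 32, 16)), {0: [0], 1: [1]})
assert spec.get_sharded_shape_per_device() == torch.Size((16, 8, 16))
```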
@@ -5,6 +5,90 @@ import torch.nn as nn
from colossalai.tensor.colo_tensor import ColoTensor
def all_gather_simulator(target_pair):
    '''
    Simulate the all-gather operation: analyze the communication cost
    and the influence on the DimSpec.

    We don't allow non-contiguous layouts, e.g. all-gather(S012) -> S02 is NOT allowed.
    Therefore, the all-gather operation just removes the last element of the shard list,
    e.g.:
        all-gather(S01) -> S0

    Argument:
        target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
            and the second element describes which logical axes shard that dimension.
    '''
    _, shard_list = target_pair
    new_shard_list = shard_list[:-1]

    return new_shard_list
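A quick check of the behavior described above (a sketch based on the code as written):

```python
from colossalai.tensor.utils import all_gather_simulator

# all-gather drops the last axis from the shard list:
assert all_gather_simulator((0, [0, 1])) == [0]    # all-gather(S01) -> S0
assert all_gather_simulator((0, [0])) == []        # all-gather(S0)  -> R
```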
def all_to_all_simulator(f_target_pair, b_target_pair):
    '''
    Simulate the all-to-all operation: analyze the communication cost
    and the influence on the DimSpec.

    We BAN all representations whose shard_list is in decreasing order,
    such as S10, so all-to-all(S0, S1) -> RS01 is NOT allowed.
    Therefore, if the behind shard_list is not empty, we extend it onto the front shard_list,
    e.g.:
        all-to-all(S0, S1) -> [S01, R]
        all-to-all(R, S1) -> [S1, R]
    Otherwise, we move the front shard_list to the behind position,
    e.g.:
        all-to-all(S0, R) -> [R, S0]

    Argument:
        f_target_pair(Tuple[int, List[int]]): The front pair; the first element is the dimension of the tensor
            to be sharded, and the second element describes which logical axes shard that dimension.
        b_target_pair(Tuple[int, List[int]]): The behind pair, in the same format.
    '''
    _, f_shard_list = f_target_pair
    _, b_shard_list = b_target_pair
    if not len(b_shard_list):
        b_shard_list.extend(f_shard_list)
        f_shard_list = []
    else:
        f_shard_list.extend(b_shard_list)
        b_shard_list = []

    return f_shard_list, b_shard_list
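The two branches above, exercised on small inputs (a sketch; note that the helper mutates the lists it is given, so fresh literals are used here):

```python
from colossalai.tensor.utils import all_to_all_simulator

# Behind shard list non-empty: it is appended onto the front one.
assert all_to_all_simulator((0, [0]), (1, [1])) == ([0, 1], [])    # (S0, S1) -> [S01, R]
# Behind shard list empty: the front shard list moves behind.
assert all_to_all_simulator((0, [0]), (1, [])) == ([], [0])        # (S0, R)  -> [R, S0]
```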
def shard_simulator(target_pair, legal_sharding_dims):
    '''
    Simulate the shard operation: analyze the communication cost (always ZERO)
    and the influence on the DimSpec.

    We don't allow non-contiguous layouts, e.g. shard(S0) -> S02 is NOT allowed.
    In addition, we BAN all representations whose shard_list is in decreasing order,
    such as S10, so shard(S0) -> S10 is NOT allowed.
    Therefore, for an R dimension, we may append any legal sharding dim to it,
    e.g.:
        shard(R) -> S0
    For an S dimension, the shard_list after sharding must remain in ascending order,
    e.g.:
        shard(S0) -> S01

    Argument:
        target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
            and the second element describes which logical axes shard that dimension.
        legal_sharding_dims(List[int]): The logical mesh axes that may still be used for sharding.
    '''
    _, shard_list = target_pair
    shard_list_list = []
    for dim in legal_sharding_dims:
        if len(shard_list) != 0 and dim <= shard_list[-1]:
            continue
        new_shard_list = shard_list + [dim]
        shard_list_list.append(new_shard_list)

    return shard_list_list
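And the ascending-order rule in shard_simulator, on the two cases from the docstring (a sketch):

```python
from colossalai.tensor.utils import shard_simulator

# From S0 with legal axes [0, 1], only appending 1 keeps the shard list ascending.
assert shard_simulator((0, [0]), [0, 1]) == [[0, 1]]    # shard(S0) -> [S01]
# From R, any legal axis may be appended.
assert shard_simulator((0, []), [0, 1]) == [[0], [1]]   # shard(R) -> [S0] or [S1]
```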
# The function is credited to PyTorch Team
def named_params_with_colotensor(
module: nn.Module,
......
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
import torch
from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((64, 32, 16))
shape_consistency_manager = ShapeConsistencyManager()
def test_one_step_transform():
    dim_partition_dict = {0: [0], 1: [1]}
    # DistSpec:
    #     shard_sequence: S0,S1,R
    #     device_mesh_shape: (4, 4)
    sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
    # {DistSpec:
    #     shard_sequence: R,S1,R
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0), DistSpec:
    #     shard_sequence: S0,R,R
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)}
    rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)

    assert '[R, S1, R]' in [
@@ -39,12 +42,12 @@ def test_shape_consistency():
    #     device_mesh_shape: (4, 4)
    sharding_spec_all2all = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_all2all)
    # {DistSpec:
    #     shard_sequence: S01,R,R
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 1), 0), DistSpec:
    #     shard_sequence: R,S1,S0
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0), DistSpec:
    #     shard_sequence: S0,R,S1
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)}
    rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, 0)

    assert '[S01, R, R]' in [
@@ -63,12 +66,12 @@ def test_shape_consistency():
    #     device_mesh_shape: (4, 4)
    sharding_spec_shard = ShardingSpec(device_mesh, entire_shape, dim_partition_shard)
    # {DistSpec:
    #     shard_sequence: S01,R,R
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1), 0), DistSpec:
    #     shard_sequence: S0,S1,R
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1), 0), DistSpec:
    #     shard_sequence: S0,R,S1
    #     device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)}
    rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, 0)

    assert '[S01, R, R]' in [
@@ -82,5 +85,48 @@ def test_shape_consistency():
    ]
def test_shape_consistency():
    dim_partition_source = {1: [0, 1]}
    dim_partition_target = {0: [0, 1]}

    # DistSpec:
    #     shard_sequence: R,S01,R
    #     device_mesh_shape: (4, 4)
    sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)

    # DistSpec:
    #     shard_sequence: S01,R,R
    #     device_mesh_shape: (4, 4)
    sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)

    transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
        sharding_spec_source, sharding_spec_target)

    transform_path_str = '->'.join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path])
    assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]'

    # all-gather(S01) -> S0
    assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.ALLGATHER
    assert comm_action_sequence[0].gather_dim == 1
    assert comm_action_sequence[0].logical_process_axis == 1

    # all-to-all(R, S0) -> [S0, R]
    assert comm_action_sequence[1].comm_pattern == CollectiveCommPattern.ALLTOALL
    assert comm_action_sequence[1].gather_dim == 1
    assert comm_action_sequence[1].shard_dim == 0
    assert comm_action_sequence[1].logical_process_axis == 0

    # shard(S0) -> [S01]
    assert comm_action_sequence[2].comm_pattern == CollectiveCommPattern.SHARD
    assert comm_action_sequence[2].shard_dim == 0
    assert comm_action_sequence[2].logical_process_axis == 1

    assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]',
                                                                       '[S01, R, R]')][0] == transform_path
    assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]',
                                                                       '[S01, R, R]')][1] == comm_action_sequence
if __name__ == '__main__':
    test_one_step_transform()
    test_shape_consistency()