Unverified Commit 0f304236 authored by YuliangLiu0306, committed by GitHub

[tensor] shape consistency generate transform path and communication cost (#1435)

* [tensor] shape consistency output transform path and communication cost

* polish code
parent 5774fe02
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from copy import deepcopy
from enum import Enum
from functools import reduce
import operator
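# Heuristic cost weights: build_difference_2d_dict below combines them to score how expensive it is
# to convert one dim spec into another (an all-gather step is weighted higher than a shard step, and
# every extra step adds STEP_PENALTY); NAN marks spec pairs not covered by any of its cases.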
ALLGATHER_COST = 20
SHARD_COST = 5
STEP_PENALTY = 6
NAN = 'nan'
class _DimSpec:
@@ -15,6 +26,7 @@ class _DimSpec:
    def __init__(self, shard_list):
        self.is_replica = len(shard_list) == 0
        self.shard_list = shard_list
        self.build_difference_2d_dict()

    def __eq__(self, other):
        return str(self) == str(other)
@@ -27,11 +39,101 @@ class _DimSpec:
        target += str(dim)
        return target
    def _convert_str_to_shard_list(self, str_spec):
        '''
        Convert str_spec into shard_list.

        Argument:
            str_spec(str): dim spec in str type.
        '''
        if str_spec == 'R':
            return []
        if str_spec == 'S0':
            return [0]
        if str_spec == 'S1':
            return [1]
        if str_spec == 'S01':
            return [0, 1]
    def build_difference_2d_dict(self):
        '''
        Build a difference mapping for the 2D device mesh case. It will be used to
        compute the difference between DimSpec pairs.
        '''
        source_spec_list = ['R', 'S0', 'S1', 'S01']
        target_spec_list = ['R', 'S0', 'S1', 'S01']
        difference_dict = {}
        for source_spec in source_spec_list:
            for target_spec in target_spec_list:
                legal_sharding_dims = []
                spec_pair = (deepcopy(source_spec), deepcopy(target_spec))
                source_shard_list = self._convert_str_to_shard_list(source_spec)
                target_shard_list = self._convert_str_to_shard_list(target_spec)

                # source same as target
                if source_shard_list == target_shard_list:
                    difference = 0
                # all_gather(source) -> target
                elif len(source_shard_list) == len(target_shard_list) + 1 and source_shard_list[:-1] == target_shard_list:
                    difference = ALLGATHER_COST
                # shard(source) -> target
                elif len(source_shard_list) == len(target_shard_list) - 1 and source_shard_list == target_shard_list[:-1] and target_shard_list[-1] not in source_shard_list:
                    difference = SHARD_COST
                # S1 -> S0 or S0 -> S1
                elif len(source_shard_list) == len(target_shard_list):
                    # source -> R -> target
                    difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST
                # R -> S01
                elif len(source_shard_list) == len(target_shard_list) - 2:
                    difference = SHARD_COST + STEP_PENALTY + SHARD_COST
                # S01 -> R
                elif len(source_shard_list) == len(target_shard_list) + 2:
                    difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST
                # S1 -> S01
                elif len(source_shard_list) == len(target_shard_list) - 1:
                    difference = ALLGATHER_COST + STEP_PENALTY + SHARD_COST + STEP_PENALTY + SHARD_COST
                # S01 -> S1
                elif len(source_shard_list) == len(target_shard_list) + 1:
                    difference = ALLGATHER_COST + STEP_PENALTY + ALLGATHER_COST + STEP_PENALTY + SHARD_COST
                else:
                    difference = NAN
                difference_dict[spec_pair] = difference

        self.difference_dict = difference_dict
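        # A few entries of the table above, worked out from the constants at the top of this file
        # (shown only to illustrate the heuristic, not as measured communication cost):
        #   ('S0',  'S01'): 5    # one shard step
        #   ('S01', 'S0'):  20   # one all-gather step
        #   ('S01', 'R'):   46   # all-gather + step penalty + all-gather
        #   ('S0',  'S1'):  31   # all-gather + step penalty + shard (via R)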
    def difference(self, other):
        '''
        The difference between two _DimSpec objects.

        Argument:
            other(_DimSpec): the dim spec to compare with.

        Return:
            difference(int): the difference between two _DimSpec.

        Example:
            dim_spec = _DimSpec([0])
            other_dim_spec = _DimSpec([0, 1])
            print(dim_spec.difference(other_dim_spec))

        Output:
            5
        '''
        difference = self.difference_dict[(str(self), str(other))]
        return difference
class ShardingSpec:
@@ -43,8 +145,9 @@ class ShardingSpec:
    Argument:
        device_mesh(DeviceMesh): A logical view of a physical mesh.
        entire_shape(torch.Size): The entire shape of tensor before sharded.
        dim_partition_dict(Dict[int, List[int]], optional): The key is the dimension of tensor to be sharded,
            and the value of the key describes which logical axis will be sharded in that dimension.
        sharding_sequence(List[_DimSpec], optional): A straight view of ShardingSpec looks like [R, R, S0, S1].
    '''
    def __init__(self, device_mesh, entire_shape, dim_partition_dict=None, sharding_sequence=None):
@@ -79,12 +182,18 @@ class ShardingSpec:
                    f"find an invalid sharding axis {element} in dim_partition_dict in tensor dimension {dim}.")
    def convert_dict_to_shard_sequence(self):
        '''
        Convert dim_partition_dict into list of _DimSpec, and assign it to sharding_sequence.
        '''
        sharding_sequence = [_DimSpec([])] * len(self.entire_shape)
        for dim, shard_list in self.dim_partition_dict.items():
            sharding_sequence[dim] = _DimSpec(shard_list)
        self.sharding_sequence = sharding_sequence
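        # For illustration (hypothetical values): with a 3-dimensional entire_shape and
        # dim_partition_dict = {0: [0, 1]}, the resulting sharding_sequence reads [S01, R, R].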
    def convert_shard_sequence_to_dict(self):
        '''
        Convert sharding_sequence into dim_partition_dict.
        '''
        new_dim_partition_dict = {}
        for index, dim_spec in enumerate(self.sharding_sequence):
            if not dim_spec.is_replica:
@@ -95,6 +204,45 @@ class ShardingSpec:
    def sharding_sequence_difference(self, other):
        '''
        This function is a naive version of difference computation. It simply accumulates the difference of
        every dimension between the pair of sharding sequences.

        Example:
            dim_partition_dict = {0: [0, 1]}
            # DistSpec:
            # shard_sequence: S01,R,R
            # device_mesh_shape: (4, 4)
            sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
            dim_partition_dict_to_compare = {0: [0], 1: [1]}
            # DistSpec:
            # shard_sequence: S0,S1,R
            # device_mesh_shape: (4, 4)
            sharding_spec_to_compare = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_to_compare)
            print(sharding_spec.sharding_sequence_difference(sharding_spec_to_compare))

        Output:
            25

        Argument:
            other(ShardingSpec): The ShardingSpec to compare with.

        Return:
            difference(int): Difference between two ShardingSpec.
        '''
        assert len(self.sharding_sequence) == len(other.sharding_sequence), \
            'Cannot compare difference for two sharding specs with different length.'
        difference = 0
        for orig_dim_spec, other_dim_spec in zip(self.sharding_sequence, other.sharding_sequence):
            difference += orig_dim_spec.difference(other_dim_spec)
        return difference
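        # Working through the docstring example above with the cost constants at the top of this file:
        # dimension 0 compares S01 against S0 (one all-gather, cost 20), dimension 1 compares R against S1
        # (one shard, cost 5), dimension 2 is R against R (cost 0), so the naive difference is 20 + 5 + 0 = 25.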
    def get_sharded_shape_per_device(self):
        sharded_shape = list(self.entire_shape)
        for dim, shard_list in self.dim_partition_dict.items():
            mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
            shard_partitions = reduce(operator.mul, mesh_list, 1)
            assert sharded_shape[dim] % shard_partitions == 0, \
                f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
            sharded_shape[dim] //= shard_partitions
        return torch.Size(sharded_shape)
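    # Minimal usage sketch (hypothetical values, assuming the same (4, 4) logical mesh as the tests below):
    #   spec = ShardingSpec(device_mesh, torch.Size((64, 32, 16)), {0: [0], 1: [1]})
    #   spec.get_sharded_shape_per_device()   # -> torch.Size([16, 8, 16]), i.e. 64 // 4 and 32 // 4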
@@ -5,6 +5,90 @@ import torch.nn as nn
from colossalai.tensor.colo_tensor import ColoTensor
def all_gather_simulator(target_pair):
    '''
    Simulate the all-gather operation: analyze the communication cost
    and the influence on the DimSpec.

    We don't allow non-contiguous layouts, so e.g. all-gather(S012) -> S02 is NOT allowed.
    Therefore, the all-gather operation just removes the last element of the shard list,
    e.g.:
        all-gather(S01) -> S0

    Argument:
        target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
            and the second element describes which logical axes are sharded in that dimension.
    '''
    _, shard_list = target_pair
    new_shard_list = shard_list[:-1]

    return new_shard_list
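# Illustration (values worked out from the rule above, not from real communication):
#   all_gather_simulator((0, [0, 1]))  # -> [0]   i.e. all-gather(S01) -> S0
#   all_gather_simulator((1, [1]))     # -> []    i.e. all-gather(S1)  -> R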
def all_to_all_simulator(f_target_pair, b_target_pair):
    '''
    Simulate the all-to-all operation: analyze the communication cost
    and the influence on the DimSpec.

    We ban all representations whose shard_list is in decreasing order, such as S10,
    so all-to-all(S0, S1) -> RS01 is NOT allowed.
    Therefore, if the latter shard_list is not empty, we extend it onto the former shard_list,
    e.g.:
        all-to-all(S0, S1) -> [S01, R]
        all-to-all(R, S1) -> [S1, R]
    Otherwise, we move the former shard_list behind,
    e.g.:
        all-to-all(S0, R) -> [R, S0]

    Argument:
        f_target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
            and the second element describes which logical axes are sharded in that dimension.
        b_target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
            and the second element describes which logical axes are sharded in that dimension.
    '''
    _, f_shard_list = f_target_pair
    _, b_shard_list = b_target_pair
    if not len(b_shard_list):
        # the latter dim spec is replicated: move the former shard list behind
        b_shard_list.extend(f_shard_list)
        f_shard_list = []
    else:
        # otherwise, merge the latter shard list into the former one
        f_shard_list.extend(b_shard_list)
        b_shard_list = []

    return f_shard_list, b_shard_list
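# Illustration (worked out from the rule above; note that the incoming shard lists are extended in place):
#   all_to_all_simulator((0, [0]), (1, [1]))  # -> ([0, 1], [])  i.e. all-to-all(S0, S1) -> [S01, R]
#   all_to_all_simulator((0, [0]), (1, []))   # -> ([], [0])     i.e. all-to-all(S0, R)  -> [R, S0]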
def shard_simulator(target_pair, legal_sharding_dims):
    '''
    Simulate the shard operation: analyze the communication cost (always ZERO)
    and the influence on the DimSpec.

    We don't allow non-contiguous layouts, so e.g. shard(S0) -> S02 is NOT allowed.
    In addition, we ban all representations whose shard_list is in decreasing order,
    such as S10, so shard(S0) -> S10 is NOT allowed.
    Therefore, for an R dimension, we can append any legal sharding dim to it,
    e.g.:
        shard(R) -> S0
    For an S dimension, we need to make sure the shard_list after sharding still keeps ascending order,
    e.g.:
        shard(S0) -> S01

    Argument:
        target_pair(Tuple[int, List[int]]): The first element is the dimension of the tensor to be sharded,
            and the second element describes which logical axes are sharded in that dimension.
        legal_sharding_dims(List[int]): the logical mesh axes that are allowed to appear in a shard list.
    '''
    _, shard_list = target_pair
    shard_list_list = []
    for dim in legal_sharding_dims:
        if len(shard_list) != 0 and dim <= shard_list[-1]:
            continue
        new_shard_list = shard_list + [dim]
        shard_list_list.append(new_shard_list)

    return shard_list_list
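# Illustration (worked out from the rule above, assuming a 2D logical mesh with axes 0 and 1):
#   shard_simulator((0, []), [0, 1])   # -> [[0], [1]]   i.e. shard(R)  -> S0 or S1
#   shard_simulator((0, [0]), [0, 1])  # -> [[0, 1]]     i.e. shard(S0) -> S01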
# The function is credited to PyTorch Team
def named_params_with_colotensor(
    module: nn.Module,
...
from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
import torch
from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
from colossalai.device.device_mesh import DeviceMesh
physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
mesh_shape = (4, 4)
# [[0, 1, 2, 3],
# [4, 5, 6, 7],
# [8, 9, 10,11],
# [12,13,14,15]]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
entire_shape = torch.Size((64, 32, 16))
shape_consistency_manager = ShapeConsistencyManager()
def test_one_step_transform():
    dim_partition_dict = {0: [0], 1: [1]}
    # DistSpec:
    # shard_sequence: S0,S1,R
    # device_mesh_shape: (4, 4)
    sharding_spec = ShardingSpec(device_mesh, entire_shape, dim_partition_dict)
    # {DistSpec:
    # shard_sequence: R,S1,R
    # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:0, logical_process_axis:0), 0), DistSpec:
    # shard_sequence: S0,R,R
    # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:allgather, gather_dim:1, logical_process_axis:1), 0)}
    rst_dict = shape_consistency_manager.get_all_all_gather_spec(sharding_spec, 0)

    assert '[R, S1, R]' in [
@@ -40,11 +43,11 @@ def test_shape_consistency():
    sharding_spec_all2all = ShardingSpec(device_mesh, entire_shape, dim_partition_dict_all2all)

    # {DistSpec:
    # shard_sequence: S01,R,R
    # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:0, logical_process_axis: 1), 0), DistSpec:
    # shard_sequence: R,S1,S0
    # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:0, shard_dim:2, logical_process_axis: 0), 0), DistSpec:
    # shard_sequence: S0,R,S1
    # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:all2all, gather_dim:1, shard_dim:2, logical_process_axis: 1), 0)}
    rst_dict_all2all = shape_consistency_manager.get_all_all_to_all_spec(sharding_spec_all2all, 0)

    assert '[S01, R, R]' in [
@@ -64,11 +67,11 @@ def test_shape_consistency():
    sharding_spec_shard = ShardingSpec(device_mesh, entire_shape, dim_partition_shard)

    # {DistSpec:
    # shard_sequence: S01,R,R
    # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:0, logical_process_axis:1), 0), DistSpec:
    # shard_sequence: S0,S1,R
    # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:1, logical_process_axis:1), 0), DistSpec:
    # shard_sequence: S0,R,S1
    # device_mesh_shape: (4, 4): (CommSpec:(comm_pattern:shard, shard_dim:2, logical_process_axis:1), 0)}
    rst_dict_shard = shape_consistency_manager.get_all_shard_spec(sharding_spec_shard, 0)

    assert '[S01, R, R]' in [
@@ -82,5 +85,48 @@ def test_shape_consistency():
    ]
def test_shape_consistency():
    dim_partition_source = {1: [0, 1]}
    dim_partition_target = {0: [0, 1]}

    # DistSpec:
    # shard_sequence: R,S01,R
    # device_mesh_shape: (4, 4)
    sharding_spec_source = ShardingSpec(device_mesh, entire_shape, dim_partition_source)

    # DistSpec:
    # shard_sequence: S01,R,R
    # device_mesh_shape: (4, 4)
    sharding_spec_target = ShardingSpec(device_mesh, entire_shape, dim_partition_target)

    transform_path, comm_action_sequence, total_cost = shape_consistency_manager.shape_consistency(
        sharding_spec_source, sharding_spec_target)

    transform_path_str = '->'.join([str(sharding_spec.sharding_sequence) for sharding_spec in transform_path])
    assert transform_path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]'

    # all-gather(S01) -> S0
    assert comm_action_sequence[0].comm_pattern == CollectiveCommPattern.ALLGATHER
    assert comm_action_sequence[0].gather_dim == 1
    assert comm_action_sequence[0].logical_process_axis == 1

    # all-to-all(R, S0) -> [S0, R]
    assert comm_action_sequence[1].comm_pattern == CollectiveCommPattern.ALLTOALL
    assert comm_action_sequence[1].gather_dim == 1
    assert comm_action_sequence[1].shard_dim == 0
    assert comm_action_sequence[1].logical_process_axis == 0

    # shard(S0) -> [S01]
    assert comm_action_sequence[2].comm_pattern == CollectiveCommPattern.SHARD
    assert comm_action_sequence[2].shard_dim == 0
    assert comm_action_sequence[2].logical_process_axis == 1

    assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]', '[S01, R, R]')][0] == transform_path
    assert shape_consistency_manager.cached_spec_pairs_transform_path[('[R, S01, R]', '[S01, R, R]')][1] == comm_action_sequence
if __name__ == '__main__':
    test_one_step_transform()
    test_shape_consistency()