Unverified Commit 8e4e8601 authored by YuliangLiu0306's avatar YuliangLiu0306 Committed by GitHub
Browse files

[DTensor] implement layout converter (#3055)

* [DTensor] refactor LayoutConverter for DTensor

* polish code

* polish docstring
parent 89aa7926
This diff is collapsed.
import operator
from functools import reduce
from typing import Dict
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec
from colossalai.tensor.d_tensor.layout import Layout
def get_comm_cost(layout: Layout, comm_spec: CommSpec, forward_only: bool = False) -> Dict[str, float]:
    '''
    Compute the communication cost for a given layout and comm_spec.

    For all-gather, all-to-all and all-reduce operations, the alpha-beta model
    provided by the DeviceMesh is used to compute the communication cost. A shard
    is an on-chip operation, so it is charged a tiny constant cost instead.

    Args:
        layout: the layout of the tensor.
        comm_spec: the comm_spec to instruct the communication operation.
        forward_only: if it is True, we will just count the forward communication cost.
            If it is False, we will count both forward and backward communication cost.

    Returns:
        A dict with keys 'forward', 'backward' and 'total' mapping to float costs.

    Raises:
        ValueError: if ``comm_spec.comm_pattern`` is not a recognized pattern.
    '''
    # number of elements held by each device for this layout
    comm_size = reduce(operator.mul, layout.get_sharded_shape_per_device(), 1)
    device_mesh = layout.device_mesh
    comm_pattern = comm_spec.comm_pattern
    logical_process_axis = comm_spec.logical_process_axis
    # tiny constant cost charged for on-chip shard/split operations
    tiny_shard_cost = 100

    if comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD:
        # the comm size for all gather is the size of the gathered tensor,
        # i.e. the local shard size times the mesh extent along the gathered axis
        gather_dim = comm_spec.gather_dim
        all_gather_axis = layout.sharding_spec.dim_partition_dict[gather_dim][-1]
        all_gather_size = device_mesh.mesh_shape[all_gather_axis]
        comm_size_for_all_gather = comm_size * all_gather_size
        forward_communication_cost = device_mesh.all_gather_cost(comm_size_for_all_gather, logical_process_axis)
        # backward of a gather is a split (shard), an on-chip operation
        backward_communication_cost = tiny_shard_cost
    elif comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD:
        forward_communication_cost = device_mesh.all_to_all_cost(comm_size, logical_process_axis)
        # grad should have same shape as input tensor, and the backward
        # all-to-all uses the same logical process axis as forward
        backward_communication_cost = device_mesh.all_to_all_cost(comm_size, logical_process_axis)
    elif comm_pattern == CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD:
        forward_communication_cost = device_mesh.all_reduce_cost(comm_size, logical_process_axis)
        backward_communication_cost = 0
    elif comm_pattern == CollectiveCommPattern.IDENTITY_FWD_ALLREDUCE_BWD:
        forward_communication_cost = 0
        backward_communication_cost = device_mesh.all_reduce_cost(comm_size, logical_process_axis)
    elif comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD:
        # forward split is on-chip; backward gathers the shards back together
        forward_communication_cost = tiny_shard_cost
        backward_communication_cost = device_mesh.all_gather_cost(comm_size, logical_process_axis)
    else:
        # previously an unknown pattern fell through all the `if`s and raised a
        # confusing NameError on the unset cost variables; fail loudly instead
        raise ValueError(f'Unsupported communication pattern: {comm_pattern}')

    if forward_only:
        backward_communication_cost = 0
    cost_dict = {
        "forward": forward_communication_cost,
        "backward": backward_communication_cost,
        "total": forward_communication_cost + backward_communication_cost,
    }
    return cost_dict
import math
from functools import partial
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
# shape of the whole (unsharded) tensor used by every check below
entire_shape = torch.Size((64, 32, 16))
# one shared converter; it caches solved transform paths across calls
# (see the cached_solution assertions further down)
layout_converter = LayoutConverter()
# device ids [0, 1, 2, 3] arranged as a 2 x 2 logical mesh
physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
mesh_shape = (2, 2)
def check_one_step_transform(rank, world_size, port):
    """Verify the single-step transform enumerators of LayoutConverter.

    Checks that all_gather_transform_layouts, all_to_all_transform_layout and
    shard_transform_layout each produce the expected candidate sharding
    sequences for layouts on a 2 x 2 device mesh.
    """
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # logical mesh:
    # [[0, 1],
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    # source layout: shard_sequence S0,S1,R on the (2, 2) mesh
    gather_source_spec = ShardingSpec(dim_size=3, dim_partition_dict={0: [0], 1: [1]})
    gather_source_layout = Layout(device_mesh=device_mesh,
                                  device_type=torch.device('cuda'),
                                  sharding_spec=gather_source_spec,
                                  entire_shape=entire_shape)
    gather_sequences = [
        str(candidate.sharding_spec.sharding_sequence)
        for candidate in layout_converter.all_gather_transform_layouts(gather_source_layout).keys()
    ]
    assert '[R, S1, R]' in gather_sequences
    assert '[S0, R, R]' in gather_sequences

    # source layout: shard_sequence S0,S1,R on the (2, 2) mesh
    all2all_source_spec = ShardingSpec(dim_size=3, dim_partition_dict={0: [0], 1: [1]})
    all2all_source_layout = Layout(device_mesh=device_mesh,
                                   device_type=torch.device('cuda'),
                                   sharding_spec=all2all_source_spec,
                                   entire_shape=entire_shape)
    all2all_sequences = [
        str(candidate.sharding_spec.sharding_sequence)
        for candidate in layout_converter.all_to_all_transform_layout(all2all_source_layout).keys()
    ]
    assert '[S01, R, R]' in all2all_sequences
    assert '[R, S1, S0]' in all2all_sequences
    assert '[S0, R, S1]' in all2all_sequences

    # source layout: shard_sequence S0,R,R on the (2, 2) mesh
    shard_source_spec = ShardingSpec(dim_size=3, dim_partition_dict={0: [0]})
    shard_source_layout = Layout(device_mesh=device_mesh,
                                 device_type=torch.device('cuda'),
                                 sharding_spec=shard_source_spec,
                                 entire_shape=entire_shape)
    shard_sequences = [
        str(candidate.sharding_spec.sharding_sequence)
        for candidate in layout_converter.shard_transform_layout(shard_source_layout).keys()
    ]
    assert '[S01, R, R]' in shard_sequences
    assert '[S0, S1, R]' in shard_sequences
    assert '[S0, R, S1]' in shard_sequences
def check_layout_converting(rank, world_size, port):
    """Verify layout_converting finds the expected multi-step transform path.

    Converts [R, S01, R] -> [S01, R, R] and checks the intermediate layouts,
    the emitted communication actions, the solution cache, and that the cost
    accounting is self-consistent.
    """
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    # source: shard_sequence R,S01,R   target: shard_sequence S01,R,R
    src_spec = ShardingSpec(dim_size=3, dim_partition_dict={1: [0, 1]})
    src_layout = Layout(device_mesh=device_mesh,
                        device_type=torch.device('cuda'),
                        sharding_spec=src_spec,
                        entire_shape=entire_shape)
    tgt_spec = ShardingSpec(dim_size=3, dim_partition_dict={0: [0, 1]})
    tgt_layout = Layout(device_mesh=device_mesh,
                        device_type=torch.device('cuda'),
                        sharding_spec=tgt_spec,
                        entire_shape=entire_shape)

    transform_path, comm_action_sequence = layout_converter.layout_converting(src_layout, tgt_layout)

    # the path should walk through one intermediate layout per comm action
    path_str = '->'.join(str(step.sharding_spec.sharding_sequence) for step in transform_path)
    assert path_str == '[R, S01, R]->[R, S0, R]->[S0, R, R]->[S01, R, R]'

    # step 1: all-gather(S01) -> S0
    first_action = comm_action_sequence[0]
    assert first_action.comm_pattern == CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
    assert first_action.gather_dim == 1
    assert first_action.logical_process_axis == 1
    # step 2: all-to-all(R, S0) -> [S0, R]
    second_action = comm_action_sequence[1]
    assert second_action.comm_pattern == CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
    assert second_action.gather_dim == 1
    assert second_action.shard_dim == 0
    assert second_action.logical_process_axis == 0
    # step 3: shard(S0) -> [S01]
    third_action = comm_action_sequence[2]
    assert third_action.comm_pattern == CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
    assert third_action.shard_dim == 0
    assert third_action.logical_process_axis == 1

    # check the cached spec-pairs -> transform-path solution
    cached = layout_converter.cached_solution[('[R, S01, R]', '[S01, R, R]')]
    assert cached[0] == transform_path
    assert cached[1] == comm_action_sequence

    # forward and backward costs should match, and total should be their sum
    comm_cost = layout_converter.get_total_comm_cost(src_layout, tgt_layout)
    assert comm_cost['forward'] == comm_cost['backward']
    assert math.floor(comm_cost['total']) == math.floor(comm_cost['forward'] + comm_cost['backward'])
def check_layout_converting_apply(rank, world_size, port):
    """Verify LayoutConverter.apply actually reshards tensor data.

    Starting from this rank's [R, S01, R] slice of a shared random tensor,
    applying the conversion to [S01, R, R] must yield exactly the slice that
    layout would assign to this rank.
    """
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    # source: shard_sequence R,S01,R   target: shard_sequence S01,R,R
    src_spec = ShardingSpec(dim_size=3, dim_partition_dict={1: [0, 1]})
    src_layout = Layout(device_mesh=device_mesh,
                        device_type=torch.device('cuda'),
                        sharding_spec=src_spec,
                        entire_shape=entire_shape)
    tgt_spec = ShardingSpec(dim_size=3, dim_partition_dict={0: [0, 1]})
    tgt_layout = Layout(device_mesh=device_mesh,
                        device_type=torch.device('cuda'),
                        sharding_spec=tgt_spec,
                        entire_shape=entire_shape)

    original_tensor = torch.rand(entire_shape).cuda()
    # S01 spans all 4 devices: dim 1 (32) splits into chunks of 8,
    # dim 0 (64) splits into chunks of 16
    # this rank's [R, S01, R] slice
    tensor_to_apply = original_tensor.narrow(1, rank * 8, 8)
    # this rank's expected [S01, R, R] slice
    tensor_to_check = original_tensor.narrow(0, rank * 16, 16)

    converted_tensor = layout_converter.apply(tensor_to_apply, src_layout, tgt_layout)
    assert converted_tensor.equal(tensor_to_check)
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_layout_converter():
    """Spawn 4 worker processes and run each layout-converter check in turn."""
    world_size = 4
    checks = (check_one_step_transform, check_layout_converting, check_layout_converting_apply)
    for check_fn in checks:
        # each spawn gets its own fresh port to avoid address clashes
        run_func = partial(check_fn, world_size=world_size, port=free_port())
        mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_layout_converter()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment