Commit eb39154d, authored by Frank Lee and committed by GitHub

[dtensor] updated api and doc (#3845)

parent d51e83d6
# 🗄 Device
## 📚 Table of Contents
- [🗄 Device](#-device)
  - [📚 Table of Contents](#-table-of-contents)
  - [🔗 Introduction](#-introduction)
  - [📝 Design](#-design)
  - [🔨 Usage](#-usage)
## 🔗 Introduction
This module implements the device topology abstraction. It represents the device topology and manages the distributed information related to the network.
## 📝 Design
This module is inspired by the DeviceMesh in the [Alpa project](https://github.com/alpa-projects/alpa); the device array can be represented as a 1D or 2D mesh. We plan to extend the device mesh to support 3D meshes in the future.
## 🔨 Usage
- Create a device mesh
```python
import torch

from colossalai.device.device_mesh import DeviceMesh

# this is the list of global ranks involved in the device mesh
# assume we have 4 GPUs and the global ranks for these GPUs are 0, 1, 2, 3
physical_mesh_id = torch.arange(4)
mesh_shape = [2, 2]
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
```
- View the mesh
```python
# view the mesh shape
# expected output
# [2, 2]
print(device_mesh.shape)
# view the logical mesh with global ranks
# expected output
# [
# [0, 1],
# [2, 3]
# ]
print(device_mesh.logical_mesh_id)
# view the number of devices in the mesh
# expected output
# 4
print(device_mesh.num_devices)
```
- Initialize the process group
```python
# initialize the process groups
device_mesh.init_logical_process_group()
# get the process group for a rank with respect to an axis
# this is the process group involving global ranks 0 and 2
print(device_mesh.get_process_group(axis=0, global_rank=0))
# get the ranks in the process group with respect to an axis
# expected output
# [0, 2]
print(device_mesh.get_ranks_in_process_group(axis=0, global_rank=0))
```
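The returned process group is a regular `torch.distributed` group, so it can be passed straight to collective calls. A minimal sketch, mirroring the updated tests in this commit and assuming the distributed backend has already been initialized (e.g. via `colossalai.launch_from_torch`):

```python
import torch
import torch.distributed as dist

# each rank fetches the process group along mesh axis 0 that contains itself
pg = device_mesh.get_process_group(axis=0)
tensor = torch.ones(4).cuda()

# sums the tensor across the two ranks that share a column of the 2 x 2 mesh
dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=pg)
```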
```diff
 from types import MethodType
-from typing import Callable, Optional, Union
+from typing import Callable, Dict, Optional, Union

 import torch
 import torch.distributed as dist
@@ -8,8 +8,9 @@ from torch import Tensor
 from torch.utils._pytree import tree_map

 from colossalai._analyzer._subclasses import MetaTensor
+from colossalai.device.device_mesh import DeviceMesh
 from colossalai.tensor.d_tensor.d_tensor import DTensor
-from colossalai.tensor.d_tensor.layout import Layout
+from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

 # reference: https://pytorch.org/cppdocs/notes/tensor_creation.html
 _NORMAL_FACTORY = [
@@ -172,7 +173,7 @@ class LazyTensor(torch.Tensor):
         self.clean()
         return _convert_cls(self, target)

-    def distribute(self, layout: Layout) -> torch.Tensor:
+    def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
         """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.

         Args:
@@ -183,7 +184,7 @@ class LazyTensor(torch.Tensor):
         """
         target = self._materialize_data()
         self.clean()
-        local_tensor = DTensor(target, layout).local_tensor
+        local_tensor = DTensor(target, device_mesh, sharding_spec).local_tensor
         return _convert_cls(self, local_tensor)

     def clean(self) -> None:
@@ -536,7 +537,10 @@ class LazyInitContext:
         return _apply_to_lazy_module(module, apply_fn, verbose)

     @staticmethod
-    def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
+    def distribute(module: nn.Module,
+                   device_mesh: DeviceMesh,
+                   sharding_spec_dict: Dict[str, ShardingSpec],
+                   verbose: bool = False) -> nn.Module:
         """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.

         Args:
@@ -546,7 +550,7 @@ class LazyInitContext:
         """

         def apply_fn(name: str, p: LazyTensor):
-            p.distribute(layout_dict[name])
+            p.distribute(device_mesh, sharding_spec_dict[name])

         return _apply_to_lazy_module(module, apply_fn, verbose)
```
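A hedged sketch of the updated `LazyInitContext.distribute` call pattern after this change, mirroring the test changes later in this commit (the fully replicated sharding specs below are illustrative):

```python
import torch.nn as nn

from colossalai.lazy.lazy_init import LazyInitContext
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

ctx = LazyInitContext()
with ctx:
    model = nn.Linear(8, 8)    # parameters are created lazily inside the context

# an empty dim_partition_dict means the parameter is fully replicated
sharding_spec_dict = {name: ShardingSpec(p.dim(), {}) for name, p in model.named_parameters()}

# device_mesh is constructed as in the Device README above
ctx.distribute(model, device_mesh, sharding_spec_dict, verbose=True)
```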
```diff
@@ -16,69 +16,66 @@ def _all_gather(tensor, comm_spec):
     '''
     Implement all gather operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            tensor_list = [
-                torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
-                for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
-            ]
-            # without this contiguous operation, the all gather may get some unexpected results.
-            tensor = tensor.contiguous()
-            dist.all_gather(tensor_list, tensor, group=process_group)
-            output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
-            return output
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+
+    tensor_list = [
+        torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
+        for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
+    ]
+    # without this contiguous operation, the all gather may get some unexpected results.
+    tensor = tensor.contiguous()
+    dist.all_gather(tensor_list, tensor, group=process_group)
+    output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
+    return output


 def _split(tensor, comm_spec):
     '''
     Implement shard operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, _ in process_groups_list:
-        if dist.get_rank() in rank_list:
-            dim = comm_spec.shard_dim
-            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
-            start = length * rank_list.index(dist.get_rank())
-            output = torch.narrow(tensor, dim, start, length).contiguous()
-            return output
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+
+    dim = comm_spec.shard_dim
+    length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
+    start = length * dist.get_rank(process_group)
+    output = torch.narrow(tensor, dim, start, length).contiguous()
+    return output


 def _all_to_all(tensor, comm_spec):
     '''
     Implement all to all operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            new_shape = list(tensor.shape)
-            new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
-            new_shape = torch.Size(new_shape)
-            output_tensor_list = [
-                torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
-            ]
-            dim = comm_spec.shard_dim
-            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
-            input_tensor_list = [
-                torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
-            ]
-            group = process_group
-            dist.all_to_all(output_tensor_list, input_tensor_list, group)
-            output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
-            return output
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+    world_size = dist.get_world_size(process_group)
+
+    new_shape = list(tensor.shape)
+    new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
+    new_shape = torch.Size(new_shape)
+    output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
+    dim = comm_spec.shard_dim
+    length = tensor.shape[comm_spec.shard_dim] // world_size
+    input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
+    group = process_group
+    dist.all_to_all(output_tensor_list, input_tensor_list, group)
+    output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
+    return output


 def _all_reduce(tensor, comm_spec, async_op=False):
     '''
     Implement all reduce operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            if not tensor.is_contiguous():
-                tensor = tensor.contiguous()
-            dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
-            return tensor
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+
+    if not tensor.is_contiguous():
+        tensor = tensor.contiguous()
+    dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
+    return tensor


 def _mix_gather(tensor, comm_spec):
@@ -414,7 +411,7 @@ class CommSpec:
         self.forward_only = forward_only
         if isinstance(self.logical_process_axis, list):
             if not mix_gather:
-                self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
+                self.device_mesh = self.sharding_spec.device_mesh.flatten()
                 self.logical_process_axis = 0
             else:
                 self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes
```
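The substantive API change in this hunk is that `DeviceMesh` now hands back one `ProcessGroup` per mesh axis for the calling rank, so the old loop over `(rank_list, process_group)` pairs disappears. A hedged sketch of the new accessor:

```python
import torch.distributed as dist

# maps mesh axis -> the ProcessGroup that contains the current rank along that axis
process_groups = device_mesh.get_process_group_for_all_axes()
for axis, group in process_groups.items():
    print(axis, dist.get_world_size(group), dist.get_rank(group))
```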
# 🔢 Distributed Tensor
## 📚 Table of Contents
- [🔢 Distributed Tensor](#-distributed-tensor)
  - [📚 Table of Contents](#-table-of-contents)
  - [🔗 Introduction](#-introduction)
  - [📝 Design](#-design)
  - [🔨 Usage](#-usage)
  - [🎈 Progress Log](#-progress-log)
## 🔗 Introduction
A distributed tensor is a tensor whose data is spread across multiple devices. It wraps a native PyTorch tensor and is used to support distributed training.
It can represent the device topology and the tensor placement over the devices in that topology, and it provides a set of APIs to manipulate the distributed tensor.
## 📝 Design
Our implementation is inspired by the work [Alpa](https://arxiv.org/abs/2201.12023), which unifies data parallelism and tensor parallelism as intra-op parallelism. It uses the notation `S` for a sharded dimension and `R` for a replicated dimension. For example, given a 2D matrix, `[S, R]` means the tensor is sharded along the first dimension and replicated along the second.
Each sharded dimension carries a subscript that indicates its placement over the device mesh axes. Assume we have 4 GPUs arranged as a 2 x 2 mesh, and a 2D matrix `A` like below:
```text
[1, 2, 3, 4 ]
A = [4, 5, 6, 7 ]
[8, 9, 10, 11]
[12, 13, 14, 15]
```
`[S0, R]` would mean that the first dimension is sharded over the rows in the device topology.
```text
|------------------|------------------|
|                  |                  |
| [1, 2, 3, 4 ]    | [1, 2, 3, 4 ]    |
| [4, 5, 6, 7 ]    | [4, 5, 6, 7 ]    |
|                  |                  |
|------------------|------------------|
|                  |                  |
| [8, 9, 10, 11]   | [8, 9, 10, 11]   |
| [12, 13, 14, 15] | [12, 13, 14, 15] |
|                  |                  |
|------------------|------------------|
```
`[S01, R]` would mean that the first dimension is sharded over both the row and column axes of the device topology.
```text
|------------------|------------------|
|                  |                  |
| [1, 2, 3, 4 ]    | [4, 5, 6, 7 ]    |
|                  |                  |
|------------------|------------------|
|                  |                  |
| [8, 9, 10, 11]   | [12, 13, 14, 15] |
|                  |                  |
|------------------|------------------|
```
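To make the two placements above concrete, the local shard shape can be derived from the sharding spec. A minimal sketch using the `Layout` helper from this commit (`get_sharded_shape_per_device`); the variable names are illustrative:

```python
import torch

from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

physical_mesh_id = torch.arange(4)
device_mesh = DeviceMesh(physical_mesh_id, (2, 2))    # 2 x 2 device mesh
global_shape = torch.Size((4, 4))                     # shape of the matrix A above

# [S0, R]: dim 0 is sharded over mesh axis 0 -> each device holds a (2, 4) slice
spec_s0 = ShardingSpec(dim_size=2, dim_partition_dict={0: [0]})
layout_s0 = Layout(device_mesh=device_mesh, sharding_spec=spec_s0, global_shape=global_shape)
print(layout_s0.get_sharded_shape_per_device())    # torch.Size([2, 4])

# [S01, R]: dim 0 is sharded over both mesh axes -> each device holds a (1, 4) slice
spec_s01 = ShardingSpec(dim_size=2, dim_partition_dict={0: [0, 1]})
layout_s01 = Layout(device_mesh=device_mesh, sharding_spec=spec_s01, global_shape=global_shape)
print(layout_s01.get_sharded_shape_per_device())    # torch.Size([1, 4])
```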
## 🔨 Usage
A sample API usage is given below.
```python
import torch
import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import DTensor, ShardingSpec
colossalai.launch_from_torch(config={})
# define your device mesh
# assume you have 4 GPUs
physical_mesh_id = torch.arange(0, 4).reshape(1, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
# define a tensor
a = torch.rand(16, 32).cuda()
# create sharding spec for the tensor
# assume the sharding spec is [S0, R]
dim_partition_dict = {0: [0]}
sharding_spec = ShardingSpec(a.dim(), dim_partition_dict)
# create a distributed tensor
d_tensor = DTensor(a, device_mesh, sharding_spec)
print(d_tensor)
global_tensor = d_tensor.to_global()
print(global_tensor)
```
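The placement of an existing `DTensor` can also be converted in place. A short sketch continuing the code above, assuming the `layout_convert(device_mesh, sharding_spec)` signature introduced in this commit:

```python
# convert the tensor from [S0, R] to [R, S0], i.e. shard the column dimension instead
new_sharding_spec = ShardingSpec(a.dim(), {1: [0]})
d_tensor.layout_convert(device_mesh, new_sharding_spec)

# the local shard now holds a slice of the columns
print(d_tensor.local_tensor)
```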
## 🎈 Progress Log
- [x] Support layout conversion
- [x] Support sharding on 2D device mesh
- [ ] Support sharding on 3D device mesh
- [ ] Support sharding on 4D device mesh
- [ ] Support sharding info saving and offline tensor merge (we can save a tensor as a DTensor and gather the shards back into the global tensor, based on the saved sharding info, in a single CPU process; useful for saving and loading distributed training checkpoints)
```python
from .d_tensor import DTensor
from .sharding_spec import ShardingSpec

__all__ = ['DTensor', 'ShardingSpec']
```
```diff
@@ -24,12 +24,12 @@ class CommSpec:
     '''
     Communication spec is used to record the communication action. It converts the communication spec
-    to real action which will be used in runtime. It contains comm_pattern to determine the
-    communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim
-    to determine the buffer shape, and logical_process_axis
+    to a real action which will be used at runtime. It contains comm_pattern to determine the
+    communication method, process_group_dict to determine the process groups, gather_dim and shard_dim
+    to determine the buffer shape, and logical_process_axis to determine the mesh axis along which the
+    communication takes place.

     Argument:
         comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
-        process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
-        gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
-        shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
+        process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
+        gather_dim(int, Optional): The dimension of the tensor along which it will be gathered.
+        shard_dim(int, Optional): The dimension of the tensor along which it will be sharded.
         logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
@@ -37,7 +37,7 @@ class CommSpec:
     def __init__(self,
                  comm_pattern: CollectiveCommPattern,
-                 process_groups_dict: Dict,
+                 process_group_dict: Dict,
                  gather_dim: int = None,
                  shard_dim: int = None,
                  logical_process_axis: int = None):
@@ -45,7 +45,7 @@ class CommSpec:
         self.gather_dim = gather_dim
         self.shard_dim = shard_dim
         self.logical_process_axis = logical_process_axis
-        self.process_groups_dict = process_groups_dict
+        self.process_group_dict = process_group_dict

     def __repr__(self):
         res_list = ["CommSpec:("]
@@ -92,68 +92,56 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
     '''
     Implement all gather operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            tensor_list = [
-                torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
-            ]
-            # without this contiguous operation, the all gather may get some unexpected results.
-            tensor = tensor.contiguous()
-            dist.all_gather(tensor_list, tensor, group=process_group)
-            output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
-            return output
+    process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
+    world_size = dist.get_world_size(process_group)
+    tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
+    # without this contiguous operation, the all gather may get some unexpected results.
+    tensor = tensor.contiguous()
+    dist.all_gather(tensor_list, tensor, group=process_group)
+    output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
+    return output


 def _split(tensor: torch.Tensor, comm_spec: CommSpec):
     '''
     Implement shard operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, _ in process_groups_list:
-        if dist.get_rank() in rank_list:
-            dim = comm_spec.shard_dim
-            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
-            start = length * rank_list.index(dist.get_rank())
-            output = torch.narrow(tensor, dim, start, length).contiguous()
-            return output
+    process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
+    dim = comm_spec.shard_dim
+    length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
+    start = length * dist.get_rank(process_group)
+    output = torch.narrow(tensor, dim, start, length).contiguous()
+    return output


 def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
     '''
     Implement all to all operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            new_shape = list(tensor.shape)
-            new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
-            new_shape = torch.Size(new_shape)
-            output_tensor_list = [
-                torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
-            ]
-            dim = comm_spec.shard_dim
-            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
-            input_tensor_list = [
-                torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
-            ]
-            group = process_group
-            dist.all_to_all(output_tensor_list, input_tensor_list, group)
-            output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
-            return output
+    process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
+    world_size = dist.get_world_size(process_group)
+    new_shape = list(tensor.shape)
+    new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
+    new_shape = torch.Size(new_shape)
+    output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
+    dim = comm_spec.shard_dim
+    length = tensor.shape[comm_spec.shard_dim] // world_size
+    input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
+    group = process_group
+    dist.all_to_all(output_tensor_list, input_tensor_list, group)
+    output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
+    return output


 def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False):
     '''
     Implement all reduce operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            if not tensor.is_contiguous():
-                tensor = tensor.contiguous()
-            dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
-            return tensor
+    process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
+    if not tensor.is_contiguous():
+        tensor = tensor.contiguous()
+    dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
+    return tensor


 class _ReduceGrad(torch.autograd.Function):
@@ -269,7 +257,7 @@ class _AllToAll(torch.autograd.Function):
     def forward(ctx, input_, comm_spec):
         output = _all_to_all(input_, comm_spec)
         comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
-                                          process_groups_dict=comm_spec.process_groups_dict,
+                                          process_group_dict=comm_spec.process_group_dict,
                                           gather_dim=comm_spec.shard_dim,
                                           shard_dim=comm_spec.gather_dim,
                                           logical_process_axis=comm_spec.logical_process_axis)
```
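For reference, a hedged sketch of how the updated `CommSpec` is constructed with the new `process_group_dict` argument; the underscore-prefixed `DeviceMesh` attributes are internal state, taken from the layout-converter changes below:

```python
# per-rank mapping: mesh axis -> ProcessGroup (internal DeviceMesh state)
current_rank = device_mesh._global_rank_of_current_process
process_group_dict = device_mesh._process_group_dict[current_rank]

comm_spec = CommSpec(CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                     process_group_dict=process_group_dict,
                     gather_dim=0,
                     shard_dim=0,
                     logical_process_axis=0)
```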
````diff
@@ -3,55 +3,119 @@ from typing import Optional

 import torch
 from torch.utils._pytree import tree_map

+from colossalai.device.device_mesh import DeviceMesh
+
 from .layout import Layout
 from .layout_converter import LayoutConverter, to_global
 from .sharding_spec import ShardingSpec

+__all__ = ['DTensor', 'distribute_tensor', 'distribute_module', 'construct_default_sharding_spec']
+
 layout_converter = LayoutConverter()


 class DTensor(torch.Tensor):
+    """
+    DTensor stands for distributed tensor. It is a subclass of `torch.Tensor` and contains meta information
+    about the tensor distribution. The meta information includes the device mesh, the sharding specification,
+    and the entire shape of the tensor.
+
+    During runtime, we will not directly use the DTensor objects for computation. Instead, we will only use the
+    `DTensor.local_tensor` for computation. The `DTensor.local_tensor` is the local tensor in the current rank.
+    In this way, all tensors involved in computation will only be native PyTorch tensors.
+
+    Example:
+        ```python
+        from colossalai.device import DeviceMesh
+
+        # define your device mesh
+        # assume you have 4 GPUs
+        physical_mesh_id = torch.arange(0, 4).reshape(1, 4)
+        mesh_shape = (2, 2)
+        device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
+
+        # define a tensor
+        x = torch.rand(16, 32)
+
+        # create sharding spec for the tensor
+        # assume the sharding spec is [S0, R]
+        dim_partition_dict = {0: [0]}
+        sharding_spec = ShardingSpec(x.dim(), dim_partition_dict)
+
+        # create a distributed tensor
+        d_tensor = DTensor(x, device_mesh, sharding_spec)
+        ```
+
+    Args:
+        tensor (`torch.Tensor`): the unsharded tensor.
+        device_mesh (`DeviceMesh`): the device mesh for abstraction of the compute devices.
+        sharding_spec (`ShardingSpec`): the sharding specification which describes how the tensor will be sharded.
+    """

-    def __init__(self, local_tensor: torch.Tensor, dist_layout: Layout):
-        self.local_tensor = local_tensor
-        self.data_type = local_tensor.dtype
-        self.entire_shape = local_tensor.shape
+    def __init__(self, tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec):
+        # ensure this tensor is not a DTensor
+        assert not isinstance(tensor, DTensor), 'The input tensor should not be a DTensor.'
+
+        # store meta info
+        self.local_tensor = tensor
+        self.data_type = tensor.dtype
+        self.global_shape = tensor.shape
+
+        # create distributed layout
+        dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=self.global_shape)
         self.dist_layout = dist_layout
+
+        # shard the tensor
         self._apply_layout()

     @staticmethod
-    def __new__(cls, local_tensor, layout):
-        return torch.Tensor._make_subclass(cls, local_tensor, local_tensor.requires_grad)
+    def __new__(cls, tensor, *args, **kwargs):
+        return torch.Tensor._make_subclass(cls, tensor, tensor.requires_grad)

     def __repr__(self):
-        return f"DTensor({self.to_global()}, {self.dist_layout})"
+        return f"DTensor(\n{self.to_global()}\n{self.dist_layout})"

     def __str__(self):
         return self.__repr__()

-    def layout_convert(self, target_layout):
+    def layout_convert(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
         '''
         Convert the layout of the tensor from source_spec to target_spec.
+        This will update the `local_tensor` and `dist_layout` in place.
+
+        Args:
+            device_mesh (`DeviceMesh`): the device mesh for abstraction of the compute devices.
+            sharding_spec (`ShardingSpec`): the sharding specification which describes the target layout.
         '''
-        self.local_tensor = layout_converter.apply(self.local_tensor, self.dist_layout, target_layout)
+        target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=self.global_shape)
+        self.local_tensor = layout_converter.apply(tensor=self.local_tensor,
+                                                   source_layout=self.dist_layout,
+                                                   target_layout=target_layout)
         self.dist_layout = target_layout

     def _apply_layout(self):
         '''
         Apply the layout to the local tensor during initializing process.
         '''
+        # the layout converter requires a source and a target layout
+        # we construct the source layout for an unsharded tensor
+        # and use self.dist_layout as the target layout for the sharded tensor
         source_spec = construct_default_sharding_spec(self.local_tensor)
         source_layout = Layout(device_mesh=self.dist_layout.device_mesh,
-                               device_type=self.dist_layout.device_type,
                                sharding_spec=source_spec,
-                               entire_shape=self.entire_shape)
-        self.local_tensor = layout_converter.apply(self.local_tensor, source_layout, self.dist_layout)
+                               global_shape=self.global_shape)
+        self.local_tensor = layout_converter.apply(tensor=self.local_tensor,
+                                                   source_layout=source_layout,
+                                                   target_layout=self.dist_layout)

     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
             kwargs = {}

+        # convert all DTensors to native PyTorch tensors
+        # so that operations will be conducted on native tensors
         def filter_arg(arg):
             if isinstance(arg, DTensor):
                 return arg.local_tensor
@@ -60,9 +124,9 @@ class DTensor(torch.Tensor):
         args = tree_map(filter_arg, args)
         kwargs = tree_map(filter_arg, kwargs)

-        # if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors
-        # and op type.
+        # NOTE: if we want to convert the result into DTensor, we need to infer the layout of result from the layout of input tensors
+        # and op type.
         return func(*args, **kwargs)

     @property
@@ -85,7 +149,6 @@ class DTensor(torch.Tensor):
         '''
         self.local_tensor = self.local_tensor.to(*args, **kwargs)
         self.data_type = self.local_tensor.dtype
-        self.dist_layout.device_type = self.local_tensor.device
         # TODO: update the device mesh process groups or we should just cache
         # both the cpu process groups and the cuda process groups?
         return self
@@ -98,7 +161,7 @@ class DTensor(torch.Tensor):
     def to_global(self):
         '''
-        Recover the global tensor from the distributed tensor.
+        Recover the global tensor from the distributed tensor by returning a new `torch.Tensor` object.

         Note: This function will all_gather the local tensor to the global tensor and it
         will not change the layout of the DTensor. This function is mainly used for debugging or
@@ -107,24 +170,29 @@ class DTensor(torch.Tensor):
         return to_global(self.local_tensor, self.dist_layout)


-def distribute_tensor(local_tensor: torch.Tensor, dist_layout: Layout) -> DTensor:
+def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> DTensor:
     '''
     Distribute the local tensor to the distributed tensor according to the dist_layout specified.

     Args:
-        local_tensor: tensor to be distributed.
-        dist_layout: the layout specification of the distributed tensor.
+        tensor (`torch.Tensor`): tensor to be distributed.
+        device_mesh (`DeviceMesh`): the device mesh for abstraction of the compute devices.
+        sharding_spec (`ShardingSpec`): the sharding specification which describes how the tensor will be sharded.

     Returns:
         A 'DTensor' object.
     '''
-    return DTensor(local_tensor, dist_layout)
+    return DTensor(tensor, device_mesh, sharding_spec)


 def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable] = None) -> torch.nn.Module:
     '''
     This function converts all the parameters in the module to DTensor(DParam).

+    Args:
+        module (`torch.nn.Module`): the module to be distributed.
+        partition_fn (callable): the partition function which will be used to partition the parameters.
+
     Note: This function is subject to future change as the DParam has not been implemented yet.
     '''
     for name, param in module.named_parameters():
@@ -138,5 +206,11 @@ def distribute_module(module: torch.nn.Module, partition_fn: Optional[callable]
 def construct_default_sharding_spec(tensor: torch.Tensor,) -> ShardingSpec:
     '''
     Construct the default sharding specification for the tensor.
+
+    Args:
+        tensor (`torch.Tensor`): the tensor to be sharded.
+
+    Returns:
+        A `ShardingSpec` object without any sharding specified.
     '''
     return ShardingSpec(dim_size=tensor.dim(), dim_partition_dict={})
````
```diff
@@ -11,28 +11,32 @@ from .sharding_spec import ShardingSpec

 class Layout:
-    """Layout of a tensor.
+    """
+    Layout of a tensor refers to the tensor placement on the device mesh and how the tensor is sharded over the devices.

-    Attributes:
-        device_mesh: the device mesh to store the tensor distributed.
-        device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'.
-        sharding_spec: the sharding specification to describe how the tensor is sharded.
-        entire_shape: the entire shape of the global tensor.
+    Args:
+        device_mesh (`DeviceMesh`): the device mesh to store the tensor distributed.
+        sharding_spec (`ShardingSpec`): the sharding specification to describe how the tensor is sharded.
+        global_shape (`torch.Size`): the entire shape of the global tensor.
     """

-    def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec,
-                 entire_shape: torch.Size):
+    def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size):
         self.device_mesh = device_mesh
-        self.device_type = device_type
         self.sharding_spec = sharding_spec
-        self.entire_shape = entire_shape
+        self.global_shape = global_shape
         self._sanity_check()

     def __hash__(self) -> int:
         return hash(f'{self.sharding_spec}')

-    def get_sharded_shape_per_device(self):
-        sharded_shape = list(self.entire_shape)
+    def get_sharded_shape_per_device(self) -> torch.Size:
+        """
+        Compute the shape of the sharded tensor on each device.
+
+        Returns:
+            `torch.Size`: the shape of the sharded tensor on each device.
+        """
+        sharded_shape = list(self.global_shape)
         for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
             mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
             shard_partitions = reduce(operator.mul, mesh_list, 1)
@@ -56,7 +60,7 @@ class Layout:
         # make sure that the sharding for a dimension is divisible by the number of devices
         for dim, shard_list in sharding_spec.dim_partition_dict.items():
-            tensor_dim_size = self.entire_shape[dim]
+            tensor_dim_size = self.global_shape[dim]
             num_devices = 1

             for element in shard_list:
```
```diff
@@ -3,10 +3,8 @@ from copy import deepcopy
 from dataclasses import dataclass
 from typing import Dict, List, Tuple

-import numpy as np
 import torch

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.tensor.d_tensor.comm_spec import *
 from colossalai.tensor.d_tensor.layout import Layout
@@ -28,13 +26,21 @@ class LayoutConverterOptions:
     pass


-def to_global(distributed_tensor: torch.Tensor, layout: Layout) -> torch.Tensor:
+def to_global(distributed_tensor: "DTensor", layout: Layout) -> torch.Tensor:
+    """
+    Convert a distributed tensor to the global tensor with the given layout.
+    This function returns a native `torch.Tensor` object.
+
+    Args:
+        distributed_tensor (`DTensor`): the distributed tensor to be converted.
+        layout (`Layout`): the target layout specification.
+    """
     layout_converter = LayoutConverter()
     global_sharding_spec = ShardingSpec(distributed_tensor.dim(), {})
     global_layout = Layout(device_mesh=layout.device_mesh,
-                           device_type=layout.device_type,
                            sharding_spec=global_sharding_spec,
-                           entire_shape=layout.entire_shape)
+                           global_shape=layout.global_shape)
     with torch.no_grad():
         global_tensor = layout_converter.apply(distributed_tensor, layout, global_layout)
     return global_tensor
@@ -49,6 +55,9 @@ def set_layout_converting_options(options: LayoutConverterOptions):

 class LayoutConverter(metaclass=SingletonMeta):
+    """
+    LayoutConverter is a singleton class which converts the layout of a distributed tensor.
+    """

     def __init__(self):
         self._options = None
@@ -91,15 +100,14 @@ class LayoutConverter(metaclass=SingletonMeta):
         # [[0, 1,
         #  [2, 3]]
         device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-        entire_shape = (4, 4, 4)
+        global_shape = (4, 4, 4)
         dim_partition_dict = {0: [0], 1: [1]}

         # [S0,S1,R]
         sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
         layout = Layout(device_mesh=device_mesh,
-                        device_type=torch.device('cuda'),
                         sharding_spec=sharding_spec,
-                        entire_shape=entire_shape)
+                        global_shape=global_shape)

         rst_dict = layout_converter.all_gather_transform_layouts(layout)
         for layout, comm_spec in rst_dict.items():
@@ -112,7 +120,12 @@ class LayoutConverter(metaclass=SingletonMeta):
         valid_spec_dict = {}
         comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
         source_spec = source_layout.sharding_spec
-        process_groups_dict = source_layout.device_mesh.process_groups_dict
+
+        # the key of the dict is the axis
+        # the value is the process group
+        current_rank = source_layout.device_mesh._global_rank_of_current_process
+        process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
+
         for target_pair in source_spec.dim_partition_dict.items():
             shard_list = all_gather_simulator(target_pair)
             index = target_pair[0]
@@ -130,7 +143,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                 logical_process_axis = target_pair[1][-1]
                 comm_spec = CommSpec(
                     comm_pattern,
-                    process_groups_dict=process_groups_dict,
+                    process_group_dict=process_group_dict,
                     gather_dim=gather_dim,
                     # shard_dim will be used during backward
                     shard_dim=gather_dim,
@@ -141,8 +154,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                 new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
                 new_layout = Layout(device_mesh=source_layout.device_mesh,
                                     sharding_spec=new_sharding_spec,
-                                    device_type=source_layout.device_type,
-                                    entire_shape=source_layout.entire_shape)
+                                    global_shape=source_layout.global_shape)

                 valid_spec_dict[new_layout] = comm_spec
             except LayoutException:
@@ -167,15 +179,14 @@ class LayoutConverter(metaclass=SingletonMeta):
         # [[0, 1,
         #  [2, 3]]
         device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-        entire_shape = (4, 4, 4)
+        global_shape = (4, 4, 4)
         dim_partition_dict = {0: [0], 1: [1]}

         # [S0,S1,R]
         sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
         layout = Layout(device_mesh=device_mesh,
-                        device_type=torch.device('cuda'),
                         sharding_spec=sharding_spec,
-                        entire_shape=entire_shape)
+                        global_shape=global_shape)

         rst_dict = layout_converter.all_to_all_transform_layout(layout)
         for layout, comm_spec in rst_dict.items():
@@ -188,7 +199,12 @@ class LayoutConverter(metaclass=SingletonMeta):
         '''
         valid_spec_dict = {}
         comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
-        process_groups_dict = source_layout.device_mesh.process_groups_dict
+
+        # the key of the dict is the axis
+        # the value is the process group
+        current_rank = source_layout.device_mesh._global_rank_of_current_process
+        process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
+
         source_spec = source_layout.sharding_spec
         tensor_dims = source_spec.dims
         for f_index in range(tensor_dims - 1):
@@ -229,7 +245,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                     shard_dim = f_index
                     logical_process_axis = b_target_pair[1][-1]
                 comm_spec = CommSpec(comm_pattern,
-                                     process_groups_dict,
+                                     process_group_dict=process_group_dict,
                                      gather_dim=gather_dim,
                                      shard_dim=shard_dim,
                                      logical_process_axis=logical_process_axis)
@@ -252,8 +268,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                 new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
                 new_layout = Layout(device_mesh=source_layout.device_mesh,
                                     sharding_spec=new_sharding_spec,
-                                    device_type=source_layout.device_type,
-                                    entire_shape=source_layout.entire_shape)
+                                    global_shape=source_layout.global_shape)
                 valid_spec_dict[new_layout] = comm_spec
             except LayoutException:
                 pass
@@ -278,16 +293,15 @@ class LayoutConverter(metaclass=SingletonMeta):
         # [[0, 1,
         #  [2, 3]]
         device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-        entire_shape = (4, 4, 4)
+        global_shape = (4, 4, 4)

         dim_partition_dict = {0: [0]}

         # [S0,R,R]
         sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
         layout = Layout(device_mesh=device_mesh,
-                        device_type=torch.device('cuda'),
                         sharding_spec=sharding_spec,
-                        entire_shape=entire_shape)
+                        global_shape=global_shape)

         rst_dict = layout_converter.shard_transform_layout(layout)
         for layout, comm_spec in rst_dict.items():
@@ -301,7 +315,11 @@ class LayoutConverter(metaclass=SingletonMeta):
         valid_spec_dict = {}
         comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
         source_spec = source_layout.sharding_spec
-        process_groups_dict = source_layout.device_mesh.process_groups_dict
+
+        # the key of the dict is the axis
+        # the value is the process group
+        current_rank = source_layout.device_mesh._global_rank_of_current_process
+        process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]

         # legal sharding dims means the mesh_id is still available to use.
         legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.mesh_shape))]
@@ -329,7 +347,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                 shard_dim = index
                 logical_process_axis = shard_list[-1]
                 comm_spec = CommSpec(comm_pattern,
-                                     process_groups_dict,
+                                     process_group_dict=process_group_dict,
                                      gather_dim=shard_dim,
                                      shard_dim=shard_dim,
                                      logical_process_axis=logical_process_axis)
@@ -340,8 +358,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                                                  dim_partition_dict=new_dim_partition_dict)
                 new_layout = Layout(device_mesh=source_layout.device_mesh,
                                     sharding_spec=new_sharding_spec,
-                                    device_type=source_layout.device_type,
-                                    entire_shape=source_layout.entire_shape)
+                                    global_shape=source_layout.global_shape)
                 valid_spec_dict[new_layout] = comm_spec
             except LayoutException:
                 pass
@@ -399,7 +416,7 @@ class LayoutConverter(metaclass=SingletonMeta):
         # [[0, 1,
         #  [2, 3]]
         device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-        entire_shape = (4, 4, 4)
+        global_shape = (4, 4, 4)

         dim_partition_source = {1: [0, 1]}
         dim_partition_target = {0: [0, 1]}
@@ -407,16 +424,14 @@ class LayoutConverter(metaclass=SingletonMeta):
         # [R,S01,R]
         sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
         source_layout = Layout(device_mesh=device_mesh,
-                               device_type=torch.device('cuda'),
                                sharding_spec=sharding_spec_source,
-                               entire_shape=entire_shape)
+                               global_shape=global_shape)

         # [S01,R,R]
         sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
         target_layout = Layout(device_mesh=device_mesh,
-                               device_type=torch.device('cuda'),
                                sharding_spec=sharding_spec_target,
-                               entire_shape=entire_shape)
+                               global_shape=global_shape)

         transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
         transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
@@ -505,21 +520,19 @@ class LayoutConverter(metaclass=SingletonMeta):
         # [[0, 1,
         #  [2, 3]]
         device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-        entire_shape = (4, 4, 4)
+        global_shape = (4, 4, 4)

         # [S0,R,R]
         sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
         source_layout = Layout(device_mesh=device_mesh,
-                               device_type=torch.device('cuda'),
                                sharding_spec=sharding_spec_source,
-                               entire_shape=entire_shape)
+                               global_shape=global_shape)

         # [R,S0,R]
         sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
         target_layout = Layout(device_mesh=device_mesh,
-                               device_type=torch.device('cuda'),
                                sharding_spec=sharding_spec_target,
-                               entire_shape=entire_shape)
+                               global_shape=global_shape)

         if rank in (0, 1):
             sharded_tensor_0 = torch.zeros(2, 1)
@@ -554,3 +567,4 @@ class LayoutConverter(metaclass=SingletonMeta):
             for comm_spec in comm_action_sequence:
                 tensor = comm_spec.covert_spec_to_action(tensor)
             return tensor
+        return tensor
```
```diff
-from colossalai.device.device_mesh import DeviceMesh
 import torch
+
+from colossalai.device.device_mesh import DeviceMesh


 def test_device_mesh():
-    physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
+    physical_mesh_id = torch.arange(0, 16)
     mesh_shape = (4, 4)
     # [[0, 1, 2, 3],
     #  [4, 5, 6, 7],
     #  [8, 9, 10,11],
     #  [12,13,14,15]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    assert device_mesh.convert_map[5] == [1, 1]
-    assert device_mesh.convert_map[11] == [2, 3]
-    assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]]
-    assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]]
-    assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3]
+    assert device_mesh.global_rank_to_local_rank(5) == [1, 1]
+    assert device_mesh.global_rank_to_local_rank(11) == [2, 3]
+    assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3]


 if __name__ == '__main__':
```
```diff
@@ -20,16 +20,12 @@ def check_layer(rank, world_size, port):
     # [[0, 1,
     #  [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]}
-    logical_process_groups = device_mesh.process_groups_dict
-
-    for mesh_dim, pgs in logical_pg_dict.items():
-        for index, pg in enumerate(pgs):
-            if rank in pg:
-                tensor = torch.ones(4).cuda()
-                group = logical_process_groups[mesh_dim][index][1]
-                dist.all_reduce(tensor, op=ReduceOp.SUM, group=group)
-                assert tensor.equal(tensor_to_check)
+
+    for axis in range(len(mesh_shape)):
+        tensor = torch.ones(4).cuda()
+        pg = device_mesh.get_process_group(axis=axis)
+        dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg)
+        assert tensor.equal(tensor_to_check)

     gpc.destroy()
```
```diff
@@ -6,7 +6,9 @@ import numpy as np
 import torch
 from packaging import version

+from colossalai.device.device_mesh import DeviceMesh
 from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
+from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.layout_converter import to_global
 from tests.kit.model_zoo.registry import ModelAttribute
@@ -81,7 +83,8 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
     print(f'{model.__class__.__name__} pass')


-def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
+def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh,
+                            sharding_spec_dict: dict) -> None:
     state = model.state_dict()
     distributed_state = distributed_model.state_dict()
@@ -91,6 +94,7 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.
         assert n1 == n2
         t1 = t1.cuda()
         t2 = t2.cuda()
-        if n2 in layout_dict:
-            t2 = to_global(t2, layout_dict[n2])
+        if n2 in sharding_spec_dict:
+            layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape)
+            t2 = to_global(t2, layout)
         assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
```
@@ -26,23 +26,19 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]:
     return dim

-def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
+def make_sharding_spec(original_tensor: torch.Tensor) -> ShardingSpec:
     shard_dim = find_shard_dim(original_tensor.shape)
     dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
     target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
-    layout = Layout(device_mesh=device_mesh,
-                    device_type=torch.device('cuda'),
-                    sharding_spec=target_sharding_spec,
-                    entire_shape=original_tensor.shape)
-    return layout
+    return target_sharding_spec

 def _get_current_name(prefix: str, name: str) -> str:
     return f'{prefix}.{name}'.lstrip('.')

-def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
-    layout_dict = {}
+def generate_sharding_spec_dict(model: nn.Module) -> dict:
+    sharding_spec_dict = {}

     @torch.no_grad()
     def generate_recursively(module: nn.Module, prefix: str = ''):
@@ -53,17 +49,17 @@ def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
         # initialize tensors directly attached to the current module
         for name, param in module.named_parameters(recurse=False):
             if isinstance(param, LazyTensor):
-                layout = make_layout(device_mesh, param)
-                layout_dict[_get_current_name(prefix, name)] = layout
+                sharding_spec = make_sharding_spec(param)
+                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec

         for name, buf in module.named_buffers(recurse=False):
             if isinstance(buf, LazyTensor):
-                layout = make_layout(device_mesh, buf)
-                layout_dict[_get_current_name(prefix, name)] = layout
+                sharding_spec = make_sharding_spec(buf)
+                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec

     generate_recursively(model)
-    return layout_dict
+    return sharding_spec_dict
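To make the renamed helper concrete, here is roughly what `make_sharding_spec` produces for a typical 2-D weight; the spec no longer carries any device or shape information (a sketch, with the example tensor invented for illustration):

```python
import torch
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

weight = torch.empty(64, 32)   # stand-in for a LazyTensor parameter
# find_shard_dim would pick a shardable dim; assume it picks dim 0 here.
# dim 0 is sharded along logical mesh axis 0, dim 1 stays replicated,
# i.e. a shard sequence of roughly S0,R.
spec = ShardingSpec(dim_size=weight.dim(), dim_partition_dict={0: [0]})
```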
 @parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
@@ -85,9 +81,9 @@ def run_dist_lazy_init(subset, seed: int = 42):
     ctx = LazyInitContext()
     with ctx:
         deferred_model = model_fn()
-    layout_dict = generate_layout_dict(deferred_model, device_mesh)
-    ctx.distribute(deferred_model, layout_dict, verbose=True)
-    assert_dist_model_equal(model, deferred_model, layout_dict)
+    sharding_spec_dict = generate_sharding_spec_dict(deferred_model)
+    ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True)
+    assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict)

 def run_dist(rank, world_size, port) -> None:
...
@@ -125,23 +125,6 @@ def check_all_reduce_bwd(process_groups_dict, rank):
     assert tensor_to_comm.equal(tensor_to_check)

-def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank):
-    # tensor to comm
-    tensor_to_comm = torch.ones(2, 2).cuda() * rank
-    # reduce through logical process axis 0 at flatten device mesh
-    # tensor to check
-    # tensor([[6., 6.],
-    #         [6., 6.]])
-    tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()
-    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
-    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0)
-    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
-    assert tensor_to_comm.equal(tensor_to_check)
 def check_comm(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -153,24 +136,22 @@ def check_comm(rank, world_size, port):
     # [[0, 1],
     #  [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    process_groups_dict = device_mesh.process_groups_dict
+    process_group_dict = device_mesh._process_group_dict[rank]

     # test all gather
-    check_all_gather(process_groups_dict, rank)
+    check_all_gather(process_group_dict, rank)
     # test shard
-    check_shard(process_groups_dict, rank)
+    check_shard(process_group_dict, rank)
     # test all to all
-    check_all_to_all(process_groups_dict, rank)
+    check_all_to_all(process_group_dict, rank)
     # test all reduce
-    check_all_reduce_fwd(process_groups_dict, rank)
-    check_all_reduce_bwd(process_groups_dict, rank)
-    flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict
-    # test all reduce in 1D flatten device mesh
-    check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank)
+    check_all_reduce_fwd(process_group_dict, rank)
+    check_all_reduce_bwd(process_group_dict, rank)

     gpc.destroy()
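One detail worth noting here: the mesh-wide `process_groups_dict` is replaced by a per-rank lookup, and the flattened-mesh all-reduce check is dropped since the axis-wise groups now cover it. A sketch of inspecting the new mapping (`_process_group_dict` is a private attribute, so this is for illustration only; `device_mesh` and `rank` come from the surrounding test):

```python
import torch.distributed as dist

# Each rank holds a {mesh_axis: ProcessGroup} mapping for the groups
# it belongs to, one NCCL group per logical mesh axis.
process_group_dict = device_mesh._process_group_dict[rank]
for axis, pg in process_group_dict.items():
    print(f'axis {axis}: rank {dist.get_rank(group=pg)} of {dist.get_world_size(group=pg)}')
```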
...
@@ -31,13 +31,9 @@ def check_dtensor(rank, world_size, port):
     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
     target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
-    layout = Layout(device_mesh=device_mesh,
-                    device_type=torch.device('cuda'),
-                    sharding_spec=target_sharding_spec,
-                    entire_shape=original_tensor.shape)
-    d_tensor = DTensor(original_tensor, layout)
+    d_tensor = DTensor(original_tensor, device_mesh, target_sharding_spec)

-    assert d_tensor.entire_shape == original_tensor.shape
+    assert d_tensor.global_shape == original_tensor.shape
     assert d_tensor.data_type == original_tensor.dtype

     if rank in (0, 1):
@@ -57,12 +53,7 @@ def check_dtensor(rank, world_size, port):
         raise ValueError(f'rank {rank} is not in the device mesh')

     new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
-    new_layout = Layout(device_mesh=device_mesh,
-                        device_type=torch.device('cuda'),
-                        sharding_spec=new_sharding_spec,
-                        entire_shape=original_tensor.shape)
-    d_tensor.layout_convert(new_layout)
+    d_tensor.layout_convert(device_mesh, new_sharding_spec)

     if rank == 0:
         assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
@@ -75,7 +66,7 @@ def check_dtensor(rank, world_size, port):
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')

-    dtensor_from_local = distribute_tensor(original_tensor, new_layout)
+    dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec)

     if rank == 0:
         assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
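Taken together, this hunk shows the updated `DTensor` surface: construction, resharding, and `distribute_tensor` all accept `(device_mesh, sharding_spec)` directly, and `entire_shape` is renamed to `global_shape`. A condensed sketch (import paths follow the test file; the 2x2 mesh and the example tensor are assumptions):

```python
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

original_tensor = torch.rand(4, 4).cuda()
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)

# Shard dim 0 across mesh axis 0: each mesh row holds half the tensor rows.
spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
d_tensor = DTensor(original_tensor, device_mesh, spec)
assert d_tensor.global_shape == original_tensor.shape

# Reshard dim 0 across both axes, then build another DTensor the same way.
new_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
d_tensor.layout_convert(device_mesh, new_spec)
dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_spec)
```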
...
@@ -12,9 +12,9 @@ from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
 from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
 from colossalai.testing import rerun_if_address_is_in_use, spawn

-entire_shape = torch.Size((64, 32, 16))
+global_shape = torch.Size((64, 32, 16))
 layout_converter = LayoutConverter()
-physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
+physical_mesh_id = torch.arange(0, 4)
 mesh_shape = (2, 2)
@@ -30,10 +30,7 @@ def check_one_step_transform(rank, world_size, port):
     # shard_sequence: S0,S1,R
     # device_mesh_shape: (2, 2)
     sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
-    layout = Layout(device_mesh=device_mesh,
-                    device_type=torch.device('cuda'),
-                    sharding_spec=sharding_spec,
-                    entire_shape=entire_shape)
+    layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)

     rst_dict = layout_converter.all_gather_transform_layouts(layout)
@@ -49,10 +46,7 @@ def check_one_step_transform(rank, world_size, port):
     # shard_sequence: S0,S1,R
     # device_mesh_shape: (4, 4)
     sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all)
-    layout_all2all = Layout(device_mesh=device_mesh,
-                            device_type=torch.device('cuda'),
-                            sharding_spec=sharding_spec_all2all,
-                            entire_shape=entire_shape)
+    layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape)

     rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all)
@@ -71,10 +65,7 @@ def check_one_step_transform(rank, world_size, port):
     # shard_sequence: S0,R,R
     # device_mesh_shape: (4, 4)
     sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard)
-    shard_layout = Layout(device_mesh=device_mesh,
-                          device_type=torch.device('cuda'),
-                          sharding_spec=sharding_spec_shard,
-                          entire_shape=entire_shape)
+    shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape)

     rst_dict_shard = layout_converter.shard_transform_layout(shard_layout)
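The pattern repeated in all three hunks above: `Layout` now takes only the mesh, the spec, and `global_shape`; the `device_type` and `entire_shape` arguments are gone. For instance, the S0,S1,R case in isolation (a sketch reusing the module-level `global_shape` and the 2x2 `device_mesh` from the surrounding test):

```python
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

# S0,S1,R: tensor dim 0 sharded on mesh axis 0, dim 1 on mesh axis 1,
# dim 2 replicated everywhere.
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict={0: [0], 1: [1]})
layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
```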
@@ -100,19 +91,13 @@ def check_layout_converting(rank, world_size, port):
     # shard_sequence: R,S01,R
     # device_mesh_shape: (4, 4)
     sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
-    source_layout = Layout(device_mesh=device_mesh,
-                           device_type=torch.device('cuda'),
-                           sharding_spec=sharding_spec_source,
-                           entire_shape=entire_shape)
+    source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)

     # DistSpec:
     # shard_sequence: S01,R,R
     # device_mesh_shape: (4, 4)
     sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
-    target_layout = Layout(device_mesh=device_mesh,
-                           device_type=torch.device('cuda'),
-                           sharding_spec=sharding_spec_target,
-                           entire_shape=entire_shape)
+    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)

     transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
@@ -159,21 +144,15 @@ def check_layout_converting_apply(rank, world_size, port):
     # shard_sequence: R,S01,R
     # device_mesh_shape: (4, 4)
     sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
-    source_layout = Layout(device_mesh=device_mesh,
-                           device_type=torch.device('cuda'),
-                           sharding_spec=sharding_spec_source,
-                           entire_shape=entire_shape)
+    source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)

     # DistSpec:
     # shard_sequence: S01,R,R
     # device_mesh_shape: (4, 4)
     sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
-    target_layout = Layout(device_mesh=device_mesh,
-                           device_type=torch.device('cuda'),
-                           sharding_spec=sharding_spec_target,
-                           entire_shape=entire_shape)
+    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)

-    original_tensor = torch.rand(entire_shape).cuda()
+    original_tensor = torch.rand(global_shape).cuda()

     # tensor_to_apply: [R, S01, R]
     tensor_to_apply = original_tensor.narrow(1, rank * 8, 8)
...
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
 import torch
-from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec

-physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
+physical_mesh_id = torch.arange(0, 16)
 mesh_shape = (4, 4)
 # [[0, 1, 2, 3],
 #  [4, 5, 6, 7],
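A change that recurs throughout these tests: `physical_mesh_id` is now passed as a flat 1-D tensor and `DeviceMesh` derives the logical grid from `mesh_shape`, so the old `.reshape(...)` calls are dropped. Roughly:

```python
import torch
from colossalai.device.device_mesh import DeviceMesh

physical_mesh_id = torch.arange(0, 16)   # flat global ranks 0..15
mesh_shape = (4, 4)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
# logical mesh:
# [[ 0,  1,  2,  3],
#  [ 4,  5,  6,  7],
#  [ 8,  9, 10, 11],
#  [12, 13, 14, 15]]
```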
...
@@ -26,7 +26,7 @@ def run_dist(rank, world_size, port):
     # the mesh is in the following topo
     # [[0, 1],
     #  [2, 3]]
-    physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
+    physical_mesh_id = torch.arange(0, 4)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     row_id = rank // 2
...
@@ -5,7 +5,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 def test_sharding_spec():
-    physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
+    physical_mesh_id = torch.arange(0, 16)
     mesh_shape = (4, 4)
     # [[0, 1, 2, 3],
     #  [4, 5, 6, 7],
...