Commit c4b1b659 authored by Frank Lee

[test] fixed tests failed due to dtensor change (#4082)

* [test] fixed tests failed due to dtensor change

* polish code
parent 92f67910
@@ -188,7 +188,7 @@ class NodeHandler(ABC):
         remove_strategy_list = []
         for strategy in self.strategies_vector:
             shard_axis_list = []
-            last_axis = len(self.device_mesh.mesh_shape) - 1
+            last_axis = len(self.device_mesh.shape) - 1
             for op_data, sharding_spec in strategy.sharding_specs.items():
                 if op_data.data is not None and isinstance(op_data.data, torch.Tensor):
                     for dim, shard_axes in sharding_spec.dim_partition_dict.items():
...
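For context, a minimal sketch of the attribute this commit migrates to: `DeviceMesh` now exposes its logical mesh shape as `.shape` instead of `.mesh_shape`. The constructor arguments mirror the test snippets later in this diff; the import path is an assumption, not part of the commit.

```python
import torch
from colossalai.device.device_mesh import DeviceMesh  # assumed import path

physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)

# the handler code above only needs the number of mesh axes and their sizes
last_axis = len(device_mesh.shape) - 1    # 1 for a 2D mesh
print(device_mesh.shape)                  # expected: (2, 2)
```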
@@ -984,7 +984,7 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
     def collate_strategies(self) -> List[ShardingStrategy]:
         strategy_list = []
         device_mesh_is_1d = True
-        if len(self.device_mesh.mesh_shape) == 2 and 1 not in self.device_mesh.mesh_shape:
+        if len(self.device_mesh.shape) == 2 and 1 not in self.device_mesh.shape:
             device_mesh_is_1d = False
         if device_mesh_is_1d:
@@ -992,10 +992,10 @@ class BatchedMatMulStrategyGenerator(MatMulStrategyGenerator):
             # Sb = Sb x Sb
             # can be None as it is only for 1D device mesh
             # only for 1D device mesh
-            if len(self.device_mesh.mesh_shape) == 1:
+            if len(self.device_mesh.shape) == 1:
                 mesh_dim = 0
             else:
-                mesh_dim = self.device_mesh.mesh_shape.index(1)
+                mesh_dim = self.device_mesh.shape.index(1)
             strategy_list.append(self.split_one_batch_dim(mesh_dim))
         else:
             # for 2D device mesh
...
@@ -46,8 +46,8 @@ def check_sharding_spec_validity(sharding_spec: ShardingSpec, tensor: torch.Tens
     # make sure all dims are covered in sharding spec
     sharding_len = len(sharding_spec.sharding_sequence)
     tensor_num_dim = tensor.dim()
-    num_devices_in_col = sharding_spec.device_mesh.mesh_shape[0]
-    num_devices_in_row = sharding_spec.device_mesh.mesh_shape[1]
+    num_devices_in_col = sharding_spec.device_mesh.shape[0]
+    num_devices_in_row = sharding_spec.device_mesh.shape[1]
     assert sharding_len == tensor_num_dim, \
         f'The ShardingSpec ({sharding_spec.sharding_sequence}) is created for {sharding_len}-dimension tensor, but the given tensor is {tensor_num_dim}-dimension ({tensor.shape}).'
...
@@ -99,7 +99,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024)
     for key, weight in state_dict.items():
         ret_block = None
         ret_block_size = 0
-        if is_distributed_tensor(weight):
+        if not is_distributed_tensor(weight):
             weight_size = calculate_tensor_size(weight)
             # If this weight is going to tip up over the maximal size, we split.
...
@@ -146,7 +146,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
                 continue
             # If the states are stored as DTensors, mark isDTensor as true.
-            if type(state_tensor) == DTensor:
+            if is_distributed_tensor(state_tensor):
                 isDTensor = True
             state_size += calculate_tensor_size(state_tensor)
...
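A minimal sketch of what the two checkpoint fixes above guard: only plain tensors are measured for size-based sharding, and DTensors are now detected with the `is_distributed_tensor` helper rather than an exact `type(...) == DTensor` comparison, so tensor subclasses still match. The import path and the helper function below are illustrative assumptions, not code from this commit.

```python
from colossalai.tensor.d_tensor.api import is_distributed_tensor  # assumed import path

def split_plain_and_distributed(state_dict: dict):
    """Hypothetical helper: separate plain tensors from distributed tensors."""
    plain, distributed = {}, {}
    for key, tensor in state_dict.items():
        if is_distributed_tensor(tensor):
            # DTensors are handled by the distributed checkpoint path
            distributed[key] = tensor
        else:
            # only plain tensors count towards max_shard_size
            plain[key] = tensor
    return plain, distributed
```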
 from types import MethodType
-from typing import Callable, Optional, Union
+from typing import Callable, Dict, Optional, Union

 import torch
 import torch.distributed as dist
@@ -173,7 +173,7 @@ class LazyTensor(torch.Tensor):
         self.clean()
         return _convert_cls(self, target)

-    def distribute(self, layout: Layout) -> torch.Tensor:
+    def distribute(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> torch.Tensor:
         """Distribute the ``LazyTensor`` to ``torch.Tensor`` by modifying __class__ (inplace), according to the layout.

         Args:
@@ -537,7 +537,10 @@ class LazyInitContext:
         return _apply_to_lazy_module(module, apply_fn, verbose)

     @staticmethod
-    def distribute(module: nn.Module, layout_dict: dict, verbose: bool = False) -> nn.Module:
+    def distribute(module: nn.Module,
+                   device_mesh: DeviceMesh,
+                   sharding_spec_dict: Dict[str, ShardingSpec],
+                   verbose: bool = False) -> nn.Module:
         """Distribute all ``nn.Parameter`` from ``LazyTensor``. This function will modify the module in-place.

         Args:
@@ -547,7 +550,7 @@ class LazyInitContext:
         """

         def apply_fn(name: str, p: LazyTensor):
-            p.distribute(layout_dict[name])
+            p.distribute(device_mesh, sharding_spec_dict[name])

         return _apply_to_lazy_module(module, apply_fn, verbose)
...
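A hedged usage sketch of the new signature: distribution now takes a `DeviceMesh` plus a dict mapping parameter names to `ShardingSpec`s instead of a dict of `Layout`s. It assumes an initialized distributed environment with 4 ranks; the import paths, the tiny model, and the spec values are illustrative assumptions, not part of the commit.

```python
import torch
import torch.nn as nn
from colossalai.device.device_mesh import DeviceMesh                # assumed import path
from colossalai.lazy.lazy_init import LazyInitContext               # assumed import path
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec   # assumed import path

with LazyInitContext():
    model = nn.Linear(8, 8)    # parameters are LazyTensors at this point

device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2), init_process_group=True)
sharding_spec_dict = {
    # shard the weight's dim 0 across mesh axis 0, keep the bias replicated
    'weight': ShardingSpec(dim_size=2, dim_partition_dict={0: [0]}),
    'bias': ShardingSpec(dim_size=1, dim_partition_dict={}),
}
LazyInitContext.distribute(model, device_mesh, sharding_spec_dict, verbose=True)
```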
@@ -16,69 +16,66 @@ def _all_gather(tensor, comm_spec):
     '''
     Implement all gather operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            tensor_list = [
-                torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
-                for _ in range(comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis])
-            ]
-            # without this contiguous operation, the all gather may get some unexpected results.
-            tensor = tensor.contiguous()
-            dist.all_gather(tensor_list, tensor, group=process_group)
-            output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
-            return output
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+    tensor_list = [
+        torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device)
+        for _ in range(comm_spec.device_mesh.shape[comm_spec.logical_process_axis])
+    ]
+    # without this contiguous operation, the all gather may get some unexpected results.
+    tensor = tensor.contiguous()
+    dist.all_gather(tensor_list, tensor, group=process_group)
+    output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
+    return output

 def _split(tensor, comm_spec):
     '''
     Implement shard operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, _ in process_groups_list:
-        if dist.get_rank() in rank_list:
-            dim = comm_spec.shard_dim
-            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
-            start = length * rank_list.index(dist.get_rank())
-            output = torch.narrow(tensor, dim, start, length).contiguous()
-            return output
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+    dim = comm_spec.shard_dim
+    length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
+    start = length * dist.get_rank(process_group)
+    output = torch.narrow(tensor, dim, start, length).contiguous()
+    return output

 def _all_to_all(tensor, comm_spec):
     '''
     Implement all to all operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            new_shape = list(tensor.shape)
-            new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
-            new_shape = torch.Size(new_shape)
-            output_tensor_list = [
-                torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
-            ]
-            dim = comm_spec.shard_dim
-            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
-            input_tensor_list = [
-                torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
-            ]
-            group = process_group
-            dist.all_to_all(output_tensor_list, input_tensor_list, group)
-            output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
-            return output
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+    world_size = dist.get_world_size(process_group)
+    new_shape = list(tensor.shape)
+    new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
+    new_shape = torch.Size(new_shape)
+    output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
+    dim = comm_spec.shard_dim
+    length = tensor.shape[comm_spec.shard_dim] // world_size
+    input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
+    group = process_group
+    dist.all_to_all(output_tensor_list, input_tensor_list, group)
+    output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
+    return output

 def _all_reduce(tensor, comm_spec, async_op=False):
     '''
     Implement all reduce operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.device_mesh.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            if not tensor.is_contiguous():
-                tensor = tensor.contiguous()
-            dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
-            return tensor
+    process_groups = comm_spec.device_mesh.get_process_group_for_all_axes()
+    process_group = process_groups[comm_spec.logical_process_axis]
+    if not tensor.is_contiguous():
+        tensor = tensor.contiguous()
+    dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
+    return tensor

 def _mix_gather(tensor, comm_spec):
@@ -128,7 +125,7 @@ def _mix_gather(tensor, comm_spec):
         process_group = "[0, 1, 2, 3, 4, 5, 6, 7]"
         tensor_list = [(0,0),(1,0),(2,0),(3,0),(4,0),(5,0),(6,0),(7,0)]
     '''
-    total_slices = comm_spec.device_mesh.mesh_shape[0]
+    total_slices = comm_spec.device_mesh.shape[0]
     tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(total_slices)]
     leading_group_dim = comm_spec.logical_process_axes[0]
     assert len(comm_spec.device_mesh.process_groups_dict) == 1
@@ -149,7 +146,7 @@ def _mix_gather(tensor, comm_spec):
     if comm_spec.logical_process_axes[0] == comm_spec.logical_process_axes[1]:
         output = torch.cat(tuple(tensor_list), comm_spec.gather_dim[0]).contiguous()
     else:
-        mesh_shape = comm_spec.device_meshes.mesh_shape
+        mesh_shape = comm_spec.device_meshes.shape
         cat_slice = [mesh_shape[comm_spec.logical_process_axes[0]], mesh_shape[comm_spec.logical_process_axes[1]]]
         tmp_tensor_shape = list(tensor.shape)
         tmp_tensor_shape[comm_spec.gather_dim[0]] *= cat_slice[0]
@@ -181,9 +178,9 @@ def _mix_split(tensor, comm_spec):
     #  [4, 5, 6, 7]]
     # return {0: [0, 4, 1, 5, 2, 6, 3, 7], 1: [0, 1, 2, 3, 4, 5, 6, 7]}
     '''
-    mesh_shape = comm_spec.device_meshes.mesh_shape
+    mesh_shape = comm_spec.device_meshes.shape
     dim = comm_spec.gather_dim
-    total_slices = comm_spec.device_mesh.mesh_shape[0]
+    total_slices = comm_spec.device_mesh.shape[0]
     # Get global rank
     rank = dist.get_rank()
@@ -414,7 +411,7 @@ class CommSpec:
         self.forward_only = forward_only
         if isinstance(self.logical_process_axis, list):
             if not mix_gather:
-                self.device_mesh = self.sharding_spec.device_mesh.flatten_device_mesh
+                self.device_mesh = self.sharding_spec.device_mesh.flatten()
                 self.logical_process_axis = 0
             else:
                 self.device_meshes = self.sharding_spec.device_mesh.flatten_device_meshes
...
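The restructuring above replaces the old "loop over rank lists until we find our own group" pattern with a direct per-axis lookup. A minimal sketch of that pattern, assuming `get_process_group_for_all_axes()` returns a dict mapping mesh axis to a `ProcessGroup`; the helper function name below is illustrative, not colossalai API.

```python
import torch
import torch.distributed as dist

def split_along_mesh_axis(tensor, device_mesh, shard_dim, axis):
    # hypothetical helper mirroring the new _split logic above
    process_group = device_mesh.get_process_group_for_all_axes()[axis]
    world_size = dist.get_world_size(process_group)
    rank_in_group = dist.get_rank(process_group)
    length = tensor.shape[shard_dim] // world_size
    return torch.narrow(tensor, shard_dim, length * rank_in_group, length).contiguous()
```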
@@ -24,12 +24,12 @@ class CommSpec:
     '''
     Communication spec is used to record the communication action. It converts the communication spec
     to real action which will be used in runtime. It contains comm_pattern to determine the
-    communication method, process_groups_dict to determine the process groups, gather_dim and shard_dim
+    communication method, process_group_dict to determine the process groups, gather_dim and shard_dim
     to determine the buffer shape, and logical_process_axis

     Argument:
         comm_pattern(CollectiveCommPattern): describe the communication method used in this spec.
-        process_groups_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
+        process_group_dict(Dict): A dict which contains the process groups used to apply this CommSpec.
         gather_dim(int, Optional): The gather_dim of the tensor will be gathered.
         shard_dim(int, Optional): The shard_dim of the tensor will be sharded.
         logical_process_axis(Union(int, List[int]), Optional): The mesh_dim to implement the communication action.
@@ -37,7 +37,7 @@ class CommSpec:
     def __init__(self,
                  comm_pattern: CollectiveCommPattern,
-                 process_groups_dict: Dict,
+                 process_group_dict: Dict,
                  gather_dim: int = None,
                  shard_dim: int = None,
                  logical_process_axis: int = None):
@@ -45,7 +45,7 @@ class CommSpec:
         self.gather_dim = gather_dim
         self.shard_dim = shard_dim
         self.logical_process_axis = logical_process_axis
-        self.process_groups_dict = process_groups_dict
+        self.process_group_dict = process_group_dict

     def __repr__(self):
         res_list = ["CommSpec:("]
@@ -92,68 +92,56 @@ def _all_gather(tensor: torch.Tensor, comm_spec: CommSpec):
     '''
     Implement all gather operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            tensor_list = [
-                torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
-            ]
-            # without this contiguous operation, the all gather may get some unexpected results.
-            tensor = tensor.contiguous()
-            dist.all_gather(tensor_list, tensor, group=process_group)
-            output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
-            return output
+    process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
+    world_size = dist.get_world_size(process_group)
+    tensor_list = [torch.zeros(tensor.shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
+    # without this contiguous operation, the all gather may get some unexpected results.
+    tensor = tensor.contiguous()
+    dist.all_gather(tensor_list, tensor, group=process_group)
+    output = torch.cat(tuple(tensor_list), comm_spec.gather_dim).contiguous()
+    return output

 def _split(tensor: torch.Tensor, comm_spec: CommSpec):
     '''
     Implement shard operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, _ in process_groups_list:
-        if dist.get_rank() in rank_list:
-            dim = comm_spec.shard_dim
-            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
-            start = length * rank_list.index(dist.get_rank())
-            output = torch.narrow(tensor, dim, start, length).contiguous()
-            return output
+    process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
+    dim = comm_spec.shard_dim
+    length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
+    start = length * dist.get_rank(process_group)
+    output = torch.narrow(tensor, dim, start, length).contiguous()
+    return output

 def _all_to_all(tensor: torch.Tensor, comm_spec: CommSpec):
     '''
     Implement all to all operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            new_shape = list(tensor.shape)
-            new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // len(rank_list)
-            new_shape = torch.Size(new_shape)
-            output_tensor_list = [
-                torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(len(rank_list))
-            ]
-            dim = comm_spec.shard_dim
-            length = tensor.shape[comm_spec.shard_dim] // len(rank_list)
-            input_tensor_list = [
-                torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(len(rank_list))
-            ]
-            group = process_group
-            dist.all_to_all(output_tensor_list, input_tensor_list, group)
-            output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
-            return output
+    process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
+    world_size = dist.get_world_size(process_group)
+    new_shape = list(tensor.shape)
+    new_shape[comm_spec.shard_dim] = new_shape[comm_spec.shard_dim] // world_size
+    new_shape = torch.Size(new_shape)
+    output_tensor_list = [torch.zeros(new_shape, dtype=tensor.dtype, device=tensor.device) for _ in range(world_size)]
+    dim = comm_spec.shard_dim
+    length = tensor.shape[comm_spec.shard_dim] // world_size
+    input_tensor_list = [torch.narrow(tensor, dim, length * i, length).contiguous() for i in range(world_size)]
+    group = process_group
+    dist.all_to_all(output_tensor_list, input_tensor_list, group)
+    output = torch.cat(tuple(output_tensor_list), comm_spec.gather_dim).contiguous()
+    return output

 def _all_reduce(tensor: torch.Tensor, comm_spec: CommSpec, async_op: bool = False):
     '''
     Implement all reduce operation on device mesh based on information provided by comm_spec.
     '''
-    process_groups_list = comm_spec.process_groups_dict[comm_spec.logical_process_axis]
-    for rank_list, process_group in process_groups_list:
-        if dist.get_rank() in rank_list:
-            if not tensor.is_contiguous():
-                tensor = tensor.contiguous()
-            dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
-            return tensor
+    process_group = comm_spec.process_group_dict[comm_spec.logical_process_axis]
+    if not tensor.is_contiguous():
+        tensor = tensor.contiguous()
+    dist.all_reduce(tensor, op=ReduceOp.SUM, group=process_group, async_op=async_op)
+    return tensor

 class _ReduceGrad(torch.autograd.Function):
@@ -269,7 +257,7 @@ class _AllToAll(torch.autograd.Function):
     def forward(ctx, input_, comm_spec):
         output = _all_to_all(input_, comm_spec)
         comm_spec_for_backward = CommSpec(comm_pattern=comm_spec.comm_pattern,
-                                          process_groups_dict=comm_spec.process_groups_dict,
+                                          process_group_dict=comm_spec.process_group_dict,
                                           gather_dim=comm_spec.shard_dim,
                                           shard_dim=comm_spec.gather_dim,
                                           logical_process_axis=comm_spec.logical_process_axis)
...
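A hedged construction sketch matching the revised docstring and `__init__` above: the d_tensor `CommSpec` now receives a `process_group_dict` keyed by mesh axis. It assumes an already-initialized process group and that `CommSpec` and `CollectiveCommPattern` are importable from the module shown in the wildcard import above; the dict contents are illustrative.

```python
import torch.distributed as dist
from colossalai.tensor.d_tensor.comm_spec import CollectiveCommPattern, CommSpec  # names assumed to be exported here

# process_group_dict maps a logical mesh axis to the current rank's ProcessGroup along that axis,
# e.g. {0: pg_along_axis_0, 1: pg_along_axis_1}; a 1D mesh using the world group is shown here
process_group_dict = {0: dist.group.WORLD}
comm_spec = CommSpec(comm_pattern=CollectiveCommPattern.GATHER_FWD_SPLIT_BWD,
                     process_group_dict=process_group_dict,
                     gather_dim=0,
                     shard_dim=0,    # used by the backward (split) pass
                     logical_process_axis=0)
```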
@@ -14,24 +14,21 @@ class Layout:
     Attributes:
         device_mesh: the device mesh to store the tensor distributed.
-        device_type: the type of the device mesh, e.g. 'cpu' or 'cuda'.
         sharding_spec: the sharding specification to describe how the tensor is sharded.
-        entire_shape: the entire shape of the global tensor.
+        global_shape: the entire shape of the global tensor.
     """

-    def __init__(self, device_mesh: DeviceMesh, device_type: torch.device, sharding_spec: ShardingSpec,
-                 entire_shape: torch.Size):
+    def __init__(self, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size):
         self.device_mesh = device_mesh
-        self.device_type = device_type
         self.sharding_spec = sharding_spec
-        self.entire_shape = entire_shape
+        self.global_shape = global_shape
         self._sanity_check()

     def __hash__(self) -> int:
         return hash(f'{self.sharding_spec}')

     def get_sharded_shape_per_device(self):
-        sharded_shape = list(self.entire_shape)
+        sharded_shape = list(self.global_shape)
         for dim, shard_list in self.sharding_spec.dim_partition_dict.items():
             mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
             shard_partitions = reduce(operator.mul, mesh_list, 1)
@@ -56,7 +53,7 @@ class Layout:
         # make sure that the sharding for a dimension is divisible by the number of devices
         for dim, shard_list in sharding_spec.dim_partition_dict.items():
-            tensor_dim_size = self.entire_shape[dim]
+            tensor_dim_size = self.global_shape[dim]
             num_devices = 1
             for element in shard_list:
...
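A short hedged sketch of the new `Layout` signature and the per-device shape arithmetic in `get_sharded_shape_per_device`: each sharded dimension is divided by the product of the mesh sizes it is sharded over. The mesh and spec values are illustrative, and the `DeviceMesh`/`ShardingSpec` import paths are assumptions.

```python
import torch
from colossalai.device.device_mesh import DeviceMesh                 # assumed import path
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec    # assumed import path

device_mesh = DeviceMesh(torch.arange(0, 4), (2, 2))
# dim 0 sharded over mesh axis 0, dim 1 over mesh axis 1 -> [S0, S1, R]
sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict={0: [0], 1: [1]})
layout = Layout(device_mesh=device_mesh,
                sharding_spec=sharding_spec,
                global_shape=torch.Size((4, 4, 4)))
# each sharded dim is divided by a mesh size of 2: (4/2, 4/2, 4)
print(layout.get_sharded_shape_per_device())    # expected to be (2, 2, 4)
```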
@@ -3,10 +3,8 @@ from copy import deepcopy
 from dataclasses import dataclass
 from typing import Dict, List, Tuple

-import numpy as np
 import torch

-from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, TrainCycleItem
 from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.tensor.d_tensor.comm_spec import *
 from colossalai.tensor.d_tensor.layout import Layout
@@ -37,6 +35,9 @@ def set_layout_converting_options(options: LayoutConverterOptions):
 class LayoutConverter(metaclass=SingletonMeta):
+    """
+    LayoutConverter is a singleton class which converts the layout of a distributed tensor.
+    """

     def __init__(self):
         self._options = None
@@ -79,15 +80,14 @@ class LayoutConverter(metaclass=SingletonMeta):
             # [[0, 1,
             #  [2, 3]]
             device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-            entire_shape = (4, 4, 4)
+            global_shape = (4, 4, 4)
             dim_partition_dict = {0: [0], 1: [1]}
             # [S0,S1,R]
             sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
             layout = Layout(device_mesh=device_mesh,
-                            device_type=torch.device('cuda'),
                             sharding_spec=sharding_spec,
-                            entire_shape=entire_shape)
+                            global_shape=global_shape)

             rst_dict = layout_converter.all_gather_transform_layouts(layout)
             for layout, comm_spec in rst_dict.items():
@@ -100,7 +100,12 @@ class LayoutConverter(metaclass=SingletonMeta):
         valid_spec_dict = {}
         comm_pattern = CollectiveCommPattern.GATHER_FWD_SPLIT_BWD
         source_spec = source_layout.sharding_spec
-        process_groups_dict = source_layout.device_mesh.process_groups_dict
+
+        # the key of the dict is the axis
+        # the value is the process group
+        current_rank = source_layout.device_mesh._global_rank_of_current_process
+        process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
+
         for target_pair in source_spec.dim_partition_dict.items():
             shard_list = all_gather_simulator(target_pair)
             index = target_pair[0]
@@ -118,7 +123,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                 logical_process_axis = target_pair[1][-1]
                 comm_spec = CommSpec(
                     comm_pattern,
-                    process_groups_dict=process_groups_dict,
+                    process_group_dict=process_group_dict,
                     gather_dim=gather_dim,
                     # shard_dim will be used during backward
                     shard_dim=gather_dim,
@@ -129,8 +134,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                 new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
                 new_layout = Layout(device_mesh=source_layout.device_mesh,
                                     sharding_spec=new_sharding_spec,
-                                    device_type=source_layout.device_type,
-                                    entire_shape=source_layout.entire_shape)
+                                    global_shape=source_layout.global_shape)

                 valid_spec_dict[new_layout] = comm_spec
             except LayoutException:
@@ -155,15 +159,14 @@ class LayoutConverter(metaclass=SingletonMeta):
             # [[0, 1,
             #  [2, 3]]
             device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-            entire_shape = (4, 4, 4)
+            global_shape = (4, 4, 4)
             dim_partition_dict = {0: [0], 1: [1]}
             # [S0,S1,R]
             sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
             layout = Layout(device_mesh=device_mesh,
-                            device_type=torch.device('cuda'),
                             sharding_spec=sharding_spec,
-                            entire_shape=entire_shape)
+                            global_shape=global_shape)

             rst_dict = layout_converter.all_to_all_transform_layout(layout)
             for layout, comm_spec in rst_dict.items():
@@ -176,7 +179,12 @@ class LayoutConverter(metaclass=SingletonMeta):
         '''
         valid_spec_dict = {}
         comm_pattern = CollectiveCommPattern.ALL2ALL_FWD_ALL2ALL_BWD
-        process_groups_dict = source_layout.device_mesh.process_groups_dict
+
+        # the key of the dict is the axis
+        # the value is the process group
+        current_rank = source_layout.device_mesh._global_rank_of_current_process
+        process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]
+
         source_spec = source_layout.sharding_spec
         tensor_dims = source_spec.dims
         for f_index in range(tensor_dims - 1):
@@ -217,7 +225,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                     shard_dim = f_index
                     logical_process_axis = b_target_pair[1][-1]
                 comm_spec = CommSpec(comm_pattern,
-                                     process_groups_dict,
+                                     process_group_dict=process_group_dict,
                                      gather_dim=gather_dim,
                                      shard_dim=shard_dim,
                                      logical_process_axis=logical_process_axis)
@@ -240,8 +248,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                 new_sharding_spec = ShardingSpec(source_spec.dims, dim_partition_dict=new_dim_partition_dict)
                 new_layout = Layout(device_mesh=source_layout.device_mesh,
                                     sharding_spec=new_sharding_spec,
-                                    device_type=source_layout.device_type,
-                                    entire_shape=source_layout.entire_shape)
+                                    global_shape=source_layout.global_shape)

                 valid_spec_dict[new_layout] = comm_spec
             except LayoutException:
                 pass
@@ -266,16 +273,15 @@ class LayoutConverter(metaclass=SingletonMeta):
             # [[0, 1,
             #  [2, 3]]
             device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-            entire_shape = (4, 4, 4)
+            global_shape = (4, 4, 4)
             dim_partition_dict = {0: [0]}
             # [S0,R,R]
             sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
             layout = Layout(device_mesh=device_mesh,
-                            device_type=torch.device('cuda'),
                             sharding_spec=sharding_spec,
-                            entire_shape=entire_shape)
+                            global_shape=global_shape)

             rst_dict = layout_converter.shard_transform_layout(layout)
             for layout, comm_spec in rst_dict.items():
@@ -289,7 +295,11 @@ class LayoutConverter(metaclass=SingletonMeta):
         valid_spec_dict = {}
         comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD
         source_spec = source_layout.sharding_spec
-        process_groups_dict = source_layout.device_mesh.process_groups_dict
+
+        # the key of the dict is the axis
+        # the value is the process group
+        current_rank = source_layout.device_mesh._global_rank_of_current_process
+        process_group_dict = source_layout.device_mesh._process_group_dict[current_rank]

         # legal sharding dims means the mesh_id is still available to use.
         legal_sharding_dims = [i for i in range(len(source_layout.device_mesh.shape))]
@@ -317,7 +327,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                     shard_dim = index
                     logical_process_axis = shard_list[-1]
                     comm_spec = CommSpec(comm_pattern,
-                                         process_groups_dict,
+                                         process_group_dict=process_group_dict,
                                          gather_dim=shard_dim,
                                          shard_dim=shard_dim,
                                          logical_process_axis=logical_process_axis)
@@ -328,8 +338,7 @@ class LayoutConverter(metaclass=SingletonMeta):
                                                      dim_partition_dict=new_dim_partition_dict)
                     new_layout = Layout(device_mesh=source_layout.device_mesh,
                                         sharding_spec=new_sharding_spec,
-                                        device_type=source_layout.device_type,
-                                        entire_shape=source_layout.entire_shape)
+                                        global_shape=source_layout.global_shape)

                     valid_spec_dict[new_layout] = comm_spec
                 except LayoutException:
                     pass
@@ -387,7 +396,7 @@ class LayoutConverter(metaclass=SingletonMeta):
             # [[0, 1,
             #  [2, 3]]
             device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-            entire_shape = (4, 4, 4)
+            global_shape = (4, 4, 4)

             dim_partition_source = {1: [0, 1]}
             dim_partition_target = {0: [0, 1]}
@@ -395,16 +404,14 @@ class LayoutConverter(metaclass=SingletonMeta):
             # [R,S01,R]
             sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
             source_layout = Layout(device_mesh=device_mesh,
-                                   device_type=torch.device('cuda'),
                                    sharding_spec=sharding_spec_source,
-                                   entire_shape=entire_shape)
+                                   global_shape=global_shape)

             # [S01,R,R]
             sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
             target_layout = Layout(device_mesh=device_mesh,
-                                   device_type=torch.device('cuda'),
                                    sharding_spec=sharding_spec_target,
-                                   entire_shape=entire_shape)
+                                   global_shape=global_shape)

             transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
             transform_path_str = '->'.join([str(layout.sharding_spec.sharding_sequence) for layout in transform_path])
@@ -493,21 +500,19 @@ class LayoutConverter(metaclass=SingletonMeta):
             # [[0, 1,
             #  [2, 3]]
             device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-            entire_shape = (4, 4, 4)
+            global_shape = (4, 4, 4)

             # [S0,R,R]
             sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
             source_layout = Layout(device_mesh=device_mesh,
-                                   device_type=torch.device('cuda'),
                                    sharding_spec=sharding_spec_source,
-                                   entire_shape=entire_shape)
+                                   global_shape=global_shape)

             # [R,S0,R]
             sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
             target_layout = Layout(device_mesh=device_mesh,
-                                   device_type=torch.device('cuda'),
                                    sharding_spec=sharding_spec_target,
-                                   entire_shape=entire_shape)
+                                   global_shape=global_shape)

             if rank in (0, 1):
                 sharded_tensor_0 = torch.zeros(2, 1)
...
@@ -285,7 +285,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
         comm_pattern = CollectiveCommPattern.SPLIT_FWD_GATHER_BWD

         # legal sharding dims means the mesh_id is still available to use.
-        legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.mesh_shape))]
+        legal_sharding_dims = [i for i in range(len(source_spec.device_mesh.shape))]
         for dim, shard_list in source_spec.dim_partition_dict.items():
             for element in shard_list:
                 legal_sharding_dims.remove(element)
@@ -435,7 +435,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             """
             input_shape = compute_shape(comm_spec.sharding_spec)
             input_numel = np.prod(input_shape)
-            output_numel = input_numel * comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
+            output_numel = input_numel * comm_spec.device_mesh.shape[comm_spec.logical_process_axis]
             peak_numel = max(peak_numel, alloc_numel + output_numel * 2)
             alloc_numel += output_numel
             if discard_input:
@@ -461,7 +461,7 @@ class ShapeConsistencyManager(metaclass=SingletonMeta):
             # generate a new tensor
             input_shape = compute_shape(comm_spec.sharding_spec)
             input_numel = np.prod(input_shape)
-            output_numel = input_numel // comm_spec.device_mesh.mesh_shape[comm_spec.logical_process_axis]
+            output_numel = input_numel // comm_spec.device_mesh.shape[comm_spec.logical_process_axis]
             alloc_numel += output_numel
             peak_numel = max(peak_numel, alloc_numel)
             if discard_input:
...
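A small worked example of the peak-memory bookkeeping above, under the assumption that `comm_spec.device_mesh.shape[axis]` is the number of ranks along the communication axis: an all-gather multiplies the element count by that size, and the matching split divides by it.

```python
import numpy as np

input_shape = (2, 4, 4)    # per-device shard before the all-gather
axis_size = 2              # comm_spec.device_mesh.shape[logical_process_axis]

input_numel = np.prod(input_shape)           # 32
output_numel = input_numel * axis_size       # 64 elements after the all-gather
# the gathered output and its temporary buffer coexist with what is already allocated
peak_numel = max(0, 0 + output_numel * 2)    # 128 elements at the peak

split_output_numel = output_numel // axis_size    # 32: a split undoes the growth
```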
@@ -195,7 +195,7 @@ class ShardingSpec:
     def __repr__(self):
         res_list = ["DistSpec:"]
         res_list.append(f"\n\tshard_sequence: " + ",".join(str(dimspec) for dimspec in self.sharding_sequence))
-        res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.mesh_shape}")
+        res_list.append(f"\n\tdevice_mesh_shape: {self.device_mesh.shape}")
         return ' '.join(res_list)

     def _sanity_check(self):
@@ -222,7 +222,7 @@ class ShardingSpec:
                 num_devices = 1
                 for element in shard_list:
-                    num_devices *= self.device_mesh.mesh_shape[element]
+                    num_devices *= self.device_mesh.shape[element]
                 if tensor_dim_size % num_devices != 0:
                     raise ShardingNotDivisibleError(
@@ -288,7 +288,7 @@ class ShardingSpec:
         sharded_shape = list(self.entire_shape)
         for dim, shard_list in self.dim_partition_dict.items():
-            mesh_list = [self.device_mesh.mesh_shape[mesh_dim] for mesh_dim in shard_list]
+            mesh_list = [self.device_mesh.shape[mesh_dim] for mesh_dim in shard_list]
             shard_partitions = reduce(operator.mul, mesh_list, 1)
             assert sharded_shape[
                 dim] % shard_partitions == 0, f'Cannot shard dimension {dim} into {shard_partitions} partitions.'
...
 from colossalai.tensor.d_tensor.api import to_distributed_tensor
@@ -58,13 +58,4 @@ def test_evoformer_block(model, shape, max_memory):

 if __name__ == "__main__":
-    run_test(
-        rank=0,
-        data=get_data(LATENTS_SHAPE),
-        max_memory=None,
-        model=UNet2DModel,
-        print_code=False,
-        print_mem=True,
-        print_est_mem=False,
-        print_progress=False,
-    )
+    test_evoformer_block()
...
@@ -22,7 +22,7 @@ from tests.kit.model_zoo import model_zoo
 @parameterize('use_safetensors', [False, True])
 def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
     from transformers import BertForSequenceClassification
-    (model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
+    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     bert_model = model_fn()

     with shared_tempdir() as tempdir:
@@ -53,7 +53,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
 @parameterize('shard', [True, False])
 @parameterize('model_name', ['transformers_gpt'])
 def exam_state_dict(placement_policy, shard: bool, model_name: str):
-    (model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
+    (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
     plugin = GeminiPlugin(placement_policy=placement_policy)
     booster = Booster(plugin=plugin)
...
@@ -8,18 +8,16 @@ from colossalai.testing import rerun_if_address_is_in_use, spawn

 def test_device_mesh():
-    physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
+    physical_mesh_id = torch.arange(0, 16)
     mesh_shape = (4, 4)
     # [[0, 1, 2, 3],
     #  [4, 5, 6, 7],
     #  [8, 9, 10,11],
     #  [12,13,14,15]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
-    assert device_mesh.convert_map[5] == [1, 1]
-    assert device_mesh.convert_map[11] == [2, 3]
-    assert device_mesh.global_rank_to_process_groups_with_logical_rank(0)[0] == [[0, 0], [1, 0], [2, 0], [3, 0]]
-    assert device_mesh.global_rank_to_process_groups_with_logical_rank(2)[1] == [[0, 0], [0, 1], [0, 2], [0, 3]]
-    assert device_mesh.global_rank_to_process_groups_with_global_rank(2)[1] == [0, 1, 2, 3]
+    assert device_mesh.global_rank_to_local_rank(5) == [1, 1]
+    assert device_mesh.global_rank_to_local_rank(11) == [2, 3]
+    assert device_mesh.get_ranks_in_process_group(axis=1, global_rank=2) == [0, 1, 2, 3]

 def check_1d_device_mesh():
...
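For readers of the updated assertions: on the 4x4 mesh above, `global_rank_to_local_rank` maps a flat rank to its per-axis coordinates, and `get_ranks_in_process_group(axis=1, ...)` returns the ranks that share the caller's row. A sketch of the arithmetic, using the mesh layout from the comment in the test; the helper below is illustrative and not colossalai API.

```python
def local_coords(global_rank: int, mesh_shape=(4, 4)):
    # row-major layout: rank 5 -> [1, 1], rank 11 -> [2, 3]
    return [global_rank // mesh_shape[1], global_rank % mesh_shape[1]]

assert local_coords(5) == [1, 1]
assert local_coords(11) == [2, 3]
# along axis 1 (the row axis), rank 2 sits in row 0, whose members are [0, 1, 2, 3]
```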
@@ -20,16 +20,12 @@ def check_layer(rank, world_size, port):
     # [[0, 1,
     #  [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    logical_pg_dict = {0: [[0, 2], [1, 3]], 1: [[0, 1], [2, 3]]}
-    logical_process_groups = device_mesh.process_groups_dict
-
-    for mesh_dim, pgs in logical_pg_dict.items():
-        for index, pg in enumerate(pgs):
-            if rank in pg:
-                tensor = torch.ones(4).cuda()
-                group = logical_process_groups[mesh_dim][index][1]
-                dist.all_reduce(tensor, op=ReduceOp.SUM, group=group)
-                assert tensor.equal(tensor_to_check)
+
+    for axis in range(len(mesh_shape)):
+        tensor = torch.ones(4).cuda()
+        pg = device_mesh.get_process_group(axis=axis)
+        dist.all_reduce(tensor, op=ReduceOp.SUM, group=pg)
+        assert tensor.equal(tensor_to_check)

     gpc.destroy()
...
+from typing import List
+
 import torch
 from numpy import isin
 from torch.fx import GraphModule
@@ -7,19 +9,23 @@ from torch.utils._pytree import tree_flatten
 from colossalai._analyzer.fx import symbolic_trace

-def trace_model_and_compare_output(model, data_gen):
+def trace_model_and_compare_output(model, data_gen, ignore_data: List[str] = None):
     # must turn on eval mode to ensure the output is consistent
     model.eval()

+    inputs = data_gen()
+
+    if ignore_data is not None:
+        # drop the ignore_data key
+        inputs = {k: v for k, v in inputs.items() if k not in ignore_data}
+
     try:
-        kwargs = data_gen()
-        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
+        meta_args = {k: v.to('meta') for k, v in inputs.items()}
         gm = symbolic_trace(model, meta_args=meta_args)
     except Exception as e:
         raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}")

     # run forward
-    inputs = data_gen()
     non_fx_out = model(**inputs)
     fx_out = gm(**inputs)
...
@@ -15,7 +15,7 @@ SEQ_LENGTH = 16

 def test_albert():
     sub_registry = model_zoo.get_sub_registry('transformers_albert')
-    for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+    for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
         model = model_fn()
         trace_model_and_compare_output(model, data_gen_fn)
...
@@ -12,9 +12,9 @@ from tests.kit.model_zoo import model_zoo

 def test_bert():
     sub_registry = model_zoo.get_sub_registry('transformers_bert')
-    for name, (model_fn, data_gen_fn, _, _) in sub_registry.items():
+    for name, (model_fn, data_gen_fn, _, _, _) in sub_registry.items():
         model = model_fn()
-        trace_model_and_compare_output(model, data_gen_fn)
+        trace_model_and_compare_output(model, data_gen_fn, ignore_data=['labels', 'next_sentence_label'])

 if __name__ == '__main__':
...
@@ -47,7 +47,7 @@ def test_diffusers():
     sub_model_zoo = model_zoo.get_sub_registry('diffusers')
-    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _, attribute) in sub_model_zoo.items():
         data = data_gen_fn()
         trace_and_compare(model_fn, data, output_transform_fn)
         torch.cuda.synchronize()
...