Unverified Commit ddcf58ca authored by Frank Lee, committed by GitHub

Revert "[sync] sync feature/shardformer with develop"

parent 24651fdd
@@ -6,9 +6,7 @@ import numpy as np
 import torch
 from packaging import version

-from colossalai.device.device_mesh import DeviceMesh
 from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
-from colossalai.tensor.d_tensor.layout import Layout
 from colossalai.tensor.d_tensor.layout_converter import to_global

 from tests.kit.model_zoo.registry import ModelAttribute
@@ -83,8 +81,7 @@ def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False,
        print(f'{model.__class__.__name__} pass')


-def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, device_mesh: DeviceMesh,
-                            sharding_spec_dict: dict) -> None:
+def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.Module, layout_dict: dict) -> None:
     state = model.state_dict()
     distributed_state = distributed_model.state_dict()
@@ -94,7 +91,6 @@ def assert_dist_model_equal(model: torch.nn.Module, distributed_model: torch.nn.
         assert n1 == n2
         t1 = t1.cuda()
         t2 = t2.cuda()
-        if n2 in sharding_spec_dict:
-            layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_dict[n2], global_shape=t1.shape)
-            t2 = to_global(t2, layout)
+        if n2 in layout_dict:
+            t2 = to_global(t2, layout_dict[n2])
         assert torch.equal(t1, t2), f'{n1} {t1} vs {t2}'
@@ -26,19 +26,23 @@ def find_shard_dim(shape: torch.Size) -> Optional[int]:
     return dim


-def make_sharding_spec(original_tensor: torch.Tensor) -> Layout:
+def make_layout(device_mesh: DeviceMesh, original_tensor: torch.Tensor) -> Layout:
     shard_dim = find_shard_dim(original_tensor.shape)
     dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
     target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
-    return target_sharding_spec
+    layout = Layout(device_mesh=device_mesh,
+                    device_type=torch.device('cuda'),
+                    sharding_spec=target_sharding_spec,
+                    entire_shape=original_tensor.shape)
+    return layout


 def _get_current_name(prefix: str, name: str) -> str:
     return f'{prefix}.{name}'.lstrip('.')


-def generate_sharding_spec_dict(model: nn.Module) -> dict:
-    sharding_spec_dict = {}
+def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
+    layout_dict = {}

     @torch.no_grad()
     def generate_recursively(module: nn.Module, prefix: str = ''):
@@ -49,17 +53,17 @@ def generate_sharding_spec_dict(model: nn.Module) -> dict:
         # initialize tensors directly attached to the current module
         for name, param in module.named_parameters(recurse=False):
             if isinstance(param, LazyTensor):
-                sharding_spec = make_sharding_spec(param)
-                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
+                layout = make_layout(device_mesh, param)
+                layout_dict[_get_current_name(prefix, name)] = layout

         for name, buf in module.named_buffers(recurse=False):
             if isinstance(buf, LazyTensor):
-                sharding_spec = make_sharding_spec(buf)
-                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec
+                layout = make_layout(device_mesh, buf)
+                layout_dict[_get_current_name(prefix, name)] = layout

     generate_recursively(model)
-    return sharding_spec_dict
+    return layout_dict


 @parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
@@ -81,9 +85,9 @@ def run_dist_lazy_init(subset, seed: int = 42):
         ctx = LazyInitContext()
         with ctx:
             deferred_model = model_fn()
-        sharding_spec_dict = generate_sharding_spec_dict(deferred_model)
-        ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True)
-        assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict)
+        layout_dict = generate_layout_dict(deferred_model, device_mesh)
+        ctx.distribute(deferred_model, layout_dict, verbose=True)
+        assert_dist_model_equal(model, deferred_model, layout_dict)


 def run_dist(rank, world_size, port) -> None:
...
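For orientation, the two files above restore the Layout-per-parameter path for distributing a lazily initialised model: build a Layout (device mesh, device type, sharding spec, entire shape) for each parameter, then hand the dict to LazyInitContext.distribute. A minimal sketch of that flow, assuming four ranks on a 2x2 mesh and a toy nn.Linear standing in for the model zoo entries (the model, mesh and shard choice here are illustrative, not part of the diff):

    import torch
    import torch.nn as nn

    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.lazy.lazy_init import LazyInitContext
    from colossalai.tensor.d_tensor.layout import Layout
    from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

    # illustrative 2x2 mesh over ranks 0-3 (assumes torch.distributed is already launched)
    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)

    ctx = LazyInitContext()
    with ctx:
        model = nn.Linear(16, 16)    # hypothetical toy model; the test iterates the model zoo instead

    # one Layout per lazy parameter, mirroring make_layout / generate_layout_dict above
    layout_dict = {}
    for name, param in model.named_parameters():
        spec = ShardingSpec(dim_size=param.dim(), dim_partition_dict={0: [0]})    # shard dim 0 on mesh axis 0
        layout_dict[name] = Layout(device_mesh=device_mesh,
                                   device_type=torch.device('cuda'),
                                   sharding_spec=spec,
                                   entire_shape=param.shape)

    # materialise and shard the deferred parameters according to layout_dict
    ctx.distribute(model, layout_dict, verbose=True)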
@@ -125,6 +125,23 @@ def check_all_reduce_bwd(process_groups_dict, rank):
     assert tensor_to_comm.equal(tensor_to_check)


+def check_all_reduce_in_flatten_device_mesh(process_groups_dict, rank):
+    # tensor to comm
+    tensor_to_comm = torch.ones(2, 2).cuda() * rank
+
+    # reduce through logical process axis 0 at flatten device mesh
+    # tensor to check
+    # tensor([[6., 6.],
+    #         [6., 6.]])
+    tensor_to_check = torch.tensor([[6, 6], [6, 6]], dtype=tensor_to_comm.dtype).cuda()
+
+    # CommSpec:(comm_pattern:all_reduce, logical_process_axis:[0, 1])
+    comm_spec = CommSpec(CollectiveCommPattern.ALLREDUCE_FWD_IDENTITY_BWD, process_groups_dict, logical_process_axis=0)
+    tensor_to_comm = comm_spec.covert_spec_to_action(tensor_to_comm)
+
+    assert tensor_to_comm.equal(tensor_to_check)
+
+
 def check_comm(rank, world_size, port):
     disable_existing_loggers()
     launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -136,22 +153,24 @@ def check_comm(rank, world_size, port):
     # [[0, 1,
     #  [2, 3]]
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
-    process_group_dict = device_mesh._process_group_dict[rank]
+    process_groups_dict = device_mesh.process_groups_dict

     # test all gather
-    check_all_gather(process_group_dict, rank)
+    check_all_gather(process_groups_dict, rank)

     # test shard
-    check_shard(process_group_dict, rank)
+    check_shard(process_groups_dict, rank)

     # test all to all
-    check_all_to_all(process_group_dict, rank)
+    check_all_to_all(process_groups_dict, rank)

     # test all reduce
-    check_all_reduce_fwd(process_group_dict, rank)
-    check_all_reduce_bwd(process_group_dict, rank)
+    check_all_reduce_fwd(process_groups_dict, rank)
+    check_all_reduce_bwd(process_groups_dict, rank)
+
+    flatten_process_groups_dict = device_mesh.flatten_device_mesh.process_groups_dict
+    # test all reduce in 1D flatten device mesh
+    check_all_reduce_in_flatten_device_mesh(flatten_process_groups_dict, rank)

     gpc.destroy()
...
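As a sanity check on the restored flatten-mesh test above: with the four ranks of the [[0, 1], [2, 3]] mesh each contributing torch.ones(2, 2) * rank, an all-reduce over the flattened mesh sums 0 + 1 + 2 + 3 = 6 in every position, which is where the all-sixes expected tensor comes from. The arithmetic can be checked in a single process without any collective:

    import torch

    world_size = 4    # ranks 0-3 on the 2x2 mesh used in check_comm
    # each rank contributes torch.ones(2, 2) * rank; the all-reduce over the
    # flattened device mesh sums all four contributions
    expected = sum(torch.ones(2, 2) * rank for rank in range(world_size))
    assert torch.equal(expected, torch.tensor([[6., 6.], [6., 6.]]))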
@@ -31,9 +31,13 @@ def check_dtensor(rank, world_size, port):
     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
     target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
-    d_tensor = DTensor(original_tensor, device_mesh, target_sharding_spec)
+    layout = Layout(device_mesh=device_mesh,
+                    device_type=torch.device('cuda'),
+                    sharding_spec=target_sharding_spec,
+                    entire_shape=original_tensor.shape)
+    d_tensor = DTensor(original_tensor, layout)

-    assert d_tensor.global_shape == original_tensor.shape
+    assert d_tensor.entire_shape == original_tensor.shape
     assert d_tensor.data_type == original_tensor.dtype

     if rank in (0, 1):
@@ -53,7 +57,12 @@ def check_dtensor(rank, world_size, port):
         raise ValueError(f'rank {rank} is not in the device mesh')

     new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
-    d_tensor.layout_convert(device_mesh, new_sharding_spec)
+    new_layout = Layout(device_mesh=device_mesh,
+                        device_type=torch.device('cuda'),
+                        sharding_spec=new_sharding_spec,
+                        entire_shape=original_tensor.shape)
+    d_tensor.layout_convert(new_layout)

     if rank == 0:
         assert d_tensor.local_tensor.equal(original_tensor.narrow(0, 0, 1))
@@ -66,7 +75,7 @@ def check_dtensor(rank, world_size, port):
     else:
         raise ValueError(f'rank {rank} is not in the device mesh')

-    dtensor_from_local = distribute_tensor(original_tensor, device_mesh, new_sharding_spec)
+    dtensor_from_local = distribute_tensor(original_tensor, new_layout)

     if rank == 0:
         assert dtensor_from_local.local_tensor.equal(original_tensor.narrow(0, 0, 1))
...
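The DTensor hunks above return to an explicit-Layout construction: DTensor(tensor, layout) instead of DTensor(tensor, device_mesh, sharding_spec), with re-sharding through layout_convert(new_layout). A compact sketch of that restored pattern, reusing the names from the test; the DTensor/distribute_tensor import path is an assumption, the rest follows the calls shown in the diff:

    import torch

    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.d_tensor.d_tensor import DTensor, distribute_tensor    # assumed import path
    from colossalai.tensor.d_tensor.layout import Layout
    from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    original_tensor = torch.rand(4, 8).cuda()    # illustrative global shape

    # shard dim 0 across mesh axis 0 and wrap it in a Layout, as in the restored test
    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
    layout = Layout(device_mesh=device_mesh,
                    device_type=torch.device('cuda'),
                    sharding_spec=target_sharding_spec,
                    entire_shape=original_tensor.shape)
    d_tensor = DTensor(original_tensor, layout)
    assert d_tensor.entire_shape == original_tensor.shape

    # re-sharding takes a new Layout rather than a (device_mesh, sharding_spec) pair
    new_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0, 1]})
    new_layout = Layout(device_mesh=device_mesh,
                        device_type=torch.device('cuda'),
                        sharding_spec=new_sharding_spec,
                        entire_shape=original_tensor.shape)
    d_tensor.layout_convert(new_layout)

    # a sharded tensor can also be built directly from a global one
    dtensor_from_local = distribute_tensor(original_tensor, new_layout)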
@@ -12,9 +12,9 @@ from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
 from colossalai.tensor.d_tensor.sharding_spec import DimSpec, ShardingSpec
 from colossalai.testing import rerun_if_address_is_in_use, spawn

-global_shape = torch.Size((64, 32, 16))
+entire_shape = torch.Size((64, 32, 16))
 layout_converter = LayoutConverter()
-physical_mesh_id = torch.arange(0, 4)
+physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
 mesh_shape = (2, 2)
@@ -30,7 +30,10 @@ def check_one_step_transform(rank, world_size, port):
     # shard_sequence: S0,S1,R
     # device_mesh_shape: (2, 2)
     sharding_spec = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict)
-    layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
+    layout = Layout(device_mesh=device_mesh,
+                    device_type=torch.device('cuda'),
+                    sharding_spec=sharding_spec,
+                    entire_shape=entire_shape)

     rst_dict = layout_converter.all_gather_transform_layouts(layout)
@@ -46,7 +49,10 @@ def check_one_step_transform(rank, world_size, port):
     # shard_sequence: S0,S1,R
     # device_mesh_shape: (4, 4)
     sharding_spec_all2all = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_dict_all2all)
-    layout_all2all = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_all2all, global_shape=global_shape)
+    layout_all2all = Layout(device_mesh=device_mesh,
+                            device_type=torch.device('cuda'),
+                            sharding_spec=sharding_spec_all2all,
+                            entire_shape=entire_shape)

     rst_dict_all2all = layout_converter.all_to_all_transform_layout(layout_all2all)
@@ -65,7 +71,10 @@ def check_one_step_transform(rank, world_size, port):
     # shard_sequence: S0,R,R
     # device_mesh_shape: (4, 4)
     sharding_spec_shard = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_shard)
-    shard_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_shard, global_shape=global_shape)
+    shard_layout = Layout(device_mesh=device_mesh,
+                          device_type=torch.device('cuda'),
+                          sharding_spec=sharding_spec_shard,
+                          entire_shape=entire_shape)

     rst_dict_shard = layout_converter.shard_transform_layout(shard_layout)
@@ -91,13 +100,19 @@ def check_layout_converting(rank, world_size, port):
     # shard_sequence: R,S01,R
     # device_mesh_shape: (4, 4)
     sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
-    source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
+    source_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_source,
+                           entire_shape=entire_shape)

     # DistSpec:
     # shard_sequence: S01,R,R
     # device_mesh_shape: (4, 4)
     sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
-    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
+    target_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_target,
+                           entire_shape=entire_shape)

     transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)
@@ -144,15 +159,21 @@ def check_layout_converting_apply(rank, world_size, port):
     # shard_sequence: R,S01,R
     # device_mesh_shape: (4, 4)
     sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_source)
-    source_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_source, global_shape=global_shape)
+    source_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_source,
+                           entire_shape=entire_shape)

     # DistSpec:
     # shard_sequence: S01,R,R
     # device_mesh_shape: (4, 4)
     sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict=dim_partition_target)
-    target_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec_target, global_shape=global_shape)
+    target_layout = Layout(device_mesh=device_mesh,
+                           device_type=torch.device('cuda'),
+                           sharding_spec=sharding_spec_target,
+                           entire_shape=entire_shape)

-    original_tensor = torch.rand(global_shape).cuda()
+    original_tensor = torch.rand(entire_shape).cuda()

     # tensor_to_apply: [R, S01, R]
     tensor_to_apply = original_tensor.narrow(1, rank * 8, 8)
...
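The layout-converter hunks above follow the same Layout signature change. A brief sketch of the conversion query those tests exercise, using the module-level constants restored above; the dim_partition dicts are reconstructed from the shard-sequence comments (R,S01,R to S01,R,R) and are an assumption, since their definitions sit outside the diff:

    import torch

    from colossalai.device.device_mesh import DeviceMesh
    from colossalai.tensor.d_tensor.layout import Layout
    from colossalai.tensor.d_tensor.layout_converter import LayoutConverter
    from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec

    entire_shape = torch.Size((64, 32, 16))
    physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, (2, 2), init_process_group=True)
    layout_converter = LayoutConverter()

    # source R,S01,R: dim 1 sharded over both mesh axes; target S01,R,R: dim 0 sharded over both
    sharding_spec_source = ShardingSpec(dim_size=3, dim_partition_dict={1: [0, 1]})
    sharding_spec_target = ShardingSpec(dim_size=3, dim_partition_dict={0: [0, 1]})

    source_layout = Layout(device_mesh=device_mesh,
                           device_type=torch.device('cuda'),
                           sharding_spec=sharding_spec_source,
                           entire_shape=entire_shape)
    target_layout = Layout(device_mesh=device_mesh,
                           device_type=torch.device('cuda'),
                           sharding_spec=sharding_spec_target,
                           entire_shape=entire_shape)

    # the converter returns the intermediate layouts and the collectives needed to move between them
    transform_path, comm_action_sequence = layout_converter.layout_converting(source_layout, target_layout)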
-from colossalai.tensor.shape_consistency import ShapeConsistencyManager, CollectiveCommPattern
 import torch
-from colossalai.tensor.sharding_spec import _DimSpec, ShardingSpec
 from colossalai.device.device_mesh import DeviceMesh
+from colossalai.tensor.shape_consistency import CollectiveCommPattern, ShapeConsistencyManager
+from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec

-physical_mesh_id = torch.arange(0, 16)
+physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
 mesh_shape = (4, 4)
 # [[0, 1, 2, 3],
 #  [4, 5, 6, 7],
...
@@ -26,7 +26,7 @@ def run_dist(rank, world_size, port):
     # the mesh is in the following topo
     # [[0, 1],
     #  [2, 3]]
-    physical_mesh_id = torch.arange(0, 4)
+    physical_mesh_id = torch.arange(0, 4).reshape(2, 2)
     mesh_shape = (2, 2)
     device_mesh = DeviceMesh(physical_mesh_id, mesh_shape)
     row_id = rank // 2
...
@@ -5,7 +5,7 @@ from colossalai.tensor.sharding_spec import ShardingSpec, _DimSpec
 def test_sharding_spec():
-    physical_mesh_id = torch.arange(0, 16)
+    physical_mesh_id = torch.arange(0, 16).reshape(2, 8)
     mesh_shape = (4, 4)
     # [[0, 1, 2, 3],
     #  [4, 5, 6, 7],
...