Unverified Commit 0c51ff2c authored by ver217's avatar ver217 Committed by GitHub
Browse files

[hotfix] ZeroDDP use new process group (#1333)

* process group supports getting ranks in group

* chunk mgr receives a process group

* update unit test

* fix unit tests
parent 11d1436a
...@@ -4,9 +4,8 @@ from dataclasses import dataclass ...@@ -4,9 +4,8 @@ from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Optional, Dict, List from typing import Optional, Dict, List
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup
class TensorState(Enum): class TensorState(Enum):
...@@ -65,14 +64,16 @@ class Chunk: ...@@ -65,14 +64,16 @@ class Chunk:
def __init__(self, def __init__(self,
chunk_size: int, chunk_size: int,
src_rank: int, src_rank: int,
process_group: ColoProcessGroup,
dtype: torch.dtype, dtype: torch.dtype,
init_device: Optional[torch.device] = None, init_device: Optional[torch.device] = None,
force_data_on_cuda: bool = False) -> None: force_data_on_cuda: bool = False) -> None:
self.size = chunk_size self.size = chunk_size
self.utilized_size = 0 self.utilized_size = 0
self.src_rank = src_rank self.src_rank = src_rank
self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank self.process_group = process_group
self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank] self.is_src_rank = process_group.dp_local_rank() == src_rank
self.global_src_rank = process_group.get_ranks_in_dp()[src_rank]
self.dtype = dtype self.dtype = dtype
device = init_device or get_current_device() device = init_device or get_current_device()
if force_data_on_cuda: if force_data_on_cuda:
...@@ -150,7 +151,7 @@ class Chunk: ...@@ -150,7 +151,7 @@ class Chunk:
if not self.is_src_rank: if not self.is_src_rank:
alloc_storage(self._payload) alloc_storage(self._payload)
self.move_device(get_current_device(), update_ptr=False) self.move_device(get_current_device(), update_ptr=False)
dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA)) dist.broadcast(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
# update tensor meta info # update tensor meta info
self._update_tensors_ptr() self._update_tensors_ptr()
...@@ -193,9 +194,9 @@ class Chunk: ...@@ -193,9 +194,9 @@ class Chunk:
""" """
self.move_device(get_current_device(), update_ptr=False) self.move_device(get_current_device(), update_ptr=False)
if is_all_reduce: if is_all_reduce:
dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA)) dist.all_reduce(self.data, group=self.process_group.dp_process_group())
else: else:
dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA)) dist.reduce(self.data, self.global_src_rank, group=self.process_group.dp_process_group())
self._update_tensors_ptr() self._update_tensors_ptr()
self._update_tensors_state(TensorState.HOLD) self._update_tensors_state(TensorState.HOLD)
...@@ -216,7 +217,7 @@ class Chunk: ...@@ -216,7 +217,7 @@ class Chunk:
# invalid calls will be ignored and nothing changes # invalid calls will be ignored and nothing changes
if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS: if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
# print( # print(
# f'WARNING: Rank{gpc.get_global_rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}' # f'WARNING: Rank{self.process_group.rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
# ) # )
return return
self.tensors_info[tensor].state = tensor_state self.tensors_info[tensor].state = tensor_state
......
...@@ -2,9 +2,8 @@ import torch ...@@ -2,9 +2,8 @@ import torch
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
from collections import deque from collections import deque
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from colossalai.tensor import ProcessGroup as ColoProcessGroup
from .chunk import Chunk, ChunkFullError, TensorState from .chunk import Chunk, ChunkFullError, TensorState
...@@ -20,10 +19,13 @@ class ChunkManager: ...@@ -20,10 +19,13 @@ class ChunkManager:
def __init__(self, def __init__(self,
chunk_size: Optional[int], chunk_size: Optional[int],
process_group: ColoProcessGroup,
enable_distributed_storage: bool = False, enable_distributed_storage: bool = False,
init_device: Optional[torch.device] = None) -> None: init_device: Optional[torch.device] = None) -> None:
assert chunk_size is None or chunk_size > 0 assert chunk_size is None or chunk_size > 0
assert isinstance(process_group, ColoProcessGroup)
self.chunk_size = chunk_size self.chunk_size = chunk_size
self.process_group = process_group
self.enable_distributed_storage = enable_distributed_storage self.enable_distributed_storage = enable_distributed_storage
self.device = init_device or get_current_device() self.device = init_device or get_current_device()
self.chunk_groups: Dict[str, Deque[Chunk]] = {} self.chunk_groups: Dict[str, Deque[Chunk]] = {}
...@@ -69,6 +71,7 @@ class ChunkManager: ...@@ -69,6 +71,7 @@ class ChunkManager:
src_rank = self._get_next_src_rank(group_name) src_rank = self._get_next_src_rank(group_name)
chunk = Chunk(chunk_size, chunk = Chunk(chunk_size,
src_rank, src_rank,
self.process_group,
tensor.dtype, tensor.dtype,
self.device, self.device,
force_data_on_cuda=self.groups_force_data_on_cuda[group_name]) force_data_on_cuda=self.groups_force_data_on_cuda[group_name])
...@@ -89,17 +92,17 @@ class ChunkManager: ...@@ -89,17 +92,17 @@ class ChunkManager:
def _get_next_src_rank(self, group_name: str) -> int: def _get_next_src_rank(self, group_name: str) -> int:
if not self.enable_distributed_storage: if not self.enable_distributed_storage:
# the chunk is owned by the current rank if no distributed storage is enabled # the chunk is owned by the current rank if no distributed storage is enabled
return gpc.get_local_rank(ParallelMode.DATA) return self.process_group.dp_local_rank()
if self.chunk_size is None: if self.chunk_size is None:
if group_name not in self.rank_load: if group_name not in self.rank_load:
self.rank_load[group_name] = torch.zeros(gpc.get_world_size(ParallelMode.DATA), dtype=torch.int64) self.rank_load[group_name] = torch.zeros(self.process_group.dp_world_size(), dtype=torch.int64)
# the process owning the tensor will be the process with the smallest number of elements # the process owning the tensor will be the process with the smallest number of elements
src_rank = torch.argmin(self.rank_load[group_name]).item() src_rank = torch.argmin(self.rank_load[group_name]).item()
else: else:
# chunk is owned by processes in a round-robin fashion # chunk is owned by processes in a round-robin fashion
chunk_idx = len(self.chunk_groups[group_name]) chunk_idx = len(self.chunk_groups[group_name])
src_rank = chunk_idx % gpc.get_world_size(ParallelMode.DATA) src_rank = chunk_idx % self.process_group.dp_world_size()
return src_rank return src_rank
def access_chunk(self, chunk: Chunk) -> None: def access_chunk(self, chunk: Chunk) -> None:
...@@ -222,7 +225,7 @@ class ChunkManager: ...@@ -222,7 +225,7 @@ class ChunkManager:
self.lazy_release_tensors.clear() self.lazy_release_tensors.clear()
def __repr__(self) -> str: def __repr__(self) -> str:
msg = f'Rank {gpc.get_local_rank(ParallelMode.DATA)}:\n' msg = f'Rank {self.process_group.dp_local_rank()}:\n'
msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n' msg += 'Total memory: ' + ', '.join([f'{k}={v}B' for k, v in self.total_mem.items()]) + '\n'
for group_name, group in self.chunk_groups.items(): for group_name, group in self.chunk_groups.items():
msg += f'Group {group_name}:\n' msg += f'Group {group_name}:\n'
......
...@@ -118,7 +118,7 @@ class ColoDDP(torch.nn.Module): ...@@ -118,7 +118,7 @@ class ColoDDP(torch.nn.Module):
return empty_grad return empty_grad
else: else:
#TODO(jiaruifang) fixme # TODO(jiaruifang) fixme
self.process_group.set_cpu_groups() self.process_group.set_cpu_groups()
dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group()) dist.all_reduce(grad, group=self.process_group.cpu_dp_process_group())
return grad return grad
...@@ -191,11 +191,8 @@ class ZeroDDP(ColoDDP): ...@@ -191,11 +191,8 @@ class ZeroDDP(ColoDDP):
For more details, see the API reference of ``GeminiManager``. For more details, see the API reference of ``GeminiManager``.
""" """
def __init__(self, def __init__(self, module: torch.nn.Module, gemini_manager: GeminiManager) -> None:
module: torch.nn.Module, super().__init__(module.half(), process_group=gemini_manager.chunk_manager.process_group)
gemini_manager: GeminiManager,
process_group: Optional[ColoProcessGroup] = None) -> None:
super().__init__(module.half(), process_group=process_group)
self.gemini_manager = gemini_manager self.gemini_manager = gemini_manager
self.chunk_manager = gemini_manager.chunk_manager self.chunk_manager = gemini_manager.chunk_manager
self.param_op_hook = ZeROHookV2(gemini_manager) self.param_op_hook = ZeROHookV2(gemini_manager)
......
...@@ -171,3 +171,9 @@ class ProcessGroup: ...@@ -171,3 +171,9 @@ class ProcessGroup:
def cpu_tp_process_group(self): def cpu_tp_process_group(self):
assert self._has_cpu_groups assert self._has_cpu_groups
return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo') return PYTORCHPGDICT_.get(self._tp_rank_list, 'gloo')
def get_ranks_in_dp(self):
return self._dp_rank_list
def get_ranks_in_tp(self):
return self._tp_rank_list
...@@ -33,11 +33,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP: ...@@ -33,11 +33,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP: def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False) -> ZeroDDP:
pg = ProcessGroup()
chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None chunk_size = ChunkManager.search_chunk_size(module, 64, 2) if use_chunk else None
chunk_manager = ChunkManager(chunk_size) chunk_manager = ChunkManager(chunk_size, pg)
gemini_manager = GeminiManager('cuda', chunk_manager) gemini_manager = GeminiManager('cuda', chunk_manager)
pg = ProcessGroup() return ZeroDDP(module, gemini_manager)
return ZeroDDP(module, gemini_manager, pg)
class Net(torch.nn.Module): class Net(torch.nn.Module):
......
...@@ -28,11 +28,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP: ...@@ -28,11 +28,11 @@ def init_ddp(module: torch.nn.Module) -> ColoDDP:
def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP: def init_ddpv2(module: torch.nn.Module, use_chunk: bool = False, use_zero: bool = False) -> ZeroDDP:
pg = ProcessGroup()
chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None chunk_size = ChunkManager.search_chunk_size(module, 64, 4) if use_chunk else None
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero) chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
gemini_manager = GeminiManager('cuda', chunk_manager) gemini_manager = GeminiManager('cuda', chunk_manager)
pg = ProcessGroup() return ZeroDDP(module, gemini_manager)
return ZeroDDP(module, gemini_manager, process_group=pg)
def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]): def run_state_dict(ddp_init_func: Callable[[torch.nn.Module], ColoDDP]):
......
...@@ -7,8 +7,7 @@ from functools import partial ...@@ -7,8 +7,7 @@ from functools import partial
from colossalai.gemini import ChunkManager from colossalai.gemini import ChunkManager
from colossalai.testing import rerun_if_address_is_in_use, parameterize from colossalai.testing import rerun_if_address_is_in_use, parameterize
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.core import global_context as gpc from colossalai.tensor import ProcessGroup as ColoProcessGroup
from colossalai.context import ParallelMode
def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]): def check_has_params(params: List[torch.Tensor], has_tensors: List[bool]):
...@@ -38,12 +37,13 @@ TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512, ...@@ -38,12 +37,13 @@ TOTAL_MEM = {True: {True: [512, 512], False: [1024, 1024]}, False: {True: [512,
@parameterize('use_chunk', [False, True]) @parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True]) @parameterize('use_zero', [False, True])
def run_chunk_zero(use_chunk, use_zero): def run_chunk_zero(use_chunk, use_zero):
rank = gpc.get_local_rank(ParallelMode.DATA) pg = ColoProcessGroup()
rank = pg.rank()
if rank == 0: if rank == 0:
print(f'use_chunk={use_chunk}, use_zero={use_zero}') print(f'use_chunk={use_chunk}, use_zero={use_zero}')
params = [torch.rand(8, 8) for _ in range(3)] params = [torch.rand(8, 8) for _ in range(3)]
chunk_size = 128 if use_chunk else None chunk_size = 128 if use_chunk else None
chunk_manager = ChunkManager(chunk_size, enable_distributed_storage=use_zero) chunk_manager = ChunkManager(chunk_size, pg, enable_distributed_storage=use_zero)
chunk_manager.create_group('param') chunk_manager.create_group('param')
assert chunk_manager.total_mem['cpu'] == 0 assert chunk_manager.total_mem['cpu'] == 0
assert chunk_manager.total_mem['cuda'] == 0 assert chunk_manager.total_mem['cuda'] == 0
......
...@@ -31,8 +31,6 @@ def check_param_equal(model, torch_model, pg: ProcessGroup): ...@@ -31,8 +31,6 @@ def check_param_equal(model, torch_model, pg: ProcessGroup):
def check_grad_equal(model, torch_model, pg: ProcessGroup): def check_grad_equal(model, torch_model, pg: ProcessGroup):
for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()): for (n, p), (tn, tp) in zip(model.named_parameters(), torch_model.named_parameters()):
if p.grad is not None: if p.grad is not None:
torch.distributed.barrier()
print(torch.distributed.get_rank(), p.grad)
assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad, assert tensor_shard_equal(tp.grad.to(dtype=p.grad.dtype, device=p.grad.device), p.grad,
pg.tp_local_rank(), pg.tp_world_size()), \ pg.tp_local_rank(), pg.tp_world_size()), \
f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}' f'{tp.grad} vs {p.grad}\n{n}:\n\t{tp.grad.shape} vs {p.grad.shape} in {pg.rank()}'
...@@ -63,9 +61,9 @@ def init_1d_col_spec(model, pg: ProcessGroup): ...@@ -63,9 +61,9 @@ def init_1d_col_spec(model, pg: ProcessGroup):
p.set_tensor_spec(*spec) p.set_tensor_spec(*spec)
@parameterize('use_chunk', [False]) @parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False]) @parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda']) @parameterize('placement_policy', ['cuda', 'cpu'])
def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None): def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
set_seed(42) set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable('gpt2') get_components_func = non_distributed_component_funcs.get_callable('gpt2')
...@@ -92,10 +90,11 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None): ...@@ -92,10 +90,11 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
chunk_manager = ChunkManager(chunk_size, chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=use_zero, enable_distributed_storage=use_zero,
init_device=GeminiManager.get_default_device(placement_policy)) init_device=GeminiManager.get_default_device(placement_policy))
gemini_manager = GeminiManager(placement_policy, chunk_manager) gemini_manager = GeminiManager(placement_policy, chunk_manager)
model = ZeroDDP(model, gemini_manager, pg) model = ZeroDDP(model, gemini_manager)
optim = HybridAdam(model.parameters(), lr=1e-3) optim = HybridAdam(model.parameters(), lr=1e-3)
optim = ZeroOptimizer(optim, model, initial_scale=1) optim = ZeroOptimizer(optim, model, initial_scale=1)
...@@ -104,7 +103,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None): ...@@ -104,7 +103,7 @@ def run_gpt(use_chunk, use_zero, placement_policy, tp_init_spec_func=None):
torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config) torch_model, torch_optim = convert_to_apex_amp(torch_model, torch_optim, amp_config)
torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group()) torch_model = DDP(torch_model, device_ids=[pg.rank()], process_group=pg.dp_process_group())
# print(chunk_manager) print(chunk_manager)
check_param_equal(model, torch_model, pg) check_param_equal(model, torch_model, pg)
model.eval() model.eval()
...@@ -129,13 +128,12 @@ def run_dist(rank, world_size, port): ...@@ -129,13 +128,12 @@ def run_dist(rank, world_size, port):
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
if world_size == 4: if world_size == 4:
run_gpt(tp_init_spec_func=init_1d_col_spec) run_gpt(tp_init_spec_func=init_1d_col_spec)
# run_gpt(tp_init_spec_func=init_1d_row_spec) run_gpt(tp_init_spec_func=init_1d_row_spec)
else: else:
run_gpt(tp_init_spec_func=init_1d_col_spec) run_gpt(tp_init_spec_func=init_1d_col_spec)
@pytest.mark.dist @pytest.mark.dist
@pytest.mark.skip("buggy test")
@pytest.mark.parametrize('world_size', [1, 4]) @pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use() @rerun_if_address_is_in_use()
def test_gpt(world_size): def test_gpt(world_size):
......
...@@ -20,13 +20,14 @@ from colossalai.tensor import ProcessGroup ...@@ -20,13 +20,14 @@ from colossalai.tensor import ProcessGroup
def init_zero(model, use_chunk, use_zero, placement_policy): def init_zero(model, use_chunk, use_zero, placement_policy):
pg = ProcessGroup()
chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None chunk_size = ChunkManager.search_chunk_size(model, 8192, 8) if use_chunk else None
chunk_manager = ChunkManager(chunk_size, chunk_manager = ChunkManager(chunk_size,
pg,
enable_distributed_storage=use_zero, enable_distributed_storage=use_zero,
init_device=GeminiManager.get_default_device(placement_policy)) init_device=GeminiManager.get_default_device(placement_policy))
gemini_manager = GeminiManager(placement_policy, chunk_manager) gemini_manager = GeminiManager(placement_policy, chunk_manager)
pg = ProcessGroup() return ZeroDDP(model, gemini_manager)
return ZeroDDP(model, gemini_manager, pg)
def run_step(model, optim, criterion, data, label): def run_step(model, optim, criterion, data, label):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment