Unverified Commit 6b2f2ab9 authored by ver217's avatar ver217 Committed by GitHub
Browse files

[ddp] ColoDDP uses bucket all-reduce (#1177)

* add reducer

* update colo ddp with reducer

* polish unit test

* polish unit test
parent 7487215b
...@@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Optional ...@@ -12,6 +12,7 @@ from typing import Dict, Iterable, List, Optional
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from collections import OrderedDict from collections import OrderedDict
from colossalai.tensor.colo_parameter import ColoParameter from colossalai.tensor.colo_parameter import ColoParameter
from .reducer import Reducer
try: try:
from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
except ImportError: except ImportError:
...@@ -61,7 +62,9 @@ class ColoDDP(torch.nn.Module): ...@@ -61,7 +62,9 @@ class ColoDDP(torch.nn.Module):
def __init__(self, def __init__(self,
module: torch.nn.Module, module: torch.nn.Module,
process_group: Optional[dist.ProcessGroup] = None, process_group: Optional[dist.ProcessGroup] = None,
cpu_process_group: Optional[dist.ProcessGroup] = None) -> None: cpu_process_group: Optional[dist.ProcessGroup] = None,
bucket_cap_mb: int = 25,
rebuild_bucket: bool = True) -> None:
assert not isinstance(module, ColoDDP) assert not isinstance(module, ColoDDP)
super().__init__() super().__init__()
self.module = module self.module = module
...@@ -69,6 +72,8 @@ class ColoDDP(torch.nn.Module): ...@@ -69,6 +72,8 @@ class ColoDDP(torch.nn.Module):
self.process_group = process_group or gpc.get_group(ParallelMode.DATA) self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA) self.cpu_process_group = cpu_process_group or gpc.get_cpu_group(ParallelMode.DATA)
self.dp_world_size = self.process_group.size() self.dp_world_size = self.process_group.size()
self.reducer = Reducer(bucket_cap_mb)
self.rebuild_bucket = rebuild_bucket
for p in module.parameters(): for p in module.parameters():
if getattr(p, '_ddp_to_ignore', False): if getattr(p, '_ddp_to_ignore', False):
continue continue
...@@ -87,7 +92,11 @@ class ColoDDP(torch.nn.Module): ...@@ -87,7 +92,11 @@ class ColoDDP(torch.nn.Module):
def backward(self, loss: torch.Tensor): def backward(self, loss: torch.Tensor):
loss.backward() loss.backward()
with torch.cuda.stream(self.comm_stream):
self.reducer.flush()
torch.cuda.current_stream().wait_stream(self.comm_stream) torch.cuda.current_stream().wait_stream(self.comm_stream)
if self.rebuild_bucket:
self.reducer.free()
for p in self.module.parameters(): for p in self.module.parameters():
if getattr(p, '_ddp_to_ignore', False): if getattr(p, '_ddp_to_ignore', False):
continue continue
...@@ -102,8 +111,9 @@ class ColoDDP(torch.nn.Module): ...@@ -102,8 +111,9 @@ class ColoDDP(torch.nn.Module):
grad = grad / self.dp_world_size grad = grad / self.dp_world_size
self.comm_stream.wait_stream(torch.cuda.current_stream()) self.comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(self.comm_stream): with torch.cuda.stream(self.comm_stream):
dist.all_reduce(grad, group=self.process_group) self.reducer.all_reduce_async(grad,
ColoDDP._save_grad(p, grad) group=self.process_group,
callback_fn=partial(self._save_grad, p))
grad.record_stream(self.comm_stream) grad.record_stream(self.comm_stream)
else: else:
ColoDDP._save_grad(p, grad) ColoDDP._save_grad(p, grad)
......
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
import functools
from typing import Callable, Dict, List, Optional, Tuple
import torch
import torch.distributed as dist
from torch import Tensor
from torch.distributed import ProcessGroup
class Bucket:
def __init__(self, size: int, dtype: torch.dtype, device: torch.device, group: ProcessGroup):
self.buffer = torch.zeros(size, dtype=dtype, device=device)
self.group = group
self.offset = 0
self.callbacks: List[Callable] = []
def flush(self) -> None:
"""Flush content of the bucket."""
if self.offset == 0:
assert len(self.callbacks) == 0
return
# reduce-scatter bucket
dist.all_reduce(self.buffer[:self.offset], group=self.group)
# execute post-reduction callbacks
for callback_fn in self.callbacks:
callback_fn()
# reuse input bucket but allocate a fresh output shard
self.offset = 0
self.callbacks.clear()
self.buffer = torch.zeros_like(self.buffer)
def alloc(self) -> None:
if self.buffer.storage().size() == 0:
self.buffer.storage().resize_(self.buffer.numel())
def free(self) -> None:
assert self.offset == 0 and self.callbacks == [], "Incorrect call of teardown"
self.buffer.storage().resize_(0)
def append(self, tensor: Tensor, callback_fn: Callable):
tensor_size = tensor.numel()
offset = self.offset
self.buffer[offset:offset + tensor_size].copy_(tensor.flatten())
self.offset += tensor_size
# callback will be given the reduced result
if callback_fn is not None:
result_view = self.buffer[offset:offset + tensor_size].view(tensor.shape)
self.callbacks.append(functools.partial(callback_fn, result_view))
@property
def avail_size(self) -> int:
return self.buffer.size(0) - self.offset
class Reducer:
def __init__(self, bucket_size_mb: int = 25):
self.bucket_size_mb = bucket_size_mb
self.buckets: Dict[Tuple[torch.dtype, torch.device, ProcessGroup], Bucket] = {}
@torch.no_grad()
def all_reduce_async(
self,
tensor: Tensor,
group: ProcessGroup,
callback_fn: Optional[Callable] = None,
) -> None:
bucket_size = self._get_bucket_size(tensor.element_size())
if tensor.numel() >= bucket_size:
dist.all_reduce(tensor, group=group)
if callback_fn is not None:
callback_fn(tensor)
return
bucket = self._get_bucket(tensor, group)
if tensor.numel() > bucket.avail_size:
# not enough space remaining in bucket, flush it now
bucket.flush()
bucket.append(tensor, callback_fn)
@torch.no_grad()
def flush(self) -> None:
for bucket in self.buckets.values():
bucket.flush()
@torch.no_grad()
def free(self) -> None:
for bucket in self.buckets.values():
bucket.free()
@functools.lru_cache()
def _get_bucket_size(self, element_size: int) -> int:
if self.bucket_size_mb <= 0: # Values <= 0 disable bucketing.
return 0
MB = 1024 * 1024
bucket_size = self.bucket_size_mb * MB / element_size
return int(bucket_size)
def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
key = (tensor.dtype, tensor.device, group)
if key not in self.buckets:
bucket_size = self._get_bucket_size(tensor.element_size())
self.buckets[key] = Bucket(bucket_size, tensor.dtype, tensor.device, group)
self.buckets[key].alloc()
return self.buckets[key]
import pytest
import colossalai
import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ChunkManager
from functools import partial
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.nn.parallel import ZeroDDP, ColoDDP
from colossalai.gemini.gemini_mgr import GeminiManager
from typing import Callable
from collections import OrderedDict
from colossalai.nn.parallel.reducer import Reducer
import torch.distributed as dist
from torch.distributed.distributed_c10d import _get_default_group
REDUCE_CNT = 0
def check_eq(grad, grad_clone):
global REDUCE_CNT
print(f'Rank{dist.get_rank()} check {REDUCE_CNT}')
REDUCE_CNT += 1
assert torch.allclose(grad, grad_clone)
def run_reducer():
grads = [torch.rand(64, i + 1, device=get_current_device()) for i in range(10)]
grads_clone = [g.clone().detach() for g in grads]
for g in grads:
dist.all_reduce(g)
reducer = Reducer(bucket_size_mb=1)
for g, g_clone in zip(grads, grads_clone):
reducer.all_reduce_async(g_clone, _get_default_group(), partial(check_eq, g))
reducer.flush()
def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_reducer()
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
@rerun_if_address_is_in_use()
def test_reducer(world_size):
run_func = partial(run_dist, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_reducer(2)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment