Unverified Commit 351f35e1 authored by Benjamin Lefaudeux, committed by GitHub

[perf] ShardedDDP: better handling of the callback queue, try to consume it as we go. (#254)

* Better handling of the callback queue, try to consume it as we go.

* Dropping the bucketed reduce path for now, always hitting the same unused-params issue
parent 19cb5938
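For context, the change below moves gradient-reduce bookkeeping onto a queue of work handles that is drained opportunistically during the backward pass, and drained blockingly only once every expected reduce has been requested. The following is a minimal, self-contained sketch of that pattern, not the library code: the `Workhandle` fields and the two drain helpers mirror the diff, while the free functions and the single-process "gloo" group are only there to make the sketch runnable.

```python
# Sketch of the callback-queue pattern introduced by this PR: async collectives
# are queued as (handle, callback) pairs, completed ones are consumed on the fly,
# and the remainder is drained blockingly at the end.
from collections import deque
from dataclasses import dataclass
from typing import Any, Callable, Deque, Optional

import torch
import torch.distributed as dist


@dataclass
class Workhandle:
    handle: Any  # async work object returned by a collective (supports wait()/is_completed())
    callback: Optional[Callable[[], None]]


work_handles: Deque[Workhandle] = deque()


def try_consume_work_handle() -> None:
    # Non-blocking: finalize only the oldest handles which have already completed.
    while work_handles and work_handles[0].handle.is_completed():
        work_handle = work_handles.popleft()
        if work_handle.callback is not None:
            work_handle.callback()


def consume_work_handles() -> None:
    # Blocking: wait for everything still in flight, oldest (most likely ready) first.
    while work_handles:
        work_handle = work_handles.popleft()
        work_handle.handle.wait()
        if work_handle.callback is not None:
            work_handle.callback()


if __name__ == "__main__":
    # Single-process process group, just enough to exercise the queue (illustrative setup).
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)
    grad = torch.ones(4)
    handle = dist.reduce(grad, dst=0, async_op=True)
    work_handles.append(Workhandle(handle, callback=lambda: grad.div_(dist.get_world_size())))
    try_consume_work_handle()  # opportunistic: fires only if the reduce already finished
    consume_work_handles()     # guarantees the callback ran before moving on
    dist.destroy_process_group()
```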
@@ -11,7 +11,7 @@ reduction automatically.
import contextlib
from itertools import chain
import logging
from typing import Any, Callable, Generator, List, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Tuple, Union
import torch
from torch import nn
@@ -92,6 +92,10 @@ class ShardedDataParallel(nn.Module):
# we build an iterator which goes through all the parameters involved globally
self._param_iterator = chain(*[optim.should_bucket_param.keys() for optim in self.sharded_optimizers])
self._grad_to_be_reduced = [True for _ in self._param_iterator]
self._reduced_grads: Dict[OSS, int] = {}
self._reduced_grads_max = {o: len(o.param_to_rank.values()) for o in self.sharded_optimizers}
self._clear_counters()
self._grad_accs: List[Callable] = []
self._setup_backward_hooks()
@@ -110,7 +114,7 @@ class ShardedDataParallel(nn.Module):
self.sync_buffers(blocking=True)
# Reset all the grad reduce and bucket state flags
self._grad_to_be_reduced = [True] * len(self._grad_to_be_reduced)
self._clear_counters()
# Normal FW on the base model
return self.module(*inputs, **kwargs)
@@ -122,18 +126,6 @@ class ShardedDataParallel(nn.Module):
"""
logging.warning("This is not useful anymore, gradients have been reduced automatically with the backward pass")
def _sync_params_and_buffers(self) -> None:
"""
Sync the complete model states in between the ranks
"""
with torch.no_grad():
work_handles = [
dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
for t in self.module.state_dict().values()
]
_ = list(map(lambda x: x.wait(), work_handles))
def sync_buffers(self, blocking: bool = False) -> None:
"""
Sync all the param buffers in between ranks (including for instance batch norm statistics).
@@ -155,6 +147,12 @@ class ShardedDataParallel(nn.Module):
yield
self.should_accumulate_grads = old_should_accumulate_grads
def _clear_counters(self) -> None:
""" Reset all the grad reduce and call counters
"""
self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
self._reduced_grads = {o: 0 for o in self.sharded_optimizers}
def _find_rank(self, param: Parameter) -> Tuple[OSS, int]:
""" Look up where this parameter belongs to """
for optim in self.sharded_optimizers:
@@ -164,9 +162,7 @@ class ShardedDataParallel(nn.Module):
assert False, "This parameter is not present in an optimizer, this should not happen"
return (None, -1)
def _get_reduce_fn(
self, index: int, param: torch.Tensor, should_bucket: bool, dst_rank: int, optimizer: OSS
) -> Callable:
def _get_reduce_fn(self, index: int, param: torch.Tensor, dst_rank: int, optimizer: OSS) -> Callable:
"""
Two possible backward hooks for a given parameter: either directly reduce to the appropriate rank,
or contribute to a bucket and reduce when the bucket is full.
@@ -174,7 +170,7 @@ class ShardedDataParallel(nn.Module):
Either way a delayed action is necessary and is passed as a callback.
"""
def reduce_direct(*_: Any) -> None:
def reduce(*_: Any) -> None:
# Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
@@ -197,50 +193,18 @@ class ShardedDataParallel(nn.Module):
callback=cleanup,
)
)
self._reduced_grads[optimizer] += 1
# If all the reduce operations have been called,
# make sure that all the asynchronous calls have concluded before moving on
# and execute the delayed actions (release gradients, unroll the buckets)
if len(optimizer.work_handles) == optimizer._max_work_handles:
optimizer._consume_work_handles()
# Bucket, update status, and possibly unroll the results
def reduce_bucket(*_: Any) -> None:
# Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
# Make sure that this is not fired twice
self._grad_to_be_reduced[index] = False
# Copy to the flat buffer, update the buffer state
bucket = optimizer.buckets[param.device][dst_rank]
assert bucket.append(param, use_gradient=True), "Bucket overflow: max %s - current %s - adding %s" % (
bucket.max_size,
bucket.current_offset,
param.grad.numel(),
)
if bucket.full():
bucket.buffer /= self.world_size
optimizer.work_handles.append(
Workhandle(
handle=dist.reduce(
tensor=bucket.buffer, dst=dst_rank, group=self.process_group, async_op=True,
),
callback=bucket.unroll,
)
)
# Opportunistically try to empty the queue
optimizer._try_consume_work_handle()
# If all the reduce operations have been called,
# make sure that all the asynchronous calls have concluded before moving on
# and execute the delayed actions (release gradients, unroll the buckets)
if len(optimizer.work_handles) == optimizer._max_work_handles:
if self._reduced_grads[optimizer] == self._reduced_grads_max[optimizer]:
optimizer._consume_work_handles()
return reduce_bucket if should_bucket else reduce_direct
return reduce
def _setup_backward_hooks(self) -> None:
"""
@@ -250,7 +214,7 @@ class ShardedDataParallel(nn.Module):
# Go through the parameters, attach the hook
for sharded_optimizer in self.sharded_optimizers:
for param, should_bucket in sharded_optimizer.should_bucket_param.items():
for param, _ in sharded_optimizer.should_bucket_param.items():
if param.grad is not None and param.grad.requires_grad:
raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
@@ -262,5 +226,17 @@ class ShardedDataParallel(nn.Module):
dst_rank = sharded_optimizer.param_to_rank[param]
index = len(self._grad_accs)
grad_acc.register_hook(self._get_reduce_fn(index, param, should_bucket, dst_rank, sharded_optimizer))
grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank, sharded_optimizer))
self._grad_accs.append(grad_acc) # keep this function in scope
def _sync_params_and_buffers(self) -> None:
"""
Sync the complete model states in between the ranks
"""
with torch.no_grad():
work_handles = [
dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
for t in self.module.state_dict().values()
]
_ = list(map(lambda x: x.wait(), work_handles))
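To make the ShardedDataParallel side of the hunks above easier to follow, here is a hedged, single-process sketch of the hook/counter mechanics: a hook registered on each parameter's gradient accumulator fires once per backward pass, a per-optimizer counter (`_reduced_grads` vs `_reduced_grads_max` in the diff) tracks how many reduces were requested, and the blocking drain happens only when the counter reaches its maximum. The class name and the pretend "reduce" are illustrative, not fairscale API; the distributed call is replaced by counting so the example runs standalone.

```python
# Sketch (not the library code): fire-once backward hooks plus a request counter
# gating the final blocking drain of the work-handle queue.
from typing import Any, Callable, List

import torch


class HookedModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self._reduced_grads_max = len(list(self.parameters()))
        self._grad_to_be_reduced: List[bool] = [True] * self._reduced_grads_max
        self._reduced_grads = 0
        self._grad_accs: List[Any] = []  # keep the AccumulateGrad nodes alive
        self._setup_backward_hooks()

    def _setup_backward_hooks(self) -> None:
        for index, param in enumerate(self.parameters()):
            # Standard trick: expand_as gives a view whose grad_fn leads to the
            # AccumulateGrad node, which runs when this param's grad is computed.
            p_tmp = param.expand_as(param)
            grad_acc = p_tmp.grad_fn.next_functions[0][0]
            grad_acc.register_hook(self._get_reduce_fn(index, param))
            self._grad_accs.append(grad_acc)

    def _get_reduce_fn(self, index: int, param: torch.Tensor) -> Callable:
        def reduce(*_: Any) -> None:
            if self._grad_to_be_reduced[index]:
                self._grad_to_be_reduced[index] = False  # never fire twice per pass
                # Real code: queue an async dist.reduce(param.grad, dst=owner_rank, ...)
                # as a Workhandle and opportunistically drain completed handles here.
                self._reduced_grads += 1
                if self._reduced_grads == self._reduced_grads_max:
                    # Real code: optimizer._consume_work_handles()
                    print("all grads requested, draining the queue")
        return reduce

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Mirrors _clear_counters(): reset the per-pass flags and the counter.
        self._grad_to_be_reduced = [True] * self._reduced_grads_max
        self._reduced_grads = 0
        return self.linear(x)


module = HookedModule()
module(torch.randn(2, 4)).sum().backward()
```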
@@ -3,14 +3,14 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from collections import OrderedDict
from collections import OrderedDict, deque
import copy
from enum import Enum, auto
import itertools
from itertools import chain
import logging
from math import inf
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tuple, Type, Union
import torch
import torch.distributed as dist
@@ -74,8 +74,6 @@ class OSS(Optimizer):
broadcast_buffer_size: int = 2 ** 17,
**default: Any,
):
# logging.warning("Disabling bucketing for now, error prone for some models")
broadcast_buffer_size = 0
# Hold all the model params in the root .param_groups
self.in_super_constructor = True
@@ -92,7 +90,6 @@ class OSS(Optimizer):
self.world_size = dist.get_world_size(self.group)
self.rank = dist.get_rank(self.group)
self.global_rank = self.get_global_rank(self.group, self.rank)
self.optim = optim(self.partition_parameters()[self.rank], **default)
# - Sync local and global param_groups keys
@@ -107,6 +104,11 @@ class OSS(Optimizer):
# Current default device is set by the parameters allocated to this rank
self._device = list(self.per_device_params.keys())[0]
self.buckets: Dict[torch.device, List[Bucket]] = {}
# if torch.cuda.is_available() and self.world_size <= torch.cuda.device_count():
# broadcast_buffer_size = 0
# logging.warning("Assuming single node job, bucketing is disabled")
self.bucket_size = broadcast_buffer_size
for device, per_device in self.per_device_params.items():
# Allocate one buffer per rank and per device to group the small parameters
@@ -115,7 +117,7 @@ class OSS(Optimizer):
for _ in range(len(per_device))
]
self.should_bucket_param: Dict[torch.Tensor, bool] = {}
self.work_handles: List[Workhandle] = []
self.work_handles: Deque[Workhandle] = deque()
self._max_work_handles = -1
self._setup_bucket_strategy()
@@ -554,12 +556,19 @@ class OSS(Optimizer):
We start from the first/older ones, since they are the most likely to be ready and non-blocking
"""
for work_handle in self.work_handles:
while len(self.work_handles) > 0:
work_handle = self.work_handles.popleft()
work_handle.handle.wait()
if work_handle.callback is not None:
work_handle.callback()
self.work_handles.clear()
def _try_consume_work_handle(self) -> None:
""" Try to consume the oldest future. This is non blocking, if not ready we'll pass
"""
while len(self.work_handles) > 0 and self.work_handles[0].handle.is_completed():
work_handle = self.work_handles.popleft()
if work_handle.callback is not None:
work_handle.callback()
def _handle_trailing_buckets(self, flush_type: BucketFlush) -> None:
"""
@@ -594,6 +603,10 @@ class OSS(Optimizer):
network requests have been issued.
"""
# Determine the max work handles in flight:
# - count all the buckets on the fly
self._max_work_handles = 0
for device, per_rank_params in self.per_device_params.items():
for dst_rank, params in enumerate(per_rank_params):
offset = 0
@@ -604,6 +617,11 @@ class OSS(Optimizer):
# - enough room in the bucket
if (offset + param.numel()) < self.buckets[device][dst_rank].max_size:
self.should_bucket_param[param] = True
if offset == 0:
# count this bucket, only once
self._max_work_handles += 1
offset += param.numel()
else:
self.should_bucket_param[param] = False
@@ -615,8 +633,4 @@ class OSS(Optimizer):
# Determine the max work handles in flight:
# - all the direct reduce/broadcast
self._max_work_handles = sum(not value for value in self.should_bucket_param.values())
# - if we're bucketing, this means more work handles: one per rank and per device
if self.bucket_size > 0:
self._max_work_handles += len(self.per_device_params.keys()) * self.world_size
self._max_work_handles += sum(not value for value in self.should_bucket_param.values())
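Finally, a small self-contained sketch of the handle-count bookkeeping that the last two hunks converge on, with the distributed state stubbed out: parameters that fit in a per-rank flat buffer share that bucket's single work handle (counted once, when the first parameter lands at offset 0), while every parameter that does not fit gets a direct collective, and therefore a handle, of its own. The sizes and the per-rank partition below are made up for illustration.

```python
# Sketch of the _max_work_handles accounting, mirroring the `if offset == 0`
# branch above: one handle per non-empty bucket, one per non-bucketed parameter.
import torch

bucket_max_size = 2 ** 10          # illustrative flat-buffer size, in elements
per_rank_params = [                # pretend partition: rank -> its parameters
    [torch.zeros(256), torch.zeros(512), torch.zeros(4096)],   # rank 0
    [torch.zeros(128), torch.zeros(8192)],                     # rank 1
]

should_bucket_param = {}
max_work_handles = 0

for params in per_rank_params:
    offset = 0
    for param in params:
        if offset + param.numel() < bucket_max_size:
            should_bucket_param[param] = True
            if offset == 0:
                max_work_handles += 1      # count this rank's bucket, only once
            offset += param.numel()
        else:
            should_bucket_param[param] = False

# every non-bucketed parameter gets a direct async collective of its own
max_work_handles += sum(not bucketed for bucketed in should_bucket_param.values())
print(max_work_handles)   # 2 buckets + 2 direct reduces -> 4
```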