Unverified Commit 351f35e1 authored by Benjamin Lefaudeux, committed by GitHub

[perf] ShardedDDP: better handling of the callback queue, try to consume it as we go. (#254)

* Better handling of the callback queue: try to consume it as we go.

* Drop bucketing for the reduce path, since it keeps running into the same unused-params issue.
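For context on the title, here is a minimal sketch of the "consume as we go" pattern, using `concurrent.futures` stand-ins for the async `torch.distributed` work handles. The names (`WorkHandle`, `try_consume`, `consume_all`) are illustrative only, not fairscale's API; the actual change below keeps a FIFO of handles, drains the completed ones opportunistically after each gradient is posted, and falls back to a blocking wait once everything has been queued.

```python
from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from typing import Callable, Deque, Optional
import time


@dataclass
class WorkHandle:
    handle: Future  # stand-in for the async request returned by dist.reduce/broadcast
    callback: Optional[Callable[[], None]] = None  # delayed action, e.g. releasing a grad


def try_consume(queue: Deque[WorkHandle]) -> None:
    # Non-blocking: only pop requests that already finished, oldest first.
    while queue and queue[0].handle.done():
        wh = queue.popleft()
        if wh.callback is not None:
            wh.callback()


def consume_all(queue: Deque[WorkHandle]) -> None:
    # Blocking: drain whatever is left, oldest first.
    while queue:
        wh = queue.popleft()
        wh.handle.result()  # equivalent of work_handle.handle.wait()
        if wh.callback is not None:
            wh.callback()


if __name__ == "__main__":
    work_handles: Deque[WorkHandle] = deque()
    with ThreadPoolExecutor() as pool:
        for i in range(4):
            fut = pool.submit(time.sleep, 0.01 * i)  # stand-in for an async reduce
            work_handles.append(WorkHandle(fut, callback=lambda i=i: print("callback", i)))
            try_consume(work_handles)  # opportunistic drain, as done in each backward hook
        consume_all(work_handles)  # final blocking flush once everything has been posted
```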
parent 19cb5938
@@ -11,7 +11,7 @@ reduction automatically.
 import contextlib
 from itertools import chain
 import logging
-from typing import Any, Callable, Generator, List, Tuple, Union
+from typing import Any, Callable, Dict, Generator, List, Tuple, Union

 import torch
 from torch import nn
@@ -92,6 +92,10 @@ class ShardedDataParallel(nn.Module):
         # we build an iterator which goes through all the parameters involved globally
         self._param_iterator = chain(*[optim.should_bucket_param.keys() for optim in self.sharded_optimizers])
         self._grad_to_be_reduced = [True for _ in self._param_iterator]
+        self._reduced_grads: Dict[OSS, int] = {}
+        self._reduced_grads_max = {o: len(o.param_to_rank.values()) for o in self.sharded_optimizers}
+        self._clear_counters()
+
         self._grad_accs: List[Callable] = []
         self._setup_backward_hooks()
@@ -110,7 +114,7 @@ class ShardedDataParallel(nn.Module):
             self.sync_buffers(blocking=True)

         # Reset all the grad reduce and bucket state flags
-        self._grad_to_be_reduced = [True] * len(self._grad_to_be_reduced)
+        self._clear_counters()

         # Normal FW on the base model
         return self.module(*inputs, **kwargs)
@@ -122,18 +126,6 @@ class ShardedDataParallel(nn.Module):
         """
         logging.warning("This is not useful anymore, gradients have been reduced automatically with the backward pass")

-    def _sync_params_and_buffers(self) -> None:
-        """
-        Sync the complete model states in between the ranks
-        """
-        with torch.no_grad():
-            work_handles = [
-                dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
-                for t in self.module.state_dict().values()
-            ]
-
-            _ = list(map(lambda x: x.wait(), work_handles))
-
     def sync_buffers(self, blocking: bool = False) -> None:
         """
         Sync all the param buffers in between ranks (including for instance batch norm statistics).
@@ -155,6 +147,12 @@ class ShardedDataParallel(nn.Module):
             yield
         self.should_accumulate_grads = old_should_accumulate_grads

+    def _clear_counters(self) -> None:
+        """ Reset all the grad reduce and call counters
+        """
+        self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
+        self._reduced_grads = {o: 0 for o in self.sharded_optimizers}
+
     def _find_rank(self, param: Parameter) -> Tuple[OSS, int]:
         """ Look up where this parameter belongs to """
         for optim in self.sharded_optimizers:
@@ -164,9 +162,7 @@ class ShardedDataParallel(nn.Module):
         assert False, "This parameter is not present in an optimizer, this should not happen"
         return (None, -1)

-    def _get_reduce_fn(
-        self, index: int, param: torch.Tensor, should_bucket: bool, dst_rank: int, optimizer: OSS
-    ) -> Callable:
+    def _get_reduce_fn(self, index: int, param: torch.Tensor, dst_rank: int, optimizer: OSS) -> Callable:
         """
         Two possible backward hooks for a given parameter: either directly reduce to the appropriate rank,
         or contribute to a bucket and reduce when the bucket is full.
@@ -174,7 +170,7 @@ class ShardedDataParallel(nn.Module):
         Either way a delayed action is necessary and is passed as a callback.
         """

-        def reduce_direct(*_: Any) -> None:
+        def reduce(*_: Any) -> None:
             # Skip gradient reduction, do not alter status flags
             if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
                 assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
@@ -197,50 +193,18 @@ class ShardedDataParallel(nn.Module):
                         callback=cleanup,
                     )
                 )
+                self._reduced_grads[optimizer] += 1

-                # If all the reduce operations have been called,
-                # make sure that all the asynchronous calls have concluded before moving on
-                # and execute the delayed actions (release gradients, unroll the buckets)
-                if len(optimizer.work_handles) == optimizer._max_work_handles:
-                    optimizer._consume_work_handles()
-
-        # Bucket, update status, and possibly unroll the results
-        def reduce_bucket(*_: Any) -> None:
-            # Skip gradient reduction, do not alter status flags
-            if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
-                assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
-
-                # Make sure that this is not fired twice
-                self._grad_to_be_reduced[index] = False
-
-                # Copy to the flat buffer, update the buffer state
-                bucket = optimizer.buckets[param.device][dst_rank]
-                assert bucket.append(param, use_gradient=True), "Bucket overflow: max %s - current %s - adding %s" % (
-                    bucket.max_size,
-                    bucket.current_offset,
-                    param.grad.numel(),
-                )
-
-                if bucket.full():
-                    bucket.buffer /= self.world_size
-
-                    optimizer.work_handles.append(
-                        Workhandle(
-                            handle=dist.reduce(
-                                tensor=bucket.buffer, dst=dst_rank, group=self.process_group, async_op=True,
-                            ),
-                            callback=bucket.unroll,
-                        )
-                    )
+                # Opportunistically try to empty the queue
+                optimizer._try_consume_work_handle()

                 # If all the reduce operations have been called,
                 # make sure that all the asynchronous calls have concluded before moving on
                 # and execute the delayed actions (release gradients, unroll the buckets)
-                if len(optimizer.work_handles) == optimizer._max_work_handles:
+                if self._reduced_grads[optimizer] == self._reduced_grads_max[optimizer]:
                     optimizer._consume_work_handles()

-        return reduce_bucket if should_bucket else reduce_direct
+        return reduce

     def _setup_backward_hooks(self) -> None:
         """
@@ -250,7 +214,7 @@ class ShardedDataParallel(nn.Module):
         # Go through the parameters, attach the hook
         for sharded_optimizer in self.sharded_optimizers:
-            for param, should_bucket in sharded_optimizer.should_bucket_param.items():
+            for param, _ in sharded_optimizer.should_bucket_param.items():
                 if param.grad is not None and param.grad.requires_grad:
                     raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
@@ -262,5 +226,17 @@ class ShardedDataParallel(nn.Module):
                 dst_rank = sharded_optimizer.param_to_rank[param]

                 index = len(self._grad_accs)
-                grad_acc.register_hook(self._get_reduce_fn(index, param, should_bucket, dst_rank, sharded_optimizer))
+                grad_acc.register_hook(self._get_reduce_fn(index, param, dst_rank, sharded_optimizer))
                 self._grad_accs.append(grad_acc)  # keep this function in scope
+
+    def _sync_params_and_buffers(self) -> None:
+        """
+        Sync the complete model states in between the ranks
+        """
+        with torch.no_grad():
+            work_handles = [
+                dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
+                for t in self.module.state_dict().values()
+            ]
+
+            _ = list(map(lambda x: x.wait(), work_handles))
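Summing up the ShardedDataParallel side: the hooks above now count, per sharded optimizer, how many of its gradients have been posted (`self._reduced_grads`) against the total it manages (`self._reduced_grads_max`), drain the work-handle queue opportunistically after every post, and only do the blocking flush once the counter reaches the maximum. Below is a hypothetical, self-contained sketch of that gating; `FakeOptim`, `on_grad_ready` and the counts are made up for illustration and are not part of fairscale.

```python
from typing import Dict


class FakeOptim:
    """Stand-in for OSS, exposing only the two consume entry points used by the hooks."""

    def __init__(self, name: str) -> None:
        self.name = name

    def _try_consume_work_handle(self) -> None:
        print(f"{self.name}: non-blocking drain")

    def _consume_work_handles(self) -> None:
        print(f"{self.name}: blocking flush, all grads posted")


optimizers = [FakeOptim("oss0")]
reduced_grads: Dict[FakeOptim, int] = {o: 0 for o in optimizers}
reduced_grads_max: Dict[FakeOptim, int] = {o: 3 for o in optimizers}  # e.g. 3 params managed


def on_grad_ready(optimizer: FakeOptim) -> None:
    # Mirrors the tail of the reduce() hook: count, drain opportunistically,
    # and block only when this optimizer has posted all of its gradients.
    reduced_grads[optimizer] += 1
    optimizer._try_consume_work_handle()
    if reduced_grads[optimizer] == reduced_grads_max[optimizer]:
        optimizer._consume_work_handles()


for _ in range(3):
    on_grad_ready(optimizers[0])
```

The remaining hunks are on the OSS optimizer side: the work-handle list becomes a deque, a non-blocking `_try_consume_work_handle` is added next to the blocking `_consume_work_handles`, and the expected handle count is recomputed per actual bucket.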
@@ -3,14 +3,14 @@
 # This source code is licensed under the BSD license found in the
 # LICENSE file in the root directory of this source tree.

-from collections import OrderedDict
+from collections import OrderedDict, deque
 import copy
 from enum import Enum, auto
 import itertools
 from itertools import chain
 import logging
 from math import inf
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import TYPE_CHECKING, Any, Callable, Deque, Dict, List, Optional, Tuple, Type, Union

 import torch
 import torch.distributed as dist
@@ -74,8 +74,6 @@ class OSS(Optimizer):
         broadcast_buffer_size: int = 2 ** 17,
         **default: Any,
     ):
-        # logging.warning("Disabling bucketing for now, error prone for some models")
-        broadcast_buffer_size = 0

         # Hold all the model params in the root .param_groups
         self.in_super_constructor = True
@@ -92,7 +90,6 @@ class OSS(Optimizer):
         self.world_size = dist.get_world_size(self.group)
         self.rank = dist.get_rank(self.group)
         self.global_rank = self.get_global_rank(self.group, self.rank)
-
         self.optim = optim(self.partition_parameters()[self.rank], **default)

         # - Sync local and global param_groups keys
@@ -107,6 +104,11 @@ class OSS(Optimizer):
         # Current default device is set by the parameters allocated to this rank
         self._device = list(self.per_device_params.keys())[0]
         self.buckets: Dict[torch.device, List[Bucket]] = {}
+
+        # if torch.cuda.is_available() and self.world_size <= torch.cuda.device_count():
+        #     broadcast_buffer_size = 0
+        #     logging.warning("Assuming single node job, bucketing is disabled")
+
         self.bucket_size = broadcast_buffer_size
         for device, per_device in self.per_device_params.items():
             # Allocate one buffer per rank and per device to group the small parameters
@@ -115,7 +117,7 @@ class OSS(Optimizer):
                 for _ in range(len(per_device))
             ]

         self.should_bucket_param: Dict[torch.Tensor, bool] = {}
-        self.work_handles: List[Workhandle] = []
+        self.work_handles: Deque[Workhandle] = deque()
         self._max_work_handles = -1
         self._setup_bucket_strategy()
@@ -554,12 +556,19 @@ class OSS(Optimizer):
         We start from the first/older ones, since they are the most likely to be ready and non-blocking
         """
-        for work_handle in self.work_handles:
+        while len(self.work_handles) > 0:
+            work_handle = self.work_handles.popleft()
             work_handle.handle.wait()

             if work_handle.callback is not None:
                 work_handle.callback()

-        self.work_handles.clear()
+    def _try_consume_work_handle(self) -> None:
+        """ Try to consume the oldest future. This is non blocking, if not ready we'll pass
+        """
+        while len(self.work_handles) > 0 and self.work_handles[0].handle.is_completed():
+            work_handle = self.work_handles.popleft()
+            if work_handle.callback is not None:
+                work_handle.callback()

     def _handle_trailing_buckets(self, flush_type: BucketFlush) -> None:
         """
@@ -594,6 +603,10 @@ class OSS(Optimizer):
         network requests have been issued.
         """

+        # Determine the max work handles in flight:
+        # - count all the buckets on the fly
+        self._max_work_handles = 0
+
         for device, per_rank_params in self.per_device_params.items():
             for dst_rank, params in enumerate(per_rank_params):
                 offset = 0
@@ -604,6 +617,11 @@ class OSS(Optimizer):
                     # - enough room in the bucket
                     if (offset + param.numel()) < self.buckets[device][dst_rank].max_size:
                         self.should_bucket_param[param] = True
+
+                        if offset == 0:
+                            # count this bucket, only once
+                            self._max_work_handles += 1
+
                         offset += param.numel()
                     else:
                         self.should_bucket_param[param] = False
@@ -615,8 +633,4 @@ class OSS(Optimizer):

         # Determine the max work handles in flight:
         # - all the direct reduce/broadcast
-        self._max_work_handles = sum(not value for value in self.should_bucket_param.values())
-
-        # - if we're bucketing, this means more work handles: one per rank and per device
-        if self.bucket_size > 0:
-            self._max_work_handles += len(self.per_device_params.keys()) * self.world_size
+        self._max_work_handles += sum(not value for value in self.should_bucket_param.values())
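The last two hunks change how the expected number of in-flight requests is estimated: a bucket is counted only when it actually receives its first parameter (`offset == 0`), plus one handle per parameter that is reduced directly, replacing the old one-bucket-per-rank-and-device assumption. A small standalone sketch of that counting follows; the bucket and parameter sizes are made up for illustration.

```python
# Hypothetical sizes, in number of elements, just to exercise the counting logic.
bucket_max_size = 100
params_per_bucket = {("cuda:0", 0): [10, 20, 200], ("cuda:0", 1): [300, 400]}  # (device, dst_rank) -> param sizes

max_work_handles = 0
should_bucket = {}

for key, sizes in params_per_bucket.items():
    offset = 0
    for i, numel in enumerate(sizes):
        if offset + numel < bucket_max_size:
            # Fits in the flat buffer: one reduce per bucket, counted when it gets its first param
            should_bucket[(key, i)] = True
            if offset == 0:
                max_work_handles += 1
            offset += numel
        else:
            # Does not fit: reduced directly, one handle per parameter
            should_bucket[(key, i)] = False

max_work_handles += sum(not v for v in should_bucket.values())
print(max_work_handles)  # 1 bucket + 3 direct reduces = 4
```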