Unverified commit ad933b34, authored by Benjamin Lefaudeux, committed by GitHub

[feat] ShardedDataParallel with autoreduce (#157)

* rewrite using autograd and Variable execution queue to make the reduce automatic
* share buckets with OSS to remove duplication
* some speed is likely still left on the table, since the measured speed with bucketing does not match expectations; this could be a follow-up
parent 35d4129f
......@@ -124,10 +124,12 @@ run_oss_benchmark: &run_oss_benchmark
python benchmarks/oss.py --check_regression --world_size 4 --reference_speed 760 --reference_memory 1120 --reference_loss 0.023
run_oss_gloo: &run_oss_gloo
- run:
name: Run OSS with Gloo
command: |
python benchmarks/oss.py --gloo --optim_type oss_ddp --epochs 3
- run:
name: Run OSS with Gloo
command: |
python benchmarks/oss.py --gloo --optim_type oss_ddp --epochs 2
python benchmarks/oss.py --gloo --optim_type oss_sharded_ddp --epochs 2
run_oss_amp: &run_oss_amp
- run:
......
......@@ -97,19 +97,10 @@ def train(
scaler = (TorchGradScaler() if args.optim_type == OptimType.vanilla else ShardedGradScaler()) if args.amp else None
if optim_type == OptimType.oss_sharded_ddp:
model = ShardedDDP(
model,
optimizer=OPTIM,
optimizer_params={"lr": 1e-4, "momentum": 0.9},
world_size=args.world_size,
broadcast_buffers=True,
)
optimizer = model.sharded_optimizer
optimizer = OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
model = ShardedDDP(model, optimizer)
else:
if args.cpu:
device_ids = None
else:
device_ids = [rank]
device_ids = None if args.cpu else [rank]
model = DDP(model, device_ids=device_ids, find_unused_parameters=False) # type: ignore
optimizer = (
OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
......@@ -120,6 +111,7 @@ def train(
# Reset the memory use counter
if not args.cpu:
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats(rank)
torch.cuda.synchronize(rank)
......@@ -159,9 +151,6 @@ def train(
loss = loss_fn(outputs, data["label"])
loss.backward()
if optim_type == OptimType.oss_sharded_ddp:
model.reduce()
if args.debug and rank == 0 and next(model.parameters()).grad is not None:
logging.debug(
"after BW: param {} -- grad {}".format(
......
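The benchmark change above switches to the new API: construct the OSS optimizer first, then wrap the model with ShardedDDP, and drop the explicit model.reduce() call after backward. A minimal end-to-end sketch of that pattern follows (assumptions, not from the diff: a single-process gloo group, a toy model and random data):

import tempfile
import torch
import torch.distributed as dist
from torch.nn import Linear, Sequential
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS

# Single-process group, just to make the sketch runnable on one machine
dist.init_process_group(
    init_method=f"file://{tempfile.mkstemp()[1]}", backend=dist.Backend.GLOO, rank=0, world_size=1
)

model = Sequential(Linear(2, 3), Linear(3, 1))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.9)
model = ShardedDDP(model, optimizer)  # gradients are reduced to their owner rank during backward

for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.rand(64, 2)).abs().sum()
    loss.backward()  # no explicit model.reduce() needed anymore
    optimizer.step()

dist.destroy_process_group()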
......@@ -8,3 +8,4 @@ API Reference
optim/oss
optim/grad_scaler
nn/pipe
nn/sharded_ddp
ShardedDataParallel
====================
.. autoclass:: fairscale.nn.ShardedDataParallel
:members:
:undoc-members:
......@@ -3,7 +3,8 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from .data_parallel import ShardedDataParallel
from .moe import MOELayer, Top2Gate
from .pipe import LazyModule, Pipe, PipeRPCWrapper
__all__ = ["Pipe", "PipeRPCWrapper", "Top2Gate", "LazyModule"]
__all__ = ["Pipe", "PipeRPCWrapper", "Top2Gate", "LazyModule", "ShardedDataParallel"]
......@@ -4,234 +4,283 @@
# LICENSE file in the root directory of this source tree.
"""
A distributed data parallel class that works with OSS optimizer.
Adapted from the LegacyDistributedDataParallel module in fairseq.
An nn.Module wrapper to go with a sharded optimizer, which handles the targeted gradient
reduction automatically.
"""
from contextlib import contextmanager
import copy
from typing import Any, Dict, Generator, List, Type, cast
import contextlib
from itertools import chain
import logging
from typing import Any, Callable, Generator, List, Tuple, Union
import torch
from torch import Tensor, nn
from torch import nn
from torch.autograd import Variable
import torch.distributed as dist
from torch.nn import Parameter
from fairscale.optim import OSS
from fairscale.optim.utils import Workhandle
class ShardedDataParallel(nn.Module):
"""Implements distributed data parallel training with optimizer state sharding.
A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
This version uses a c10d process group for communication and optionally
broadcasts buffers.
Args:
module (~torch.nn.Module): module to be parallelized
optimizer (~torch.optim.Optimizer): optimizer to be used for training
optimizer_params(Dict): extra parameters for the optimizer
world_size (int): number of parallel workers
broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
the module at beginning of the forward function. (default: ``True``)
process_group (optional): the c10d process group to be used for
distributed gradient reduction. If None, the default WORLD process group
will be used.
buffer_size (int, optional): number of elements to buffer before
performing reduce (default: 512k). Used to reduce multiple small
params to avoid communication overhead.
"""
Wrap the model, and reduce the gradients to the right rank during the backward pass.
- the partition is given by the sharded optimizer
- wrap the base model with a model which knows where to reduce each gradient
- add an autograd function which calls the model grad dispatch on the way back
Args:
module (nn.Module):
model to be wrapped
sharded_optimizer (OSS, or list of OSS):
the sharded optimizer(s) which will decide the gradient partitioning
Keyword Args:
process_group (group):
torch.distributed group (default: group.WORLD)
broadcast_buffers (bool):
Whether to additionally broadcast model buffers in between ranks at the beginning of each forward pass.
Same setting as in Pytorch DDP, this is in addition to the broadcast and reduction of the model parameters.
sync_models_at_startup (bool):
Synchronize the models in between the ranks when starting up. Not needed if each rank has the same seed,
or the training restarts from a saved state
"""
def __init__(
self,
module: nn.Module,
optimizer: Type[torch.optim.Optimizer],
optimizer_params: Dict[str, Any],
world_size: int,
broadcast_buffers: bool,
sharded_optimizer: Union[OSS, List[OSS]],
process_group: Any = None,
buffer_size: int = 2 ** 19,
broadcast_buffers: bool = True,
sync_models_at_startup: bool = True,
):
super().__init__()
self.module = module
self.world_size = world_size
self.sharded_optimizers = [sharded_optimizer] if isinstance(sharded_optimizer, OSS) else sharded_optimizer
self.enable_broadcast_buffers = broadcast_buffers
# Handle a no_sync() context which prevents the gradient synchronization,
# accumulate in place
self.should_accumulate_grads = False
# Communication related attributes
self.process_group = process_group if process_group is not None else dist.group.WORLD
self.world_size = dist.get_world_size(self.process_group)
self.reference_global_rank = OSS.get_global_rank(self.process_group, 0) # picking rank 0 as the reference
self.rank = dist.get_rank(self.process_group)
self.broadcast_buffers = broadcast_buffers
self.authoritative_rank = 0
self.global_rank = OSS.get_global_rank(self.process_group, self.rank)
# Expose some of the PyTorch DDP attributes, since some frameworks rely on them.
# See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
# The device_id related logic is not present here and is not handled
devices = {p.device for p in self.module.parameters()}
self.is_multi_device_module = len(devices) > 1
self.device = list(devices)[0]
distinct_device_types = {p.device.type for p in self.module.parameters()}
assert len(distinct_device_types) == 1, (
"ShardedDataParallel's input module must be on "
"the same type of devices, but input module parameters are located on {} different device types."
).format(distinct_device_types)
self.device_type = list(distinct_device_types)[0]
# Scaffolding to be able to reduce the grads during the BW pass
# several optimizers can be present, each working on separate parameter sets;
# we build an iterator which goes through all the parameters involved globally
self._param_iterator = chain(*[optim.should_bucket_param.keys() for optim in self.sharded_optimizers])
self._grad_to_be_reduced = [True for _ in self._param_iterator]
self._grad_accs: List[Callable] = []
self._setup_backward_hooks()
# Make sure that all ranks start with the same model
if sync_models_at_startup:
self._sync_params_and_buffers()
def forward(self, *inputs: Any, **kwargs: Any) -> Any:
"""
Module forward pass, handles any DDP-specific work in the background. Primes the
backward pass for gradient reduction to the proper ranks.
"""
if self.enable_broadcast_buffers:
# NCCL communications are on a different stream; they need to be blocking
# for the subsequent FW to be correct
self.sync_buffers(blocking=True)
# Flag used to make sure we only reduce gradients one time in the execution engine
self.need_reduction = False
# Normal FW on the base model
return self.module(*inputs, **kwargs)
# We can also forcibly accumulate grads locally and only do the
# gradients-reduce at some later time
self.accumulate_grads = False
def reduce(self) -> None:
""" .. deprecated:: 0.0.4
# Build the sharded optimizer
self.sharded_optimizer = OSS(self.module.parameters(), optim=optimizer, group=process_group, **optimizer_params)
This does not need to be called, the gradient reduction is done automatically during the BW pass
"""
logging.warning("This is not useful anymore, gradients have been reduced automatically with the backward pass")
def _sync_params_and_buffers(self) -> None:
"""
Sync the complete model states in between the ranks
"""
with torch.no_grad():
work_handles = [
dist.broadcast(t, src=self.reference_global_rank, group=self.process_group, async_op=True)
for t in self.module.state_dict().values()
]
# Allocate reduce buffers
# - Never use a bigger buffer than the number of model params
buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))
self._reduce_buffers: Dict[torch.device, List[torch.Tensor]] = {}
_ = list(map(lambda x: x.wait(), work_handles))
# - One buffer per rank per device
for device, per_device in self.sharded_optimizer.per_device_params.items():
buffer_dtype = per_device[0][0].dtype
self._reduce_buffers[device] = [
torch.zeros(buffer_size, dtype=buffer_dtype, device=device) for _ in range(len(per_device))
def sync_buffers(self, blocking: bool = False) -> None:
"""
Sync all the param buffers in between ranks (including for instance batch norm statistics).
"""
with torch.no_grad():
work_handles = [
dist.broadcast(buffer.data, self.reference_global_rank, self.process_group, async_op=True)
for buffer in self.module.buffers(recurse=True)
]
# Sanity checks
assert len(self.sharded_optimizer.param_to_rank) == len(
list(self.module.parameters())
), "number of params do not match"
for param in self.module.parameters():
assert param in self.sharded_optimizer.param_to_rank, f"{param} not in the optimizer"
def __getstate__(self) -> Dict:
attrs = copy.copy(self.__dict__)
return attrs
@property
def optimizer(self) -> torch.optim.Optimizer:
return self.sharded_optimizer
def train(self, mode: bool = True) -> "ShardedDataParallel":
pre_mode = self.module.training
self.module.train(mode)
if self.module.training:
assert not self.need_reduction or pre_mode, "incorrect state transition"
else:
assert not self.need_reduction, "try to enter eval with grads unreduced"
return self
@contextmanager
if blocking:
_ = list(map(lambda x: x.wait(), work_handles))
@contextlib.contextmanager
def no_sync(self) -> Generator:
"""A context manager to disable gradient synchronization."""
old_accumulate_grads = self.accumulate_grads
self.accumulate_grads = True
old_should_accumulate_grads = self.should_accumulate_grads
self.should_accumulate_grads = True
yield
self.accumulate_grads = old_accumulate_grads
self.should_accumulate_grads = old_should_accumulate_grads
def forward(self, *inputs: Any, **kwargs: Any) -> Tensor:
if self.module.training:
if self.need_reduction:
raise RuntimeError("OssDdp requires explicit reduction, must call OssDdp.reduce")
if not self.accumulate_grads:
self.need_reduction = True
if self.broadcast_buffers and len(list(self.module.buffers())) > 0:
self._sync_buffers()
def _find_rank(self, param: Parameter) -> Tuple[OSS, int]:
""" Look up where this parameter belongs to """
for optim in self.sharded_optimizers:
if param in optim.param_to_rank.keys():
return optim, optim.param_to_rank[param]
return self.module(*inputs, **kwargs)
assert False, "This parameter is not present in an optimizer, this should not happen"
return (None, -1)
def reduce(self) -> None:
def _get_reduce_fn(
self, index: int, param: torch.Tensor, should_bucket: bool, dst_rank: int, optimizer: OSS
) -> Callable:
"""
This function must be called explicitly after backward to reduce
gradients. There is no automatic hook like c10d.
Two possible backward hooks for a given parameter: either directly reduce to the appropriate rank,
or contribute to a bucket and reduce when the bucket is full.
Either way a delayed action is necessary and is passed as a callback.
"""
assert self.module.training, "Cannot call reduce in eval"
if not self.need_reduction or self.accumulate_grads:
return
def gatekeeper() -> None:
# Make sure that all the asynchronous calls have concluded before moving on. Consume the futures
# and execute the delayed actions (release gradients, unroll the buckets)
Variable._execution_engine.queue_callback(optimizer._consume_work_handles)
# Reset all the grad reduce and bucket state flags
self._grad_to_be_reduced = [True] * len(self._grad_to_be_reduced)
def reduce_direct(*_: Any) -> None:
# Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
# Make sure that this is not fired twice
self._grad_to_be_reduced[index] = False
param.grad /= self.world_size
# Future work includes clearing up the buffer if possible
def cleanup() -> None:
if dst_rank != self.global_rank:
param.grad = None
# Async reduce for this buffer, log the future
optimizer.work_handles.append(
Workhandle(
handle=dist.reduce(
tensor=param.grad.data, dst=dst_rank, group=self.process_group, async_op=True
),
callback=cleanup,
)
)
self.need_reduction = False
# If all the reduce operations have been called, add the gatekeeper
if len(optimizer.work_handles) == optimizer._max_work_handles:
gatekeeper()
with torch.no_grad():
for device, per_device in self.sharded_optimizer.per_device_params.items():
self._reduce_grads_task(
self._reduce_buffers[device],
per_device,
group=self.process_group,
self_rank=self.rank,
world_size=self.world_size,
)
# Bucket, update status, and possibly unroll the results
def reduce_bucket(*_: Any) -> None:
# Skip gradient reduction, do not alter status flags
if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
@staticmethod
def _reduce_grads_task(
buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]], group: Any, self_rank: int, world_size: int
) -> None:
"""Helper to reduce a list of params. The params are sorted by size, smallest first, which allows for
an opportunistic bucketing.
NOTE: All param gradients are assumed to exist"""
buffer_size = buffers[0].numel()
bucket_requests = []
requests = []
for (rank, params), buffer in zip(enumerate(per_rank_params), buffers):
# All the params are sorted per rank and per increasing size
if len(params) == 0:
continue
for p in params:
if p.grad is None:
p.grad = torch.zeros_like(p)
global_rank = OSS.get_global_rank(group, rank)
# Copy small gradients into per-GPU buffers and then async reduce
i_bucketed = 0 # the number of tensors packed in the buffer
offset = 0
# Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
end = offset + params[i_bucketed].numel()
buffer[offset:end].copy_(params[i_bucketed].grad.data.view(-1)) # type: ignore
offset = end
i_bucketed += 1
if i_bucketed > 0:
buffer.div_(world_size)
bucket_requests.append(
(
dist.reduce(tensor=buffer, dst=global_rank, group=group, async_op=True), # type: ignore
rank,
)
# Make sure that this is not fired twice
self._grad_to_be_reduced[index] = False
# Copy to the flat buffer, update the buffer state
bucket = optimizer.buckets[param.device][dst_rank]
assert bucket.append(param, use_gradient=True), "Bucket overflow: max %s - current %s - adding %s" % (
bucket.max_size,
bucket.current_offset,
param.grad.numel(),
)
# Directly reduce the other grads
for p in params[i_bucketed:]:
p.grad = cast(Tensor, p.grad)
if p.grad.requires_grad:
raise RuntimeError("DistributedDataParallel only works with gradients that don't require grad")
if bucket.full():
p.grad.div_(world_size)
requests.append(dist.reduce(tensor=p.grad, dst=global_rank, group=group, async_op=True)) # type: ignore
def unwrap() -> None:
for flat in bucket.params:
if dst_rank != self.global_rank:
# this rank is not the owner, release the grad
flat.param.grad = None
else:
# this rank is the owner, unroll the results
assert flat.param.grad is not None
# Unroll the initial packed small gradients, as soon as possible
for future, rank in bucket_requests:
future.wait()
flat.param.grad.data.copy_(
bucket.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
)
if rank == self_rank:
i_bucketed = 0 # the number of tensors packed in the buffer
offset = 0
params = per_rank_params[rank]
buffer = buffers[rank]
bucket.reset()
while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
end = offset + params[i_bucketed].numel()
params[i_bucketed].grad.data.copy_(buffer[offset:end].view_as(params[i_bucketed])) # type: ignore
offset = end
i_bucketed += 1
bucket.buffer /= self.world_size
# Make sure that we're done with this device before moving on and cleaning the unused params
_ = list(map(lambda x: x.wait(), requests))
optimizer.work_handles.append(
Workhandle(
handle=dist.reduce(
tensor=bucket.buffer, dst=dst_rank, group=self.process_group, async_op=True,
),
callback=unwrap,
)
)
# If all the reduce operations have been called, add the gatekeeper
if len(optimizer.work_handles) == optimizer._max_work_handles:
gatekeeper()
def _sync_buffers(self) -> None:
return reduce_bucket if should_bucket else reduce_direct
def _setup_backward_hooks(self) -> None:
"""
Sync all the param buffers in between ranks.
TODO: Could be worth bucketing ?
Attach a reduce function to each grad-requiring parameter.
This makes the gradient reduction automatic whenever there is a backward pass.
"""
_ = list(
map(
lambda x: x.wait(),
map(
lambda x: dist.broadcast(x, self.authoritative_rank, self.process_group, async_op=True),
self.module.buffers(),
),
)
)
# Go through the parameters, attach the hook
for sharded_optimizer in self.sharded_optimizers:
for param, should_bucket in sharded_optimizer.should_bucket_param.items():
if param.grad is not None and param.grad.requires_grad:
raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
# Register the hook on the next function in line (the AccumulateGrad node),
# so that the hook fires once this gradient has been fully computed
p_tmp = param.expand_as(param)
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0]
dst_rank = sharded_optimizer.param_to_rank[param]
index = len(self._grad_accs)
grad_acc.register_hook(self._get_reduce_fn(index, param, should_bucket, dst_rank, sharded_optimizer))
self._grad_accs.append(grad_acc) # keep this function in scope
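The hook setup above uses a generic autograd pattern rather than anything fairscale-specific: get hold of the AccumulateGrad node sitting behind each parameter and attach a hook that fires once the gradient has been written, which is the moment the reduce can be launched. A standalone sketch of just that mechanism (toy parameter, no process group involved):

import torch

param = torch.nn.Parameter(torch.randn(3))

# A cheap graph node whose next function is the AccumulateGrad node of `param`
p_tmp = param.expand_as(param)
grad_acc = p_tmp.grad_fn.next_functions[0][0]

def on_grad_ready(*unused):
    # param.grad is populated at this point; this is where ShardedDataParallel
    # would kick off the async reduce towards the rank owning this parameter
    print("grad ready, norm:", param.grad.norm().item())

grad_acc.register_hook(on_grad_ready)
keep_alive = [grad_acc]  # keep the node referenced, as the wrapper does with _grad_accs

(param * 2.0).sum().backward()  # the hook fires during this backward pass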
......@@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, Optional
from typing import Dict
import torch
from torch.cuda.amp import GradScaler as TorchGradScaler
......@@ -32,15 +32,15 @@ class ShardedGradScaler(TorchGradScaler):
def __init__(self) -> None:
super().__init__()
def step(self, optimizer: Optimizer, *args: Any, **kwargs: Any) -> Optional[float]:
def unscale_(self, optimizer: Optimizer) -> None:
assert isinstance(optimizer, OSS), "ShardedGradScaler is to be used in combination with a sharded optimizer"
# Re-use the GradScaler machinery, but make sure that the status is synced across the ranks
# Call the upstream unscale_ method which will only act on this rank's gradients
super().unscale_(optimizer)
# Synchronize the detected inf across the ranks
optimizer_state = self._per_optimizer_states[id(optimizer)]
handles = [dist.all_reduce(v, async_op=True) for v in optimizer_state["found_inf_per_device"].values()]
# Make sure that the calls are done before moving out
_ = list(map(lambda x: x.wait(), handles))
# Call Torch's GradScaler in turn, states have been synchronized across ranks
return super().step(optimizer, *args, **kwargs)
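A sketch of how the sharded scaler above slots into an AMP training loop. Everything outside the scaler calls is assumed (a CUDA device, an initialized process group, a model already wrapped in ShardedDataParallel with an OSS optimizer, plus placeholder loader and loss_fn names), so this only shows the call order:

import torch
from fairscale.optim.grad_scaler import ShardedGradScaler

scaler = ShardedGradScaler()

for batch, target in loader:  # placeholder data loader
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = loss_fn(model(batch), target)  # model wrapped in ShardedDataParallel
    scaler.scale(loss).backward()  # gradients get reduced to their owner rank here
    scaler.step(optimizer)  # unscale_ runs and syncs the inf/nan status across ranks
    scaler.update()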
......@@ -16,7 +16,7 @@ import torch.distributed as dist
from torch.nn import Parameter
from torch.optim import SGD, Optimizer
from .utils import broadcast_object, recursive_copy_to_device
from .utils import Bucket, Workhandle, broadcast_object, recursive_copy_to_device
__all__ = ["OSS"]
......@@ -73,7 +73,7 @@ class OSS(Optimizer):
super().__init__(params, default)
self.in_super_constructor = False
# Partition information. lazy evaluation, computed if requested
# Partition information. lazy evaluation, computed when requested
self._per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict() # device, rank, params
self._param_rank: Dict[torch.Tensor, int] = {}
self._partition_parameters: List[List[dict]] = []
......@@ -88,22 +88,26 @@ class OSS(Optimizer):
# - Sync local and global param_groups keys
for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
for k, v in local_group.items():
if k != "params":
global_group[k] = v
for key, value in local_group.items():
if key != "params":
global_group[key] = value
# Optional consolidated optimizer state
self._all_states: List[Dict[str, Any]] = []
# Current default device is set by the parameters allocated to this rank
self._device = self.partition_parameters()[self.rank][0]["params"][0].device
self._broadcast_buffers: Dict[torch.device, List[torch.Tensor]] = {}
self.buckets: Dict[torch.device, List[Bucket]] = {}
for device, per_device in self.per_device_params.items():
# Allocate one buffer per rank and per device to group the small parameters
self._broadcast_buffers[device] = [
torch.zeros(broadcast_buffer_size, dtype=per_device[0][0].dtype, device=device)
self.buckets[device] = [
Bucket(buffer=torch.zeros(broadcast_buffer_size, dtype=per_device[0][0].dtype, device=device))
for _ in range(len(per_device))
]
self.should_bucket_param: Dict[torch.Tensor, bool] = {}
self.work_handles: List[Workhandle] = []
self._max_work_handles = -1
self._setup_bucket_strategy()
# Partition helpers
def partition_parameters(self) -> List[List[dict]]:
......@@ -150,9 +154,9 @@ class OSS(Optimizer):
self._per_device_params[device][self.param_to_rank[param]] += [param]
# Sort param_lists by size
for k in self._per_device_params.keys():
for r in self._per_device_params[k]:
r.sort(key=lambda x: x.numel())
for device in self._per_device_params.keys():
for rank_params in self._per_device_params[device]:
rank_params.sort(key=lambda x: x.numel())
return self._per_device_params
......@@ -164,6 +168,9 @@ class OSS(Optimizer):
for param_group in param_groups:
for param in param_group["params"]:
self._param_rank[param] = rank
logging.debug("ZeRO: Parameters dispatched to ranks %s " % list(self._param_rank.values()))
return self._param_rank
# NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
......@@ -181,20 +188,16 @@ class OSS(Optimizer):
self._sync_param_groups()
# Run the optimizer step on this shard only:
self._free_other_grads()
if closure is not None:
loss = self.optim.step(closure=closure, **kwargs) # type: ignore
else:
loss = self.optim.step(**kwargs)
# Depending on the DDP engine used, gradients specific to other ranks may still be loaded
self._free_other_grads()
# Sync all the updated shards in between the ranks
with torch.no_grad():
for (
device,
device_params,
) in self.per_device_params.items(): # all the params on this device (inc all ranks)
self._broadcast_params(self._broadcast_buffers[device], device_params)
self._broadcast_params()
# Sync hypothetical new results from the wrapped optimizer to the exposed param_groups
self._sync_param_groups(local_to_global=True)
......@@ -489,61 +492,107 @@ class OSS(Optimizer):
for t in p["params"]:
t.grad = None
def _broadcast_params(self, buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]]) -> None:
def _broadcast_params(self) -> None:
"""Helper function to broadcast all the parameters from a given device"""
buffer_size = buffers[0].numel()
bucket_requests = []
direct_requests = []
# Bucket and issue all the async calls
for (src_rank, params), buffer in zip(enumerate(per_rank_params), buffers):
global_src_rank = self.get_global_rank(self.group, src_rank)
# Copy small parameters into per-GPU buffers and then async broadcast
offset = 0
bucket_sent = False
bucket_params = []
# All the params are sorted per rank and per increasing size
for p in params:
# Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
if not bucket_sent and offset + p.numel() < buffer_size:
end = offset + p.numel()
buffer[offset:end].copy_(p.data.view(-1))
bucket_params.append((p, offset, end))
offset = end
else:
if offset > 0 and not bucket_sent:
bucket_requests.append(
(
dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
src_rank,
bucket_params,
)
# The unroll callback is called when the broadcast is done.
# If this rank is a recipient and the call was bucketed, the results from the broadcast are unrolled
# onto the corresponding parameters.
def get_unroll_callback(src_rank: int, bucket: Bucket) -> Callable:
def unroll() -> None:
if src_rank != self.rank:
for flat in bucket.params:
flat.param.data.copy_(
bucket.buffer[flat.start : flat.stop].view_as(flat.param.data), non_blocking=True
)
bucket_sent = True
bucket.reset()
direct_requests.append(
dist.broadcast(tensor=p.data, src=global_src_rank, group=self.group, async_op=True)
)
return unroll
# Catch a trailing bucket
if not bucket_sent:
bucket_requests.append(
(
dist.broadcast(tensor=buffer, src=global_src_rank, group=self.group, async_op=True),
src_rank,
bucket_params,
)
)
with torch.no_grad():
for (
device,
device_params,
) in self.per_device_params.items(): # all the params on this device (inc all ranks)
buckets = self.buckets[device]
# Unroll the initial packed small parameters
for work_handle, src_rank, bucket_params in bucket_requests:
work_handle.wait()
if src_rank != self.rank:
for p, offset, end in bucket_params:
p.data.copy_(buffers[src_rank][offset:end].view_as(p.data))
# Bucket and issue all the async calls
for (src_rank, params), bucket in zip(enumerate(device_params), buckets):
global_src_rank = self.get_global_rank(self.group, src_rank)
for param in params:
# Bucket broadcast
if self.should_bucket_param[param]:
assert bucket.append(param), "Bucket overflow: max %s - current %s - adding %s" % (
bucket.max_size,
bucket.current_offset,
param.numel(),
)
if bucket.full():
self.work_handles.append(
Workhandle(
handle=dist.broadcast(
tensor=bucket.buffer, src=global_src_rank, group=self.group, async_op=True
),
callback=get_unroll_callback(src_rank, bucket),
)
)
# Direct
else:
self.work_handles.append(
Workhandle(
handle=dist.broadcast(
tensor=param.data, src=global_src_rank, group=self.group, async_op=True
),
callback=None,
)
)
self._consume_work_handles()
def _consume_work_handles(self) -> None:
""" Consume all the futures which are tied to this optimizer's buckets.
We start from the first/oldest ones, since they are the most likely to already be complete and thus non-blocking
"""
for work_handle in self.work_handles:
work_handle.handle.wait()
if work_handle.callback is not None:
work_handle.callback()
self.work_handles.clear()
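A hedged sketch of the Workhandle pattern that _consume_work_handles drains: queue asynchronous collectives together with an optional callback, then wait on them oldest-first and run the delayed action. It assumes an initialized process group; the tensor is a toy example:

import torch
import torch.distributed as dist
from fairscale.optim.utils import Workhandle

t = torch.ones(4)
work_handles = [
    Workhandle(
        handle=dist.broadcast(t, src=0, async_op=True),
        callback=lambda: print("t is now in sync"),
    )
]

for wh in work_handles:
    wh.handle.wait()
    if wh.callback is not None:
        wh.callback()
work_handles.clear()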
def _setup_bucket_strategy(self) -> None:
""" Tag parameters to either bucket them or broadcast/reduce them directly. The parameters are ordered
(smallest first), the bucket will hold the smallest elements, the remaining ones will be directly sent
over the wire.
Generating the partition once and for all allows us to save some time at runtime, and to know when all the
network requests have been issued.
"""
# Unroll all the async work items, just in case
_ = list(map(lambda x: x.wait(), direct_requests))
for device, per_rank_params in self.per_device_params.items():
for dst_rank, params in enumerate(per_rank_params):
offset = 0
bucket_size = self.buckets[device][dst_rank].max_size
for param in params:
if (offset + param.numel()) < bucket_size:
# This parameter is small enough to fit in the remaining size of the bucket
self.should_bucket_param[param] = True
offset += param.numel()
else:
# The parameters are sorted by size, so all the following parameters
# will be too big and can be skipped
self.should_bucket_param[param] = False
# Register the max offset for this buffer
self.buckets[device][dst_rank].max_offset = offset
# Determine the max work handles in flight:
# - all the direct reduce/broadcast + 1 bucket
self._max_work_handles = sum(not value for value in self.should_bucket_param.values()) + 1
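The bucketing decision above boils down to a simple cutoff over size-sorted parameters: pack the smallest ones into the flat buffer until the next one no longer fits, and send everything else directly. A standalone illustration with toy numbers (not fairscale internals):

def tag_for_bucketing(param_numels, bucket_size):
    # Parameters are assumed sorted by increasing size, as per_device_params guarantees
    tags, offset = [], 0
    for numel in sorted(param_numels):
        if offset + numel < bucket_size:
            tags.append(True)  # packed into the flat buffer
            offset += numel
        else:
            tags.append(False)  # broadcast/reduced directly
    return tags, offset  # offset plays the role of the bucket's max_offset

print(tag_for_bucketing([10, 20, 40, 1000, 5000], bucket_size=100))
# -> ([True, True, True, False, False], 70)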
......@@ -4,13 +4,70 @@
# LICENSE file in the root directory of this source tree.
import io
from typing import Any, Dict
from typing import Any, Callable, Dict, List, Optional
import torch
from torch._six import container_abcs
import torch.distributed as dist
class Workhandle:
def __init__(self, handle: Any, callback: Optional[Callable]) -> None:
self.handle = handle
self.callback = callback
class FlatParam:
def __init__(self, tensor: torch.Tensor, start: int, stop: int) -> None:
self.param = tensor
self.start = start
self.stop = stop
class Bucket:
"""
Helper class to simplify the handling of broadcast or reduce buckets
"""
def __init__(self, buffer: torch.Tensor) -> None:
# The actual flat tensor
self.buffer = buffer
self.max_size = buffer.numel()
# Handles to the params and their position in this tensor, can be useful for a callback
self.params: List[FlatParam] = []
# Current status for this buffer
self.current_offset = 0
self.max_offset = 0
def reset(self) -> None:
""" empty the bucket """
self.current_offset = 0
self.params.clear()
def append(self, tensor: torch.Tensor, use_gradient: bool = False) -> bool:
""" add a tensor to the bucket """
end = self.current_offset + tensor.numel()
if end > self.max_size:
return False
if use_gradient:
assert tensor.grad is not None
data_source = tensor.grad.data if use_gradient else tensor.data # type: ignore # mypy is drunk
self.buffer[self.current_offset : end].copy_(data_source.view(-1))
self.params.append(FlatParam(tensor=tensor, start=self.current_offset, stop=end))
self.current_offset = end
return True
def full(self) -> bool:
""" is the bucket full ? """
return self.current_offset == self.max_offset
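A small usage sketch of the Bucket helper above, with toy sizes (in fairscale the buffer dtype and device match the parameters, and max_offset is set by the bucketing strategy):

import torch

buffer = torch.zeros(8)
bucket = Bucket(buffer=buffer)
bucket.max_offset = 6  # normally set by _setup_bucket_strategy

a, b = torch.ones(4), torch.full((2,), 2.0)
assert bucket.append(a) and bucket.append(b)  # both tensors fit in the flat buffer
assert bucket.full()  # current_offset has reached max_offset, ready to broadcast/reduce

# bucket.buffer[:6] now holds a and b flattened back to back; once the collective
# completes, the callback copies the slices back through bucket.params and resets:
bucket.reset()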
# Credits: classy_vision/generic/distributed_util.py
def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
"""
......
......@@ -324,7 +324,7 @@ class Tensor:
def coalesce(self) -> Tensor: ...
def conj(self) -> Tensor: ...
def contiguous(self) -> Tensor: ...
def copy_(self, other: Tensor) -> None: ...
def copy_(self, other: Tensor, non_blocking: Optional[_bool]=False) -> None: ...
def cos(self) -> Tensor: ...
def cos_(self) -> Tensor: ...
def cosh(self) -> Tensor: ...
......
......@@ -12,3 +12,4 @@ class GradScaler(object):
def _unscale_grads_(self, optimizer: Optimizer, inv_scale: Tensor, found_inf: Tensor, allow_fp16: bool) -> Dict[device, Tensor]:...
def step(self, optimizer: Optimizer, *args: Any, **kwargs: Any): ...
def update(self, new_scale: Optional[float]=None): ...
def unscale_(self, optimizer: Optimizer) -> None: ...
......@@ -28,8 +28,10 @@ class ReduceOp:
def get_rank(group: Any = None) -> int: ...
def get_world_size(group: Any = None) -> int: ...
def get_backend(group: Optional[Any] = None) -> Any: ...
def broadcast(tensor: Tensor, src: Any, group: Any, async_op: Any = False): ...
def gather(tensor: Tensor, gather_list: Optional[List[Tensor]], dst: Any, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
def reduce(tensor: Tensor, dst: Any, op: Optional[Any]=ReduceOp.SUM, group:Optional[ProcessGroup] = None, async_op: Optional[bool] = False): ...
def is_initialized() -> bool: ...
......
......@@ -8,7 +8,9 @@ Testing OssDdp class.
"""
import tempfile
from typing import List
import numpy as np
import pytest
import torch
import torch.distributed as dist
......@@ -16,18 +18,20 @@ import torch.multiprocessing as mp
from torch.nn import Linear, Sequential
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS
skip_if_no_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required")
skip_if_single_gpu = pytest.mark.skipif(torch.cuda.device_count() < 2, reason="multiple GPUs required")
from contextlib import suppress
def test_on_cpu():
run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"))
def test_step_on_cpu():
run_test(backend=dist.Backend.GLOO, device=torch.device("cpu"), world_size=4)
@skip_if_no_cuda
@skip_if_single_gpu
def test_on_gpu():
def test_step_on_gpu():
run_test(backend=dist.Backend.NCCL, device=torch.device("cuda"))
......@@ -37,46 +41,78 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
if device == torch.device("cuda"):
torch.cuda.set_device(rank)
# Any model works. Add one different buffer per rank
model = Sequential(Linear(2, 3)).to(device)
model.register_buffer("test_buffer", torch.ones((1)) * rank)
def weights_init(m):
if isinstance(m, Linear):
torch.nn.init.constant_(m.weight.data, 1.0)
torch.nn.init.constant_(m.bias.data, 1.0)
model.apply(weights_init)
model.to(device)
ddp = ShardedDataParallel(
module=model,
optimizer=torch.optim.SGD,
optimizer_params={"lr": 0.01, "momentum": 0.99},
world_size=world_size,
broadcast_buffers=True,
)
optimizer = ddp.optimizer
model = ddp.module
# Different input per rank, allows for checking that the gradients have been properly reduced
input_tensor = (torch.ones((64, 2)) * rank).to(device)
output = ddp(input_tensor).abs().sum()
output.backward()
ddp.reduce()
# Check that all the grads have been populated, for the shard
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
if param.shape == torch.Size([3, 2]):
assert param.grad[0, 0].cpu() == torch.tensor([32.0])
if param.shape == torch.Size([3]):
assert param.grad[0].cpu() == torch.tensor([64.0])
# Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
for b in model.buffers():
assert b.cpu().item() == 0.0
torch.manual_seed(rank)
np.random.seed(rank)
def check(broadcast_buffers: bool, grad_accumulation: bool = False) -> None:
# Any model works. Add one different buffer per rank
model = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
model.register_buffer("test_buffer", torch.ones((1)) * rank)
model.to(device)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer, broadcast_buffers=broadcast_buffers)
def check_same_model_params(same_params: bool):
# Check that all the params are the same on all ranks
# This should be true with and without broadcast_buffers, we don't have any real buffer here
receptacle: List[torch.Tensor] = []
if dist.get_backend() != "nccl":
for pg in optimizer.param_groups:
for p in pg["params"]:
# Check the params
receptacle = [p.clone() for _ in range(world_size)] if rank == 0 else []
dist.gather(p, receptacle, dst=0)
if rank == 0:
for sync_p in receptacle[1:]:
if same_params:
assert torch.all(torch.eq(receptacle[0], sync_p)), "Models differ in between ranks"
else:
assert not torch.all(
torch.eq(receptacle[0], sync_p)
), "Gradients should not have been synced"
# Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
if broadcast_buffers:
for b in ddp_model.buffers():
receptacle = [b.clone() for _ in range(world_size)] if rank == 0 else []
dist.gather(b, receptacle, dst=0)
if rank == 0:
for sync_b in receptacle[1:]:
if same_params:
assert torch.all(torch.eq(receptacle[0], sync_b)), "Models differ in between ranks"
else:
assert not torch.all(
torch.eq(receptacle[0], sync_b)
), "Gradients should not have been synced"
assert b.cpu().item() == 0.0
# The model should be synchronized in between the ranks at ShardedDataParallel construction time, check that
check_same_model_params(same_params=True)
# Optim loop
def closure():
optimizer.zero_grad()
with ddp_model.no_sync() if grad_accumulation else suppress():
input_tensor = torch.rand((64, 2)).to(device)
loss = ddp_model(input_tensor).abs().sum()
loss.backward()
return loss
# The models should stay the same in between the ranks
for i in range(5):
_ = optimizer.step(closure=closure)
# when running on cpu/gloo the "nodes" are not really different
same_params = device == torch.device("cpu") or grad_accumulation
check_same_model_params(same_params=same_params)
check(broadcast_buffers=False)
check(broadcast_buffers=True)
check(broadcast_buffers=False, grad_accumulation=True)
check(broadcast_buffers=True, grad_accumulation=True)
dist.destroy_process_group()
......@@ -85,33 +121,116 @@ def run_test(backend, device, world_size=2):
mp.spawn(run_one_step, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
def run_eval_mode(_unused):
""" Testing eval mode make sure this is no asserts. """
dist.init_process_group(
init_method=f"file://{tempfile.mkstemp()[1]}", backend=dist.Backend.GLOO, rank=0, world_size=1
)
model = Sequential(Linear(2, 3), Linear(3, 4))
optimizer_params = {"lr": 0.1, "momentum": 0.99}
ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1, broadcast_buffers=False)
optimizer = ddp.optimizer
ddp.eval()
for _ in range(5):
input_tensor = torch.rand((64, 2))
output = ddp(input_tensor)
ddp.train()
try:
for _ in range(5):
input_tensor = torch.rand((64, 2))
output = ddp(input_tensor)
except RuntimeError:
pass
else:
assert False, "Multiple forward passes on training mode should not pass"
def run_test_two_inputs(rank, world_size, backend, device, temp_file_name):
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
if device == torch.device("cuda"):
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
class _DoubleInput(torch.nn.Module):
def __init__(self):
super().__init__()
self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
def forward(self, x, y):
x1 = self.mlp(x)
x2 = self.mlp(y)
return torch.cat((x1, x2), dim=1)
model = _DoubleInput().to(device)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
# Optim loop
def closure():
optimizer.zero_grad()
input_tensor = torch.rand((64, 2)).to(device)
loss = ddp_model(input_tensor, input_tensor).abs().sum()
loss.backward()
return loss
# The models should stay the same in between the ranks
for i in range(5):
_ = optimizer.step(closure=closure)
dist.destroy_process_group()
def test_eval_mode():
mp.spawn(run_eval_mode, args=(), join=True)
def test_inputs():
# Check that the ShardedDDP wrapper accepts tuple(tensors) as inputs
world_size = 2
backend = "gloo"
temp_file_name = tempfile.mkstemp()[1]
device = "cpu"
mp.spawn(run_test_two_inputs, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
def test_ddp_attributes():
# Check that ShardedDDP exposes the same attributes as Pytorch's DDP
# - is multi_device_module
# - device_type
url = "file://" + tempfile.mkstemp()[1]
dist.init_process_group(init_method=url, backend="gloo", rank=0, world_size=1)
model = Sequential(Linear(2, 3), Linear(3, 3))
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
assert hasattr(ddp_model, "is_multi_device_module")
assert hasattr(ddp_model, "device_type")
dist.destroy_process_group()
def run_test_two_optimizers(rank, world_size, backend, device, temp_file_name):
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
if device == torch.device("cuda"):
torch.cuda.set_device(rank)
torch.manual_seed(rank)
np.random.seed(rank)
class _DoubleInput(torch.nn.Module):
def __init__(self):
super().__init__()
self.mlp = Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
def forward(self, x, y):
x1 = self.mlp(x)
x2 = self.mlp(y)
return torch.cat((x1, x2), dim=1)
model = _DoubleInput().to(device)
parameters = list(model.parameters())
optimizer_1 = OSS(params=parameters[:-10], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
optimizer_2 = OSS(params=parameters[-10:], optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, [optimizer_1, optimizer_2])
# Optim loop
def closure():
optimizer_1.zero_grad()
optimizer_2.zero_grad()
input_tensor = torch.rand((64, 2)).to(device)
loss = ddp_model(input_tensor, input_tensor).abs().sum()
loss.backward()
return loss
# The models should stay the same in between the ranks
for i in range(5):
_ = optimizer_1.step(closure=closure)
optimizer_2.step()
dist.destroy_process_group()
def test_two_optimizers():
# Check that the ShardedDDP wrapper can handle several sharded optimizers at once
world_size = 2
backend = "gloo"
temp_file_name = tempfile.mkstemp()[1]
device = "cpu"
mp.spawn(run_test_two_optimizers, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)