Unverified Commit 341d8b2b authored by Benjamin Lefaudeux, committed by GitHub

[feat] OSS/SDP : bucketing (#122)

Use the same bucketing strategy for OSS and SDP:
sort everything ahead of time, per rank and per size, smallest tensors first. Pack the smallest elements into a fixed-size bucket buffer and send it asynchronously, then send all the remaining tensors asynchronously as well, and come back to the bucket. Once the bucketed request has completed, scatter its contents back to the original tensors if needed (a rough sketch follows below).
parent 6e7ad798
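A minimal sketch of the reduce-side bucketing, using a hypothetical bucketed_reduce helper (the real logic lives in ShardedDataParallel._reduce_grads_task and, for the broadcast side, OSS._broadcast_params below). It assumes every parameter already has a gradient, each per-rank list is pre-sorted by increasing size, and the default WORLD group so that group ranks and global ranks coincide:

from typing import Any, List

import torch
import torch.distributed as dist
from torch.nn import Parameter


def bucketed_reduce(
    per_rank_params: List[List[Parameter]],
    buffers: List[torch.Tensor],
    group: Any,
    self_rank: int,
    world_size: int,
) -> None:
    buffer_size = buffers[0].numel()
    bucket_requests = []  # async work handles for the packed buckets
    requests = []  # async work handles for the big, directly-reduced grads

    for rank, (params, buffer) in enumerate(zip(per_rank_params, buffers)):
        if len(params) == 0:
            continue

        # Params are pre-sorted by increasing size: pack as many small grads
        # as fit into this rank's fixed-size buffer.
        offset, n_packed = 0, 0
        while n_packed < len(params) and offset + params[n_packed].numel() < buffer_size:
            end = offset + params[n_packed].numel()
            buffer[offset:end].copy_(params[n_packed].grad.data.view(-1))
            offset, n_packed = end, n_packed + 1

        if n_packed > 0:
            buffer.div_(world_size)
            bucket_requests.append(
                (dist.reduce(tensor=buffer, dst=rank, group=group, async_op=True), rank, n_packed)
            )

        # The remaining (larger) grads are reduced directly, still async.
        for p in params[n_packed:]:
            p.grad.data.div_(world_size)
            requests.append(dist.reduce(tensor=p.grad.data, dst=rank, group=group, async_op=True))

    # Come back to the buckets: wait for them, then scatter the packed grads
    # back in place on the rank which owns them.
    for work, rank, n_packed in bucket_requests:
        work.wait()
        if rank == self_rank:
            offset = 0
            for p in per_rank_params[rank][:n_packed]:
                end = offset + p.numel()
                p.grad.data.copy_(buffers[rank][offset:end].view_as(p))
                offset = end

    # Make sure the direct reduces are done as well.
    _ = list(map(lambda x: x.wait(), requests))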
@@ -165,6 +165,8 @@ def train(
print("[Regression Test] VALID")
dist.destroy_process_group() # type: ignore
if __name__ == "__main__":
@@ -246,7 +248,10 @@ if __name__ == "__main__":
backend,
True, # OSS
True, # SDP
False, # no regression check
args.check_regression,
-1, # Not checking SDP for speed regression for now, still slower than OSS
args.reference_memory,
args.reference_loss,
),
nprocs=args.world_size,
join=True,
@@ -11,7 +11,7 @@ Adopted from LegacyDistributedDataParallel module from fairseq.
from contextlib import contextmanager
import copy
from typing import Any, Dict, Generator, List, Optional, Type, cast
from typing import Any, Dict, Generator, List, Type, cast
import torch
from torch import Tensor, nn
@@ -39,7 +39,7 @@ class ShardedDataParallel(nn.Module):
distributed gradient reduction. If None, the default WORLD process group
will be used.
buffer_size (int, optional): number of elements to buffer before
performing reduce (default: 256M). Used to reduce multiple small
performing reduce (default: 512k). Used to reduce multiple small
params to avoid communication overhead.
"""
@@ -51,7 +51,7 @@ class ShardedDataParallel(nn.Module):
world_size: int,
broadcast_buffers: bool,
process_group: Any = None,
buffer_size: int = 2 ** 28,
buffer_size: int = 2 ** 19,
):
super().__init__()
@@ -62,10 +62,6 @@ class ShardedDataParallel(nn.Module):
self.broadcast_buffers = broadcast_buffers
self.authoritative_rank = 0
# Never use a bigger buffer than the number of model params
self.buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))
self.buffer: Optional[Tensor] = None
# Flag used to make sure we only reduce gradients one time in the execution engine
self.need_reduction = False
@@ -76,6 +72,18 @@ class ShardedDataParallel(nn.Module):
# Build the sharded optimizer
self.sharded_optimizer = OSS(self.module.parameters(), optim=optimizer, group=process_group, **optimizer_params)
# Allocate reduce buffers
# - Never use a bigger buffer than the number of model params
buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))
self._reduce_buffers: Dict[torch.device, List[torch.Tensor]] = {}
# - One buffer per rank per device
for device, per_device in self.sharded_optimizer.per_device_params.items():
buffer_dtype = per_device[0][0].dtype
self._reduce_buffers[device] = [
torch.zeros(buffer_size, dtype=buffer_dtype, device=device) for _ in range(len(per_device))
]
# Sanity checks
assert len(self.sharded_optimizer.param_to_rank) == len(
list(self.module.parameters())
@@ -126,85 +134,92 @@ class ShardedDataParallel(nn.Module):
"""
assert self.module.training, "Cannot call reduce in eval"
def reduce_grads(params: List[Parameter], params_rank: int) -> None:
""" Helper to reduce a list of params that should fit in the buffer.
if not self.need_reduction or self.accumulate_grads:
return
self.need_reduction = False
with torch.no_grad():
for device, per_device in self.sharded_optimizer.per_device_params.items():
self._reduce_grads_task(
self._reduce_buffers[device],
per_device,
group=self.process_group,
self_rank=self.rank,
world_size=self.world_size,
)
@staticmethod
def _reduce_grads_task(
buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]], group: Any, self_rank: int, world_size: int
) -> None:
"""Helper to reduce a list of params. The params are sorted by size, smallest first, which allows for
an opportunistic bucketing.
NOTE: All param gradients are assumed to exist"""
assert self.buffer is not None
# Fill in the packed IO buffer
buffer: Tensor = cast(Tensor, self.buffer)
if len(params) > 1:
offset = 0
for p in params:
sz = p.numel()
buffer[offset : offset + sz].copy_(p.grad.data.view(-1)) # type: ignore
offset += sz
else:
# we only have a single grad to reduce
buffer = params[0].grad.data # type: ignore
buffer_size = buffers[0].numel()
bucket_requests = []
requests = []
# Reduce
buffer.div_(self.world_size) # type: ignore
dist.reduce(tensor=buffer, dst=params_rank, group=self.process_group) # type: ignore
for (rank, params), buffer in zip(enumerate(per_rank_params), buffers):
# All the params are sorted per rank and per increasing size
if len(params) == 0:
continue
# Copy reduced grads back into their original place, or free corresponding memory
if params_rank == self.rank:
offset = 0
for p in params:
sz = p.numel()
p.grad.data.copy_(buffer[offset : offset + sz].view_as(p)) # type: ignore
offset += sz
else:
for p in params:
p.grad = None
def reduction_fn() -> None:
# This function only needs to be called once
if not self.need_reduction or self.accumulate_grads:
return
self.need_reduction = False
if p.grad is None:
p.grad = torch.zeros_like(p)
if self.buffer is None:
self.buffer = next(self.module.parameters()).new(self.buffer_size) # type: ignore
global_rank = OSS.get_global_rank(group, rank)
for params in self.sharded_optimizer.per_device_params:
# Reduce the gradients in buckets
# Copy small gradients into per-GPU buffers and then async reduce
i_bucketed = 0 # the number of tensors packed in the buffer
offset = 0
buffered_params: List[Parameter] = []
param_rank: Optional[int] = None
for param in params:
last_param_rank: Optional[int] = param_rank
param_rank = self.sharded_optimizer.param_to_rank[param]
if not param.requires_grad:
continue
if param.grad is None:
param.grad = torch.zeros_like(param)
if param.grad.requires_grad:
# Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
end = offset + params[i_bucketed].numel()
buffer[offset:end].copy_(params[i_bucketed].grad.data.view(-1)) # type: ignore
offset = end
i_bucketed += 1
if i_bucketed > 0:
buffer.div_(world_size) # type: ignore
bucket_requests.append(
(
dist.reduce(tensor=buffer, dst=global_rank, group=group, async_op=True), # type: ignore
rank,
)
)
# Directly reduce the other grads
for p in params[i_bucketed:]:
p.grad = cast(Tensor, p.grad)
if p.grad.requires_grad:
raise RuntimeError("DistributedDataParallel only works with gradients that don't require grad")
sz = param.numel()
if sz > self.buffer.numel():
# reduce big params directly
assert param_rank is not None
reduce_grads([param], cast(int, param_rank))
else:
# smaller params are packed together from the same device
# and same rank.
if offset + sz > self.buffer.numel() or (
last_param_rank is not None and last_param_rank != param_rank
):
assert last_param_rank is not None
reduce_grads(buffered_params, cast(int, last_param_rank))
p.grad.div_(world_size) # type: ignore
requests.append(dist.reduce(tensor=p.grad, dst=global_rank, group=group, async_op=True)) # type: ignore
# Unroll the initial packed small gradients, as soon as possible
for future, rank in bucket_requests:
future.wait()
if rank == self_rank:
i_bucketed = 0 # the number of tensors packed in the buffer
offset = 0
buffered_params.clear()
buffered_params.append(cast(Parameter, param))
offset += sz
params = per_rank_params[rank]
buffer = buffers[rank]
if len(buffered_params) > 0:
assert param_rank is not None
reduce_grads(buffered_params, cast(int, param_rank))
while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
end = offset + params[i_bucketed].numel()
params[i_bucketed].grad.data.copy_(buffer[offset:end].view_as(params[i_bucketed])) # type: ignore
offset = end
i_bucketed += 1
reduction_fn()
# Make sure that we're done with this device before moving on and cleaning the unused params
_ = list(map(lambda x: x.wait(), requests))
def _sync_buffers(self) -> None:
"""
@@ -49,6 +49,8 @@ class OSS(Optimizer):
optimizer to shard (default: SGD)
group (group):
torch.distributed group (default: group.WORLD)
broadcast_buffer_size (int):
the size of the buffer used to batch the small parameter tensors (default: 128k).
"""
#: The optimizer used for a given shard
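For illustration, a minimal construction sketch using the new broadcast_buffer_size knob (it assumes torch.distributed is already initialized and that OSS is importable from fairscale.optim; the toy model and learning rate are placeholders):

import torch
from fairscale.optim import OSS

model = torch.nn.Linear(8, 8)
optimizer = OSS(
    model.parameters(),
    optim=torch.optim.SGD,
    lr=0.1,  # extra kwargs are forwarded to the wrapped optimizer
    broadcast_buffer_size=2 ** 17,  # elements per rank and per device buffer (128k default)
)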
@@ -56,38 +58,53 @@ class OSS(Optimizer):
in_super_constructor: bool
def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Optional[Any] = None, **default: Any):
def __init__(
self,
params: _params_t,
optim: Type[Optimizer] = SGD,
group: Optional[Any] = None,
broadcast_buffer_size: int = 2 ** 17,
**default: Any,
):
# Hold all the model params in the root .param_groups
self.in_super_constructor = True
super().__init__(params, default)
self.in_super_constructor = False
# Partition information: lazily evaluated, computed on request
self._per_device_params: List[List[Parameter]] = []
self._per_device_params: OrderedDict[
torch.device, List[List[Parameter]]
] = OrderedDict() # device, rank, params
self._param_rank: Dict[torch.Tensor, int] = {}
self._partition_parameters: List[List[dict]] = []
# Build the wrapped optimizer, responsible for a shard of the params
self.group = group if group is not None else dist.group.WORLD
self.world_size = dist.get_world_size(self.group)
self.rank = dist.get_rank(self.group)
self.global_rank = self.get_global_rank(self.group, self.rank)
self.optim = optim(self.partition_parameters()[self.rank], **default)
# Optional consolidated optimizer state
self._all_states: List[Dict[str, Any]] = []
# Current device is set by the parameters allocated to this rank
self._device = self.partition_parameters()[self.rank][0]["params"][0].device
# Sync local and global param_groups keys
# - Sync local and global param_groups keys
for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
for k, v in local_group.items():
if k != "params":
global_group[k] = v
# Optional consolidated optimizer state
self._all_states: List[Dict[str, Any]] = []
# Current default device is set by the parameters allocated to this rank
self._device = self.partition_parameters()[self.rank][0]["params"][0].device
self._broadcast_buffers: Dict[torch.device, List[torch.Tensor]] = {}
for device, per_device in self.per_device_params.items():
# Allocate one buffer per rank and per device to group the small parameters
self._broadcast_buffers[device] = [
torch.zeros(broadcast_buffer_size, dtype=per_device[0][0].dtype, device=device)
for _ in range(len(per_device))
]
# Partition helpers
def partition_parameters(self) -> List[List[dict]]:
"""Partitions parameters across distributed data parallel ranks.
@@ -116,22 +133,26 @@ class OSS(Optimizer):
return self._partition_parameters
@property
def per_device_params(self) -> List[List[Parameter]]:
# TODO (Min): The algorithm here can be improved. We are sorting params by device
# and by rank. Then in reduction_fn below, we pack smaller ones into
# a buffer for reduction.
# We can pre-sort them here and simplify the reduction_fn logic below
# since their size shouldn't change.
def per_device_params(self) -> Dict[torch.device, List[List[Parameter]]]:
"""Sorted list of all the params, first per device then per rank.
Within a list params are sorted per number of elements to allow for an easy bucketing.
"""
if len(self._per_device_params) == 0:
# Go through all the params and group them per device.
# The ordering matters here: it needs to be the same on all ranks
# so that subsequent broadcast calls match up.
for param_group in self.param_groups:
param_lists: OrderedDict = OrderedDict()
for param in param_group["params"]:
device = param.device
if param_lists.get(device) is None:
param_lists[device] = []
param_lists[device] += [param]
self._per_device_params = list(param_lists.values())
if self._per_device_params.get(device) is None:
self._per_device_params[device] = [[] for _ in range(self.world_size)]
self._per_device_params[device][self.param_to_rank[param]] += [param]
# Sort each per-rank list of params by increasing size
for k in self._per_device_params.keys():
for r in self._per_device_params[k]:
r.sort(key=lambda x: x.numel())
return self._per_device_params
@@ -145,13 +166,6 @@ class OSS(Optimizer):
self._param_rank[param] = rank
return self._param_rank
def get_global_rank(self, group: Any, rank: int) -> int:
if group is dist.group.WORLD:
return rank
else:
global_rank = dist.distributed_c10d._get_global_rank(group, rank) # type: ignore
return global_rank
# NOTE(msb) We add a kwargs in order to support Optimizer sub-classes that support extra kwargs.
# For example, the apex library contains fused optimizers with a step that supports extra kwargs.
def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) -> Optional[float]:
@@ -174,25 +188,14 @@ class OSS(Optimizer):
else:
loss = self.optim.step(**kwargs)
# Sync all the states. Broadcast requests are issued async, we check completeness before moving on
requests = []
requires_grad = []
for rank, param_groups in enumerate(self.partition_parameters()):
for param_group in param_groups:
for param in param_group["params"]:
# NOTE: Broadcast is in-place and not differentiable
# Gloo will rightly assert on this operation for any tensor that requires grad.
# We save and restore the grad requirement state to work around that, in our case
# the grad is only useful on the source rank.
global_rank = self.get_global_rank(self.group, rank)
# Sync all the updated shards between the ranks
with torch.no_grad():
for (
device,
device_params,
) in self.per_device_params.items(): # all the params on this device (inc all ranks)
self._broadcast_params(self._broadcast_buffers[device], device_params, self.group, self.global_rank)
requires_grad.append((param, param.requires_grad))
param.requires_grad = False
requests.append(dist.broadcast(tensor=param, src=global_rank, group=self.group, async_op=True))
for fut, req_grad in zip(requests, requires_grad):
fut.wait()
req_grad[0].requires_grad = req_grad[1]
return loss
def local_state_dict(self) -> dict:
@@ -322,7 +325,10 @@ class OSS(Optimizer):
super().add_param_group(param_group)
if not self.in_super_constructor:
self._partition_parameters.clear() # Force a re-partitioning
# Force a re-partitioning
self._partition_parameters.clear()
self._per_device_params.clear()
self._param_rank.clear()
param_groups = self.partition_parameters()[self.rank]
if len(param_groups) == len(self.optim.param_groups) + 1:
@@ -353,7 +359,6 @@ class OSS(Optimizer):
else:
# Fetch the optim state from the other replicas
global_rank = self.get_global_rank(self.group, rank)
logging.debug("Receiving state from rank %s ", global_rank)
replica_state = broadcast_object(
empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device
)
@@ -382,16 +387,93 @@ class OSS(Optimizer):
else:
global_rank = self.get_global_rank(self.group, rank)
# Discard this tensor/rank, broadcast necessary for syncing
logging.debug("Discarding broadcast from rank %s", global_rank)
broadcast_object(empty_buffer, src_rank=global_rank, group=self.group, dist_device=self._device)
def _free_other_grads(self) -> None:
"""Free all the gradients only useful for the other ranks
"""
for i, partition in enumerate(self.partition_parameters()):
if i == self.rank:
for rank, partition in enumerate(self.partition_parameters()):
if rank == self.rank:
continue
for p in partition:
for t in p["params"]:
t.grad = None
@staticmethod
def get_global_rank(group: Any, rank: int) -> int:
if group is dist.group.WORLD:
return rank
else:
global_rank = dist.distributed_c10d._get_global_rank(group, rank) # type: ignore
return global_rank
@staticmethod
def _broadcast_params(
buffers: List[torch.Tensor], per_rank_params: List[List[Parameter]], group: Any, self_rank: int
) -> None:
"""Helper function to broadcast all the parameters from a given device
"""
buffer_size = buffers[0].numel()
restore_require_grad = []
bucket_requests = []
requests = []
# Bucket and issue all the async calls
for (rank, params), buffer in zip(enumerate(per_rank_params), buffers):
# All the params are sorted per rank and per increasing size
if len(params) == 0:
continue
global_rank = OSS.get_global_rank(group, rank)
# Copy small parameters into per-GPU buffers
i_bucketed = 0 # the number of tensors packed in the buffer
offset = 0
# Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
end = offset + params[i_bucketed].numel()
if rank == self_rank:
buffer[offset:end].copy_(params[i_bucketed].data.view(-1)) # type: ignore
offset = end
i_bucketed += 1
if i_bucketed > 0:
future = dist.broadcast(tensor=buffer, src=global_rank, group=group, async_op=True)
if rank != self_rank:
# This request will need to be unrolled
bucket_requests.append((future, rank))
# Directly broadcast the rest
for param in params[i_bucketed:]:
# NOTE: Broadcast is in-place and not differentiable
# Gloo will assert on this operation for any tensor that requires grad.
# We save and restore the grad requirement state to work around that, in our case
# the grad is only useful on the source rank.
if param.requires_grad:
restore_require_grad.append(param)
param.requires_grad = False
requests.append(dist.broadcast(tensor=param, src=global_rank, group=group, async_op=True))
# Unroll the initial packed small parameters
for gate, rank in bucket_requests:
gate.wait()
params = per_rank_params[rank]
buffer = buffers[rank]
i_bucketed = 0 # the number of tensors packed in the buffer
offset = 0
while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
end = offset + params[i_bucketed].numel()
params[i_bucketed].data.copy_(buffer[offset:end].view_as(params[i_bucketed])) # type: ignore
offset = end
i_bucketed += 1
# Unroll all the async work items, just in case
_ = list(map(lambda x: x.wait(), requests))
for p in restore_require_grad:
p.requires_grad = True
@@ -38,8 +38,15 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
torch.cuda.set_device(rank)
# Any model works. Add one different buffer per rank
model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
model = Sequential(Linear(2, 3)).to(device)
model.register_buffer("test_buffer", torch.ones((1)) * rank)
def weights_init(m):
if isinstance(m, Linear):
torch.nn.init.constant_(m.weight.data, 1.0)
torch.nn.init.constant_(m.bias.data, 1.0)
model.apply(weights_init)
model.to(device)
ddp = ShardedDataParallel(
@@ -52,24 +59,26 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
optimizer = ddp.optimizer
model = ddp.module
input_tensor = torch.rand((64, 2)).to(device)
output = ddp(input_tensor).abs().sum() / input_tensor.numel()
# Different input per rank, which allows checking that the gradients have been properly reduced
input_tensor = (torch.ones((64, 2)) * rank).to(device)
output = ddp(input_tensor).abs().sum()
output.backward()
ddp.reduce()
# Check that all the grads for this shard have been populated
if device == torch.device("cuda"):
torch.cuda.synchronize() # flush any remaining cuda op, just in case
for pg in optimizer.optim.param_groups:
for param in pg["params"]:
if param.requires_grad:
assert param.grad.abs().sum().item() > 0.0, "The reduce step should have populated all the gradients"
if param.shape == torch.Size([3, 2]):
assert param.grad[0, 0].cpu() == torch.tensor([32.0])
if param.shape == torch.Size([3]):
assert param.grad[0].cpu() == torch.tensor([64.0])
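# Why 32 and 64 (worked out from the setup above, which is what these asserts rely on):
# weights and biases are initialized to 1 and rank r feeds an input filled with r, so every
# pre-abs activation equals 2 * r + 1 > 0 and abs() is the identity. The weight grad per rank
# is input[0] summed over the 64 samples = 64 * r, and the bias grad per rank is 64. The
# reduce divides by world_size before summing across the two ranks, which gives
# (0 + 64) / 2 = 32 for the weight and (64 + 64) / 2 = 64 for the bias.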
# Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
for b in model.buffers():
assert b.cpu().item() == 0.0
dist.destroy_process_group()
def run_test(backend, device, world_size=2):
temp_file_name = tempfile.mkstemp()[1]
@@ -101,6 +110,8 @@ def run_eval_mode(_unused):
else:
assert False, "Multiple forward passes on training mode should not pass"
dist.destroy_process_group()
def test_eval_mode():
mp.spawn(run_eval_mode, args=(), join=True)