Unverified Commit 79ded821 authored by Benjamin Lefaudeux, committed by GitHub

[ShardedDDP] Sync buffers + small cleanup (#112)

- adding the buffer broadcast option
- minor cleanup in shardedDDP
parent 41819af9
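For context, a minimal usage sketch of the option this commit adds, based on the constructor shown in the diff below. The model, learning rate and world size are illustrative placeholders, the import path is assumed, and torch.distributed must already be initialized:

    import torch
    from fairscale.nn.data_parallel import ShardedDataParallel

    model = torch.nn.Linear(2, 4).cuda()
    ddp = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.SGD,
        optimizer_params={"lr": 0.1, "momentum": 0.9},
        world_size=2,
        broadcast_buffers=True,  # new flag: broadcast module buffers from rank 0 before each forward
    )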
@@ -62,7 +62,7 @@ def train(
 ):
     assert not use_sdp or (use_sdp and use_oss), "ShardedDataParallel requires OSS"

     # DDP
-    dist_init(rank, world_size, backend)
+    dist_init(rank=rank, world_size=world_size, backend=backend)

     # Setup
     torch.cuda.set_device(rank)
@@ -81,7 +81,11 @@ def train(
     if use_sdp:
         ddp = ShardedDataParallel(
-            module=model, optimizer=OPTIM, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size,
+            module=model,
+            optimizer=OPTIM,
+            optimizer_params={"lr": 1e-4, "momentum": 0.9},
+            world_size=world_size,
+            broadcast_buffers=False,
         )
         ddp.train()
         optimizer = ddp.optimizer
@@ -25,7 +25,7 @@ class ShardedDataParallel(nn.Module):
     """Implements distributed data parallel training with optimizer state sharding.

     A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
-    This version uses a c10d process group for communication and does not
+    This version uses a c10d process group for communication and optionally
     broadcast buffers.

     Args:
@@ -33,6 +33,8 @@ class ShardedDataParallel(nn.Module):
         optimizer (~torch.optim.Optimizer): optimizer to be used for training
         optimizer_params(Dict): extra parameters for the optimizer
         world_size (int): number of parallel workers
+        broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
+            the module at beginning of the forward function. (default: ``True``)
         process_group (optional): the c10d process group to be used for
             distributed gradient reduction. If None, the default WORLD process group
             will be used.
@@ -47,6 +49,7 @@ class ShardedDataParallel(nn.Module):
         optimizer: Type[torch.optim.Optimizer],
         optimizer_params: Dict[str, Any],
         world_size: int,
+        broadcast_buffers: bool,
         process_group: Any = None,
         buffer_size: int = 2 ** 28,
     ):
@@ -56,6 +59,8 @@ class ShardedDataParallel(nn.Module):
         self.world_size = world_size
         self.process_group = process_group if process_group is not None else dist.group.WORLD
         self.rank = dist.get_rank(self.process_group)
+        self.broadcast_buffers = broadcast_buffers
+        self.authoritative_rank = 0

         # Never use a bigger buffer than the number of model params
         self.buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))
@@ -71,7 +76,7 @@ class ShardedDataParallel(nn.Module):
         # Build the sharded optimizer
         self.sharded_optimizer = OSS(self.module.parameters(), optim=optimizer, group=process_group, **optimizer_params)

-        # sanity checks
+        # Sanity checks
         assert len(self.sharded_optimizer.param_to_rank) == len(
             list(self.module.parameters())
         ), "number of params do not match"
@@ -109,6 +114,9 @@ class ShardedDataParallel(nn.Module):
             raise RuntimeError("OssDdp requires explicit reduction, must call OssDdp.reduce")
         if not self.accumulate_grads:
             self.need_reduction = True

+        if self.broadcast_buffers and len(list(self.module.buffers())) > 0:
+            self._sync_buffers()
+
         return self.module(*inputs, **kwargs)

     def reduce(self) -> None:
@@ -118,52 +126,35 @@ class ShardedDataParallel(nn.Module):
         """
         assert self.module.training, "Cannot call reduce in eval"

-        def reduce_params(params: List[Parameter], params_rank: int) -> None:
-            """ Helper to reduce a list of params that should fix in the buffer. """
+        def reduce_grads(params: List[Parameter], params_rank: int) -> None:
+            """ Helper to reduce a list of params that should fit in the buffer.
+            NOTE: All param gradients are assumed to exist"""
             assert self.buffer is not None
+            # Fill in the packed IO buffer
             buffer: Tensor = cast(Tensor, self.buffer)
-            nonzero_buffer = False
             if len(params) > 1:
                 offset = 0
                 for p in params:
                     sz = p.numel()
-                    if p.grad is not None:
-                        # The type error could have been fixed in later
-                        # version of pytorch. Same elsewhere.
-                        buffer[offset : offset + sz].copy_(p.grad.data.view(-1))  # type: ignore
-                        nonzero_buffer = True
-                    else:
-                        buffer[offset : offset + sz].zero_()
+                    buffer[offset : offset + sz].copy_(p.grad.data.view(-1))  # type: ignore
                     offset += sz
             else:
                 # we only have a single grad to reduce
-                p = params[0]
-                if p.grad is not None:
-                    buffer = p.grad.data
-                    nonzero_buffer = True
-                elif p.numel() <= self.buffer.numel():
-                    buffer = buffer[: p.numel()]
-                    buffer.zero_()
-                else:
-                    buffer = torch.zeros_like(p)
-
-            if nonzero_buffer:
-                buffer.div_(self.world_size)  # type: ignore
+                buffer = params[0].grad.data  # type: ignore

-            dist.reduce(buffer, params_rank, group=self.process_group)  # type: ignore
+            # Reduce
+            buffer.div_(self.world_size)  # type: ignore
+            dist.reduce(tensor=buffer, dst=params_rank, group=self.process_group)  # type: ignore

+            # Copy reduced grads back into their original place, or free corresponding memory
             if params_rank == self.rank:
-                # copy reduced grads back into their original place
                 offset = 0
                 for p in params:
                     sz = p.numel()
-                    if p.grad is not None:
-                        p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))  # type: ignore
-                    else:
-                        p.grad = buffer[offset : offset + sz].view_as(p).clone()
+                    p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))  # type: ignore
                     offset += sz
             else:
                 # wipe the grads
                 for p in params:
                     p.grad = None
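The helper above packs several small gradients into one flat buffer, runs a single reduce onto the rank that owns those parameters, then either unpacks the result (owning rank) or frees the grads (every other rank). A minimal standalone sketch of that pack/reduce/unpack pattern follows; it assumes torch.distributed is already initialized, and bucket_reduce, params and owner_rank are illustrative names, not part of this diff:

    import torch
    import torch.distributed as dist

    def bucket_reduce(params, owner_rank, group=None):
        # Pack all the gradients into one contiguous buffer
        flat = torch.cat([p.grad.view(-1) for p in params])
        # Average, then reduce onto the rank that owns these params
        flat.div_(dist.get_world_size(group))
        dist.reduce(flat, dst=owner_rank, group=group)
        if dist.get_rank(group) == owner_rank:
            # Unpack the reduced values back into the original gradients
            offset = 0
            for p in params:
                p.grad.copy_(flat[offset : offset + p.numel()].view_as(p))
                offset += p.numel()
        else:
            # Other ranks drop their local gradients; their shard of the
            # optimizer will not use them
            for p in params:
                p.grad = None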
@@ -195,7 +186,7 @@ class ShardedDataParallel(nn.Module):
                     if sz > self.buffer.numel():
                         # reduce big params directly
                         assert param_rank is not None
-                        reduce_params([param], cast(int, param_rank))
+                        reduce_grads([param], cast(int, param_rank))
                     else:
                         # smaller params are packed together from the same device
                         # and same rank.
@@ -203,7 +194,7 @@ class ShardedDataParallel(nn.Module):
                             last_param_rank is not None and last_param_rank != param_rank
                         ):
                             assert last_param_rank is not None
-                            reduce_params(buffered_params, cast(int, last_param_rank))
+                            reduce_grads(buffered_params, cast(int, last_param_rank))
                             offset = 0
                             buffered_params.clear()
                         buffered_params.append(cast(Parameter, param))
@@ -211,6 +202,21 @@ class ShardedDataParallel(nn.Module):
                 if len(buffered_params) > 0:
                     assert param_rank is not None
-                    reduce_params(buffered_params, cast(int, param_rank))
+                    reduce_grads(buffered_params, cast(int, param_rank))

         reduction_fn()

+    def _sync_buffers(self) -> None:
+        """
+        Sync all the param buffers in between ranks.
+        TODO: Could be worth bucketing ?
+        """
+        _ = list(
+            map(
+                lambda x: x.wait(),
+                map(
+                    lambda x: dist.broadcast(x, self.authoritative_rank, self.process_group, async_op=True),
+                    self.module.buffers(),
+                ),
+            )
+        )
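The nested map calls in _sync_buffers launch one non-blocking broadcast per buffer and then wait on every returned work handle. An explicit-loop equivalent, written as a free function for clarity (sync_buffers is an illustrative name, not part of the diff, and torch.distributed must already be initialized):

    import torch.distributed as dist

    def sync_buffers(module, authoritative_rank=0, group=None):
        # Launch one non-blocking broadcast per buffer, from the authoritative rank
        handles = [
            dist.broadcast(buf, authoritative_rank, group=group, async_op=True)
            for buf in module.buffers()
        ]
        # Block until every broadcast has completed before running the forward pass
        for handle in handles:
            handle.wait()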
@@ -37,12 +37,20 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
     if device == torch.device("cuda"):
         torch.cuda.set_device(rank)

+    # Any model works. Add one different buffer per rank
     model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
+    model.register_buffer("test_buffer", torch.ones((1)) * rank)
+    model.to(device)

     ddp = ShardedDataParallel(
-        module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 0.1, "momentum": 0.99}, world_size=world_size
+        module=model,
+        optimizer=torch.optim.SGD,
+        optimizer_params={"lr": 0.01, "momentum": 0.99},
+        world_size=world_size,
+        broadcast_buffers=True,
     )
     optimizer = ddp.optimizer
+    model = ddp.module

     input_tensor = torch.rand((64, 2)).to(device)
     output = ddp(input_tensor).abs().sum() / input_tensor.numel()
@@ -58,10 +66,9 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
         if param.requires_grad:
             assert param.grad.abs().sum().item() > 0.0, "The reduce step should have populated all the gradients"

     # Check that the optimization process makes sense (ie. loss goes down for the same data)
     optimizer.step()
     new_eval = ddp(input_tensor).abs().sum() / input_tensor.numel()
-    # assert new_eval.item() < output.item()
+
+    # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
+    for b in model.buffers():
+        assert b.cpu().item() == 0.0

 def run_test(backend, device, world_size=2):
@@ -76,7 +83,7 @@ def run_eval_mode(_unused):
     )
     model = Sequential(Linear(2, 3), Linear(3, 4))
     optimizer_params = {"lr": 0.1, "momentum": 0.99}
-    ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1)
+    ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1, broadcast_buffers=False)
     optimizer = ddp.optimizer
     ddp.eval()