Unverified commit 79ded821, authored by Benjamin Lefaudeux and committed by GitHub

[ShardedDDP] Sync buffers + small cleanup (#112)

- adding the buffer broadcast option
- minor cleanup in shardedDDP
parent 41819af9
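The gist of the change, as a minimal usage sketch (the import path and the surrounding process-group setup are assumptions, not taken from this diff): the new flag simply controls whether module buffers are broadcast from the authoritative rank at the start of each forward pass.

# Minimal sketch, assuming torch.distributed is already initialized
# (e.g. via dist.init_process_group) and that ShardedDataParallel is
# importable as below -- the import path is an assumption.
import torch
from torch.nn import Linear, Sequential
from fairscale.nn.data_parallel import ShardedDataParallel

def build_sharded_ddp(world_size: int) -> ShardedDataParallel:
    model = Sequential(Linear(2, 3), Linear(3, 4))
    ddp = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.SGD,
        optimizer_params={"lr": 1e-4, "momentum": 0.9},
        world_size=world_size,
        broadcast_buffers=True,  # sync module buffers from rank 0 at each forward
    )
    return ddp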
@@ -62,7 +62,7 @@ def train(
 ):
     assert not use_sdp or (use_sdp and use_oss), "ShardedDataParallel requires OSS"

     # DDP
-    dist_init(rank, world_size, backend)
+    dist_init(rank=rank, world_size=world_size, backend=backend)

     # Setup
     torch.cuda.set_device(rank)
@@ -81,7 +81,11 @@ def train(
     if use_sdp:
         ddp = ShardedDataParallel(
-            module=model, optimizer=OPTIM, optimizer_params={"lr": 1e-4, "momentum": 0.9}, world_size=world_size,
+            module=model,
+            optimizer=OPTIM,
+            optimizer_params={"lr": 1e-4, "momentum": 0.9},
+            world_size=world_size,
+            broadcast_buffers=False,
         )
         ddp.train()
         optimizer = ddp.optimizer
...
@@ -25,7 +25,7 @@ class ShardedDataParallel(nn.Module):
     """Implements distributed data parallel training with optimizer state sharding.

     A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`.
-    This version uses a c10d process group for communication and does not
+    This version uses a c10d process group for communication and optionally
     broadcast buffers.

     Args:
@@ -33,6 +33,8 @@ class ShardedDataParallel(nn.Module):
         optimizer (~torch.optim.Optimizer): optimizer to be used for training
         optimizer_params(Dict): extra parameters for the optimizer
         world_size (int): number of parallel workers
+        broadcast_buffers (bool): flag that enables syncing (broadcasting) buffers of
+            the module at beginning of the forward function. (default: ``True``)
         process_group (optional): the c10d process group to be used for
             distributed gradient reduction. If None, the default WORLD process group
             will be used.
@@ -47,6 +49,7 @@ class ShardedDataParallel(nn.Module):
         optimizer: Type[torch.optim.Optimizer],
         optimizer_params: Dict[str, Any],
         world_size: int,
+        broadcast_buffers: bool,
         process_group: Any = None,
         buffer_size: int = 2 ** 28,
     ):
@@ -56,6 +59,8 @@ class ShardedDataParallel(nn.Module):
         self.world_size = world_size
         self.process_group = process_group if process_group is not None else dist.group.WORLD
         self.rank = dist.get_rank(self.process_group)
+        self.broadcast_buffers = broadcast_buffers
+        self.authoritative_rank = 0

         # Never use a bigger buffer than the number of model params
         self.buffer_size = min(buffer_size, sum(p.numel() for p in self.module.parameters()))
@@ -71,7 +76,7 @@ class ShardedDataParallel(nn.Module):
         # Build the sharded optimizer
         self.sharded_optimizer = OSS(self.module.parameters(), optim=optimizer, group=process_group, **optimizer_params)

-        # sanity checks
+        # Sanity checks
         assert len(self.sharded_optimizer.param_to_rank) == len(
             list(self.module.parameters())
         ), "number of params do not match"
@@ -109,6 +114,9 @@ class ShardedDataParallel(nn.Module):
             raise RuntimeError("OssDdp requires explicit reduction, must call OssDdp.reduce")
         if not self.accumulate_grads:
             self.need_reduction = True
+        if self.broadcast_buffers and len(list(self.module.buffers())) > 0:
+            self._sync_buffers()
+
         return self.module(*inputs, **kwargs)

     def reduce(self) -> None:
@@ -118,52 +126,35 @@ class ShardedDataParallel(nn.Module):
         """
         assert self.module.training, "Cannot call reduce in eval"

-        def reduce_params(params: List[Parameter], params_rank: int) -> None:
-            """ Helper to reduce a list of params that should fix in the buffer. """
-            assert self.buffer is not None
-            buffer: Tensor = cast(Tensor, self.buffer)
-            nonzero_buffer = False
-            if len(params) > 1:
-                offset = 0
-                for p in params:
-                    sz = p.numel()
-                    if p.grad is not None:
-                        # The type error could have been fixed in later
-                        # version of pytorch. Same elsewhere.
-                        buffer[offset : offset + sz].copy_(p.grad.data.view(-1))  # type: ignore
-                        nonzero_buffer = True
-                    else:
-                        buffer[offset : offset + sz].zero_()
-                    offset += sz
-            else:
-                # we only have a single grad to reduce
-                p = params[0]
-                if p.grad is not None:
-                    buffer = p.grad.data
-                    nonzero_buffer = True
-                elif p.numel() <= self.buffer.numel():
-                    buffer = buffer[: p.numel()]
-                    buffer.zero_()
-                else:
-                    buffer = torch.zeros_like(p)
-
-            if nonzero_buffer:
-                buffer.div_(self.world_size)  # type: ignore
-
-            dist.reduce(buffer, params_rank, group=self.process_group)  # type: ignore
-
-            if params_rank == self.rank:
-                # copy reduced grads back into their original place
-                offset = 0
-                for p in params:
-                    sz = p.numel()
-                    if p.grad is not None:
-                        p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))  # type: ignore
-                    else:
-                        p.grad = buffer[offset : offset + sz].view_as(p).clone()
-                    offset += sz
-            else:
-                # wipe the grads
-                for p in params:
-                    p.grad = None
+        def reduce_grads(params: List[Parameter], params_rank: int) -> None:
+            """ Helper to reduce a list of params that should fit in the buffer.
+            NOTE: All param gradients are assumed to exist"""
+            assert self.buffer is not None
+
+            # Fill in the packed IO buffer
+            buffer: Tensor = cast(Tensor, self.buffer)
+            if len(params) > 1:
+                offset = 0
+                for p in params:
+                    sz = p.numel()
+                    buffer[offset : offset + sz].copy_(p.grad.data.view(-1))  # type: ignore
+                    offset += sz
+            else:
+                # we only have a single grad to reduce
+                buffer = params[0].grad.data  # type: ignore
+
+            # Reduce
+            buffer.div_(self.world_size)  # type: ignore
+            dist.reduce(tensor=buffer, dst=params_rank, group=self.process_group)  # type: ignore
+
+            # Copy reduced grads back into their original place, or free corresponding memory
+            if params_rank == self.rank:
+                offset = 0
+                for p in params:
+                    sz = p.numel()
+                    p.grad.data.copy_(buffer[offset : offset + sz].view_as(p))  # type: ignore
+                    offset += sz
+            else:
+                for p in params:
+                    p.grad = None
@@ -195,7 +186,7 @@ class ShardedDataParallel(nn.Module):
                     if sz > self.buffer.numel():
                         # reduce big params directly
                         assert param_rank is not None
-                        reduce_params([param], cast(int, param_rank))
+                        reduce_grads([param], cast(int, param_rank))
                     else:
                         # smaller params are packed together from the same device
                         # and same rank.
@@ -203,7 +194,7 @@ class ShardedDataParallel(nn.Module):
                             last_param_rank is not None and last_param_rank != param_rank
                         ):
                             assert last_param_rank is not None
-                            reduce_params(buffered_params, cast(int, last_param_rank))
+                            reduce_grads(buffered_params, cast(int, last_param_rank))
                             offset = 0
                             buffered_params.clear()
                         buffered_params.append(cast(Parameter, param))
@@ -211,6 +202,21 @@ class ShardedDataParallel(nn.Module):
             if len(buffered_params) > 0:
                 assert param_rank is not None
-                reduce_params(buffered_params, cast(int, param_rank))
+                reduce_grads(buffered_params, cast(int, param_rank))

         reduction_fn()

+    def _sync_buffers(self) -> None:
+        """
+        Sync all the param buffers in between ranks.
+        TODO: Could be worth bucketing ?
+        """
+        _ = list(
+            map(
+                lambda x: x.wait(),
+                map(
+                    lambda x: dist.broadcast(x, self.authoritative_rank, self.process_group, async_op=True),
+                    self.module.buffers(),
+                ),
+            )
+        )
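The `_sync_buffers` addition above is equivalent to the loop-based sketch below (illustrative only; `dist` is `torch.distributed`, as elsewhere in the module): every buffer is broadcast asynchronously from the authoritative rank, and the returned work handles are waited on before the forward pass proceeds.

# Loop-based equivalent of the map/lambda version above (sketch only).
def _sync_buffers_explicit(self) -> None:
    # Launch one async broadcast per buffer, sourced from the authoritative rank
    work_handles = [
        dist.broadcast(buffer, self.authoritative_rank, self.process_group, async_op=True)
        for buffer in self.module.buffers()
    ]
    # Wait for all broadcasts to complete before running the forward pass
    for handle in work_handles:
        handle.wait()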
@@ -37,12 +37,20 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
     if device == torch.device("cuda"):
         torch.cuda.set_device(rank)

+    # Any model works. Add one different buffer per rank
     model = Sequential(Linear(2, 3), Linear(3, 4)).to(device)
+    model.register_buffer("test_buffer", torch.ones((1)) * rank)
+    model.to(device)
     ddp = ShardedDataParallel(
-        module=model, optimizer=torch.optim.SGD, optimizer_params={"lr": 0.1, "momentum": 0.99}, world_size=world_size
+        module=model,
+        optimizer=torch.optim.SGD,
+        optimizer_params={"lr": 0.01, "momentum": 0.99},
+        world_size=world_size,
+        broadcast_buffers=True,
     )
     optimizer = ddp.optimizer
+    model = ddp.module

     input_tensor = torch.rand((64, 2)).to(device)
     output = ddp(input_tensor).abs().sum() / input_tensor.numel()
@@ -58,10 +66,9 @@ def run_one_step(rank, world_size, backend, device, temp_file_name):
         if param.requires_grad:
             assert param.grad.abs().sum().item() > 0.0, "The reduce step should have populated all the gradients"

-    # Check that the optimization process makes sense (ie. loss goes down for the same data)
-    optimizer.step()
-    new_eval = ddp(input_tensor).abs().sum() / input_tensor.numel()
-    # assert new_eval.item() < output.item()
+    # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
+    for b in model.buffers():
+        assert b.cpu().item() == 0.0


 def run_test(backend, device, world_size=2):
@@ -76,7 +83,7 @@ def run_eval_mode(_unused):
     )
     model = Sequential(Linear(2, 3), Linear(3, 4))
     optimizer_params = {"lr": 0.1, "momentum": 0.99}
-    ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1)
+    ddp = ShardedDataParallel(model, torch.optim.SGD, optimizer_params, 1, broadcast_buffers=False)
     optimizer = ddp.optimizer
     ddp.eval()
...
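For reference, the bucketed gradient reduction that the renamed `reduce_grads` helper performs boils down to the pattern sketched below, simplified under the assumptions stated in its docstring (every parameter has a gradient and all of them fit in the flat buffer); the function name and arguments here are illustrative, not part of the diff.

# Simplified sketch of the flat-buffer reduction used by reduce_grads
# (assumes every param has a grad and that they all fit in `buffer`).
import torch
import torch.distributed as dist

def pack_and_reduce(params, buffer, dst_rank, world_size, group):
    # Pack all gradients into one contiguous flat buffer
    offset = 0
    for p in params:
        sz = p.numel()
        buffer[offset : offset + sz].copy_(p.grad.data.view(-1))
        offset += sz
    flat = buffer[:offset]
    # Pre-divide so the sum-reduce yields an average
    flat.div_(world_size)
    dist.reduce(tensor=flat, dst=dst_rank, group=group)
    if dist.get_rank(group) == dst_rank:
        # The owning rank unpacks the averaged grads back in place
        offset = 0
        for p in params:
            sz = p.numel()
            p.grad.data.copy_(flat[offset : offset + sz].view_as(p))
            offset += sz
    else:
        # Other ranks drop the grads they no longer own
        for p in params:
            p.grad = None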