Unverified Commit 121b9db0 authored by Benjamin Lefaudeux's avatar Benjamin Lefaudeux Committed by GitHub
Browse files

[fix][OSS] two small hotfixes — repro not obvious for grad_fn (#583)

parent 168c9baa
......@@ -494,18 +494,24 @@ class ShardedDataParallel(nn.Module):
if param.grad is not None and param.grad.requires_grad:
raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
# Register the hook to the next function in line,
# so that the hook is fired when this grad has properly been computed
p_tmp = param.expand_as(param)
assert p_tmp.grad_fn is not None
grad_acc = p_tmp.grad_fn.next_functions[0][0]
dst_rank = self._trainable_param_to_rank[param]
reduce_function = self._get_reduce_fn(index, param, dst_rank)
self._grad_hooks.append(grad_acc.register_hook(reduce_function))
self._grad_accs.append(grad_acc) # keep this hook in scope
self._manual_reduce.append(reduce_function)
# See https://pytorch.org/docs/stable/tensors.html?highlight=grad_fn
# We're interested in the tensors which will be tracked by Autograd
# Some tensors can have gradients independent of the inputs (ie. pooling layer for instance),
# these do not need to be sync'ed
if p_tmp.grad_fn is not None:
# Register the hook to the next function in line,
# so that the hook is fired when this grad has properly been computed
# (by default the hook with Pytorch is a pre-grad, not a post-grad)
grad_acc = p_tmp.grad_fn.next_functions[0][0]
dst_rank = self._trainable_param_to_rank[param]
reduce_function = self._get_reduce_fn(index, param, dst_rank)
self._grad_hooks.append(grad_acc.register_hook(reduce_function))
self._grad_accs.append(grad_acc) # keep this hook in scope
self._manual_reduce.append(reduce_function)
@torch.no_grad()
def _sync_params_and_buffers(self) -> None:
......
......@@ -591,13 +591,14 @@ class OSS(Optimizer):
# Merge all the trainable params in a single bucket
trainable_params = list(filter(lambda x: x.requires_grad, params))
buffer_size = sum(map(lambda x: x.numel(), trainable_params))
bucket = ParamBucket(size=buffer_size, dtype=params[0].dtype, device=device)
if trainable_params:
buffer_size = sum(map(lambda x: x.numel(), trainable_params))
bucket = ParamBucket(size=buffer_size, dtype=trainable_params[0].dtype, device=device)
for param in trainable_params:
bucket.add_param(param)
for param in trainable_params:
bucket.add_param(param)
self.buckets[device][dst_rank] = bucket
self.buckets[device][dst_rank] = bucket
# Clear the buffer keys which are not in use anymore (could be that the devices changed)
devices_in_use = list(self._per_device_params.keys())
......
......@@ -30,8 +30,25 @@ from fairscale.utils.testing import (
)
def _get_mlp():
return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
def _get_mlp(tripwire: bool = False):
if not tripwire:
return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
class Tripwire(torch.nn.Module):
"""A model made to expose possible corner cases
"""
def __init__(self) -> None:
super().__init__()
self.model = Linear(2, 3, bias=False)
# mismatched types in between trainable or not, can trip the buckets for instance
self.register_parameter("tripwire", torch.nn.Parameter(torch.LongTensor((3, 3)), requires_grad=False))
def forward(self, x):
return self.model(x)
return Tripwire()
class _DoubleInput(torch.nn.Module):
......@@ -231,6 +248,20 @@ def test_random_attributes():
dist.destroy_process_group()
def test_mixed_types():
    # Check that OSS + ShardedDDP can wrap a model mixing trainable float
    # parameters with a frozen integer parameter (the "tripwire" fixture),
    # and that a forward pass goes through without tripping the buckets.
    # NOTE(review): the original comment here was copy-pasted from
    # test_random_attributes and did not describe this test.
    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
    model = _get_mlp(tripwire=True)
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
    model = ShardedDataParallel(model, optimizer)
    input_tensor = torch.rand((2, 2))
    _ = model(input_tensor)
    dist.destroy_process_group()
def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
# Check that the wrapped module can change devices
dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
......
......@@ -322,8 +322,8 @@ def run_test_step(rank, world_size, tempfile_name):
dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
p.grad.data /= world_size
o.step()
assert m.weight == torch.tensor([[0.75]], device=rank)
assert m.bias == torch.tensor([1.85], device=rank)
assert m.weight == torch.tensor([[0.75]], device=rank), f"{rank}: {m.weight.item()}, 0.75 expected"
assert m.bias == torch.tensor([1.85], device=rank), f"{rank}: {m.bias.item()}, 1.85 expected"
dist.destroy_process_group()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment