Unverified commit 121b9db0, authored by Benjamin Lefaudeux, committed by GitHub

[fix][OSS] two small hotfixes.. repro not obvious for grad_fn (#583)

parent 168c9baa
```diff
@@ -494,10 +494,16 @@ class ShardedDataParallel(nn.Module):
             if param.grad is not None and param.grad.requires_grad:
                 raise RuntimeError("ShardedDataParallel only works with gradients that don't require grad")
 
-            # Register the hook to the next function in line,
-            # so that the hook is fired when this grad has properly been computed
-            p_tmp = param.expand_as(param)
-            assert p_tmp.grad_fn is not None
-            grad_acc = p_tmp.grad_fn.next_functions[0][0]
-            dst_rank = self._trainable_param_to_rank[param]
+            p_tmp = param.expand_as(param)
+
+            # See https://pytorch.org/docs/stable/tensors.html?highlight=grad_fn
+            # We're interested in the tensors which will be tracked by Autograd
+            # Some tensors can have gradients independent of the inputs (ie. pooling layer for instance),
+            # these do not need to be sync'ed
+            if p_tmp.grad_fn is not None:
+                # Register the hook to the next function in line,
+                # so that the hook is fired when this grad has properly been computed
+                # (by default the hook with Pytorch is a pre-grad, not a post-grad)
+                grad_acc = p_tmp.grad_fn.next_functions[0][0]
+                dst_rank = self._trainable_param_to_rank[param]
```
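For context on the hunk above: `expand_as` creates an Autograd view of the parameter, and the view's `grad_fn.next_functions[0][0]` is the `AccumulateGrad` node that fires once that parameter's gradient has been written. Parameters Autograd does not track (frozen ones, for instance) have no `grad_fn` at all, which is exactly what the old `assert` tripped over. A standalone sketch of the pattern (mine, not part of this commit), assuming a recent PyTorch:

```python
# Minimal sketch of the hook-registration trick used above (not fairscale code).
import torch
from torch.nn import Linear

layer = Linear(2, 3)
frozen = torch.nn.Parameter(torch.ones(3), requires_grad=False)
grad_accs = []  # keep the AccumulateGrad nodes alive so the hooks are not lost

for p in list(layer.parameters()) + [frozen]:
    p_tmp = p.expand_as(p)
    if p_tmp.grad_fn is not None:
        # AccumulateGrad node for this parameter: fires once its .grad is ready
        grad_acc = p_tmp.grad_fn.next_functions[0][0]
        grad_acc.register_hook(lambda *unused: print("grad ready"))
        grad_accs.append(grad_acc)
    else:
        # Frozen parameters are not tracked by Autograd, so there is nothing to hook
        print("skipped: no grad_fn")

layer(torch.rand(4, 2)).sum().backward()  # "grad ready" is printed for weight and bias
```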
```diff
@@ -591,8 +591,9 @@ class OSS(Optimizer):
 
             # Merge all the trainable params in a single bucket
             trainable_params = list(filter(lambda x: x.requires_grad, params))
-            buffer_size = sum(map(lambda x: x.numel(), trainable_params))
-            bucket = ParamBucket(size=buffer_size, dtype=params[0].dtype, device=device)
-            for param in trainable_params:
-                bucket.add_param(param)
+            if trainable_params:
+                buffer_size = sum(map(lambda x: x.numel(), trainable_params))
+                bucket = ParamBucket(size=buffer_size, dtype=trainable_params[0].dtype, device=device)
+                for param in trainable_params:
+                    bucket.add_param(param)
 
```
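This second hunk guards against parameter groups whose first entry is not trainable: sizing the bucket from `params[0].dtype` could pick up, say, an `int64` frozen parameter, while the flattened bucket only ever holds the trainable (floating point) parameters. A toy illustration of the mismatch, using plain tensors rather than fairscale's `ParamBucket`:

```python
# Toy illustration (not fairscale code) of why the bucket dtype must come from
# trainable_params[0] rather than params[0].
import torch

params = [
    torch.nn.Parameter(torch.LongTensor((3, 3)), requires_grad=False),  # frozen, int64
    torch.nn.Parameter(torch.rand(3, 3)),                               # trainable, float32
]

trainable_params = list(filter(lambda x: x.requires_grad, params))
if trainable_params:  # the new guard: a group may contain no trainable params at all
    buffer_size = sum(map(lambda x: x.numel(), trainable_params))
    assert params[0].dtype == torch.int64              # the old code would size an int64 bucket
    assert trainable_params[0].dtype == torch.float32  # the fix matches the params it stores
    print(buffer_size)  # 9
```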
```diff
@@ -30,9 +30,26 @@ from fairscale.utils.testing import (
 )
 
 
-def _get_mlp():
-    return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+def _get_mlp(tripwire: bool = False):
+    if not tripwire:
+        return Sequential(Linear(2, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
+
+    class Tripwire(torch.nn.Module):
+        """A model made to expose possible corner cases
+        """
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.model = Linear(2, 3, bias=False)
+
+            # mismatched types in between trainable or not, can trip the buckets for instance
+            self.register_parameter("tripwire", torch.nn.Parameter(torch.LongTensor((3, 3)), requires_grad=False))
+
+        def forward(self, x):
+            return self.model(x)
+
+    return Tripwire()
 
 
 class _DoubleInput(torch.nn.Module):
     def __init__(self):
```
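The tripwire model hands the optimizer one trainable float32 weight plus one frozen int64 parameter; note that the frozen parameter is registered directly on the module, so it comes first in `parameters()`, which is what made `params[0].dtype` the wrong choice. A quick standalone check (my sketch, mirroring the definition above):

```python
# Inspect the parameter mix that _get_mlp(tripwire=True) hands to OSS / ShardedDataParallel.
import torch
from torch.nn import Linear

class Tripwire(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.model = Linear(2, 3, bias=False)
        self.register_parameter("tripwire", torch.nn.Parameter(torch.LongTensor((3, 3)), requires_grad=False))

for name, p in Tripwire().named_parameters():
    print(f"{name}: dtype={p.dtype}, requires_grad={p.requires_grad}")
# tripwire: dtype=torch.int64, requires_grad=False
# model.weight: dtype=torch.float32, requires_grad=True
```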
```diff
@@ -231,6 +248,20 @@ def test_random_attributes():
     dist.destroy_process_group()
 
 
+def test_mixed_types():
+    # Check that ShardedDDP handles a mix of trainable and non-trainable parameter types
+    dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)
+
+    model = _get_mlp(tripwire=True)
+
+    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
+    model = ShardedDataParallel(model, optimizer)
+    input_tensor = torch.rand((2, 2))
+    _ = model(input_tensor)
+
+    dist.destroy_process_group()
+
+
 def run_test_device_change(rank, world_size, backend, device, temp_file_name, reduce_buffer_size):
     # Check that the wrapped module can change devices
     dist.init_process_group(init_method="file://" + temp_file_name, backend=backend, rank=rank, world_size=world_size)
```
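The new test exercises the whole wrapper without spawning processes: a world of size 1 over the gloo backend, rendezvous'd through a temporary file, is enough to hit OSS bucketing and ShardedDDP hook registration. The same pattern works for quick local experiments; the sketch below is mine, extends the test with a backward pass and an optimizer step, and assumes the fairscale import paths of this era:

```python
# Single-rank smoke-test pattern (sketch, mirroring test_mixed_types above).
# Assumptions: fairscale installed; import paths as of this commit.
import tempfile

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

dist.init_process_group(init_method="file://" + tempfile.mkstemp()[1], backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(2, 3)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
model = ShardedDataParallel(model, optimizer)

loss = model(torch.rand(2, 2)).sum()
loss.backward()   # grads are reduced to their owning rank (here rank 0 owns everything)
optimizer.step()  # OSS updates its shard, then broadcasts the updated params

dist.destroy_process_group()
```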
```diff
@@ -322,8 +322,8 @@ def run_test_step(rank, world_size, tempfile_name):
             dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)
             p.grad.data /= world_size
     o.step()
 
-    assert m.weight == torch.tensor([[0.75]], device=rank)
-    assert m.bias == torch.tensor([1.85], device=rank)
+    assert m.weight == torch.tensor([[0.75]], device=rank), f"{rank}: {m.weight.item()}, 0.75 expected"
+    assert m.bias == torch.tensor([1.85], device=rank), f"{rank}: {m.bias.item()}, 1.85 expected"
 
     dist.destroy_process_group()
```
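One note on the assertion pattern: comparing a one-element tensor with `==` yields a one-element boolean tensor, which Python can evaluate for truthiness, so the bare assert works but fails without any context; the added f-string messages report the rank and the observed value. A small sketch (illustrative names, not from the test) of that pattern plus the tolerance-based variant commonly used for noisier float results:

```python
import torch

rank = 0
weight = torch.tensor([[0.75]])

# One-element comparison: a bool tensor, truthy/falsy like a Python bool
assert weight == torch.tensor([[0.75]]), f"{rank}: {weight.item()}, 0.75 expected"

# For results that are only approximately reproducible, prefer an explicit tolerance
assert torch.allclose(weight, torch.tensor([[0.75]]), atol=1e-6), f"{rank}: {weight.item()}"
```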