Unverified Commit daa1bad5 authored by Benjamin Lefaudeux, committed by GitHub

[feat][fix] ShardedDDP deferred init (#558)

* survive the model being moved to device post-construction
* make sure that a unit test would catch a regression
parent 5e6a7a57
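
What the change enables, in rough terms: the wrapped model no longer has to sit on its final device when ShardedDataParallel is constructed, because the backward hooks and buckets are (re)built lazily at the first forward pass. A minimal usage sketch, assuming a process group is already initialized and a CUDA device is available; the toy Linear model and the device string are placeholders, not part of this commit:

```python
import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim import OSS

# Assumes torch.distributed.init_process_group(...) has already been called,
# e.g. via torch.multiprocessing.spawn as in the fairscale unit tests.
model = torch.nn.Linear(32, 4)  # placeholder model, still on CPU here

# Shard the optimizer state, then wrap the model before moving it.
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)

# Moving the model to its device *after* construction is the case this commit
# supports: hooks and buckets are set up lazily on the first forward pass.
model = model.to("cuda")

inputs = torch.rand(8, 32, device="cuda")
ddp_model(inputs).sum().backward()
optimizer.step()
```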
@@ -178,6 +178,7 @@ class ShardedDataParallel(nn.Module):
         # - setup backward hooks which will be called by Torch's autograd in due time
         self._grad_accs: List[Callable] = []
+        self._grad_hooks: List[Any] = []
         self._manual_reduce: List[Callable] = []
         # passing a handle to torch.nn.SyncBatchNorm layer
@@ -190,22 +191,26 @@ class ShardedDataParallel(nn.Module):
         self._work_handles: Deque[Workhandle] = deque()
         self._bucket_flush_callback_set = False
-        self.refresh_trainable()

     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         """
         Module forward pass, handles any DDP-specific work in the background. Primes the
         backward pass for gradient reduction to the proper ranks.
         """
-        # Optionally check whether the trainable parameters have changed
+        # Deferred initialization, or change detection
+        needs_setup = len(self._grad_hooks) == 0
+
         if self.auto_refresh_trainable:
+            # Optionally check whether the trainable parameters have changed
             trainable_mask = list(map(_trainable, self._all_params))
             if trainable_mask != self._reference_trainable_mask:
                 logging.warning("ShardedDDP detected that the trainable params changed, updating the partitioning")
-                self.refresh_trainable()
+                needs_setup = True
                 self._reference_trainable_mask = trainable_mask

+        if needs_setup:
+            self.refresh_trainable()
+
         if self.enable_broadcast_buffers:
             # NCCL communications are on a different stream, needs to be blocking
             # for the subsequent FW to be correct
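
For context on the `auto_refresh_trainable` branch kept above: the mask is simply a record of which parameters currently require gradients, and setup is re-run when that record changes. A standalone sketch of the idea in plain PyTorch, no fairscale involved (`_trainable` in ShardedDDP is essentially a `requires_grad` check):

```python
import torch

# Plain-PyTorch sketch of the change detection performed in forward() above.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
reference_trainable_mask = [p.requires_grad for p in model.parameters()]

# A user freezes part of the model between two forward passes.
for p in model[0].parameters():
    p.requires_grad = False

trainable_mask = [p.requires_grad for p in model.parameters()]
if trainable_mask != reference_trainable_mask:
    # This is where ShardedDDP logs a warning and flags needs_setup, so that
    # the partitioning and the backward hooks get rebuilt on this forward.
    reference_trainable_mask = trainable_mask
    print("trainable parameters changed -> rebuild partitioning and hooks")
```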
@@ -478,6 +483,10 @@ class ShardedDataParallel(nn.Module):
         This makes the gradient reduction automatic whenever there's a backward pass
         """

+        # Detach possible pre-existing hooks
+        while len(self._grad_hooks) > 0:
+            self._grad_hooks.pop().remove()
+
         # Go through the parameters, attach the hook
         self._grad_accs = []
         self._manual_reduce = []
@@ -493,9 +502,10 @@ class ShardedDataParallel(nn.Module):
             dst_rank = self._trainable_param_to_rank[param]
             reduce_function = self._get_reduce_fn(index, param, dst_rank)

-            grad_acc.register_hook(reduce_function)
-            self._manual_reduce.append(reduce_function)
+            self._grad_hooks.append(grad_acc.register_hook(reduce_function))
             self._grad_accs.append(grad_acc)  # keep this hook in scope
+            self._manual_reduce.append(reduce_function)

     @torch.no_grad()
     def _sync_params_and_buffers(self) -> None:
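The two hunks above are the crux of being able to call refresh_trainable() more than once: the grad-accumulator hooks now return handles that are stored in `_grad_hooks` and removed before hooks are re-attached. A standalone sketch of that mechanism in plain PyTorch (single tensor, no distributed setup; the print hook stands in for the real reduce function):

```python
import torch

param = torch.nn.Parameter(torch.ones(3))

# Same trick as in ShardedDDP: reach the AccumulateGrad node through a no-op
# expand, since leaf tensors have no grad_fn of their own.
grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]

grad_hooks = []
grad_hooks.append(grad_acc.register_hook(lambda *_: print("reduce hook fired")))

param.sum().backward()   # prints: the hook fires once the grad is accumulated

# Detach before re-registering, mirroring the while-loop added above.
while len(grad_hooks) > 0:
    grad_hooks.pop().remove()

param.grad = None
param.sum().backward()   # silent: the old hook is gone
```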
@@ -73,8 +73,12 @@ class ParamBucket(Bucket):
     @torch.no_grad()
     def _add_param_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
         assert self.buffer is not None
-        assert param.dtype == self.buffer.dtype
-        assert param.device == self.buffer.device
+        assert (
+            param.dtype == self.buffer.dtype
+        ), f"Different types for the bucket and the param, cannot proceed: {param.dtype} - {self.buffer.dtype}"
+        assert (
+            param.device == self.buffer.device
+        ), f"Different devices for the bucket and the param, cannot proceed: {param.device} - {self.buffer.device}"

         fill_next = self._fill + param.numel()
         assert fill_next <= self.buffer.numel()
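For context on the stricter assertions: a ParamBucket turns each parameter into a view of one flat buffer, which only works when the parameter and the buffer share dtype and device, so the new messages make a mismatch (for example after a late `.to(device)` on a model whose buckets were already built) diagnosable. A rough, standalone sketch of that view mechanism, not the actual fairscale implementation:

```python
import torch

buffer = torch.zeros(16, dtype=torch.float32)   # stand-in for the bucket buffer
param = torch.nn.Parameter(torch.randn(4, 4))   # float32 on CPU, like the buffer
fill = 0

fill_next = fill + param.numel()
assert param.dtype == buffer.dtype, f"{param.dtype} vs {buffer.dtype}"
assert param.device == buffer.device, f"{param.device} vs {buffer.device}"
assert fill_next <= buffer.numel()

# Copy the current value into the buffer, then re-point the param at that
# slice so the param and the bucket share storage from now on.
buffer[fill:fill_next].copy_(param.data.flatten())
param.data = buffer[fill:fill_next].view_as(param.data)
fill = fill_next

# A half-precision (or differently placed) param would trip the assertions
# above instead of failing later with an opaque copy/view error.
```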
@@ -254,7 +254,7 @@ def run_test_device_change(rank, world_size, backend, device, temp_file_name, re
 @skip_if_single_gpu
 @pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
 def test_device_change(reduce_buffer_size):
-    # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
+    # Check that ShardedDDP handles a device change properly
     world_size = 2
     backend = "nccl"
     temp_file_name = tempfile.mkstemp()[1]
@@ -386,10 +386,13 @@ def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
     np.random.seed(rank)

     model = GPT2(
         embed_dim=256, num_heads=2, num_layers=12, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
-    ).to(device)
+    )
     optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
     ddp_model = ShardedDataParallel(model, optimizer)

+    # Move the model to another device post-construction
+    model = model.to(device)
+
     # Optim loop
     def closure():
         optimizer.zero_grad()
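A side note on why moving the model after wrapping can work at all: `nn.Module.to()` moves the existing Parameter objects in place by default rather than replacing them, so the references that OSS and ShardedDataParallel captured at construction stay valid, and only the device-dependent state (buckets, hooks) needs the deferred rebuild exercised by this test. A small check illustrating that (CPU-only dtype conversion here; the same identity check holds for a CUDA move with PyTorch's default settings):

```python
import torch

model = torch.nn.Linear(4, 4)
ids_before = [id(p) for p in model.parameters()]

model = model.to(torch.float64)  # .to(device) behaves the same way by default
ids_after = [id(p) for p in model.parameters()]

# The Parameter objects are reused; only their underlying data changed.
assert ids_before == ids_after
```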