"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "2ac00ddd8684b8ba5cd571280f65b327e1c200e5"
Unverified Commit daa1bad5 authored by Benjamin Lefaudeux, committed by GitHub

[feat][fix] ShardedDDP deferred init (#558)

* survive the model being moved to device post-construction
* make sure that a unit test would catch a regression
parent 5e6a7a57
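
The change below makes ShardedDataParallel attach its gradient hooks lazily, on the first forward pass, instead of at construction time. A minimal sketch of the usage pattern this enables (assuming an already-initialized torch.distributed process group and fairscale's OSS / ShardedDataParallel APIs; the model and tensor shapes are placeholders):

    import torch
    from fairscale.nn.data_parallel import ShardedDataParallel
    from fairscale.optim import OSS

    model = torch.nn.Linear(32, 32)  # built on CPU
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)
    ddp_model = ShardedDataParallel(model, optimizer)

    # Moving the model to the device *after* wrapping now works, because the
    # backward hooks are only attached on the first forward pass.
    model = model.to("cuda")

    inputs = torch.rand(8, 32, device="cuda")
    ddp_model(inputs).sum().backward()
    optimizer.step()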
@@ -178,6 +178,7 @@ class ShardedDataParallel(nn.Module):
         # - setup backward hooks which will be called by Torch's autograd in due time
         self._grad_accs: List[Callable] = []
+        self._grad_hooks: List[Any] = []
         self._manual_reduce: List[Callable] = []

         # passing a handle to torch.nn.SyncBatchNorm layer
@@ -190,22 +191,26 @@ class ShardedDataParallel(nn.Module):
         self._work_handles: Deque[Workhandle] = deque()
         self._bucket_flush_callback_set = False

-        self.refresh_trainable()
-
     def forward(self, *inputs: Any, **kwargs: Any) -> Any:
         """
         Module forward pass, handles any DDP-specific work in the background. Primes the
         backward pass for gradient reduction to the proper ranks.
         """

-        # Optionally check whether the trainable parameters have changed
+        # Deferred initialization, or change detection
+        needs_setup = len(self._grad_hooks) == 0

         if self.auto_refresh_trainable:
+            # Optionally check whether the trainable parameters have changed
             trainable_mask = list(map(_trainable, self._all_params))
             if trainable_mask != self._reference_trainable_mask:
                 logging.warning("ShardedDDP detected that the trainable params changed, updating the partitioning")
-                self.refresh_trainable()
+                needs_setup = True
                 self._reference_trainable_mask = trainable_mask

+        if needs_setup:
+            self.refresh_trainable()
+
         if self.enable_broadcast_buffers:
             # NCCL communications are on a different stream, needs to be blocking
             # for the subsequent FW to be correct
@@ -478,6 +483,10 @@ class ShardedDataParallel(nn.Module):
         This makes the gradient reduction automatic whenever there's a backward pass
         """

+        # Detach possible pre-existing hooks
+        while len(self._grad_hooks) > 0:
+            self._grad_hooks.pop().remove()
+
         # Go through the parameters, attach the hook
         self._grad_accs = []
         self._manual_reduce = []
@@ -493,9 +502,10 @@ class ShardedDataParallel(nn.Module):
                     dst_rank = self._trainable_param_to_rank[param]

                     reduce_function = self._get_reduce_fn(index, param, dst_rank)

-                    grad_acc.register_hook(reduce_function)
-                    self._manual_reduce.append(reduce_function)
+                    self._grad_hooks.append(grad_acc.register_hook(reduce_function))
                     self._grad_accs.append(grad_acc)  # keep this hook in scope
+                    self._manual_reduce.append(reduce_function)

     @torch.no_grad()
     def _sync_params_and_buffers(self) -> None:
......
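
For reference, the hook bookkeeping above relies on the fact that registering a hook on a parameter's gradient-accumulator node returns a removable handle. A small standalone sketch of that mechanism (plain PyTorch, not fairscale internals; the names are illustrative only):

    import torch

    param = torch.nn.Parameter(torch.rand(4))
    # Grad accumulator node for the leaf parameter (same trick ShardedDDP uses)
    grad_acc = param.expand_as(param).grad_fn.next_functions[0][0]

    grad_hooks = []

    def attach(reduce_fn):
        # Detach possible pre-existing hooks before re-registering
        while grad_hooks:
            grad_hooks.pop().remove()
        grad_hooks.append(grad_acc.register_hook(reduce_fn))

    attach(lambda *_: print("gradient ready"))
    attach(lambda *_: print("gradient ready"))  # hooks do not stack
    param.sum().backward()                      # prints once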
@@ -73,8 +73,12 @@ class ParamBucket(Bucket):
     @torch.no_grad()
     def _add_param_as_view(self, param: torch.Tensor, keep_existing_value: bool = True) -> None:
         assert self.buffer is not None
-        assert param.dtype == self.buffer.dtype
-        assert param.device == self.buffer.device
+        assert (
+            param.dtype == self.buffer.dtype
+        ), f"Different types for the bucket and the param, cannot proceed: {param.dtype} - {self.buffer.dtype}"
+        assert (
+            param.device == self.buffer.device
+        ), f"Different devices for the bucket and the param, cannot proceed: {param.device} - {self.buffer.device}"

         fill_next = self._fill + param.numel()
         assert fill_next <= self.buffer.numel()
......
@@ -254,7 +254,7 @@ def run_test_device_change(rank, world_size, backend, device, temp_file_name, re
 @skip_if_single_gpu
 @pytest.mark.parametrize("reduce_buffer_size", [0, 2 ** 20])
 def test_device_change(reduce_buffer_size):
-    # Check that ShardedDDP is compatible with sync batch norm across multiple GPUs
+    # Check that ShardedDDP handles a device change properly
     world_size = 2
     backend = "nccl"
     temp_file_name = tempfile.mkstemp()[1]
@@ -386,10 +386,13 @@ def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
     np.random.seed(rank)
     model = GPT2(
         embed_dim=256, num_heads=2, num_layers=12, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
-    ).to(device)
+    )
     optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
     ddp_model = ShardedDataParallel(model, optimizer)

+    # Move the model to another device post-construction
+    model = model.to(device)
+
     # Optim loop
     def closure():
         optimizer.zero_grad()
......
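
The GPT2 test above drives training through a closure, which OSS consumes like a standard PyTorch optimizer. A sketch of that loop (illustrative; it reuses the ddp_model, optimizer, and inputs names from the sketch near the top of this page):

    def closure():
        optimizer.zero_grad()
        loss = ddp_model(inputs).sum()
        loss.backward()
        return loss

    for _ in range(3):
        loss = optimizer.step(closure)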