Unverified Commit a265586b authored by Benjamin Lefaudeux, committed by GitHub

[fix] ShardedDDP - properly handle post device change (#353)

* adding the .to(device) support + unit testing
* doc update
parent 9e8929e6
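For context, a minimal sketch of the usage pattern this commit enables, assuming a process group has already been initialized; the toy model, learning rate, and target device below are illustrative and not taken from the commit:

import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

# assumes torch.distributed.init_process_group(...) has already been called
model = torch.nn.Linear(2, 3)  # illustrative toy model, built on CPU
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01)
ddp_model = ShardedDataParallel(model, optimizer)
ddp_model.to(torch.device("cuda"))  # moving the wrapper after construction is what this fix covers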
@@ -11,7 +11,7 @@ reduction automatically.
import contextlib
from itertools import chain
import logging
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
import torch
from torch import nn
@@ -142,7 +142,6 @@ class ShardedDataParallel(nn.Module):
self.buckets: Dict[OSS, Dict[torch.device, List[Bucket]]] = {o: {} for o in self.sharded_optimizers}
self._should_bucket_grad: List[bool] = []
self._bucket_iterator: Optional[Iterable[Bucket]] = None
self._setup_bucket_strategy()
# - setup backward hooks which will be called by Torch's autograd in due time
@@ -172,6 +171,44 @@ class ShardedDataParallel(nn.Module):
# Normal FW on the base model
return self.module(*inputs, **kwargs)
def to( # type: ignore
self,
device: Optional[Union[int, torch.device]],
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> "ShardedDataParallel":
"""
Moves and/or casts the parameters and buffers.
Its signature is similar to :meth:`torch.Tensor.to`, but only accepts
floating point desired :attr:`dtype` s. In addition, this method will
only cast the floating point parameters and buffers to :attr:`dtype`
(if given). The integral parameters and buffers will be moved
:attr:`device`, if that is given, but with dtypes unchanged. When
:attr:`non_blocking` is set, it tries to convert/move asynchronously
with respect to the host if possible, e.g., moving CPU Tensors with
pinned memory to CUDA devices.
.. note::
This method modifies the module in-place.
Arguments:
device (:class:`torch.device`): the desired device of the parameters and buffers in this module.
dtype (:class:`torch.dtype`): the desired floating point type of the floating point parameters and buffers.
non_blocking (bool): make it an asynchronous call.
Returns:
Module: self.
"""
for optimizer in self.buckets.keys():
# do not shadow the `device` argument with the per-bucket device key
for bucket_device in self.buckets[optimizer].keys():
for bucket in self.buckets[optimizer][bucket_device]:
# Tensor.to() is not in-place, so keep the returned buffer
bucket.buffer = bucket.buffer.to(device=bucket_device, dtype=dtype, non_blocking=non_blocking)

self.module.to(device=device, dtype=dtype, non_blocking=non_blocking)
return self
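As a hedged illustration of the semantics described in the docstring above (reusing the `ddp_model` wrapper from the sketch near the top of this page; the half-precision cast is purely illustrative, not part of the commit):

# illustrative only: move to GPU 0 and cast floating point parameters to fp16
ddp_model = ddp_model.to(torch.device("cuda:0"), dtype=torch.float16, non_blocking=True)
# integral buffers keep their dtype; they are only moved to the target device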
def reduce(self) -> None:
""".. deprecated:: 0.0.4
@@ -215,14 +252,15 @@ class ShardedDataParallel(nn.Module):
def _clear_counters(self) -> None:
"""Reset all the grad reduce and call counters"""
if not self.should_accumulate_grads:
self._grad_to_be_reduced = [True for _ in self._grad_to_be_reduced]
self._reduced_grads = {o: 0 for o in self.sharded_optimizers}
for o in self.buckets.keys():
for d in self.buckets[o].keys():
for bucket in self.buckets[o][d]:
for optimizer in self.buckets.keys():
for device in self.buckets[optimizer].keys():
for bucket in self.buckets[optimizer][device]:
assert bucket.sent, (
"A bucket failed being sent, probably unused parameters."
"A bucket failed to be sent, probably unused parameters."
+ "Either remove the unused parameter or de-activate ShardedDDP buckets -set reduce_buffer_size to 0-"
)
......
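For reference, a minimal sketch of the workaround that this assertion message points to, reusing the illustrative `model` and `optimizer` names from the sketch near the top of this page; `reduce_buffer_size` is the constructor argument the message itself refers to:

# disable gradient bucketing when the model has unused parameters
ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=0)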
@@ -342,6 +342,35 @@ def test_random_attributes():
dist.destroy_process_group()
def run_test_device_change(rank, world_size, backend, device, temp_file_name):
# Check that the wrapped module can change devices
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
model = Sequential(Linear(2, 3), Linear(3, 3)).cpu()
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01, momentum=0.99)
ddp_model = ShardedDataParallel(model, optimizer)
ddp_model.to(device)
inputs = torch.rand((10, 2), device=device)
outputs = ddp_model(inputs)  # will raise if the module was not moved properly
loss = outputs.norm()
loss.backward()
dist.destroy_process_group()
@skip_if_no_cuda
@skip_if_single_gpu
def test_device_change():
# Check that ShardedDDP handles the wrapped module being moved to another device after wrapping
world_size = 2
backend = "gloo"
temp_file_name = tempfile.mkstemp()[1]
device = "cuda"
mp.spawn(run_test_device_change, args=(world_size, backend, device, temp_file_name), nprocs=world_size, join=True)
def run_test_ddp_sync_batch_norm(rank, world_size, backend, device, temp_file_name):
url = "file://" + temp_file_name
dist.init_process_group(init_method=url, backend=backend, rank=rank, world_size=world_size)
......