Unverified commit 7ee228bf authored by Benjamin Lefaudeux, committed by GitHub

[ShardedDDP][Minor] Backport a bucket flush fix from FSDP, may help a few existing users (#435)

* bring back a fix from FSDP, may help a few existing users
parent 6b2897ca
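For background on the mechanism being backported: ShardedDDP buckets small gradients and reduces each bucket once it fills up, so a bucket that never fills during a given backward pass would otherwise never be sent. The fix queues a one-shot callback on the autograd execution engine the first time a gradient hook fires; that callback runs after the backward pass completes and flushes whatever is still pending. The snippet below is a minimal, self-contained sketch of that callback pattern only; the TrailingBucketFlusher class and the toy tensor are illustrative and are not part of the fairscale API.

import torch
from torch.autograd import Variable


class TrailingBucketFlusher:
    """Illustration of the end-of-backward callback pattern used by the fix."""

    def __init__(self) -> None:
        self._callback_set = False

    def _flush(self) -> None:
        # In ShardedDDP this is where partially filled buckets would be reduced.
        print("backward finished, flushing trailing buckets")
        self._callback_set = False

    def grad_hook(self, *_unused: torch.Tensor) -> None:
        # Called during backward; queue the flush exactly once per backward pass.
        if not self._callback_set:
            Variable._execution_engine.queue_callback(self._flush)
            self._callback_set = True


flusher = TrailingBucketFlusher()
x = torch.randn(4, requires_grad=True)
x.register_hook(flusher.grad_hook)
x.sum().backward()  # the queued callback fires once all gradients are computed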
@@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## NEXT - TBD
+### Fixed
+- ShardedDDP auto catch trailing buckets (TBD)
 ## [0.3.0] - 2021-02-22
 ### Added
 - FullyShardedDataParallel (FSDP) ([#413](https://github.com/facebookresearch/fairscale/issues/413))
...
@@ -17,6 +17,7 @@ from typing import Any, Callable, Deque, Dict, Generator, List, Optional, Tuple,
 import torch
 from torch import nn
+from torch.autograd import Variable
 import torch.distributed as dist
 from torch.nn import Parameter
@@ -79,13 +80,6 @@ class ShardedDataParallel(nn.Module):
     One needs a `shard-aware grad scaler<ShardedGradScaler>`, which is proposed in `fairscale.optim.grad_scaler`,
     compatible with PytorchAMP.
-    .. warning:
-        ShardedDDP uses buckets to speed up the network communications. If some parameters require_grad but are not actually
-        used, there is a chance that this would prevent the bucket mechanism to function, and that this could not be automatically
-        handled. In that case ShardedDDP will raise an exception and suggest to either remove the unused parameters from your model
-        (https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html?highlight=unused_parameters is helpful)
-        or set `reduce_buffer_size` to 0
     .. warning:
         If `auto_refresh_trainable` is set to `True` (this is the default) then any trainability change in the model graph will be handled
         automatically.
@@ -102,7 +96,7 @@ class ShardedDataParallel(nn.Module):
         process_group: Any = None,
         broadcast_buffers: bool = True,
         sync_models_at_startup: bool = True,
-        reduce_buffer_size: int = 0,
+        reduce_buffer_size: int = 2 ** 23,
         auto_refresh_trainable: bool = True,
         reduce_fp16: bool = False,
     ):
@@ -188,6 +182,7 @@ class ShardedDataParallel(nn.Module):
         self._sync_params_and_buffers()
         self._work_handles: Deque[Workhandle] = deque()
+        self._bucket_flush_callback_set = False
         self.refresh_trainable()
@@ -373,13 +368,15 @@ class ShardedDataParallel(nn.Module):
             assert (
                 self.accumulate_grads_flipped or not self.training or self.should_accumulate_grads or bucket.sent
             ), (
-                "A bucket failed to be sent, probably unused parameters. "
-                + "Either mark the unused parameter as not trainable (`.requires_grad = False`) "
-                + "or de-activate ShardedDDP buckets -set `reduce_buffer_size` to 0-"
+                "A bucket failed to be sent, cannot continue as results would be wrong. "
+                + "You can try de-activating ShardedDDP buckets -set `reduce_buffer_size` to 0-. "
+                + "Please submit a GitHub issue, this should not happen"
             )
             bucket.reset()
+        self._bucket_flush_callback_set = False
         if not self.should_accumulate_grads:
             self.accumulate_grads_flipped = False
@@ -408,6 +405,10 @@ class ShardedDataParallel(nn.Module):
                 if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
                     assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
+                    if not self._bucket_flush_callback_set:
+                        Variable._execution_engine.queue_callback(self._flush_buckets)
+                        self._bucket_flush_callback_set = True
                     # Make sure that this is not fired twice
                     self._grad_to_be_reduced[index] = False
                     param.grad.mul_(self.world_size_scaling)
@@ -453,6 +454,10 @@ class ShardedDataParallel(nn.Module):
                 if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
                     assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"
+                    if not self._bucket_flush_callback_set:
+                        Variable._execution_engine.queue_callback(self._flush_buckets)
+                        self._bucket_flush_callback_set = True
                     # Make sure that this is not fired twice
                     self._grad_to_be_reduced[index] = False
                     bucket = self.buckets[param.device][dst_rank]
@@ -617,3 +622,21 @@ class ShardedDataParallel(nn.Module):
             work_handle = self._work_handles.popleft()
             if work_handle.callback is not None:
                 work_handle.callback()
+
+    # Flush all the buckets, just in case
+    def _flush_buckets(self) -> None:
+        if self._bucket_list is not None:
+            last_handle = None
+            for bucket in self._bucket_list:
+                if not bucket.sent:
+                    # Normalize the bucket in one go
+                    bucket.buffer.mul_(self.world_size_scaling)
+                    # Reduce the bucket
+                    last_handle = dist.reduce(
+                        tensor=bucket.buffer, dst=bucket.destination, group=self.process_group, async_op=True,
+                    )
+                    bucket.sent = True
+            if last_handle is not None:
+                last_handle.wait()
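Taken together with the new default of reduce_buffer_size=2 ** 23, this means bucketing is now enabled out of the box and any trailing bucket is flushed automatically at the end of the backward pass, which is presumably why the GPT2 test further below can drop its explicit reduce_buffer_size=0. The following usage sketch is not part of the commit; it assumes a fairscale build that includes this patch and uses a single-rank gloo group purely so the snippet can run standalone.

import os

import torch
import torch.distributed as dist

from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

# Single-process "world" just so the example runs standalone.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

model = torch.nn.Linear(32, 32)
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)

# Buckets are now on by default (reduce_buffer_size=2 ** 23); any bucket still
# pending at the end of backward is flushed by the queued callback.
ddp_model = ShardedDataParallel(model, optimizer)
# Bucketing can still be disabled entirely with reduce_buffer_size=0 if needed.

loss = ddp_model(torch.randn(8, 32)).sum()
loss.backward()
optimizer.step()

dist.destroy_process_group()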
@@ -445,7 +445,7 @@ def check_same_models_across_ranks(
         for sync_p in receptacle[1:]:
             assert not params_should_be_equal or torch.all(
                 torch.eq(receptacle[0], sync_p)
-            ), "Models differ in between ranks"
+            ), f"Models differ in between ranks {receptacle[0]} - {sync_p}"
     # Check that all the buffers are in sync (authoritative rank is 0, its buffer is 0)
     if check_broadcast_buffers:
@@ -456,7 +456,7 @@ def check_same_models_across_ranks(
         for sync_b in receptacle[1:]:
             assert not params_should_be_equal or torch.all(
                 torch.eq(receptacle[0], sync_b)
-            ), "Models differ in between ranks"
+            ), f"Models differ in between ranks {receptacle[0]} - {sync_b}"
 class DeviceAndTypeCheckModule(Base):
...
@@ -86,7 +86,7 @@ def run_one_step(
     for i in range(5):
         _ = optimizer.step(closure=closure)
     # when running on cpu/gloo the "nodes" are not really different
-    same_params = device == torch.device("cpu") or grad_accumulation
+    same_params = device == torch.device("cpu") or not grad_accumulation
     check_same_models_across_ranks(
         ddp_model, dist.group.WORLD, params_should_be_equal=same_params, check_broadcast_buffers=broadcast_buffers
     )
@@ -359,7 +359,7 @@ def run_test_gpt2(rank, world_size, backend, device, temp_file_name):
         embed_dim=256, num_heads=2, num_layers=12, num_positions=INPUT_DIM * INPUT_DIM, num_vocab=512, num_classes=2
     ).to(device)
     optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3, momentum=0.99)
-    ddp_model = ShardedDataParallel(model, optimizer, reduce_buffer_size=0)
+    ddp_model = ShardedDataParallel(model, optimizer)
     # Optim loop
     def closure():
...