Unverified Commit 47042917 authored by Benjamin Lefaudeux, committed by GitHub

[feat][ShardedDDP] manual reduce option (#389)

* initial implementation, with unit test and assert
* added changelog and better debug string
parent 54bd62d3
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## [next rel] - TBD
 ### Fixed
 - ShardedDDP and OSS handle model trainability changes during training ([#369](https://github.com/facebookresearch/fairscale/issues/369))
 - ShardedDDP state dict load/save bug (#386)
+
+### Added
+- ShardedDDP manual reduce option for checkpointing (#389)
+
 ## [0.1.6] - 2021-02-10
 ### Added
@@ -250,7 +250,9 @@ class ShardedDataParallel(nn.Module):
         """ If the module trainability has changed, update all the assumptions """

         # Make sure that this is not done while gradients are waiting to be reduced (if no_sync context for instance)
-        assert not functools.reduce(lambda x, y: x or y, self._grad_to_be_reduced, False), "Grads waiting to be reduced"
+        assert not functools.reduce(
+            lambda x, y: x or y, self._grad_to_be_reduced, False
+        ), "Grads waiting to be reduced: {}".format(self._grad_to_be_reduced)

         self._trainable_params = list(filter(lambda x: x.requires_grad, self._all_params))
         self._trainable_params.sort(key=lambda x: x.numel())
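For readability: the folded "or" over self._grad_to_be_reduced is equivalent to Python's built-in any(). A possible simplification, shown only as an illustration and not part of this commit:

# Equivalent formulation of the same "no pending gradients" check (illustration only)
assert not any(self._grad_to_be_reduced), "Grads waiting to be reduced: {}".format(self._grad_to_be_reduced)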
@@ -273,11 +275,21 @@
         self._setup_backward_hooks()

     def reduce(self) -> None:
         """.. deprecated:: 0.0.4
-            This does not need to be called, the gradient reduction is done automatically during the BW pass
-        """
-        logging.warning("This is not useful anymore, gradients have been reduced automatically with the backward pass")
+            This does not *need* to be called, the gradient reduction is done automatically during the BW pass.
+            Use this method to reduce the gradients manually
+        """
+        # Check that this is not a mistake, if there's nothing to reduce
+        assert functools.reduce(
+            lambda x, y: x or y, self._grad_to_be_reduced, False
+        ), "No grads waiting to be reduced; was this called twice, or was there no BW pass?"
+
+        # Trigger all the current BW hooks (wrap the lazy map() in list() so the hooks actually run)
+        _ = list(map(lambda x: x(), self._grad_accs))
+
+        # Make sure that all the futures are consumed
+        self._consume_work_handles()

     @torch.no_grad()
     def sync_buffers(self, blocking: bool = False) -> None:
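Taken together with no_sync(), the new reduce() lets callers defer gradient reduction and trigger it explicitly. A minimal usage sketch, not part of this diff, assuming a ShardedDataParallel-wrapped model and an OSS optimizer (the names inputs, loss and optimizer are illustrative):

# Illustrative only: manual-reduce path with ShardedDataParallel
with model.no_sync():              # skip the automatic reduction for this backward pass
    loss = model(inputs).sum()
    loss.backward()                # gradients stay local, flagged as "to be reduced"
model.reduce()                     # reduce the gradients across ranks now
optimizer.step()                   # each rank updates its own parameter shard

This mirrors the pattern exercised by the unit test below.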
@@ -137,10 +137,10 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
     INPUTS = 2
     BATCH_SIZE = 32

-    def check_parity(amp: bool, accumulate: bool, change_train_graph: bool):
+    def check_parity(amp: bool, accumulate: bool, change_train_graph: bool, manual_reduction: bool):
         # The API should be the exact same in between the sharded and non-sharded variants, generic closure
-        def closure(model, scaler, input_tensor, should_accumulate):
+        def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
             accumulate_steps = 3 if should_accumulate else 1
             model.zero_grad()
@@ -158,7 +158,13 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
            for _ in range(accumulate_steps - 1):
                step()

+            if not _manual_reduction:
+                step()
+            else:
+                with model.no_sync():
+                    step()
+                model.reduce()

         # Any model works. Add one different buffer per rank
         model = Sequential(Linear(INPUTS, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3), Linear(3, 3))
@@ -192,7 +198,9 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
             return closure(ddp_model, ddp_scaler, input_tensor, accumulate)

         def closure_sharded(input_tensor=input_tensor):
-            return closure(sharded_ddp_model, sharded_ddp_scaler, input_tensor, accumulate)
+            return closure(
+                sharded_ddp_model, sharded_ddp_scaler, input_tensor, accumulate, _manual_reduction=manual_reduction
+            )

         # Step/scale both
         if ddp_scaler is not None:
@@ -226,11 +234,18 @@ def run_ddp_parity(rank, world_size, backend, temp_file_name):
     for accumulate in [False, True]:
         for change_train_graph in [False, True]:
+            manual_reductions = [False, True] if not accumulate and not change_train_graph else [False]
+            for manual_reduction in manual_reductions:
                 for amp in amp_tests:
                     print(
-                        f"Checking configuration: accumulate {accumulate} - change train graph {change_train_graph} - amp {amp}"
+                        f"Checking configuration: accumulate {accumulate} - change train graph {change_train_graph} - amp {amp} - manual reduction {manual_reduction}"
                     )
-                    check_parity(amp=amp, accumulate=accumulate, change_train_graph=change_train_graph)
+                    check_parity(
+                        amp=amp,
+                        accumulate=accumulate,
+                        change_train_graph=change_train_graph,
+                        manual_reduction=manual_reduction,
+                    )

     dist.destroy_process_group()
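The process-spawning entry point for run_ddp_parity sits outside the visible hunks; the sketch below shows one plausible way such a parity test is driven (world size, backend, and the temp-file rendezvous are assumptions, not taken from this diff):

# Hypothetical driver for the parity test above: one process per simulated rank
import tempfile
import torch.multiprocessing as mp

if __name__ == "__main__":
    world_size = 2
    temp_file_name = tempfile.mkstemp()[1]  # assumed to serve as the file:// rendezvous for init_process_group
    mp.spawn(run_ddp_parity, args=(world_size, "gloo", temp_file_name), nprocs=world_size, join=True)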