Commit 72a5487c authored by Kritika Singh, committed by Facebook GitHub Bot

Allow unused params in distributed training

Summary:
Context from https://fb.workplace.com/groups/1405155842844877/permalink/2785095451517569/:

I am adding a model to pyspeech (formerly fairspeq) with the following `forward`:
```
def forward(self, src_tokens, src_lengths, prev_output_tokens, name):
    encoder_out = self.encoder(src_tokens, src_lengths)
    if name == Dataset.d1:
        decoder_out = self.decoder1(prev_output_tokens, encoder_out)
    elif name == Dataset.d2:
        decoder_out = self.decoder2(encoder_out)
    return decoder_out
```
When I run distributed training on this model, I get the following error:

```
RuntimeError: Expected to have finished reduction in the prior iteration before starting a
new one. This error indicates that your module has parameters that were not used in
producing loss. You can enable unused parameter detection by (1) passing the keyword
argument `find_unused_parameters=True` to `torch.nn.parallel.DistributedDataParallel`; (2)
making sure all `forward` function outputs participate in calculating loss. If you already have
done the above two steps, then the distributed data parallel module wasn't able to locate the
output tensors in the return value of your module's `forward` function. Please include the loss
function and the structure of the return value of `forward` of your module when reporting this
issue (e.g. list, dict, iterable). (prepare_for_backward at
caffe2/torch/csrc/distributed/c10d/reducer.cpp:410)
```

The recommended fix is to pass `find_unused_parameters=True` to `DistributedDataParallel`'s constructor. This diff exposes that option through a new `--find-unused-parameters` flag and forwards it to DDP when the installed version supports the argument.
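
A minimal, hypothetical reproduction of the situation above (the `ToyModel`, the decoder names, and the single-process gloo setup are illustrative placeholders, not code from pyspeech or fairseq): only one decoder runs per batch, so the other decoder's parameters never receive gradients, and without `find_unused_parameters=True` the c10d reducer raises the RuntimeError quoted above.

```
# Hypothetical toy example; not taken from the fairseq/pyspeech code base.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel


class ToyModel(nn.Module):
    """A model where only one of two decoders participates in each forward pass."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.decoder1 = nn.Linear(8, 4)
        self.decoder2 = nn.Linear(8, 4)

    def forward(self, x, name):
        h = self.encoder(x)
        # Whichever decoder is not selected gets no gradients this iteration.
        return self.decoder1(h) if name == "d1" else self.decoder2(h)


def main():
    # Single-process "distributed" setup on CPU, just so DDP can be constructed.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # Without find_unused_parameters=True, the c10d reducer expects a gradient
    # for every registered parameter and fails with the error quoted above.
    model = DistributedDataParallel(ToyModel(), find_unused_parameters=True)

    for name in ["d1", "d2"]:
        out = model(torch.randn(2, 8), name)
        out.sum().backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```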

Reviewed By: myleott

Differential Revision: D15439726

fbshipit-source-id: 7fd80d4a3f49ac90182dec723b49b14e6689406a
parent c11aaf14

The change to `DistributedFairseqModel`:
```
@@ -36,10 +36,13 @@ def DistributedFairseqModel(args, model):
             output_device=args.device_id,
             broadcast_buffers=False,
             bucket_cap_mb=args.bucket_cap_mb,
+            find_unused_parameters=args.find_unused_parameters
         )
-        # Maintain backward compatibility for 0.4 or earlier
+        # Maintain backward compatibility
         if 'check_reduction' in inspect.getargspec(ddp_class)[0]:
             init_kwargs['check_reduction'] = True
+        if 'find_unused_parameters' in inspect.getargspec(ddp_class)[0]:
+            init_kwargs['find_unused_parameters'] = args.find_unused_parameters
     elif args.ddp_backend == 'no_c10d':
         ddp_class = LegacyDistributedDataParallel
         init_kwargs = dict(
```

The new command-line option in `add_distributed_training_args`:
```
@@ -282,6 +282,9 @@ def add_distributed_training_args(parser):
                        help='don\'t shuffle batches between GPUs; this reduces overall '
                             'randomness and may affect precision but avoids the cost of '
                             're-reading the data')
+    group.add_argument('--find-unused-parameters', default=False, action='store_true',
+                       help='disable unused parameter detection (not applicable to '
+                            'no_c10d ddp-backend')
     # fmt: on
     return group
```
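
Taken together, the two hunks wire the new flag from the command line into the DDP constructor. A rough sketch of that flow, with a simplified parser and a stand-in module instead of fairseq's actual options and model code:

```
# Simplified sketch of how --find-unused-parameters reaches DDP; the parser and
# module below are placeholders, not the actual fairseq options/model code.
import argparse
import inspect

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

parser = argparse.ArgumentParser()
parser.add_argument('--find-unused-parameters', default=False, action='store_true')
args = parser.parse_args(['--find-unused-parameters'])

init_kwargs = dict(module=nn.Linear(8, 8))
# Mirror the backward-compatibility check from the first hunk: only pass the
# kwarg if the installed DistributedDataParallel accepts it.
if 'find_unused_parameters' in inspect.getfullargspec(DistributedDataParallel.__init__)[0]:
    init_kwargs['find_unused_parameters'] = args.find_unused_parameters
# ddp_model = DistributedDataParallel(**init_kwargs)  # requires an initialized process group
```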