Commit fc677c94 authored by Myle Ott, committed by Facebook Github Bot

Fix proxying in DistributedFairseqModel

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/302

Differential Revision: D10174608

Pulled By: myleott

fbshipit-source-id: 4e2dfc76eae97afc5488f29b47e74f9897a643ff
parent f766c9a0
fairseq/models/distributed_fairseq_model.py

@@ -12,22 +12,21 @@ from fairseq.distributed_utils import c10d_status
 from . import BaseFairseqModel


-class DistributedFairseqModel(BaseFairseqModel):
+def DistributedFairseqModel(args, model):
     """
-    A wrapper around a :class:`BaseFairseqModel` instance that adds support for
-    distributed training.
-
-    Anytime a method or attribute is called on this class we first try to
-    forward it to the underlying DistributedDataParallel instance, otherwise we
-    forward it to the original :class:`BaseFairseqModel` instance.
+    Wrap a *model* to support distributed data parallel training.
+
+    This is similar to the built-in DistributedDataParallel, but allows
+    additional configuration of the DistributedDataParallel class to
+    use, and also provides easier access to the wrapped model by
+    forwarding requests for missing attributes to the wrapped model.

     Args:
         args (argparse.Namespace): fairseq args
         model (BaseFairseqModel): model to wrap
     """
-    def __init__(self, args, model):
-        super().__init__()
-        assert isinstance(model, BaseFairseqModel)
-        if args.ddp_backend == 'c10d':
-            if c10d_status.is_default:
+    # determine which DDP class to extend
+    assert isinstance(model, BaseFairseqModel)
+    if args.ddp_backend == 'c10d':
+        if c10d_status.is_default:

@@ -39,7 +38,7 @@ class DistributedFairseqModel(BaseFairseqModel):
                 'Can\'t find c10d version of DistributedDataParallel. '
                 'Please update PyTorch.'
             )
-        self.ddp_model = ddp_class(
+        init_kwargs = dict(
             module=model,
             device_ids=[args.device_id],
             output_device=args.device_id,

@@ -51,7 +50,7 @@ class DistributedFairseqModel(BaseFairseqModel):
             ddp_class = parallel.deprecated.DistributedDataParallel
         else:
             ddp_class = parallel.DistributedDataParallel
-        self.ddp_model = ddp_class(
+        init_kwargs = dict(
             module=model,
             device_ids=[args.device_id],
             output_device=args.device_id,

@@ -60,19 +59,17 @@ class DistributedFairseqModel(BaseFairseqModel):
     else:
         raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend)

-    def __call__(self, *args, **kwargs):
-        return self.ddp_model(*args, **kwargs)
-
-    def forward(self, *args, **kwargs):
-        return self.ddp_model.forward(*args, **kwargs)
-
-    def __getattr__(self, name):
-        try:
-            return super().__getattr__(name)
-        except AttributeError:
-            pass
-        try:
-            return self.ddp_model.__getattr__(name)
-        except AttributeError:
-            pass
-        return self.ddp_model.module.__getattr__(name)
+    class _DistributedFairseqModel(ddp_class):
+        """Extend DistributedDataParallel to check for missing
+        attributes in the wrapped module."""
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        def __getattr__(self, name):
+            wrapped_module = super().__getattr__('module')
+            if hasattr(wrapped_module, name):
+                return getattr(wrapped_module, name)
+            return super().__getattr__(name)
+
+    return _DistributedFairseqModel(**init_kwargs)
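As a side note on the pattern this commit adopts: the old wrapper proxied attribute lookups through a chain of try/except blocks on an outer object, while the new code builds a subclass of the chosen DistributedDataParallel class whose __getattr__ checks the wrapped module first and otherwise falls back to the normal lookup. Below is a minimal, standalone sketch of that forwarding idea. It substitutes a plain nn.Module for DistributedDataParallel so it runs without any distributed setup, and ToyModel, make_wrapped, and max_positions are illustrative names, not part of this commit.

import torch.nn as nn


class ToyModel(nn.Module):
    """Stand-in for a BaseFairseqModel that has an extra method."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

    def max_positions(self):
        return 1024


def make_wrapped(model):
    # Same shape as the commit: pick a wrapper class, then define a subclass
    # whose __getattr__ forwards missing attributes to the wrapped module.
    wrapper_class = nn.Module  # stands in for the selected DDP class

    class _Wrapped(wrapper_class):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, *args, **kwargs):
            return self.module(*args, **kwargs)

        def __getattr__(self, name):
            # nn.Module.__getattr__ only runs when normal lookup fails,
            # so this is the fallback path for missing attributes.
            wrapped_module = super().__getattr__('module')
            if hasattr(wrapped_module, name):
                return getattr(wrapped_module, name)
            return super().__getattr__(name)

    return _Wrapped(model)


wrapped = make_wrapped(ToyModel())
print(wrapped.max_positions())  # 1024, resolved on the wrapped ToyModel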
tests/test_binaries.py

@@ -281,7 +281,7 @@ def train_language_model(data_dir, arch):
         data_dir,
         '--arch', arch,
         '--optimizer', 'nag',
-        '--lr', '1.0',
+        '--lr', '0.1',
         '--criterion', 'adaptive_loss',
         '--adaptive-softmax-cutoff', '5,10,15',
         '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',