Commit fc677c94 authored by Myle Ott, committed by Facebook Github Bot

Fix proxying in DistributedFairseqModel

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/302

Differential Revision: D10174608

Pulled By: myleott

fbshipit-source-id: 4e2dfc76eae97afc5488f29b47e74f9897a643ff
parent f766c9a0
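The "proxying" in the title is attribute forwarding: method calls and attribute lookups on the distributed wrapper should fall through to the wrapped fairseq model. A minimal, self-contained sketch of that pattern, using a plain nn.Module wrapper and illustrative names rather than fairseq's actual classes:

import torch.nn as nn


class ModuleProxy(nn.Module):
    """Illustrative wrapper: prefer attributes of the wrapped module."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails; nn.Module keeps
        # submodules in self._modules, so fetch 'module' through the parent
        # class, check it first, then fall back to the default behavior.
        wrapped_module = super().__getattr__('module')
        if hasattr(wrapped_module, name):
            return getattr(wrapped_module, name)
        return super().__getattr__(name)


model = nn.Linear(4, 2)
model.custom_flag = True           # attribute that only the wrapped model has
proxy = ModuleProxy(model)
assert proxy.custom_flag is True   # resolved through the proxying __getattr__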
@@ -12,22 +12,21 @@ from fairseq.distributed_utils import c10d_status
 from . import BaseFairseqModel
 
 
-class DistributedFairseqModel(BaseFairseqModel):
+def DistributedFairseqModel(args, model):
     """
-    A wrapper around a :class:`BaseFairseqModel` instance that adds support for
-    distributed training.
+    Wrap a *model* to support distributed data parallel training.
 
-    Anytime a method or attribute is called on this class we first try to
-    forward it to the underlying DistributedDataParallel instance, otherwise we
-    forward it to the original :class:`BaseFairseqModel` instance.
+    This is similar to the built-in DistributedDataParallel, but allows
+    additional configuration of the DistributedDataParallel class to
+    use, and also provides easier access to the wrapped model by
+    forwarding requests for missing attributes to the wrapped model.
 
     Args:
        args (argparse.Namespace): fairseq args
        model (BaseFairseqModel): model to wrap
     """
 
-    def __init__(self, args, model):
-        super().__init__()
+    # determine which DDP class to extend
     assert isinstance(model, BaseFairseqModel)
     if args.ddp_backend == 'c10d':
         if c10d_status.is_default:
@@ -39,7 +38,7 @@ class DistributedFairseqModel(BaseFairseqModel):
                 'Can\'t find c10d version of DistributedDataParallel. '
                 'Please update PyTorch.'
             )
-            self.ddp_model = ddp_class(
+        init_kwargs = dict(
             module=model,
             device_ids=[args.device_id],
             output_device=args.device_id,
......@@ -51,7 +50,7 @@ class DistributedFairseqModel(BaseFairseqModel):
ddp_class = parallel.deprecated.DistributedDataParallel
else:
ddp_class = parallel.DistributedDataParallel
self.ddp_model = ddp_class(
init_kwargs = dict(
module=model,
device_ids=[args.device_id],
output_device=args.device_id,
@@ -60,19 +59,17 @@ class DistributedFairseqModel(BaseFairseqModel):
     else:
         raise ValueError('Unknown --ddp-backend: ' + args.ddp_backend)
 
-    def __call__(self, *args, **kwargs):
-        return self.ddp_model(*args, **kwargs)
+    class _DistributedFairseqModel(ddp_class):
+        """Extend DistributedDataParallel to check for missing
+        attributes in the wrapped module."""
 
-    def forward(self, *args, **kwargs):
-        return self.ddp_model.forward(*args, **kwargs)
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
 
         def __getattr__(self, name):
-        try:
+            wrapped_module = super().__getattr__('module')
+            if hasattr(wrapped_module, name):
+                return getattr(wrapped_module, name)
             return super().__getattr__(name)
-        except AttributeError:
-            pass
-        try:
-            return self.ddp_model.__getattr__(name)
-        except AttributeError:
-            pass
-        return self.ddp_model.module.__getattr__(name)
+
+    return _DistributedFairseqModel(**init_kwargs)
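The new code turns DistributedFairseqModel into a factory: it picks a DistributedDataParallel variant at runtime, collects init_kwargs, defines a local subclass that adds the attribute forwarding, and returns an instance of it. A self-contained sketch of that construction with stand-in base classes (all names below are illustrative, not fairseq's):

class _BackendA:
    def __init__(self, module):
        self.module = module


class _BackendB:
    def __init__(self, module):
        self.module = module


class _ToyModel:
    greeting = 'hi'


def make_wrapped(backend, module):
    # choose the base class at runtime, as the code above chooses ddp_class
    base_cls = _BackendA if backend == 'a' else _BackendB
    init_kwargs = dict(module=module)

    class _Wrapped(base_cls):
        def __getattr__(self, name):
            # only reached when normal lookup fails: defer to the wrapped module
            return getattr(self.__dict__['module'], name)

    return _Wrapped(**init_kwargs)


wrapped = make_wrapped('a', module=_ToyModel())
print(type(wrapped).__bases__[0].__name__)  # _BackendA, picked at runtime
print(wrapped.greeting)                     # 'hi', forwarded to the toy module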
@@ -281,7 +281,7 @@ def train_language_model(data_dir, arch):
             data_dir,
             '--arch', arch,
             '--optimizer', 'nag',
-            '--lr', '1.0',
+            '--lr', '0.1',
             '--criterion', 'adaptive_loss',
             '--adaptive-softmax-cutoff', '5,10,15',
             '--decoder-layers', '[(850, 3)] * 2 + [(1024,4)]',
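Aside from lowering the test learning rate from 1.0 to 0.1, the surrounding arguments are unchanged. The '--decoder-layers' value is a Python expression describing the convolutional decoder; a quick, illustrative check of what it expands to (reading each tuple as (hidden_dim, kernel_size) is an assumption about the fconv convention, not stated in this diff):

# Illustrative only: the test passes this string on the command line and the
# fconv model typically parses it with eval(), so the same expression can be
# inspected directly.
layers = eval('[(850, 3)] * 2 + [(1024,4)]')
print(layers)       # [(850, 3), (850, 3), (1024, 4)]
print(len(layers))  # 3 decoder layers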