Unverified Commit 66415206 authored by Myle Ott, committed by GitHub

fairseq-py goes distributed (#106)

This PR includes breaking API changes that modularize fairseq-py, and adds support for distributed training across multiple nodes.

Changes:
- c7033ef: add support for distributed training! See updated README for usage.
- e016299: modularize fairseq-py, adding support for register_model, register_criterion, register_optimizer, etc.
- 154e440: update the LSTM implementation to use PackedSequence objects in the encoder, better following best practices and improving performance (see the sketch below)
- 90c2973 and 1da6265: improve unit test coverage
parent 7e86e30c
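For readers unfamiliar with the pattern referenced in 154e440, here is a minimal, hypothetical sketch of an encoder built around PackedSequence; the class and names are illustrative only, not the actual fairseq LSTM implementation:
```
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class TinyLSTMEncoder(nn.Module):
    """Illustrative encoder; not the fairseq LSTMEncoder."""

    def __init__(self, num_embeddings, embed_dim=32, hidden_dim=64, padding_idx=0):
        super().__init__()
        self.embed = nn.Embedding(num_embeddings, embed_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True)

    def forward(self, src_tokens, src_lengths):
        # src_tokens must be sorted by descending length before packing
        x = self.embed(src_tokens)
        packed = pack_padded_sequence(x, src_lengths, batch_first=True)
        packed_out, (h, c) = self.lstm(packed)  # the LSTM skips padding entirely
        out, _ = pad_packed_sequence(packed_out, batch_first=True)
        return out, h
```
This is also why the updated collater sorts each batch by descending source length before returning it.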
......@@ -161,6 +161,44 @@ $ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
```
# Distributed version
Distributed training in fairseq-py is implemented on top of [torch.distributed](http://pytorch.org/docs/master/distributed.html).
Running it requires one process per GPU. For those processes to discover each other,
they all need to know a unique host and port that can be used to establish the initial connection,
and each process must be assigned a rank: a unique number from 0 to n-1, where n is the total number of GPUs.
Below is an example of training a big En2Fr model on 16 nodes with 8 GPUs each (128 GPUs in total).
If you run on a cluster managed by [SLURM](https://slurm.schedmd.com/), you can train the WMT'14 En2Fr model with
the following command:
```
$ DATA=... # path to the preprocessed dataset, must be visible from all nodes
$ PORT=9218 # any available TCP port that the trainer can use to establish the initial connection
$ sbatch --job-name fairseq-py --gres gpu:8 --nodes 16 --ntasks-per-node 8 \
--cpus-per-task 10 --no-requeue --wrap 'srun --output train.log.node%t \
--error train.stderr.node%t.%j python train.py $DATA --distributed-world-size 128 \
--distributed-port $PORT --force-anneal 50 --lr-scheduler fixed --max-epoch 55 \
--arch fconv_wmt_en_fr --optimizer nag --lr 0.1,4 --max-tokens 3000 \
--clip-norm 0.1 --dropout 0.1 --criterion label_smoothed_cross_entropy \
--label-smoothing 0.1 --wd 0.0001'
```
Alternatively, without SLURM, you'll need to start one process per GPU manually:
```
$ DATA=... # path to the preprocessed dataset, must be visible from all nodes
$ HOST_PORT=your.devserver.com:9218 # must be one of the hosts used by the job,
                                    # and the port on that host must be available
$ RANK=... # the rank of this process, ranging from 0 to 127 for 128 GPUs
$ python train.py $DATA --distributed-world-size 128 \
--force-anneal 50 --lr-scheduler fixed --max-epoch 55 \
--arch fconv_wmt_en_fr --optimizer nag --lr 0.1,4 --max-tokens 3000 \
--clip-norm 0.1 --dropout 0.1 --criterion label_smoothed_cross_entropy \
--label-smoothing 0.1 --wd 0.0001 \
--distributed-init-method="tcp://$HOST_PORT" --distributed-rank=$RANK
```
# Join the fairseq community
* Facebook page: https://www.facebook.com/groups/fairseq.users
......
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import socket
import subprocess
from singleprocess_train import main as single_process_main
from fairseq import distributed_utils, options
def main(args):
if args.distributed_init_method is None and args.distributed_port > 0:
# We can determine the init method automatically for Slurm.
node_list = os.environ.get('SLURM_JOB_NODELIST')
if node_list is not None:
try:
hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', node_list])
args.distributed_init_method = 'tcp://{host}:{port}'.format(
host=hostnames.split()[0].decode('utf-8'),
port=args.distributed_port)
args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
args.device_id = int(os.environ.get('SLURM_LOCALID'))
except subprocess.CalledProcessError as e: # scontrol failed
raise e
except FileNotFoundError as e: # Slurm is not installed
pass
if args.distributed_init_method is None:
raise ValueError('--distributed-init-method or --distributed-port '
'must be specified for distributed training')
args.distributed_rank = distributed_utils.distributed_init(args)
print('| initialized host {} as rank {}'.format(socket.gethostname(), args.distributed_rank))
single_process_main(args)
if __name__ == '__main__':
parser = options.get_training_parser()
args = options.parse_args_and_arch(parser)
main(args)
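Under SLURM the init method and ranks are derived from the job environment; for the 16x8 job in the README, the relevant variables look schematically like this (values are illustrative):
```
import os

# illustrative SLURM environment for a 16-node x 8-GPU-per-node job
node_list = os.environ.get('SLURM_JOB_NODELIST')     # e.g. 'node[01-16]'; scontrol expands it
rank = int(os.environ.get('SLURM_PROCID', 0))        # global rank of this task: 0 .. 127
local_id = int(os.environ.get('SLURM_LOCALID', 0))   # task index on this node: 0 .. 7
```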
......@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from .multiprocessing_pdb import pdb
......
......@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import ctypes
import math
......
......@@ -20,6 +20,10 @@ at::Type& getDataType(const char* dtype) {
return at::getType(at::kCUDA, at::kFloat);
} else if (strcmp(dtype, "torch.FloatTensor") == 0) {
return at::getType(at::kCPU, at::kFloat);
} else if (strcmp(dtype, "torch.cuda.DoubleTensor") == 0) {
return at::getType(at::kCUDA, at::kDouble);
} else if (strcmp(dtype, "torch.DoubleTensor") == 0) {
return at::getType(at::kCPU, at::kDouble);
} else {
throw std::runtime_error(std::string("Unsupported data type: ") + dtype);
}
......
......@@ -4,12 +4,42 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from .cross_entropy import CrossEntropyCriterion
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
import importlib
import os
from .fairseq_criterion import FairseqCriterion
CRITERION_REGISTRY = {}
CRITERION_CLASS_NAMES = set()
def build_criterion(args, src_dict, dst_dict):
return CRITERION_REGISTRY[args.criterion](args, src_dict, dst_dict)
def register_criterion(name):
"""Decorator to register a new criterion."""
def register_criterion_cls(cls):
if name in CRITERION_REGISTRY:
raise ValueError('Cannot register duplicate criterion ({})'.format(name))
if not issubclass(cls, FairseqCriterion):
raise ValueError('Criterion ({}: {}) must extend FairseqCriterion'.format(name, cls.__name__))
if cls.__name__ in CRITERION_CLASS_NAMES:
# We use the criterion class name as a unique identifier in
# checkpoints, so all criterions must have unique class names.
raise ValueError('Cannot register criterion with duplicate class name ({})'.format(cls.__name__))
CRITERION_REGISTRY[name] = cls
CRITERION_CLASS_NAMES.add(cls.__name__)
return cls
return register_criterion_cls
__all__ = [
'CrossEntropyCriterion',
'LabelSmoothedCrossEntropyCriterion',
]
# automatically import any Python files in the criterions/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.criterions.' + module)
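As a usage sketch, a file dropped into fairseq/criterions/ is picked up by the import loop above, and the decorator makes the class selectable via --criterion; the criterion below is hypothetical:
```
from . import FairseqCriterion, register_criterion

@register_criterion('my_criterion')  # hypothetical name, enables --criterion my_criterion
class MyCriterion(FairseqCriterion):

    def __init__(self, args, src_dict, dst_dict):
        super().__init__(args, src_dict, dst_dict)

    def forward(self, model, sample, reduce=True):
        # compute the loss for `sample` here (see the built-in criterions below)
        raise NotImplementedError
```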
......@@ -4,18 +4,18 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch.nn.functional as F
from .fairseq_criterion import FairseqCriterion
from . import FairseqCriterion, register_criterion
@register_criterion('cross_entropy')
class CrossEntropyCriterion(FairseqCriterion):
def __init__(self, args, dst_dict):
super().__init__(args, dst_dict)
def __init__(self, args, src_dict, dst_dict):
super().__init__(args, src_dict, dst_dict)
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
......@@ -27,6 +27,7 @@ class CrossEntropyCriterion(FairseqCriterion):
"""
net_output = model(**sample['net_input'])
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = sample['target'].view(-1)
loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
reduce=reduce)
......
......@@ -4,18 +4,22 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from torch.nn.modules.loss import _Loss
class FairseqCriterion(_Loss):
def __init__(self, args, dst_dict):
def __init__(self, args, src_dict, dst_dict):
super().__init__()
self.args = args
self.padding_idx = dst_dict.pad()
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
pass
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
......
......@@ -4,16 +4,14 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch
from torch.autograd.variable import Variable
import torch.nn.functional as F
from fairseq import utils
from .fairseq_criterion import FairseqCriterion
from . import FairseqCriterion, register_criterion
class LabelSmoothedNLLLoss(torch.autograd.Function):
......@@ -46,12 +44,18 @@ class LabelSmoothedNLLLoss(torch.autograd.Function):
return utils.volatile_variable(ctx.grad_input) * grad, None, None, None, None, None
@register_criterion('label_smoothed_cross_entropy')
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
def __init__(self, args, dst_dict, weights=None):
super().__init__(args, dst_dict)
def __init__(self, args, src_dict, dst_dict):
super().__init__(args, src_dict, dst_dict)
self.eps = args.label_smoothing
self.weights = weights
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
help='epsilon for label smoothing, 0 means no label smoothing')
def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample.
......@@ -63,8 +67,9 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
"""
net_output = model(**sample['net_input'])
lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = sample['target'].view(-1)
loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, self.weights, reduce)
loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, None, reduce)
nll_loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = {
......
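For reference, label smoothing with parameter ε (--label-smoothing) mixes the one-hot target with a uniform distribution over the vocabulary V; in its standard textbook form (not copied from LabelSmoothedNLLLoss itself) the per-token objective is:
```
\mathcal{L}(\theta) = (1-\epsilon)\,\bigl(-\log p_\theta(y)\bigr)
                    + \frac{\epsilon}{|V|} \sum_{v \in V} \bigl(-\log p_\theta(v)\bigr)
```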
......@@ -4,11 +4,11 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import contextlib
import itertools
import glob
import math
import numbers
import numpy as np
import os
......@@ -130,10 +130,10 @@ class LanguageDatasets(object):
assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk()
def train_dataloader(self, split, num_workers=0, max_tokens=None,
def train_dataloader(self, split, max_tokens=None,
max_sentences=None, max_positions=(1024, 1024),
seed=None, epoch=1, sample_without_replacement=0,
sort_by_source_size=False):
sort_by_source_size=False, shard_id=0, num_shards=1):
dataset = self.splits[split]
with numpy_seed(seed):
batch_sampler = shuffled_batches_by_size(
......@@ -141,40 +141,27 @@ class LanguageDatasets(object):
max_sentences=max_sentences, epoch=epoch,
sample=sample_without_replacement, max_positions=max_positions,
sort_by_source_size=sort_by_source_size)
batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
return torch.utils.data.DataLoader(
dataset, num_workers=num_workers, collate_fn=dataset.collater,
dataset, collate_fn=dataset.collater,
batch_sampler=batch_sampler)
def eval_dataloader(self, split, num_workers=0, max_tokens=None,
max_sentences=None, max_positions=(1024, 1024),
skip_invalid_size_inputs_valid_test=False,
descending=False):
descending=False, shard_id=0, num_shards=1):
dataset = self.splits[split]
batch_sampler = list(batches_by_size(
batch_sampler = batches_by_size(
dataset.src, dataset.dst, max_tokens, max_sentences,
max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
descending=descending))
descending=descending)
batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
return torch.utils.data.DataLoader(
dataset, num_workers=num_workers, collate_fn=dataset.collater,
batch_sampler=batch_sampler)
def skip_group_enumerator(it, ngpus, offset=0):
res = []
idx = 0
for i, sample in enumerate(it):
if i < offset:
continue
res.append(sample)
if len(res) >= ngpus:
yield (i, res)
res = []
idx = i + 1
if len(res) > 0:
yield (idx, res)
class sharded_iterator(object):
def __init__(self, itr, num_shards, shard_id):
......@@ -192,7 +179,7 @@ class sharded_iterator(object):
yield v
class LanguagePairDataset(object):
class LanguagePairDataset(torch.utils.data.Dataset):
# padding constants
LEFT_PAD_SOURCE = True
......@@ -222,26 +209,47 @@ class LanguagePairDataset(object):
@staticmethod
def collate(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False):
return LanguagePairDataset.collate_tokens(
[s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning)
[s[key] for s in samples],
pad_idx, eos_idx, left_pad, move_eos_to_beginning,
)
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE)
target = merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET)
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
move_eos_to_beginning=True,
)
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
target = target.index_select(0, sort_order)
return {
'id': torch.LongTensor([s['id'].item() for s in samples]),
'id': id,
'ntokens': sum(len(s['target']) for s in samples),
'net_input': {
'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE),
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
move_eos_to_beginning=True),
'src_tokens': src_tokens,
'src_lengths': src_lengths,
'prev_output_tokens': prev_output_tokens,
},
'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET),
'target': target,
}
@staticmethod
def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning):
def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
size = max(v.size(0) for v in values)
res = values[0].new(len(values), size).fill_(pad_idx)
......@@ -292,7 +300,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
sample_len = 0
ignored = []
for idx in indices:
for idx in map(int, indices):
if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
if ignore_invalid_inputs:
ignored.append(idx)
......@@ -332,9 +340,9 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
indices = np.argsort(src.sizes, kind='mergesort')
if descending:
indices = np.flip(indices, 0)
return _make_batches(
return list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, allow_different_src_lens=False)
ignore_invalid_inputs, allow_different_src_lens=False))
def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
......@@ -380,6 +388,18 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
return batches
def mask_batches(batch_sampler, shard_id, num_shards):
if num_shards == 1:
return batch_sampler
res = [
batch
for i, batch in enumerate(batch_sampler)
if i % num_shards == shard_id
]
expected_length = int(math.ceil(len(batch_sampler) / num_shards))
return res + [[]] * (expected_length - len(res))
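Each shard keeps every num_shards-th batch, and the result is padded with empty batches so every worker performs the same number of steps; a small worked example:
```
>>> batches = [[0], [1], [2], [3], [4]]
>>> mask_batches(batches, shard_id=0, num_shards=2)
[[0], [2], [4]]
>>> mask_batches(batches, shard_id=1, num_shards=2)
[[1], [3], []]
```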
@contextlib.contextmanager
def numpy_seed(seed):
"""Context manager which seeds the NumPy PRNG with the specified seed and
......
......@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch
......@@ -115,7 +114,7 @@ class Dictionary(object):
return Dictionary.load(fd)
except FileNotFoundError as fnfe:
raise fnfe
except:
except Exception:
raise Exception("Incorrect encoding detected in {}, please "
"rebuild the dataset".format(f))
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import pickle
import torch.distributed
def distributed_init(args):
if args.distributed_world_size == 1:
raise ValueError('Cannot initialize distributed with distributed_world_size=1')
print('| distributed init (rank {}): {}'.format(
args.distributed_rank, args.distributed_init_method), flush=True)
if args.distributed_init_method.startswith('tcp://'):
torch.distributed.init_process_group(
backend=args.distributed_backend, init_method=args.distributed_init_method,
world_size=args.distributed_world_size, rank=args.distributed_rank)
else:
torch.distributed.init_process_group(
backend=args.distributed_backend, init_method=args.distributed_init_method,
world_size=args.distributed_world_size)
args.distributed_rank = torch.distributed.get_rank()
if args.distributed_rank != 0:
suppress_output()
return args.distributed_rank
def suppress_output():
"""Suppress printing on the current device. Force printing with `force=True`."""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
if 'force' in kwargs:
force = kwargs.pop('force')
if force:
builtin_print(*args, **kwargs)
__builtin__.print = print
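After distributed_init, only rank 0 prints by default; on a suppressed worker a message can still be forced through (note that rank 0 keeps the builtin print, which does not accept force):
```
# on a non-zero rank, where print has been replaced by suppress_output():
print('| rank {} finished epoch'.format(args.distributed_rank), force=True)
```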
def all_reduce_and_rescale_tensors(tensors, rescale_denom, buffer_size=10485760):
"""All-reduce and rescale tensors in chunks of the specified size.
Args:
tensors: list of Tensors to all-reduce
rescale_denom: denominator for rescaling summed Tensors
buffer_size: all-reduce chunk size in bytes
"""
# buffer size is in bytes, determine equiv. # of elements based on data type
buffer_t = tensors[0].new(math.ceil(buffer_size / tensors[0].element_size())).zero_()
buffer = []
def all_reduce_buffer():
# copy tensors into buffer_t
offset = 0
for t in buffer:
numel = t.numel()
buffer_t[offset:offset+numel].copy_(t.view(-1))
offset += numel
# all-reduce and rescale
torch.distributed.all_reduce(buffer_t[:offset])
buffer_t.div_(rescale_denom)
# copy all-reduced buffer back into tensors
offset = 0
for t in buffer:
numel = t.numel()
t.view(-1).copy_(buffer_t[offset:offset+numel])
offset += numel
filled = 0
for t in tensors:
sz = t.numel() * t.element_size()
if sz > buffer_size:
# tensor is bigger than buffer, all-reduce and rescale directly
torch.distributed.all_reduce(t)
t.div_(rescale_denom)
elif filled + sz > buffer_size:
# buffer is full, all-reduce and replace buffer with grad
all_reduce_buffer()
buffer = [t]
filled = sz
else:
# add tensor to buffer
buffer.append(t)
filled += sz
if len(buffer) > 0:
all_reduce_buffer()
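A typical use, sketched here hypothetically, is averaging gradients across workers after the backward pass:
```
# hypothetical gradient-averaging step after loss.backward()
grads = [p.grad.data for p in model.parameters() if p.grad is not None]
all_reduce_and_rescale_tensors(grads, rescale_denom=args.distributed_world_size)
```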
def all_gather_list(data, max_size=4096):
"""Gathers arbitrary data from all nodes into a list."""
world_size = torch.distributed.get_world_size()
if not hasattr(all_gather_list, '_in_buffer') or \
max_size != all_gather_list._in_buffer.size(0):
all_gather_list._in_buffer = torch.ByteTensor(max_size)
all_gather_list._out_buffers = [
torch.cuda.ByteTensor(max_size)
for i in range(world_size)
]
in_buffer = all_gather_list._in_buffer
out_buffers = all_gather_list._out_buffers
enc = pickle.dumps(data)
# the first byte of the buffer stores the length, so enc must also fit in 255 bytes
if len(enc) >= min(max_size, 256):
raise ValueError('encoded data exceeds max_size: {}'.format(len(enc)))
in_buffer[0] = len(enc)
in_buffer[1:len(enc)+1] = torch.ByteTensor(enc)
torch.distributed.all_gather(out_buffers, in_buffer.cuda())
result = []
for i in range(world_size):
out_buffer = out_buffers[i]
size = out_buffer[0]
result.append(
pickle.loads(bytes(out_buffer[1:size+1].tolist()))
)
return result
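Because the payload is pickled into a fixed-size byte buffer, this suits small Python objects such as per-worker logging stats; a hypothetical example:
```
# gather each worker's logging stats (any small picklable object)
stats = {'loss': 2.31, 'ntokens': 3000, 'rank': args.distributed_rank}
all_stats = all_gather_list(stats)  # list with one entry per rank, on every worker
total_tokens = sum(s['ntokens'] for s in all_stats)
```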
......@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import numpy as np
import os
......
......@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import time
......
......@@ -4,21 +4,58 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from .fairseq_decoder import FairseqDecoder
from .fairseq_encoder import FairseqEncoder
from .fairseq_incremental_decoder import FairseqIncrementalDecoder
from .fairseq_model import FairseqModel
import importlib
import os
from .fairseq_decoder import FairseqDecoder # noqa: F401
from .fairseq_encoder import FairseqEncoder # noqa: F401
from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401
from .fairseq_model import FairseqModel # noqa: F401
MODEL_REGISTRY = {}
ARCH_MODEL_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}
def build_model(args, src_dict, dst_dict):
return ARCH_MODEL_REGISTRY[args.arch].build_model(args, src_dict, dst_dict)
def register_model(name):
"""Decorator to register a new model (e.g., LSTM)."""
def register_model_cls(cls):
if name in MODEL_REGISTRY:
raise ValueError('Cannot register duplicate model ({})'.format(name))
if not issubclass(cls, FairseqModel):
raise ValueError('Model ({}: {}) must extend FairseqModel'.format(name, cls.__name__))
MODEL_REGISTRY[name] = cls
return cls
return register_model_cls
def register_model_architecture(model_name, arch_name):
"""Decorator to register a new model architecture (e.g., lstm_luong_wmt_en_de)."""
from . import fconv, lstm
def register_model_arch_fn(fn):
if model_name not in MODEL_REGISTRY:
raise ValueError('Cannot register model architecture for unknown model type ({})'.format(model_name))
if arch_name in ARCH_MODEL_REGISTRY:
raise ValueError('Cannot register duplicate model architecture ({})'.format(arch_name))
if not callable(fn):
raise ValueError('Model architecture must be callable ({})'.format(arch_name))
ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name]
ARCH_CONFIG_REGISTRY[arch_name] = fn
return fn
return register_model_arch_fn
__all__ = ['fconv', 'lstm']
arch_model_map = {}
for model in __all__:
archs = locals()[model].get_archs()
for arch in archs:
assert arch not in arch_model_map, 'Duplicate model architecture detected: {}'.format(arch)
arch_model_map[arch] = model
# automatically import any Python files in the models/ directory
for file in os.listdir(os.path.dirname(__file__)):
if file.endswith('.py') and not file.startswith('_'):
module = file[:file.find('.py')]
importlib.import_module('fairseq.models.' + module)
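As with criterions, registering a model and an architecture makes them selectable via --arch; the names below are hypothetical:
```
from . import FairseqModel, register_model, register_model_architecture

@register_model('my_model')  # hypothetical model type
class MyModel(FairseqModel):

    @staticmethod
    def add_args(parser):
        parser.add_argument('--my-dim', type=int, metavar='N',
                            help='hypothetical model dimension')

    @classmethod
    def build_model(cls, args, src_dict, dst_dict):
        # construct encoder/decoder from args and return cls(encoder, decoder)
        raise NotImplementedError

@register_model_architecture('my_model', 'my_model_base')
def my_model_base(args):
    # fill in defaults for any options the user did not set
    args.my_dim = getattr(args, 'my_dim', 512)
```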
......@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch.nn as nn
import torch.nn.functional as F
......@@ -17,6 +16,9 @@ class FairseqDecoder(nn.Module):
super().__init__()
self.dictionary = dictionary
def forward(self, prev_output_tokens, encoder_out):
raise NotImplementedError
def get_normalized_probs(self, net_output, log_probs):
"""Get normalized probabilities (or log probs) from a net's output."""
vocab = net_output.size(-1)
......
......@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch.nn as nn
......@@ -16,6 +15,9 @@ class FairseqEncoder(nn.Module):
super().__init__()
self.dictionary = dictionary
def forward(self, src_tokens, src_lengths):
raise NotImplementedError
def max_positions(self):
"""Maximum input length supported by the encoder."""
raise NotImplementedError
......
......@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
from . import FairseqDecoder
......@@ -17,7 +16,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
self._is_incremental_eval = False
self._incremental_state = {}
def forward(self, tokens, encoder_out):
def forward(self, prev_output_tokens, encoder_out):
if self._is_incremental_eval:
raise NotImplementedError
else:
......
......@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import torch.nn as nn
......@@ -30,10 +29,20 @@ class FairseqModel(nn.Module):
self._is_generation_fast = False
def forward(self, src_tokens, input_tokens):
encoder_out = self.encoder(src_tokens)
decoder_out, _ = self.decoder(input_tokens, encoder_out)
return decoder_out.view(-1, decoder_out.size(-1))
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
pass
@classmethod
def build_model(cls, args, src_dict, dst_dict):
"""Build a new model instance."""
raise NotImplementedError
def forward(self, src_tokens, src_lengths, prev_output_tokens):
encoder_out = self.encoder(src_tokens, src_lengths)
decoder_out, _ = self.decoder(prev_output_tokens, encoder_out)
return decoder_out
def get_normalized_probs(self, net_output, log_probs):
"""Get normalized probabilities (or log probs) from a net's output."""
......@@ -47,6 +56,16 @@ class FairseqModel(nn.Module):
"""Maximum output length supported by the decoder."""
return self.decoder.max_positions()
def load_state_dict(self, state_dict, strict=True):
"""Copies parameters and buffers from state_dict into this module and
its descendants.
Overrides the method in nn.Module; compared with that method this
additionally "upgrades" state_dicts from old checkpoints.
"""
state_dict = self.upgrade_state_dict(state_dict)
super().load_state_dict(state_dict, strict)
def upgrade_state_dict(self, state_dict):
state_dict = self.encoder.upgrade_state_dict(state_dict)
state_dict = self.decoder.upgrade_state_dict(state_dict)
......
......@@ -4,25 +4,68 @@
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from fairseq.data import LanguagePairDataset
from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution
from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel
from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model, register_model_architecture
@register_model('fconv')
class FConvModel(FairseqModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
help='dropout probability')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-layers', type=str, metavar='EXPR',
help='encoder layers [(dim, kernel_size), ...]')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-layers', type=str, metavar='EXPR',
help='decoder layers [(dim, kernel_size), ...]')
parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
help='decoder output embedding dimension')
parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
help='decoder attention [True, ...]')
parser.add_argument('--share-input-output-embed', action='store_true',
help='share input and output embeddings (requires'
' --decoder-out-embed-dim and --decoder-embed-dim'
' to be equal)')
@classmethod
def build_model(cls, args, src_dict, dst_dict):
"""Build a new model instance."""
encoder = FConvEncoder(
src_dict,
embed_dim=args.encoder_embed_dim,
convolutions=eval(args.encoder_layers),
dropout=args.dropout,
max_positions=args.max_source_positions,
)
decoder = FConvDecoder(
dst_dict,
embed_dim=args.decoder_embed_dim,
convolutions=eval(args.decoder_layers),
out_embed_dim=args.decoder_out_embed_dim,
attention=eval(args.decoder_attention),
dropout=args.dropout,
max_positions=args.max_target_positions,
share_embed=args.share_input_output_embed
)
return FConvModel(encoder, decoder)
class FConvEncoder(FairseqEncoder):
"""Convolutional encoder"""
......@@ -35,8 +78,12 @@ class FConvEncoder(FairseqEncoder):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
self.embed_positions = PositionalEmbedding(max_positions, embed_dim, padding_idx,
left_pad=LanguagePairDataset.LEFT_PAD_SOURCE)
self.embed_positions = PositionalEmbedding(
max_positions,
embed_dim,
padding_idx,
left_pad=LanguagePairDataset.LEFT_PAD_SOURCE,
)
in_channels = convolutions[0][0]
self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
......@@ -52,7 +99,7 @@ class FConvEncoder(FairseqEncoder):
in_channels = out_channels
self.fc2 = Linear(in_channels, embed_dim)
def forward(self, src_tokens):
def forward(self, src_tokens, src_lengths):
# embed tokens and positions
x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
x = F.dropout(x, p=self.dropout, training=self.training)
......@@ -151,8 +198,12 @@ class FConvDecoder(FairseqIncrementalDecoder):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
self.embed_positions = PositionalEmbedding(max_positions, embed_dim, padding_idx,
left_pad=LanguagePairDataset.LEFT_PAD_TARGET)
self.embed_positions = PositionalEmbedding(
max_positions,
embed_dim,
padding_idx,
left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
)
self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
self.projections = nn.ModuleList()
......@@ -178,19 +229,19 @@ class FConvDecoder(FairseqIncrementalDecoder):
else:
self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
def forward(self, input_tokens, encoder_out):
def forward(self, prev_output_tokens, encoder_out):
# split and transpose encoder outputs
encoder_a, encoder_b = self._split_encoder_out(encoder_out)
# embed positions
positions = self.embed_positions(input_tokens)
positions = self.embed_positions(prev_output_tokens)
if self._is_incremental_eval:
# keep only the last token for incremental forward pass
input_tokens = input_tokens[:, -1:]
prev_output_tokens = prev_output_tokens[:, -1:]
# embed tokens and positions
x = self.embed_tokens(input_tokens) + positions
x = self.embed_tokens(prev_output_tokens) + positions
x = F.dropout(x, p=self.dropout, training=self.training)
target_embedding = x
......@@ -316,63 +367,8 @@ def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
return nn.utils.weight_norm(m, dim=2)
def get_archs():
return [
'fconv', 'fconv_iwslt_de_en', 'fconv_wmt_en_ro', 'fconv_wmt_en_de', 'fconv_wmt_en_fr',
]
def _check_arch(args):
"""Check that the specified architecture is valid and not ambiguous."""
if args.arch not in get_archs():
raise ValueError('Unknown fconv model architecture: {}'.format(args.arch))
if args.arch != 'fconv':
# check that architecture is not ambiguous
for a in ['encoder_embed_dim', 'encoder_layers', 'decoder_embed_dim', 'decoder_layers',
'decoder_out_embed_dim']:
if hasattr(args, a):
raise ValueError('--{} cannot be combined with --arch={}'.format(a, args.arch))
def parse_arch(args):
_check_arch(args)
if args.arch == 'fconv_iwslt_de_en':
args.encoder_embed_dim = 256
args.encoder_layers = '[(256, 3)] * 4'
args.decoder_embed_dim = 256
args.decoder_layers = '[(256, 3)] * 3'
args.decoder_out_embed_dim = 256
elif args.arch == 'fconv_wmt_en_ro':
args.encoder_embed_dim = 512
args.encoder_layers = '[(512, 3)] * 20'
args.decoder_embed_dim = 512
args.decoder_layers = '[(512, 3)] * 20'
args.decoder_out_embed_dim = 512
elif args.arch == 'fconv_wmt_en_de':
convs = '[(512, 3)] * 9' # first 9 layers have 512 units
convs += ' + [(1024, 3)] * 4' # next 4 layers have 1024 units
convs += ' + [(2048, 1)] * 2' # final 2 layers use 1x1 convolutions
args.encoder_embed_dim = 768
args.encoder_layers = convs
args.decoder_embed_dim = 768
args.decoder_layers = convs
args.decoder_out_embed_dim = 512
elif args.arch == 'fconv_wmt_en_fr':
convs = '[(512, 3)] * 6' # first 6 layers have 512 units
convs += ' + [(768, 3)] * 4' # next 4 layers have 768 units
convs += ' + [(1024, 3)] * 3' # next 3 layers have 1024 units
convs += ' + [(2048, 1)] * 1' # next 1 layer uses 1x1 convolutions
convs += ' + [(4096, 1)] * 1' # final 1 layer uses 1x1 convolutions
args.encoder_embed_dim = 768
args.encoder_layers = convs
args.decoder_embed_dim = 768
args.decoder_layers = convs
args.decoder_out_embed_dim = 512
else:
assert args.arch == 'fconv'
# default architecture
@register_model_architecture('fconv', 'fconv')
def base_architecture(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 20')
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
......@@ -380,25 +376,51 @@ def parse_arch(args):
args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
args.decoder_attention = getattr(args, 'decoder_attention', 'True')
args.share_input_output_embed = getattr(args, 'share_input_output_embed', False)
return args
def build_model(args, src_dict, dst_dict):
encoder = FConvEncoder(
src_dict,
embed_dim=args.encoder_embed_dim,
convolutions=eval(args.encoder_layers),
dropout=args.dropout,
max_positions=args.max_source_positions,
)
decoder = FConvDecoder(
dst_dict,
embed_dim=args.decoder_embed_dim,
convolutions=eval(args.decoder_layers),
out_embed_dim=args.decoder_out_embed_dim,
attention=eval(args.decoder_attention),
dropout=args.dropout,
max_positions=args.max_target_positions,
share_embed=args.share_input_output_embed
)
return FConvModel(encoder, decoder)
@register_model_architecture('fconv', 'fconv_iwslt_de_en')
def fconv_iwslt_de_en(args):
base_architecture(args)
args.encoder_embed_dim = 256
args.encoder_layers = '[(256, 3)] * 4'
args.decoder_embed_dim = 256
args.decoder_layers = '[(256, 3)] * 3'
args.decoder_out_embed_dim = 256
@register_model_architecture('fconv', 'fconv_wmt_en_ro')
def fconv_wmt_en_ro(args):
base_architecture(args)
args.encoder_embed_dim = 512
args.encoder_layers = '[(512, 3)] * 20'
args.decoder_embed_dim = 512
args.decoder_layers = '[(512, 3)] * 20'
args.decoder_out_embed_dim = 512
@register_model_architecture('fconv', 'fconv_wmt_en_de')
def fconv_wmt_en_de(args):
base_architecture(args)
convs = '[(512, 3)] * 9' # first 9 layers have 512 units
convs += ' + [(1024, 3)] * 4' # next 4 layers have 1024 units
convs += ' + [(2048, 1)] * 2' # final 2 layers use 1x1 convolutions
args.encoder_embed_dim = 768
args.encoder_layers = convs
args.decoder_embed_dim = 768
args.decoder_layers = convs
args.decoder_out_embed_dim = 512
@register_model_architecture('fconv', 'fconv_wmt_en_fr')
def fconv_wmt_en_fr(args):
base_architecture(args)
convs = '[(512, 3)] * 6' # first 6 layers have 512 units
convs += ' + [(768, 3)] * 4' # next 4 layers have 768 units
convs += ' + [(1024, 3)] * 3' # next 3 layers have 1024 units
convs += ' + [(2048, 1)] * 1' # next 1 layer uses 1x1 convolutions
convs += ' + [(4096, 1)] * 1' # final 1 layer uses 1x1 convolutions
args.encoder_embed_dim = 768
args.encoder_layers = convs
args.decoder_embed_dim = 768
args.decoder_layers = convs
args.decoder_out_embed_dim = 512
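The layer specs above are Python expressions consumed by eval() in build_model; for example, the fconv_wmt_en_de encoder spec expands to a 15-layer list:
```
>>> convs = '[(512, 3)] * 9' + ' + [(1024, 3)] * 4' + ' + [(2048, 1)] * 2'
>>> layers = eval(convs)
>>> len(layers), layers[0], layers[-1]
(15, (512, 3), (2048, 1))
```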