Unverified Commit 66415206 authored by Myle Ott's avatar Myle Ott Committed by GitHub
Browse files

fairseq-py goes distributed (#106)

This PR includes breaking API changes to modularize fairseq-py and adds support for distributed training across multiple nodes.

Changes:
- c7033ef: add support for distributed training! See updated README for usage.
- e016299: modularize fairseq-py, adding support for register_model, register_criterion, register_optimizer, etc.
- 154e440: update LSTM implementation to use PackedSequence objects in the encoder, better following best practices and improving perf
- 90c2973 and 1da6265: improve unit test coverage
parent 7e86e30c
...@@ -161,6 +161,44 @@ $ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref ...@@ -161,6 +161,44 @@ $ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787) BLEU4 = 40.83, 67.5/46.9/34.4/25.5 (BP=1.000, ratio=1.006, syslen=83262, reflen=82787)
``` ```
# Distributed version
Distributed training in fairseq-py is implemented on top of [torch.distributed](http://pytorch.org/docs/master/distributed.html).
It requires one process per GPU. For those processes to be able to discover each other,
they need to know a unique host and port that can be used to establish the initial connection,
and each process needs to be assigned a rank: a unique number from 0 to n-1, where n is the total number of GPUs.
Below is an example of training a big En2Fr model on 16 nodes with 8 GPUs each (128 GPUs in total):
If you run on a cluster managed by [SLURM](https://slurm.schedmd.com/) you can train the WMT'14 En2Fr model with
the following command:
```
$ DATA=... # path to the preprocessed dataset, must be visible from all nodes
$ PORT=9218 # any available tcp port that can be used by the trainer to establish the initial connection
$ sbatch --job-name fairseq-py --gres gpu:8 --nodes 16 --ntasks-per-node 8 \
--cpus-per-task 10 --no-requeue --wrap 'srun --output train.log.node%t \
--error train.stderr.node%t.%j python train.py $DATA --distributed-world-size 128 \
--distributed-port $PORT --force-anneal 50 --lr-scheduler fixed --max-epoch 55 \
--arch fconv_wmt_en_fr --optimizer nag --lr 0.1,4 --max-tokens 3000 \
--clip-norm 0.1 --dropout 0.1 --criterion label_smoothed_cross_entropy \
--label-smoothing 0.1 --wd 0.0001'
```
Alternatively you'll need to manually start one process per GPU:
```
$ DATA=... # path to the preprocessed dataset, must be visible from all nodes
$ HOST_PORT=your.devserver.com:9218 # has to be one of the hosts that will be used by the job \
and the port on that host has to be available
$ RANK=... # the rank of this process, has to go from 0 to 127 in case of 128 GPUs
$ python train.py $DATA --distributed-world-size 128 \
--force-anneal 50 --lr-scheduler fixed --max-epoch 55 \
--arch fconv_wmt_en_fr --optimizer nag --lr 0.1,4 --max-tokens 3000 \
--clip-norm 0.1 --dropout 0.1 --criterion label_smoothed_cross_entropy \
--label-smoothing 0.1 --wd 0.0001 \
--distributed-init-method="tcp://$HOST_PORT" --distributed-rank=$RANK
```
# Join the fairseq community # Join the fairseq community
* Facebook page: https://www.facebook.com/groups/fairseq.users * Facebook page: https://www.facebook.com/groups/fairseq.users
......
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import os
import socket
import subprocess
from singleprocess_train import main as single_process_main
from fairseq import distributed_utils, options
def main(args):
    """Entry point for distributed training.

    If --distributed-init-method was not given but --distributed-port was,
    tries to derive the init method (and this process's rank and local
    device id) from the Slurm environment, then initializes
    torch.distributed and hands off to the single-process trainer.

    Raises:
        ValueError: if no init method could be determined.
    """
    if args.distributed_init_method is None and args.distributed_port > 0:
        # We can determine the init method automatically for Slurm.
        node_list = os.environ.get('SLURM_JOB_NODELIST')
        if node_list is not None:
            try:
                # Expand the (possibly abbreviated) node list; the first
                # hostname becomes the rendezvous host for all processes.
                hostnames = subprocess.check_output(
                    ['scontrol', 'show', 'hostnames', node_list])
                args.distributed_init_method = 'tcp://{host}:{port}'.format(
                    host=hostnames.split()[0].decode('utf-8'),
                    port=args.distributed_port)
                # NOTE(review): assumes SLURM_PROCID and SLURM_LOCALID are
                # set whenever SLURM_JOB_NODELIST is — true under srun;
                # confirm for other launch methods.
                args.distributed_rank = int(os.environ.get('SLURM_PROCID'))
                args.device_id = int(os.environ.get('SLURM_LOCALID'))
            except subprocess.CalledProcessError:
                raise  # scontrol failed; don't mask the original error
            except FileNotFoundError:
                pass  # Slurm is not installed; fall back to the check below

    if args.distributed_init_method is None:
        raise ValueError('--distributed-init-method or --distributed-port '
                         'must be specified for distributed training')

    args.distributed_rank = distributed_utils.distributed_init(args)
    print('| initialized host {} as rank {}'.format(
        socket.gethostname(), args.distributed_rank))
    single_process_main(args)
if __name__ == '__main__':
    # Build the training argument parser, then parse CLI args (including
    # any arch-specific arguments registered for the selected model) and
    # launch distributed training.
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
from .multiprocessing_pdb import pdb from .multiprocessing_pdb import pdb
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import ctypes import ctypes
import math import math
......
...@@ -20,6 +20,10 @@ at::Type& getDataType(const char* dtype) { ...@@ -20,6 +20,10 @@ at::Type& getDataType(const char* dtype) {
return at::getType(at::kCUDA, at::kFloat); return at::getType(at::kCUDA, at::kFloat);
} else if (strcmp(dtype, "torch.FloatTensor") == 0) { } else if (strcmp(dtype, "torch.FloatTensor") == 0) {
return at::getType(at::kCPU, at::kFloat); return at::getType(at::kCPU, at::kFloat);
} else if (strcmp(dtype, "torch.cuda.DoubleTensor") == 0) {
return at::getType(at::kCUDA, at::kDouble);
} else if (strcmp(dtype, "torch.DoubleTensor") == 0) {
return at::getType(at::kCPU, at::kDouble);
} else { } else {
throw std::runtime_error(std::string("Unsupported data type: ") + dtype); throw std::runtime_error(std::string("Unsupported data type: ") + dtype);
} }
......
...@@ -4,12 +4,42 @@ ...@@ -4,12 +4,42 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
from .cross_entropy import CrossEntropyCriterion import importlib
from .label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion import os
from .fairseq_criterion import FairseqCriterion
# Maps criterion name (as passed to --criterion) -> criterion class.
CRITERION_REGISTRY = {}
# Criterion class names must also be unique, since they are used as
# identifiers in checkpoints.
CRITERION_CLASS_NAMES = set()


def build_criterion(args, src_dict, dst_dict):
    """Instantiate the criterion selected by ``args.criterion``.

    Raises KeyError if the name was never registered via
    ``register_criterion``.
    """
    return CRITERION_REGISTRY[args.criterion](args, src_dict, dst_dict)
def register_criterion(name):
    """Decorator to register a new criterion."""

    def _register_criterion_cls(cls):
        # Validate uniqueness of the registry key, the class hierarchy,
        # and the class-name identifier (in that order), then record both.
        if name in CRITERION_REGISTRY:
            raise ValueError('Cannot register duplicate criterion ({})'.format(name))
        if not issubclass(cls, FairseqCriterion):
            raise ValueError('Criterion ({}: {}) must extend FairseqCriterion'.format(name, cls.__name__))
        if cls.__name__ in CRITERION_CLASS_NAMES:
            # The criterion class name is used as a unique identifier in
            # checkpoints, so duplicates are rejected even under new names.
            raise ValueError('Cannot register criterion with duplicate class name ({})'.format(cls.__name__))
        CRITERION_REGISTRY[name] = cls
        CRITERION_CLASS_NAMES.add(cls.__name__)
        return cls

    return _register_criterion_cls
__all__ = [ # automatically import any Python files in the criterions/ directory
'CrossEntropyCriterion', for file in os.listdir(os.path.dirname(__file__)):
'LabelSmoothedCrossEntropyCriterion', if file.endswith('.py') and not file.startswith('_'):
] module = file[:file.find('.py')]
importlib.import_module('fairseq.criterions.' + module)
...@@ -4,18 +4,18 @@ ...@@ -4,18 +4,18 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import math import math
import torch.nn.functional as F import torch.nn.functional as F
from .fairseq_criterion import FairseqCriterion from . import FairseqCriterion, register_criterion
@register_criterion('cross_entropy')
class CrossEntropyCriterion(FairseqCriterion): class CrossEntropyCriterion(FairseqCriterion):
def __init__(self, args, dst_dict): def __init__(self, args, src_dict, dst_dict):
super().__init__(args, dst_dict) super().__init__(args, src_dict, dst_dict)
def forward(self, model, sample, reduce=True): def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample. """Compute the loss for the given sample.
...@@ -27,6 +27,7 @@ class CrossEntropyCriterion(FairseqCriterion): ...@@ -27,6 +27,7 @@ class CrossEntropyCriterion(FairseqCriterion):
""" """
net_output = model(**sample['net_input']) net_output = model(**sample['net_input'])
lprobs = model.get_normalized_probs(net_output, log_probs=True) lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = sample['target'].view(-1) target = sample['target'].view(-1)
loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx,
reduce=reduce) reduce=reduce)
......
...@@ -4,18 +4,22 @@ ...@@ -4,18 +4,22 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
class FairseqCriterion(_Loss): class FairseqCriterion(_Loss):
def __init__(self, args, dst_dict): def __init__(self, args, src_dict, dst_dict):
super().__init__() super().__init__()
self.args = args self.args = args
self.padding_idx = dst_dict.pad() self.padding_idx = dst_dict.pad()
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
pass
def forward(self, model, sample, reduce=True): def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample. """Compute the loss for the given sample.
......
...@@ -4,16 +4,14 @@ ...@@ -4,16 +4,14 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import math import math
import torch import torch
from torch.autograd.variable import Variable
import torch.nn.functional as F import torch.nn.functional as F
from fairseq import utils from fairseq import utils
from .fairseq_criterion import FairseqCriterion from . import FairseqCriterion, register_criterion
class LabelSmoothedNLLLoss(torch.autograd.Function): class LabelSmoothedNLLLoss(torch.autograd.Function):
...@@ -46,12 +44,18 @@ class LabelSmoothedNLLLoss(torch.autograd.Function): ...@@ -46,12 +44,18 @@ class LabelSmoothedNLLLoss(torch.autograd.Function):
return utils.volatile_variable(ctx.grad_input) * grad, None, None, None, None, None return utils.volatile_variable(ctx.grad_input) * grad, None, None, None, None, None
@register_criterion('label_smoothed_cross_entropy')
class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
def __init__(self, args, dst_dict, weights=None): def __init__(self, args, src_dict, dst_dict):
super().__init__(args, dst_dict) super().__init__(args, src_dict, dst_dict)
self.eps = args.label_smoothing self.eps = args.label_smoothing
self.weights = weights
@staticmethod
def add_args(parser):
"""Add criterion-specific arguments to the parser."""
parser.add_argument('--label-smoothing', default=0., type=float, metavar='D',
help='epsilon for label smoothing, 0 means no label smoothing')
def forward(self, model, sample, reduce=True): def forward(self, model, sample, reduce=True):
"""Compute the loss for the given sample. """Compute the loss for the given sample.
...@@ -63,8 +67,9 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion): ...@@ -63,8 +67,9 @@ class LabelSmoothedCrossEntropyCriterion(FairseqCriterion):
""" """
net_output = model(**sample['net_input']) net_output = model(**sample['net_input'])
lprobs = model.get_normalized_probs(net_output, log_probs=True) lprobs = model.get_normalized_probs(net_output, log_probs=True)
lprobs = lprobs.view(-1, lprobs.size(-1))
target = sample['target'].view(-1) target = sample['target'].view(-1)
loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, self.weights, reduce) loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, None, reduce)
nll_loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, reduce=reduce) nll_loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, reduce=reduce)
sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens'] sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
logging_output = { logging_output = {
......
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import contextlib import contextlib
import itertools import itertools
import glob import glob
import math
import numbers import numbers
import numpy as np import numpy as np
import os import os
...@@ -130,10 +130,10 @@ class LanguageDatasets(object): ...@@ -130,10 +130,10 @@ class LanguageDatasets(object):
assert self.src_dict.eos() == self.dst_dict.eos() assert self.src_dict.eos() == self.dst_dict.eos()
assert self.src_dict.unk() == self.dst_dict.unk() assert self.src_dict.unk() == self.dst_dict.unk()
def train_dataloader(self, split, num_workers=0, max_tokens=None, def train_dataloader(self, split, max_tokens=None,
max_sentences=None, max_positions=(1024, 1024), max_sentences=None, max_positions=(1024, 1024),
seed=None, epoch=1, sample_without_replacement=0, seed=None, epoch=1, sample_without_replacement=0,
sort_by_source_size=False): sort_by_source_size=False, shard_id=0, num_shards=1):
dataset = self.splits[split] dataset = self.splits[split]
with numpy_seed(seed): with numpy_seed(seed):
batch_sampler = shuffled_batches_by_size( batch_sampler = shuffled_batches_by_size(
...@@ -141,40 +141,27 @@ class LanguageDatasets(object): ...@@ -141,40 +141,27 @@ class LanguageDatasets(object):
max_sentences=max_sentences, epoch=epoch, max_sentences=max_sentences, epoch=epoch,
sample=sample_without_replacement, max_positions=max_positions, sample=sample_without_replacement, max_positions=max_positions,
sort_by_source_size=sort_by_source_size) sort_by_source_size=sort_by_source_size)
batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
return torch.utils.data.DataLoader( return torch.utils.data.DataLoader(
dataset, num_workers=num_workers, collate_fn=dataset.collater, dataset, collate_fn=dataset.collater,
batch_sampler=batch_sampler) batch_sampler=batch_sampler)
def eval_dataloader(self, split, num_workers=0, max_tokens=None, def eval_dataloader(self, split, num_workers=0, max_tokens=None,
max_sentences=None, max_positions=(1024, 1024), max_sentences=None, max_positions=(1024, 1024),
skip_invalid_size_inputs_valid_test=False, skip_invalid_size_inputs_valid_test=False,
descending=False): descending=False, shard_id=0, num_shards=1):
dataset = self.splits[split] dataset = self.splits[split]
batch_sampler = list(batches_by_size( batch_sampler = batches_by_size(
dataset.src, dataset.dst, max_tokens, max_sentences, dataset.src, dataset.dst, max_tokens, max_sentences,
max_positions=max_positions, max_positions=max_positions,
ignore_invalid_inputs=skip_invalid_size_inputs_valid_test, ignore_invalid_inputs=skip_invalid_size_inputs_valid_test,
descending=descending)) descending=descending)
batch_sampler = mask_batches(batch_sampler, shard_id=shard_id, num_shards=num_shards)
return torch.utils.data.DataLoader( return torch.utils.data.DataLoader(
dataset, num_workers=num_workers, collate_fn=dataset.collater, dataset, num_workers=num_workers, collate_fn=dataset.collater,
batch_sampler=batch_sampler) batch_sampler=batch_sampler)
def skip_group_enumerator(it, ngpus, offset=0):
res = []
idx = 0
for i, sample in enumerate(it):
if i < offset:
continue
res.append(sample)
if len(res) >= ngpus:
yield (i, res)
res = []
idx = i + 1
if len(res) > 0:
yield (idx, res)
class sharded_iterator(object): class sharded_iterator(object):
def __init__(self, itr, num_shards, shard_id): def __init__(self, itr, num_shards, shard_id):
...@@ -192,7 +179,7 @@ class sharded_iterator(object): ...@@ -192,7 +179,7 @@ class sharded_iterator(object):
yield v yield v
class LanguagePairDataset(object): class LanguagePairDataset(torch.utils.data.Dataset):
# padding constants # padding constants
LEFT_PAD_SOURCE = True LEFT_PAD_SOURCE = True
...@@ -222,26 +209,47 @@ class LanguagePairDataset(object): ...@@ -222,26 +209,47 @@ class LanguagePairDataset(object):
@staticmethod @staticmethod
def collate(samples, pad_idx, eos_idx): def collate(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False): def merge(key, left_pad, move_eos_to_beginning=False):
return LanguagePairDataset.collate_tokens( return LanguagePairDataset.collate_tokens(
[s[key] for s in samples], pad_idx, eos_idx, left_pad, move_eos_to_beginning) [s[key] for s in samples],
pad_idx, eos_idx, left_pad, move_eos_to_beginning,
)
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE)
target = merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET)
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
move_eos_to_beginning=True,
)
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
target = target.index_select(0, sort_order)
return { return {
'id': torch.LongTensor([s['id'].item() for s in samples]), 'id': id,
'ntokens': sum(len(s['target']) for s in samples), 'ntokens': sum(len(s['target']) for s in samples),
'net_input': { 'net_input': {
'src_tokens': merge('source', left_pad=LanguagePairDataset.LEFT_PAD_SOURCE), 'src_tokens': src_tokens,
# we create a shifted version of targets for feeding the 'src_lengths': src_lengths,
# previous output token(s) into the next decoder step 'prev_output_tokens': prev_output_tokens,
'input_tokens': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
move_eos_to_beginning=True),
}, },
'target': merge('target', left_pad=LanguagePairDataset.LEFT_PAD_TARGET), 'target': target,
} }
@staticmethod @staticmethod
def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning): def collate_tokens(values, pad_idx, eos_idx, left_pad, move_eos_to_beginning=False):
size = max(v.size(0) for v in values) size = max(v.size(0) for v in values)
res = values[0].new(len(values), size).fill_(pad_idx) res = values[0].new(len(values), size).fill_(pad_idx)
...@@ -292,7 +300,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions, ...@@ -292,7 +300,7 @@ def _make_batches(src, dst, indices, max_tokens, max_sentences, max_positions,
sample_len = 0 sample_len = 0
ignored = [] ignored = []
for idx in indices: for idx in map(int, indices):
if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions): if not _valid_size(src.sizes[idx], dst.sizes[idx], max_positions):
if ignore_invalid_inputs: if ignore_invalid_inputs:
ignored.append(idx) ignored.append(idx)
...@@ -332,9 +340,9 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None, ...@@ -332,9 +340,9 @@ def batches_by_size(src, dst, max_tokens=None, max_sentences=None,
indices = np.argsort(src.sizes, kind='mergesort') indices = np.argsort(src.sizes, kind='mergesort')
if descending: if descending:
indices = np.flip(indices, 0) indices = np.flip(indices, 0)
return _make_batches( return list(_make_batches(
src, dst, indices, max_tokens, max_sentences, max_positions, src, dst, indices, max_tokens, max_sentences, max_positions,
ignore_invalid_inputs, allow_different_src_lens=False) ignore_invalid_inputs, allow_different_src_lens=False))
def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None, def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
...@@ -380,6 +388,18 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None, ...@@ -380,6 +388,18 @@ def shuffled_batches_by_size(src, dst, max_tokens=None, max_sentences=None,
return batches return batches
def mask_batches(batch_sampler, shard_id, num_shards):
    """Return the subset of batches belonging to this shard.

    Keeps every ``num_shards``-th batch starting at ``shard_id`` and pads
    the result with empty batches so that every shard iterates the same
    number of steps (required so distributed workers stay in lockstep).
    """
    if num_shards == 1:
        return batch_sampler
    sharded = []
    for idx, batch in enumerate(batch_sampler):
        if idx % num_shards == shard_id:
            sharded.append(batch)
    # pad to ceil(len / num_shards) entries so all shards agree on length
    target_len = int(math.ceil(len(batch_sampler) / num_shards))
    sharded.extend([[]] * (target_len - len(sharded)))
    return sharded
@contextlib.contextmanager @contextlib.contextmanager
def numpy_seed(seed): def numpy_seed(seed):
"""Context manager which seeds the NumPy PRNG with the specified seed and """Context manager which seeds the NumPy PRNG with the specified seed and
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import math import math
import torch import torch
...@@ -115,7 +114,7 @@ class Dictionary(object): ...@@ -115,7 +114,7 @@ class Dictionary(object):
return Dictionary.load(fd) return Dictionary.load(fd)
except FileNotFoundError as fnfe: except FileNotFoundError as fnfe:
raise fnfe raise fnfe
except: except Exception:
raise Exception("Incorrect encoding detected in {}, please " raise Exception("Incorrect encoding detected in {}, please "
"rebuild the dataset".format(f)) "rebuild the dataset".format(f))
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import pickle
import torch.distributed
def distributed_init(args):
    """Initialize torch.distributed and return this process's rank.

    TCP init methods require the rank to be passed explicitly; other
    methods (e.g. shared filesystem) let the backend assign ranks. On
    every rank except 0, stdout printing is suppressed afterwards.
    """
    if args.distributed_world_size == 1:
        raise ValueError('Cannot initialize distributed with distributed_world_size=1')

    print('| distributed init (rank {}): {}'.format(
        args.distributed_rank, args.distributed_init_method), flush=True)

    init_kwargs = dict(
        backend=args.distributed_backend,
        init_method=args.distributed_init_method,
        world_size=args.distributed_world_size,
    )
    if args.distributed_init_method.startswith('tcp://'):
        # TCP rendezvous needs an explicit rank for each process
        init_kwargs['rank'] = args.distributed_rank
    torch.distributed.init_process_group(**init_kwargs)

    args.distributed_rank = torch.distributed.get_rank()
    if args.distributed_rank != 0:
        # only rank 0 should print by default; others must use force=True
        suppress_output()
    return args.distributed_rank
def suppress_output():
    """Suppress printing on the current device. Force printing with `force=True`."""
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        # Swallow every call unless the caller explicitly passes
        # force=True; the flag is popped so it never reaches the real
        # print.
        if kwargs.pop('force', False):
            builtin_print(*args, **kwargs)

    __builtin__.print = print
def all_reduce_and_rescale_tensors(tensors, rescale_denom, buffer_size=10485760):
    """All-reduce and rescale tensors in chunks of the specified size.

    Each tensor in `tensors` is summed across all workers and divided by
    `rescale_denom`, **in place**. Small tensors are packed together into a
    reusable flat buffer so that one all-reduce call covers many tensors.

    Args:
        tensors: list of Tensors to all-reduce
        rescale_denom: denominator for rescaling summed Tensors
        buffer_size: all-reduce chunk size in bytes

    NOTE(review): the buffer is allocated with tensors[0]'s dtype/device
    (via .new), so this assumes all tensors share that dtype — confirm at
    call sites.
    """
    # buffer size is in bytes, determine equiv. # of elements based on data type
    buffer_t = tensors[0].new(math.ceil(buffer_size / tensors[0].element_size())).zero_()
    buffer = []

    def all_reduce_buffer():
        # Flatten the currently buffered tensors into buffer_t, reduce once,
        # then scatter the results back into the original tensors.
        # copy tensors into buffer_t
        offset = 0
        for t in buffer:
            numel = t.numel()
            buffer_t[offset:offset+numel].copy_(t.view(-1))
            offset += numel
        # all-reduce and rescale
        torch.distributed.all_reduce(buffer_t[:offset])
        buffer_t.div_(rescale_denom)
        # copy all-reduced buffer back into tensors
        offset = 0
        for t in buffer:
            numel = t.numel()
            t.view(-1).copy_(buffer_t[offset:offset+numel])
            offset += numel

    filled = 0  # bytes currently occupied in the pack buffer
    for t in tensors:
        sz = t.numel() * t.element_size()
        if sz > buffer_size:
            # tensor is bigger than buffer, all-reduce and rescale directly
            torch.distributed.all_reduce(t)
            t.div_(rescale_denom)
        elif filled + sz > buffer_size:
            # buffer is full, all-reduce and replace buffer with grad
            # (rebinding `buffer` here is visible to the closure above)
            all_reduce_buffer()
            buffer = [t]
            filled = sz
        else:
            # add tensor to buffer
            buffer.append(t)
            filled += sz
    # flush any remaining buffered tensors
    if len(buffer) > 0:
        all_reduce_buffer()
def all_gather_list(data, max_size=4096):
    """Gathers arbitrary picklable data from all nodes into a list.

    Args:
        data: any picklable Python object
        max_size: buffer size in bytes; the pickled payload plus a 2-byte
            length header must fit within it

    Returns:
        list of deserialized objects, indexed by rank

    NOTE(review): uses CUDA byte buffers for the gather, so this assumes a
    CUDA-capable backend (e.g. NCCL/Gloo with GPUs) — confirm.
    """
    world_size = torch.distributed.get_world_size()
    if not hasattr(all_gather_list, '_in_buffer') or \
            max_size != all_gather_list._in_buffer.size(0):
        # Compare against size(0): comparing an int with the torch.Size
        # returned by .size() is always unequal, which would silently
        # defeat the buffer caching and reallocate on every call.
        all_gather_list._in_buffer = torch.ByteTensor(max_size)
        all_gather_list._out_buffers = [
            torch.cuda.ByteTensor(max_size)
            for i in range(world_size)
        ]
    in_buffer = all_gather_list._in_buffer
    out_buffers = all_gather_list._out_buffers

    enc = pickle.dumps(data)
    enc_size = len(enc)
    if enc_size + 2 > max_size or enc_size > 65535:
        raise ValueError('encoded data exceeds max_size: {}'.format(enc_size))
    # Store the payload length in a 2-byte big-endian header. A single
    # byte would silently truncate (and corrupt) any payload longer than
    # 255 bytes, which is common with the default max_size of 4096.
    in_buffer[0] = enc_size // 256
    in_buffer[1] = enc_size % 256
    in_buffer[2:enc_size+2] = torch.ByteTensor(list(enc))

    torch.distributed.all_gather(out_buffers, in_buffer.cuda())

    result = []
    for i in range(world_size):
        out_buffer = out_buffers[i]
        size = (256 * int(out_buffer[0])) + int(out_buffer[1])
        result.append(
            pickle.loads(bytes(out_buffer[2:size+2].tolist()))
        )
    return result
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import numpy as np import numpy as np
import os import os
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import time import time
......
...@@ -4,21 +4,58 @@ ...@@ -4,21 +4,58 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
from .fairseq_decoder import FairseqDecoder import importlib
from .fairseq_encoder import FairseqEncoder import os
from .fairseq_incremental_decoder import FairseqIncrementalDecoder
from .fairseq_model import FairseqModel from .fairseq_decoder import FairseqDecoder # noqa: F401
from .fairseq_encoder import FairseqEncoder # noqa: F401
from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401
from .fairseq_model import FairseqModel # noqa: F401
# Maps model name (as passed to register_model) -> model class.
MODEL_REGISTRY = {}
# Maps architecture name (as passed to --arch) -> model class.
ARCH_MODEL_REGISTRY = {}
# Maps architecture name -> callable that fills in arch-specific defaults.
ARCH_CONFIG_REGISTRY = {}


def build_model(args, src_dict, dst_dict):
    """Instantiate the model class registered for ``args.arch``.

    Raises KeyError if the architecture was never registered via
    ``register_model_architecture``.
    """
    return ARCH_MODEL_REGISTRY[args.arch].build_model(args, src_dict, dst_dict)
def register_model(name):
    """Decorator to register a new model (e.g., LSTM)."""

    def _register_model_cls(cls):
        # Reject duplicate names first, then enforce the base class, then
        # record the mapping and hand the class back unchanged.
        if name in MODEL_REGISTRY:
            raise ValueError('Cannot register duplicate model ({})'.format(name))
        if not issubclass(cls, FairseqModel):
            raise ValueError('Model ({}: {}) must extend FairseqModel'.format(name, cls.__name__))
        MODEL_REGISTRY[name] = cls
        return cls

    return _register_model_cls
def register_model_architecture(model_name, arch_name):
"""Decorator to register a new model architecture (e.g., lstm_luong_wmt_en_de)."""
from . import fconv, lstm def register_model_arch_fn(fn):
if model_name not in MODEL_REGISTRY:
raise ValueError('Cannot register model architecture for unknown model type ({})'.format(model_name))
if arch_name in ARCH_MODEL_REGISTRY:
raise ValueError('Cannot register duplicate model architecture ({})'.format(arch_name))
if not callable(fn):
raise ValueError('Model architecture must be callable ({})'.format(arch_name))
ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name]
ARCH_CONFIG_REGISTRY[arch_name] = fn
return fn
return register_model_arch_fn
__all__ = ['fconv', 'lstm']
arch_model_map = {} # automatically import any Python files in the models/ directory
for model in __all__: for file in os.listdir(os.path.dirname(__file__)):
archs = locals()[model].get_archs() if file.endswith('.py') and not file.startswith('_'):
for arch in archs: module = file[:file.find('.py')]
assert arch not in arch_model_map, 'Duplicate model architecture detected: {}'.format(arch) importlib.import_module('fairseq.models.' + module)
arch_model_map[arch] = model
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
...@@ -17,6 +16,9 @@ class FairseqDecoder(nn.Module): ...@@ -17,6 +16,9 @@ class FairseqDecoder(nn.Module):
super().__init__() super().__init__()
self.dictionary = dictionary self.dictionary = dictionary
def forward(self, prev_output_tokens, encoder_out):
    """Decoder forward pass; subclasses must override (base always raises)."""
    raise NotImplementedError
def get_normalized_probs(self, net_output, log_probs): def get_normalized_probs(self, net_output, log_probs):
"""Get normalized probabilities (or log probs) from a net's output.""" """Get normalized probabilities (or log probs) from a net's output."""
vocab = net_output.size(-1) vocab = net_output.size(-1)
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import torch.nn as nn import torch.nn as nn
...@@ -16,6 +15,9 @@ class FairseqEncoder(nn.Module): ...@@ -16,6 +15,9 @@ class FairseqEncoder(nn.Module):
super().__init__() super().__init__()
self.dictionary = dictionary self.dictionary = dictionary
def forward(self, src_tokens, src_lengths):
    """Encoder forward pass; subclasses must override (base always raises)."""
    raise NotImplementedError
def max_positions(self): def max_positions(self):
"""Maximum input length supported by the encoder.""" """Maximum input length supported by the encoder."""
raise NotImplementedError raise NotImplementedError
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
from . import FairseqDecoder from . import FairseqDecoder
...@@ -17,7 +16,7 @@ class FairseqIncrementalDecoder(FairseqDecoder): ...@@ -17,7 +16,7 @@ class FairseqIncrementalDecoder(FairseqDecoder):
self._is_incremental_eval = False self._is_incremental_eval = False
self._incremental_state = {} self._incremental_state = {}
def forward(self, tokens, encoder_out): def forward(self, prev_output_tokens, encoder_out):
if self._is_incremental_eval: if self._is_incremental_eval:
raise NotImplementedError raise NotImplementedError
else: else:
......
...@@ -4,7 +4,6 @@ ...@@ -4,7 +4,6 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import torch.nn as nn import torch.nn as nn
...@@ -30,10 +29,20 @@ class FairseqModel(nn.Module): ...@@ -30,10 +29,20 @@ class FairseqModel(nn.Module):
self._is_generation_fast = False self._is_generation_fast = False
def forward(self, src_tokens, input_tokens): @staticmethod
encoder_out = self.encoder(src_tokens) def add_args(parser):
decoder_out, _ = self.decoder(input_tokens, encoder_out) """Add model-specific arguments to the parser."""
return decoder_out.view(-1, decoder_out.size(-1)) pass
@classmethod
def build_model(cls, args, src_dict, dst_dict):
    """Build a new model instance.

    Subclasses override this to construct a model from the parsed args and
    the source/target dictionaries; the base implementation always raises.
    """
    raise NotImplementedError
def forward(self, src_tokens, src_lengths, prev_output_tokens):
    """Run the encoder, condition the decoder on its output, and return the
    decoder output (the decoder's second return value is discarded)."""
    decoder_out, _ = self.decoder(
        prev_output_tokens,
        self.encoder(src_tokens, src_lengths),
    )
    return decoder_out
def get_normalized_probs(self, net_output, log_probs): def get_normalized_probs(self, net_output, log_probs):
"""Get normalized probabilities (or log probs) from a net's output.""" """Get normalized probabilities (or log probs) from a net's output."""
...@@ -47,6 +56,16 @@ class FairseqModel(nn.Module): ...@@ -47,6 +56,16 @@ class FairseqModel(nn.Module):
"""Maximum output length supported by the decoder.""" """Maximum output length supported by the decoder."""
return self.decoder.max_positions() return self.decoder.max_positions()
def load_state_dict(self, state_dict, strict=True):
    """Copy parameters and buffers from *state_dict* into this module.

    Overrides nn.Module.load_state_dict; the only difference is that the
    incoming state_dict is first "upgraded" so checkpoints saved by older
    code versions can still be loaded.
    """
    upgraded = self.upgrade_state_dict(state_dict)
    super().load_state_dict(upgraded, strict)
def upgrade_state_dict(self, state_dict): def upgrade_state_dict(self, state_dict):
state_dict = self.encoder.upgrade_state_dict(state_dict) state_dict = self.encoder.upgrade_state_dict(state_dict)
state_dict = self.decoder.upgrade_state_dict(state_dict) state_dict = self.decoder.upgrade_state_dict(state_dict)
......
...@@ -4,25 +4,68 @@ ...@@ -4,25 +4,68 @@
# This source code is licensed under the license found in the LICENSE file in # This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
#
import math import math
import torch import torch
from torch.autograd import Variable
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from fairseq.data import LanguagePairDataset from fairseq.data import LanguagePairDataset
from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution from fairseq.modules import BeamableMM, GradMultiply, LearnedPositionalEmbedding, LinearizedConvolution
from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model, register_model_architecture
@register_model('fconv')
class FConvModel(FairseqModel): class FConvModel(FairseqModel):
def __init__(self, encoder, decoder): def __init__(self, encoder, decoder):
super().__init__(encoder, decoder) super().__init__(encoder, decoder)
self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention) self.encoder.num_attention_layers = sum(layer is not None for layer in decoder.attention)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--dropout', default=0.1, type=float, metavar='D',
help='dropout probability')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-layers', type=str, metavar='EXPR',
help='encoder layers [(dim, kernel_size), ...]')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-layers', type=str, metavar='EXPR',
help='decoder layers [(dim, kernel_size), ...]')
parser.add_argument('--decoder-out-embed-dim', type=int, metavar='N',
help='decoder output embedding dimension')
parser.add_argument('--decoder-attention', type=str, metavar='EXPR',
help='decoder attention [True, ...]')
parser.add_argument('--share-input-output-embed', action='store_true',
help='share input and output embeddings (requires'
' --decoder-out-embed-dim and --decoder-embed-dim'
' to be equal)')
@classmethod
def build_model(cls, args, src_dict, dst_dict):
    """Build a new FConvModel from parsed args and the two dictionaries."""
    # NOTE(review): the layer/attention specs are eval()'d strings coming
    # from the command line — trusted-input assumption carried over as-is.
    enc = FConvEncoder(
        src_dict,
        embed_dim=args.encoder_embed_dim,
        convolutions=eval(args.encoder_layers),
        dropout=args.dropout,
        max_positions=args.max_source_positions,
    )
    dec = FConvDecoder(
        dst_dict,
        embed_dim=args.decoder_embed_dim,
        convolutions=eval(args.decoder_layers),
        out_embed_dim=args.decoder_out_embed_dim,
        attention=eval(args.decoder_attention),
        dropout=args.dropout,
        max_positions=args.max_target_positions,
        share_embed=args.share_input_output_embed,
    )
    return FConvModel(enc, dec)
class FConvEncoder(FairseqEncoder): class FConvEncoder(FairseqEncoder):
"""Convolutional encoder""" """Convolutional encoder"""
...@@ -35,8 +78,12 @@ class FConvEncoder(FairseqEncoder): ...@@ -35,8 +78,12 @@ class FConvEncoder(FairseqEncoder):
num_embeddings = len(dictionary) num_embeddings = len(dictionary)
padding_idx = dictionary.pad() padding_idx = dictionary.pad()
self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
self.embed_positions = PositionalEmbedding(max_positions, embed_dim, padding_idx, self.embed_positions = PositionalEmbedding(
left_pad=LanguagePairDataset.LEFT_PAD_SOURCE) max_positions,
embed_dim,
padding_idx,
left_pad=LanguagePairDataset.LEFT_PAD_SOURCE,
)
in_channels = convolutions[0][0] in_channels = convolutions[0][0]
self.fc1 = Linear(embed_dim, in_channels, dropout=dropout) self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
...@@ -52,7 +99,7 @@ class FConvEncoder(FairseqEncoder): ...@@ -52,7 +99,7 @@ class FConvEncoder(FairseqEncoder):
in_channels = out_channels in_channels = out_channels
self.fc2 = Linear(in_channels, embed_dim) self.fc2 = Linear(in_channels, embed_dim)
def forward(self, src_tokens): def forward(self, src_tokens, src_lengths):
# embed tokens and positions # embed tokens and positions
x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens) x = self.embed_tokens(src_tokens) + self.embed_positions(src_tokens)
x = F.dropout(x, p=self.dropout, training=self.training) x = F.dropout(x, p=self.dropout, training=self.training)
...@@ -151,8 +198,12 @@ class FConvDecoder(FairseqIncrementalDecoder): ...@@ -151,8 +198,12 @@ class FConvDecoder(FairseqIncrementalDecoder):
num_embeddings = len(dictionary) num_embeddings = len(dictionary)
padding_idx = dictionary.pad() padding_idx = dictionary.pad()
self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx) self.embed_tokens = Embedding(num_embeddings, embed_dim, padding_idx)
self.embed_positions = PositionalEmbedding(max_positions, embed_dim, padding_idx, self.embed_positions = PositionalEmbedding(
left_pad=LanguagePairDataset.LEFT_PAD_TARGET) max_positions,
embed_dim,
padding_idx,
left_pad=LanguagePairDataset.LEFT_PAD_TARGET,
)
self.fc1 = Linear(embed_dim, in_channels, dropout=dropout) self.fc1 = Linear(embed_dim, in_channels, dropout=dropout)
self.projections = nn.ModuleList() self.projections = nn.ModuleList()
...@@ -178,19 +229,19 @@ class FConvDecoder(FairseqIncrementalDecoder): ...@@ -178,19 +229,19 @@ class FConvDecoder(FairseqIncrementalDecoder):
else: else:
self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout) self.fc3 = Linear(out_embed_dim, num_embeddings, dropout=dropout)
def forward(self, input_tokens, encoder_out): def forward(self, prev_output_tokens, encoder_out):
# split and transpose encoder outputs # split and transpose encoder outputs
encoder_a, encoder_b = self._split_encoder_out(encoder_out) encoder_a, encoder_b = self._split_encoder_out(encoder_out)
# embed positions # embed positions
positions = self.embed_positions(input_tokens) positions = self.embed_positions(prev_output_tokens)
if self._is_incremental_eval: if self._is_incremental_eval:
# keep only the last token for incremental forward pass # keep only the last token for incremental forward pass
input_tokens = input_tokens[:, -1:] prev_output_tokens = prev_output_tokens[:, -1:]
# embed tokens and positions # embed tokens and positions
x = self.embed_tokens(input_tokens) + positions x = self.embed_tokens(prev_output_tokens) + positions
x = F.dropout(x, p=self.dropout, training=self.training) x = F.dropout(x, p=self.dropout, training=self.training)
target_embedding = x target_embedding = x
...@@ -316,63 +367,8 @@ def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs): ...@@ -316,63 +367,8 @@ def ConvTBC(in_channels, out_channels, kernel_size, dropout=0, **kwargs):
return nn.utils.weight_norm(m, dim=2) return nn.utils.weight_norm(m, dim=2)
def get_archs(): @register_model_architecture('fconv', 'fconv')
return [ def base_architecture(args):
'fconv', 'fconv_iwslt_de_en', 'fconv_wmt_en_ro', 'fconv_wmt_en_de', 'fconv_wmt_en_fr',
]
def _check_arch(args):
"""Check that the specified architecture is valid and not ambiguous."""
if args.arch not in get_archs():
raise ValueError('Unknown fconv model architecture: {}'.format(args.arch))
if args.arch != 'fconv':
# check that architecture is not ambiguous
for a in ['encoder_embed_dim', 'encoder_layers', 'decoder_embed_dim', 'decoder_layers',
'decoder_out_embed_dim']:
if hasattr(args, a):
raise ValueError('--{} cannot be combined with --arch={}'.format(a, args.arch))
def parse_arch(args):
_check_arch(args)
if args.arch == 'fconv_iwslt_de_en':
args.encoder_embed_dim = 256
args.encoder_layers = '[(256, 3)] * 4'
args.decoder_embed_dim = 256
args.decoder_layers = '[(256, 3)] * 3'
args.decoder_out_embed_dim = 256
elif args.arch == 'fconv_wmt_en_ro':
args.encoder_embed_dim = 512
args.encoder_layers = '[(512, 3)] * 20'
args.decoder_embed_dim = 512
args.decoder_layers = '[(512, 3)] * 20'
args.decoder_out_embed_dim = 512
elif args.arch == 'fconv_wmt_en_de':
convs = '[(512, 3)] * 9' # first 9 layers have 512 units
convs += ' + [(1024, 3)] * 4' # next 4 layers have 1024 units
convs += ' + [(2048, 1)] * 2' # final 2 layers use 1x1 convolutions
args.encoder_embed_dim = 768
args.encoder_layers = convs
args.decoder_embed_dim = 768
args.decoder_layers = convs
args.decoder_out_embed_dim = 512
elif args.arch == 'fconv_wmt_en_fr':
convs = '[(512, 3)] * 6' # first 6 layers have 512 units
convs += ' + [(768, 3)] * 4' # next 4 layers have 768 units
convs += ' + [(1024, 3)] * 3' # next 3 layers have 1024 units
convs += ' + [(2048, 1)] * 1' # next 1 layer uses 1x1 convolutions
convs += ' + [(4096, 1)] * 1' # final 1 layer uses 1x1 convolutions
args.encoder_embed_dim = 768
args.encoder_layers = convs
args.decoder_embed_dim = 768
args.decoder_layers = convs
args.decoder_out_embed_dim = 512
else:
assert args.arch == 'fconv'
# default architecture
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512) args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 20') args.encoder_layers = getattr(args, 'encoder_layers', '[(512, 3)] * 20')
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512) args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
...@@ -380,25 +376,51 @@ def parse_arch(args): ...@@ -380,25 +376,51 @@ def parse_arch(args):
args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256) args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
args.decoder_attention = getattr(args, 'decoder_attention', 'True') args.decoder_attention = getattr(args, 'decoder_attention', 'True')
args.share_input_output_embed = getattr(args, 'share_input_output_embed', False) args.share_input_output_embed = getattr(args, 'share_input_output_embed', False)
return args
@register_model_architecture('fconv', 'fconv_iwslt_de_en')
def build_model(args, src_dict, dst_dict): def fconv_iwslt_de_en(args):
encoder = FConvEncoder( base_architecture(args)
src_dict, args.encoder_embed_dim = 256
embed_dim=args.encoder_embed_dim, args.encoder_layers = '[(256, 3)] * 4'
convolutions=eval(args.encoder_layers), args.decoder_embed_dim = 256
dropout=args.dropout, args.decoder_layers = '[(256, 3)] * 3'
max_positions=args.max_source_positions, args.decoder_out_embed_dim = 256
)
decoder = FConvDecoder(
dst_dict, @register_model_architecture('fconv', 'fconv_wmt_en_ro')
embed_dim=args.decoder_embed_dim, def fconv_wmt_en_ro(args):
convolutions=eval(args.decoder_layers), base_architecture(args)
out_embed_dim=args.decoder_out_embed_dim, args.encoder_embed_dim = 512
attention=eval(args.decoder_attention), args.encoder_layers = '[(512, 3)] * 20'
dropout=args.dropout, args.decoder_embed_dim = 512
max_positions=args.max_target_positions, args.decoder_layers = '[(512, 3)] * 20'
share_embed=args.share_input_output_embed args.decoder_out_embed_dim = 512
)
return FConvModel(encoder, decoder)
@register_model_architecture('fconv', 'fconv_wmt_en_de')
def fconv_wmt_en_de(args):
    """WMT En-De fconv architecture: deep stack with widening convolutions."""
    base_architecture(args)
    # 9 x 512-unit layers, then 4 x 1024-unit layers, then 2 x 2048 (1x1) layers.
    convs = ' + '.join([
        '[(512, 3)] * 9',
        '[(1024, 3)] * 4',
        '[(2048, 1)] * 2',
    ])
    args.encoder_embed_dim = 768
    args.encoder_layers = convs
    args.decoder_embed_dim = 768
    args.decoder_layers = convs
    args.decoder_out_embed_dim = 512
@register_model_architecture('fconv', 'fconv_wmt_en_fr')
def fconv_wmt_en_fr(args):
    """WMT En-Fr fconv architecture: deep stack with widening convolutions."""
    base_architecture(args)
    # 6 x 512, 4 x 768, 3 x 1024 layers, then one 2048 and one 4096 (1x1) layer.
    convs = ' + '.join([
        '[(512, 3)] * 6',
        '[(768, 3)] * 4',
        '[(1024, 3)] * 3',
        '[(2048, 1)] * 1',
        '[(4096, 1)] * 1',
    ])
    args.encoder_embed_dim = 768
    args.encoder_layers = convs
    args.decoder_embed_dim = 768
    args.decoder_layers = convs
    args.decoder_out_embed_dim = 512
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment