"tests/vscode:/vscode.git/clone" did not exist on "f79188da4c3081f786a48366deee0a1ac21f9bc0"
Unverified commit 388c520b authored by Myle Ott, committed by GitHub

0.4.0 -> 0.5.0

Changelog:
- 97b58b46: add Transformer model from Vaswani et al. (2017)
- b2374e52: faster Transformer inference with improved caching
- 2d27ae08: simulate large mini-batch training with delayed updates (`--update-freq`)
- 7ee1d284: add FP16 training support (`--fp16`)
- 2a84f46b: faster inference by removing completed sentences from the batch
- 663fd806: batched interactive generation
- 4c2ef2de: add language modeling / gated convolutional model from Dauphin et al. (2017)
- b59815bc: add Hierarchical Neural Story Generation model from Fan et al. (2018)
- ff68a9ef: add FairseqTask to modularize task definitions (e.g., translation, language modeling)
parents ec0031df 5383b5db
@@ -5,9 +5,10 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
-import numpy as np
import os
import struct
+import numpy as np
import torch
from fairseq.tokenizer import Tokenizer
@@ -48,10 +49,11 @@ def data_file_path(prefix_path):
return prefix_path + '.bin'
-class IndexedDataset(object):
+class IndexedDataset(torch.utils.data.Dataset):
"""Loader for TorchNet IndexedDataset"""
def __init__(self, path):
+super().__init__()
with open(index_file_path(path), 'rb') as f:
magic = f.read(8)
assert magic == b'TNTIDX\x00\x00'
@@ -81,7 +83,7 @@ class IndexedDataset(object):
a = np.empty(tensor_size, dtype=self.dtype)
self.data_file.seek(self.data_offsets[i] * self.element_size)
self.data_file.readinto(a)
-return torch.from_numpy(a)
+return torch.from_numpy(a).long() - 1  # subtract 1 for 0-based indexing
def __len__(self):
return self.size
@@ -102,6 +104,7 @@ class IndexedInMemoryDataset(IndexedDataset):
self.buffer = np.empty(self.data_offsets[-1], dtype=self.dtype)
self.data_file.readinto(self.buffer)
self.data_file.close()
+self.buffer -= 1  # subtract 1 for 0-based indexing
def __del__(self):
pass
@@ -111,7 +114,7 @@ class IndexedInMemoryDataset(IndexedDataset):
tensor_size = self.sizes[self.dim_offsets[i]:self.dim_offsets[i + 1]]
a = np.empty(tensor_size, dtype=self.dtype)
np.copyto(a, self.buffer[self.data_offsets[i]:self.data_offsets[i + 1]])
-return torch.from_numpy(a)
+return torch.from_numpy(a).long()
class IndexedRawTextDataset(IndexedDataset):
@@ -134,7 +137,7 @@ class IndexedRawTextDataset(IndexedDataset):
tokens = Tokenizer.tokenize(
line, dictionary, add_if_not_exist=False,
append_eos=self.append_eos, reverse_order=self.reverse_order,
-) + 1  # +1 for Lua compatibility
+).long()
self.tokens_list.append(tokens)
self.sizes.append(len(tokens))
self.sizes = np.array(self.sizes)
@@ -153,12 +156,15 @@ class IndexedRawTextDataset(IndexedDataset):
def __len__(self):
return self.size
+@staticmethod
+def exists(path):
+return os.path.exists(path)
class IndexedDatasetBuilder(object):
element_sizes = {
np.uint8: 1,
np.int8: 1,
np.int16: 2,
np.int32: 4,
np.int64: 8,
@@ -187,10 +193,8 @@ class IndexedDatasetBuilder(object):
index = open(index_file, 'wb')
index.write(b'TNTIDX\x00\x00')
index.write(struct.pack('<Q', 1))
-index.write(struct.pack('<QQ', code(self.dtype),
-self.element_size))
-index.write(struct.pack('<QQ', len(self.data_offsets) - 1,
-len(self.sizes)))
+index.write(struct.pack('<QQ', code(self.dtype), self.element_size))
+index.write(struct.pack('<QQ', len(self.data_offsets) - 1, len(self.sizes)))
write_longs(index, self.dim_offsets)
write_longs(index, self.data_offsets)
write_longs(index, self.sizes)
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
import torch
from . import data_utils, FairseqDataset
def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False):
if len(samples) == 0:
return {}
def merge(key, left_pad, move_eos_to_beginning=False):
return data_utils.collate_tokens(
[s[key] for s in samples],
pad_idx, eos_idx, left_pad, move_eos_to_beginning,
)
id = torch.LongTensor([s['id'] for s in samples])
src_tokens = merge('source', left_pad=left_pad_source)
# sort by descending source length
src_lengths = torch.LongTensor([s['source'].numel() for s in samples])
src_lengths, sort_order = src_lengths.sort(descending=True)
id = id.index_select(0, sort_order)
src_tokens = src_tokens.index_select(0, sort_order)
prev_output_tokens = None
target = None
if samples[0].get('target', None) is not None:
target = merge('target', left_pad=left_pad_target)
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=left_pad_target,
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
else:
ntokens = sum(len(s['source']) for s in samples)
return {
'id': id,
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
'prev_output_tokens': prev_output_tokens,
},
'target': target,
}
class LanguagePairDataset(FairseqDataset):
"""A pair of torch.utils.data.Datasets."""
def __init__(
self, src, src_sizes, src_dict,
tgt=None, tgt_sizes=None, tgt_dict=None,
left_pad_source=True, left_pad_target=False,
max_source_positions=1024, max_target_positions=1024,
shuffle=True,
):
if tgt_dict is not None:
assert src_dict.pad() == tgt_dict.pad()
assert src_dict.eos() == tgt_dict.eos()
assert src_dict.unk() == tgt_dict.unk()
self.src = src
self.tgt = tgt
self.src_sizes = np.array(src_sizes)
self.tgt_sizes = np.array(tgt_sizes) if tgt_sizes is not None else None
self.src_dict = src_dict
self.tgt_dict = tgt_dict
self.left_pad_source = left_pad_source
self.left_pad_target = left_pad_target
self.max_source_positions = max_source_positions
self.max_target_positions = max_target_positions
self.shuffle = shuffle
def __getitem__(self, index):
return {
'id': index,
'source': self.src[index],
'target': self.tgt[index] if self.tgt is not None else None,
}
def __len__(self):
return len(self.src)
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
return collate(
samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
)
def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
max_source_positions, max_target_positions = self._get_max_positions(max_positions)
src_len, tgt_len = min(src_len, max_source_positions), min(tgt_len, max_target_positions)
bsz = num_tokens // max(src_len, tgt_len)
return self.collater([
{
'id': i,
'source': self.src_dict.dummy_sentence(src_len),
'target': self.tgt_dict.dummy_sentence(tgt_len) if self.tgt_dict is not None else None,
}
for i in range(bsz)
])
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
return max(self.src_sizes[index], self.tgt_sizes[index] if self.tgt_sizes is not None else 0)
def ordered_indices(self):
"""Ordered indices for batching."""
if self.shuffle:
indices = np.random.permutation(len(self))
else:
indices = np.arange(len(self))
if self.tgt_sizes is not None:
indices = indices[np.argsort(self.tgt_sizes[indices], kind='mergesort')]
return indices[np.argsort(self.src_sizes[indices], kind='mergesort')]
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
max_source_positions, max_target_positions = self._get_max_positions(max_positions)
return (
self.src_sizes[index] <= max_source_positions
and (self.tgt_sizes is None or self.tgt_sizes[index] <= max_target_positions)
)
def _get_max_positions(self, max_positions):
if max_positions is None:
return self.max_source_positions, self.max_target_positions
assert len(max_positions) == 2
max_src_pos, max_tgt_pos = max_positions
return min(self.max_source_positions, max_src_pos), min(self.max_target_positions, max_tgt_pos)
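A minimal usage sketch of the new dataset API above; the stub dictionary and toy tensors are hypothetical and only illustrate the pad/eos/unk interface the collater expects:
import torch
from fairseq.data import LanguagePairDataset  # assumed export path

class StubDict:
    # only the symbols collate() needs; a real fairseq Dictionary would be used here
    def pad(self): return 1
    def eos(self): return 2
    def unk(self): return 3

d = StubDict()
src = [torch.LongTensor([4, 5, 2]), torch.LongTensor([6, 2])]
tgt = [torch.LongTensor([7, 8, 9, 2]), torch.LongTensor([5, 2])]
dataset = LanguagePairDataset(
    src, [len(t) for t in src], d,
    tgt=tgt, tgt_sizes=[len(t) for t in tgt], tgt_dict=d,
)
batch = dataset.collater([dataset[0], dataset[1]])
# batch['net_input']['src_tokens'] is left-padded and sorted by descending source
# length; 'prev_output_tokens' is the target with eos moved to the beginning.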
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import numpy as np
import torch
from . import data_utils, FairseqDataset
def collate(samples, pad_idx, eos_idx):
if len(samples) == 0:
return {}
def merge(key):
return data_utils.collate_tokens(
[s[key] for s in samples], pad_idx, eos_idx, left_pad=False,
)
return {
'id': torch.LongTensor([s['id'] for s in samples]),
'ntokens': sum(len(s['target']) for s in samples),
'net_input': {
'src_tokens': merge('source'),
},
'target': merge('target'),
}
class MonolingualDataset(FairseqDataset):
"""A wrapper around torch.utils.data.Dataset for monolingual data."""
def __init__(self, dataset, sizes, vocab, shuffle):
self.dataset = dataset
self.sizes = np.array(sizes)
self.vocab = vocab
self.shuffle = shuffle
def __getitem__(self, index):
source, target = self.dataset[index]
return {'id': index, 'source': source, 'target': target}
def __len__(self):
return len(self.dataset)
def collater(self, samples):
"""Merge a list of samples to form a mini-batch."""
return collate(samples, self.vocab.pad(), self.vocab.eos())
def get_dummy_batch(self, num_tokens, max_positions, tgt_len=128):
assert isinstance(max_positions, float) or isinstance(max_positions, int)
tgt_len = min(tgt_len, max_positions)
bsz = num_tokens // tgt_len
target = self.vocab.dummy_sentence(tgt_len + 1)
source, target = target[:-1], target[1:]
return self.collater([
{'id': i, 'source': source, 'target': target}
for i in range(bsz)
])
def num_tokens(self, index):
"""Return an example's length (number of tokens), used for batching."""
source, target = self.dataset[index]
return len(source)
def ordered_indices(self):
"""Ordered indices for batching."""
if self.shuffle:
order = [np.random.permutation(len(self))]
else:
order = [np.arange(len(self))]
order.append(self.sizes)
return np.lexsort(order)
def valid_size(self, index, max_positions):
"""Check if an example's size is valid according to max_positions."""
assert isinstance(max_positions, float) or isinstance(max_positions, int)
return self.sizes[index] <= max_positions
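A standalone illustration of the np.lexsort trick used in ordered_indices() above (the sizes are made up): the last key is the primary sort key, so examples are ordered by length and the random permutation only breaks ties:
import numpy as np

sizes = np.array([7, 3, 7, 5])
order = [np.random.permutation(len(sizes))]  # tie-breaker between equal lengths
order.append(sizes)                          # primary key: example length
indices = np.lexsort(order)
# index 1 (length 3) always comes first, then index 3 (length 5);
# indices 0 and 2 (both length 7) appear in a random relative order.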
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import numpy as np
import torch
class TokenBlockDataset(torch.utils.data.Dataset):
"""Break a 1d tensor of tokens into blocks.
The blocks are fetched from the original tensor so no additional memory is allocated.
Args:
tokens: 1d tensor of tokens to break into blocks
sizes: sentence lengths (required for 'complete' and 'eos')
block_size: maximum block size (ignored in 'eos' break mode)
break_mode: Mode used for breaking tokens. Values can be one of:
- 'none': break tokens into equally sized blocks (up to block_size)
- 'complete': break tokens into blocks (up to block_size) such that
each block contains complete sentences, although block_size may be
exceeded if some sentences exceed block_size
- 'eos': each block contains one sentence (block_size is ignored)
include_targets: return next tokens as targets
"""
def __init__(self, tokens, sizes, block_size, break_mode=None, include_targets=False):
super().__init__()
self.tokens = tokens
self.total_size = len(tokens)
self.include_targets = include_targets
self.slice_indices = []
if break_mode is None or break_mode == 'none':
length = math.ceil(len(tokens) / block_size)
def block_at(i):
start = i * block_size
end = min(start + block_size, len(tokens))
return (start, end)
self.slice_indices = [block_at(i) for i in range(length)]
elif break_mode == 'complete':
assert sizes is not None and sum(sizes) == len(tokens)
tok_idx = 0
sz_idx = 0
curr_size = 0
while sz_idx < len(sizes):
if curr_size + sizes[sz_idx] <= block_size or curr_size == 0:
curr_size += sizes[sz_idx]
sz_idx += 1
else:
self.slice_indices.append((tok_idx, tok_idx + curr_size))
tok_idx += curr_size
curr_size = 0
if curr_size > 0:
self.slice_indices.append((tok_idx, tok_idx + curr_size))
elif break_mode == 'eos':
assert sizes is not None and sum(sizes) == len(tokens)
curr = 0
for sz in sizes:
# skip sentences of length 1 (which would contain only the eos token)
if sz > 1:
self.slice_indices.append((curr, curr + sz))
curr += sz
else:
raise ValueError('Invalid break_mode: ' + break_mode)
self.sizes = np.array([e - s for s, e in self.slice_indices])
def __getitem__(self, index):
s, e = self.slice_indices[index]
item = torch.LongTensor(self.tokens[s:e])
if self.include_targets:
if e == self.total_size:
return item[:-1], item[1:]
else:
return item, torch.LongTensor(self.tokens[s + 1:e + 1])
else:
return item
def __len__(self):
return len(self.slice_indices)
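A quick sketch of the three break modes using the TokenBlockDataset class defined above; the token stream is hypothetical, with 2 standing in for the eos index:
tokens = [5, 6, 2, 7, 8, 9, 2, 3, 2]   # three sentences, each ending in eos (=2)
sizes = [3, 4, 2]

ds_none = TokenBlockDataset(tokens, sizes, block_size=4, break_mode='none')
# slice_indices: (0, 4), (4, 8), (8, 9) -- fixed-size blocks
ds_complete = TokenBlockDataset(tokens, sizes, block_size=5, break_mode='complete')
# slice_indices: (0, 3), (3, 7), (7, 9) -- only whole sentences per block
ds_eos = TokenBlockDataset(tokens, sizes, block_size=1, break_mode='eos')
# slice_indices: (0, 3), (3, 7), (7, 9) -- one sentence per block, block_size ignored

lm = TokenBlockDataset(tokens, sizes, block_size=4, break_mode='none', include_targets=True)
src, tgt = lm[0]   # tgt is src shifted by one token: the next-token targets for LM training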
@@ -5,7 +5,6 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
-import math
import pickle
import torch.distributed
@@ -53,58 +52,6 @@ def suppress_output():
__builtin__.print = print
def all_reduce_and_rescale_tensors(tensors, rescale_denom, buffer_size=10485760):
"""All-reduce and rescale tensors in chunks of the specified size.
Args:
tensors: list of Tensors to all-reduce
rescale_denom: denominator for rescaling summed Tensors
buffer_size: all-reduce chunk size in bytes
"""
# buffer size is in bytes, determine equiv. # of elements based on data type
buffer_t = tensors[0].new(math.ceil(buffer_size / tensors[0].element_size())).zero_()
buffer = []
def all_reduce_buffer():
# copy tensors into buffer_t
offset = 0
for t in buffer:
numel = t.numel()
buffer_t[offset:offset+numel].copy_(t.view(-1))
offset += numel
# all-reduce and rescale
torch.distributed.all_reduce(buffer_t[:offset])
buffer_t.div_(rescale_denom)
# copy all-reduced buffer back into tensors
offset = 0
for t in buffer:
numel = t.numel()
t.view(-1).copy_(buffer_t[offset:offset+numel])
offset += numel
filled = 0
for t in tensors:
sz = t.numel() * t.element_size()
if sz > buffer_size:
# tensor is bigger than buffer, all-reduce and rescale directly
torch.distributed.all_reduce(t)
t.div_(rescale_denom)
elif filled + sz > buffer_size:
# buffer is full, all-reduce and replace buffer with grad
all_reduce_buffer()
buffer = [t]
filled = sz
else:
# add tensor to buffer
buffer.append(t)
filled += sz
if len(buffer) > 0:
all_reduce_buffer()
def all_gather_list(data, max_size=4096):
"""Gathers arbitrary data from all nodes into a list."""
world_size = torch.distributed.get_world_size()
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Train a network on multiple GPUs.
"""
import torch
from fairseq import optim, utils
from fairseq.meters import AverageMeter
from fairseq.optim import lr_scheduler
from fairseq.trainer import Trainer
class DynamicLossScaler:
def __init__(self, init_scale=2.**15, scale_factor=2., scale_window=2000):
self.loss_scale = init_scale
self.scale_factor = scale_factor
self.scale_window = scale_window
self._iter = 0
self._last_overflow_iter = -1
def update_scale(self, overflow):
if overflow:
self.loss_scale /= self.scale_factor
self._last_overflow_iter = self._iter
elif (self._iter - self._last_overflow_iter) % self.scale_window == 0:
self.loss_scale *= self.scale_factor
self._iter += 1
@staticmethod
def has_overflow(grad_norm):
# detect inf and nan
if grad_norm == float('inf') or grad_norm != grad_norm:
return True
return False
class FP16Trainer(Trainer):
"""Modified trainer for FP16.
We maintain two copies of the model's parameters: one in FP16 and one in FP32.
The forward/backward pass runs in FP16, while the loss and the optimizer step use the FP32 copy.
"""
def __init__(self, args, task, model, criterion):
super().__init__(args, task, model, criterion)
# convert model to FP16 (but keep criterion FP32)
self.model.half()
# dynamically scale loss to reduce overflow
self.scaler = DynamicLossScaler(init_scale=2.**7)
self.meters['loss_scale'] = AverageMeter()
def _build_optimizer(self):
# create FP32 copy of parameters and grads
params = [p for p in self.model.parameters() if p.requires_grad]
total_param_size = sum(p.data.numel() for p in params)
self.fp32_params = params[0].new(0).float().new(total_param_size)
offset = 0
for p in params:
numel = p.data.numel()
self.fp32_params[offset:offset+numel].copy_(p.data.view(-1))
offset += numel
self.fp32_params = torch.nn.Parameter(self.fp32_params)
self.fp32_params.grad = self.fp32_params.data.new(total_param_size)
# create optimizer using the copied FP32 params
self.optimizer = optim.build_optimizer(self.args, [self.fp32_params])
self.lr_scheduler = lr_scheduler.build_lr_scheduler(self.args, self.optimizer)
def save_checkpoint(self, filename, extra_state):
"""Save all training state in a checkpoint file."""
extra_state['loss_scale'] = self.scaler.loss_scale
super().save_checkpoint(filename, extra_state)
def load_checkpoint(self, filename):
"""Load all training state from a checkpoint file."""
extra_state = super().load_checkpoint(filename)
if extra_state is not None and 'loss_scale' in extra_state:
self.scaler.loss_scale = extra_state['loss_scale']
return extra_state
def zero_grad(self):
# zero both the FP16 and FP32 grads
self.model.zero_grad() # FP16
self.optimizer.zero_grad() # FP32
def _backward(self, loss):
self.meters['loss_scale'].reset()
self.meters['loss_scale'].update(self.scaler.loss_scale)
if loss is not None:
# dynamically rescale loss to stay in FP16 range
loss = loss * self.scaler.loss_scale
return super()._backward(loss)
def _all_reduce_and_rescale(self, grad_denom):
# undo effect of dynamic loss scaling on gradients
grad_denom *= self.scaler.loss_scale
if self.args.distributed_world_size > 1:
# flatten grads into a single buffer
flat_grads = self._flat_grads = self._get_flat_grads(self._flat_grads)
# scale gradients to avoid overflow in all-reduce
flat_grads.div_(self.args.distributed_world_size)
grad_denom /= self.args.distributed_world_size
# all-reduce flat grads
torch.distributed.all_reduce(flat_grads)
# copy grads back to FP32
self.fp32_params.grad.data.copy_(flat_grads)
else:
# single worker: copy grads directly to FP32
self._get_flat_grads(out=self.fp32_params.grad.data)
# rescale and clip grads
self.fp32_params.grad.data.div_(grad_denom)
grad_norm = utils.clip_grad_norm_(self.fp32_params.grad.data, self.args.clip_norm)
# detect overflow and adjust loss scale
overflow = DynamicLossScaler.has_overflow(grad_norm)
self.scaler.update_scale(overflow)
if overflow:
raise OverflowError('setting loss scale to: ' + str(self.scaler.loss_scale))
return grad_norm
def _opt(self):
# take an optimization step using the FP32 params and grads
super()._opt()
# copy FP32 params back into FP16 model
offset = 0
for p in self.model.parameters():
if not p.requires_grad:
continue
numel = p.data.numel()
p.data.copy_(self.fp32_params.data[offset:offset+numel].view_as(p.data))
offset += numel
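A standalone sketch of the DynamicLossScaler policy that FP16Trainer relies on above (the gradient norms are made up): the scale is halved whenever an overflow is detected and doubled after scale_window consecutive clean updates:
scaler = DynamicLossScaler(init_scale=2.**7, scale_window=4)
for grad_norm in [1.0, float('inf'), 2.0, 0.5, 1.5, 3.0]:
    overflow = DynamicLossScaler.has_overflow(grad_norm)  # inf or nan
    scaler.update_scale(overflow)
    # FP16Trainer raises OverflowError on overflow so that the step is skipped
print(scaler.loss_scale)  # 128.0: halved once on overflow, then doubled after 4 clean steps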
@@ -28,10 +28,11 @@ class AverageMeter(object):
class TimeMeter(object):
"""Computes the average occurrence of some event per second"""
-def __init__(self):
+def __init__(self, init=0):
-self.reset()
+self.reset(init)
-def reset(self):
+def reset(self, init=0):
+self.init = init
self.start = time.time()
self.n = 0
@@ -40,12 +41,11 @@ class TimeMeter(object):
@property
def avg(self):
-delta = time.time() - self.start
-return self.n / delta
+return self.n / self.elapsed_time
@property
def elapsed_time(self):
-return time.time() - self.start
+return self.init + (time.time() - self.start)
class StopwatchMeter(object):
...
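A small sketch of the now-resumable TimeMeter (the numbers are illustrative): init seeds elapsed_time so that avg = n / elapsed_time stays meaningful after a checkpoint restart:
import time
from fairseq.meters import TimeMeter

meter = TimeMeter(init=10.0)   # e.g. 10 seconds accumulated before a restart
meter.n = 50                   # events counted so far (normally updated by the trainer)
time.sleep(0.1)
print(meter.elapsed_time)      # ~10.1: carried-over time plus time since reset()
print(meter.avg)               # ~4.95 events per second over the combined window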
@@ -11,7 +11,9 @@ import os
from .fairseq_decoder import FairseqDecoder  # noqa: F401
from .fairseq_encoder import FairseqEncoder  # noqa: F401
from .fairseq_incremental_decoder import FairseqIncrementalDecoder  # noqa: F401
-from .fairseq_model import FairseqModel  # noqa: F401
+from .fairseq_model import BaseFairseqModel, FairseqModel, FairseqLanguageModel  # noqa: F401
+from .composite_encoder import CompositeEncoder  # noqa: F401
MODEL_REGISTRY = {}
@@ -19,8 +21,8 @@ ARCH_MODEL_REGISTRY = {}
ARCH_CONFIG_REGISTRY = {}
-def build_model(args, src_dict, dst_dict):
-return ARCH_MODEL_REGISTRY[args.arch].build_model(args, src_dict, dst_dict)
+def build_model(args, task):
+return ARCH_MODEL_REGISTRY[args.arch].build_model(args, task)
def register_model(name):
@@ -29,8 +31,8 @@ def register_model(name):
def register_model_cls(cls):
if name in MODEL_REGISTRY:
raise ValueError('Cannot register duplicate model ({})'.format(name))
-if not issubclass(cls, FairseqModel):
-raise ValueError('Model ({}: {}) must extend FairseqModel'.format(name, cls.__name__))
+if not issubclass(cls, BaseFairseqModel):
+raise ValueError('Model ({}: {}) must extend BaseFairseqModel'.format(name, cls.__name__))
MODEL_REGISTRY[name] = cls
return cls
...
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from . import FairseqEncoder
class CompositeEncoder(FairseqEncoder):
"""
Encoder class that forwards to multiple encoders, e.g. for a fusion model or question answering.
Accepts a dictionary of encoders; the first encoder's dictionary is used for initialization.
"""
def __init__(self, encoders):
super().__init__(next(iter(encoders.values())).dictionary)
self.encoders = encoders
for key in self.encoders:
self.add_module(key, self.encoders[key])
def forward(self, src_tokens, src_lengths):
encoder_out = {}
for key in self.encoders:
encoder_out[key] = self.encoders[key](src_tokens, src_lengths)
return encoder_out
def max_positions(self):
return min([self.encoders[key].max_positions() for key in self.encoders])
def upgrade_state_dict(self, state_dict):
for key in self.encoders:
self.encoders[key].upgrade_state_dict(state_dict)
return state_dict
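A hypothetical sketch of wrapping two encoders that share a dictionary; DummyEncoder and the placeholder dictionary are made up purely for illustration:
import torch
from fairseq.models import FairseqEncoder, CompositeEncoder  # assumed import path

class DummyEncoder(FairseqEncoder):
    def forward(self, src_tokens, src_lengths):
        return src_tokens  # a real encoder returns its own output structure
    def max_positions(self):
        return 1024

shared_dict = object()  # placeholder; a real fairseq Dictionary would go here
composite = CompositeEncoder({
    'doc': DummyEncoder(shared_dict),
    'question': DummyEncoder(shared_dict),
})
out = composite(torch.LongTensor([[4, 5, 2]]), torch.LongTensor([3]))
# out == {'doc': ..., 'question': ...}; max_positions() is the min over all parts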
@@ -19,9 +19,9 @@ class FairseqDecoder(nn.Module):
def forward(self, prev_output_tokens, encoder_out):
raise NotImplementedError
-def get_normalized_probs(self, net_output, log_probs):
+def get_normalized_probs(self, net_output, log_probs, _):
"""Get normalized probabilities (or log probs) from a net's output."""
-logits = net_output[0]
+logits = net_output[0].float()
if log_probs:
return F.log_softmax(logits, dim=-1)
else:
...
@@ -26,9 +26,15 @@ class FairseqIncrementalDecoder(FairseqDecoder):
"""
def apply_reorder_incremental_state(module):
if module != self and hasattr(module, 'reorder_incremental_state'):
-module.reorder_incremental_state(incremental_state, new_order)
+module.reorder_incremental_state(
+incremental_state,
+new_order,
+)
self.apply(apply_reorder_incremental_state)
+def reorder_encoder_out(self, encoder_out, new_order):
+return encoder_out
def set_beam_size(self, beam_size):
"""Sets the beam size in the decoder and all children."""
if getattr(self, '_beam_size', -1) != beam_size:
...
@@ -5,28 +5,17 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn as nn
from . import FairseqDecoder, FairseqEncoder
-class FairseqModel(nn.Module):
-"""Base class for encoder-decoder models."""
-def __init__(self, encoder, decoder):
+class BaseFairseqModel(nn.Module):
+"""Base class for fairseq models."""
+def __init__(self):
super().__init__()
-self.encoder = encoder
-self.decoder = decoder
-assert isinstance(self.encoder, FairseqEncoder)
-assert isinstance(self.decoder, FairseqDecoder)
-self.src_dict = encoder.dictionary
-self.dst_dict = decoder.dictionary
-assert self.src_dict.pad() == self.dst_dict.pad()
-assert self.src_dict.eos() == self.dst_dict.eos()
-assert self.src_dict.unk() == self.dst_dict.unk()
self._is_generation_fast = False
@staticmethod
@@ -35,29 +24,24 @@ class FairseqModel(nn.Module):
pass
@classmethod
-def build_model(cls, args, src_dict, dst_dict):
+def build_model(cls, args, task):
"""Build a new model instance."""
raise NotImplementedError
-def forward(self, src_tokens, src_lengths, prev_output_tokens):
-encoder_out = self.encoder(src_tokens, src_lengths)
-decoder_out = self.decoder(prev_output_tokens, encoder_out)
-return decoder_out
-def get_normalized_probs(self, net_output, log_probs):
-"""Get normalized probabilities (or log probs) from a net's output."""
-return self.decoder.get_normalized_probs(net_output, log_probs)
def get_targets(self, sample, net_output):
"""Get targets from either the sample or the net's output."""
return sample['target']
-def max_encoder_positions(self):
-"""Maximum input length supported by the encoder."""
-return self.encoder.max_positions()
+def get_normalized_probs(self, net_output, log_probs, sample=None):
+"""Get normalized probabilities (or log probs) from a net's output."""
+return self.decoder.get_normalized_probs(net_output, log_probs, sample)
+def max_positions(self):
+"""Maximum length supported by the model."""
+raise NotImplementedError
def max_decoder_positions(self):
-"""Maximum output length supported by the decoder."""
+"""Maximum length supported by the decoder."""
return self.decoder.max_positions()
def load_state_dict(self, state_dict, strict=True):
@@ -67,13 +51,17 @@ class FairseqModel(nn.Module):
Overrides the method in nn.Module; compared with that method this
additionally "upgrades" state_dicts from old checkpoints.
"""
-state_dict = self.upgrade_state_dict(state_dict)
+self.upgrade_state_dict(state_dict)
super().load_state_dict(state_dict, strict)
def upgrade_state_dict(self, state_dict):
-state_dict = self.encoder.upgrade_state_dict(state_dict)
-state_dict = self.decoder.upgrade_state_dict(state_dict)
-return state_dict
+assert state_dict is not None
+def do_upgrade(m):
+if m != self and hasattr(m, 'upgrade_state_dict'):
+m.upgrade_state_dict(state_dict)
+self.apply(do_upgrade)
def make_generation_fast_(self, **kwargs):
"""Optimize model for faster generation."""
@@ -87,11 +75,13 @@ class FairseqModel(nn.Module):
nn.utils.remove_weight_norm(module)
except ValueError:  # this module didn't have weight norm
return
self.apply(apply_remove_weight_norm)
def apply_make_generation_fast_(module):
if module != self and hasattr(module, 'make_generation_fast_'):
module.make_generation_fast_(**kwargs)
self.apply(apply_make_generation_fast_)
def train(mode):
@@ -101,3 +91,40 @@ class FairseqModel(nn.Module):
# this model should no longer be used for training
self.eval()
self.train = train
class FairseqModel(BaseFairseqModel):
"""Base class for encoder-decoder models."""
def __init__(self, encoder, decoder):
super().__init__()
self.encoder = encoder
self.decoder = decoder
assert isinstance(self.encoder, FairseqEncoder)
assert isinstance(self.decoder, FairseqDecoder)
def forward(self, src_tokens, src_lengths, prev_output_tokens):
encoder_out = self.encoder(src_tokens, src_lengths)
decoder_out = self.decoder(prev_output_tokens, encoder_out)
return decoder_out
def max_positions(self):
"""Maximum length supported by the model."""
return (self.encoder.max_positions(), self.decoder.max_positions())
class FairseqLanguageModel(BaseFairseqModel):
"""Base class for decoder-only models."""
def __init__(self, decoder):
super().__init__()
self.decoder = decoder
assert isinstance(self.decoder, FairseqDecoder)
def forward(self, src_tokens):
return self.decoder(src_tokens)
def max_positions(self):
"""Maximum length supported by the model."""
return self.decoder.max_positions()
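A hedged sketch of the new registry plus task-based build_model() contract introduced here; the toy decoder and architecture below are hypothetical and only illustrate the wiring:
from fairseq.models import (
    FairseqDecoder, FairseqLanguageModel, register_model, register_model_architecture,
)

class TinyDecoder(FairseqDecoder):
    def forward(self, prev_output_tokens):
        raise NotImplementedError  # a real decoder returns (logits, ...)
    def max_positions(self):
        return 1024

@register_model('tiny_lm')
class TinyLanguageModel(FairseqLanguageModel):
    @classmethod
    def build_model(cls, args, task):
        # build_model now receives the FairseqTask and pulls dictionaries from it
        return cls(TinyDecoder(task.target_dictionary))

@register_model_architecture('tiny_lm', 'tiny_lm')
def tiny_lm(args):
    pass  # set getattr(...) defaults here, as base_architecture() does for the LSTM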
@@ -6,14 +6,15 @@
# can be found in the PATENTS file in the same directory.
import torch
-from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
-from fairseq import utils
+from fairseq import options, utils
-from fairseq.data import LanguagePairDataset
-from . import FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model, register_model_architecture
+from . import (
+FairseqEncoder, FairseqIncrementalDecoder, FairseqModel, register_model,
+register_model_architecture,
+)
@register_model('lstm')
@@ -60,18 +61,10 @@ class LSTMModel(FairseqModel):
help='dropout probability for decoder output')
@classmethod
-def build_model(cls, args, src_dict, dst_dict):
+def build_model(cls, args, task):
"""Build a new model instance."""
-if not hasattr(args, 'encoder_embed_path'):
-args.encoder_embed_path = None
-if not hasattr(args, 'decoder_embed_path'):
-args.decoder_embed_path = None
-if not hasattr(args, 'encoder_hidden_size'):
-args.encoder_hidden_size = args.encoder_embed_dim
-if not hasattr(args, 'decoder_hidden_size'):
-args.decoder_hidden_size = args.decoder_embed_dim
-if not hasattr(args, 'encoder_bidirectional'):
-args.encoder_bidirectional = False
+# make sure that all args are properly defaulted (in case there are any new ones)
+base_architecture(args)
def load_pretrained_embedding_from_file(embed_path, dictionary, embed_dim):
num_embeddings = len(dictionary)
@@ -84,14 +77,14 @@ class LSTMModel(FairseqModel):
pretrained_encoder_embed = None
if args.encoder_embed_path:
pretrained_encoder_embed = load_pretrained_embedding_from_file(
-args.encoder_embed_path, src_dict, args.encoder_embed_dim)
+args.encoder_embed_path, task.source_dictionary, args.encoder_embed_dim)
pretrained_decoder_embed = None
if args.decoder_embed_path:
pretrained_decoder_embed = load_pretrained_embedding_from_file(
-args.decoder_embed_path, dst_dict, args.decoder_embed_dim)
+args.decoder_embed_path, task.target_dictionary, args.decoder_embed_dim)
encoder = LSTMEncoder(
-dictionary=src_dict,
+dictionary=task.source_dictionary,
embed_dim=args.encoder_embed_dim,
hidden_size=args.encoder_hidden_size,
num_layers=args.encoder_layers,
@@ -100,19 +93,15 @@ class LSTMModel(FairseqModel):
bidirectional=args.encoder_bidirectional,
pretrained_embed=pretrained_encoder_embed,
)
-try:
-attention = bool(eval(args.decoder_attention))
-except TypeError:
-attention = bool(args.decoder_attention)
decoder = LSTMDecoder(
-dictionary=dst_dict,
+dictionary=task.target_dictionary,
embed_dim=args.decoder_embed_dim,
hidden_size=args.decoder_hidden_size,
out_embed_dim=args.decoder_out_embed_dim,
num_layers=args.decoder_layers,
dropout_in=args.decoder_dropout_in,
dropout_out=args.decoder_dropout_out,
-attention=attention,
+attention=options.eval_bool(args.decoder_attention),
encoder_embed_dim=args.encoder_embed_dim,
encoder_output_units=encoder.output_units,
pretrained_embed=pretrained_decoder_embed,
@@ -123,11 +112,9 @@ class LSTMModel(FairseqModel):
class LSTMEncoder(FairseqEncoder):
"""LSTM encoder."""
def __init__(
self, dictionary, embed_dim=512, hidden_size=512, num_layers=1,
dropout_in=0.1, dropout_out=0.1, bidirectional=False,
-left_pad_source=LanguagePairDataset.LEFT_PAD_SOURCE,
-pretrained_embed=None,
-padding_value=0.,
+left_pad=True, pretrained_embed=None, padding_value=0.,
):
super().__init__(dictionary)
self.num_layers = num_layers
@@ -147,10 +134,10 @@ class LSTMEncoder(FairseqEncoder):
input_size=embed_dim,
hidden_size=hidden_size,
num_layers=num_layers,
-dropout=self.dropout_out,
+dropout=self.dropout_out if num_layers > 1 else 0.,
bidirectional=bidirectional,
)
-self.left_pad_source = left_pad_source
+self.left_pad = left_pad
self.padding_value = padding_value
self.output_units = hidden_size
@@ -158,7 +145,7 @@ class LSTMEncoder(FairseqEncoder):
self.output_units *= 2
def forward(self, src_tokens, src_lengths):
-if self.left_pad_source:
+if self.left_pad:
# convert left-padding to right-padding
src_tokens = utils.convert_padding_direction(
src_tokens,
@@ -183,33 +170,32 @@ class LSTMEncoder(FairseqEncoder):
state_size = 2 * self.num_layers, bsz, self.hidden_size
else:
state_size = self.num_layers, bsz, self.hidden_size
-h0 = Variable(x.data.new(*state_size).zero_())
-c0 = Variable(x.data.new(*state_size).zero_())
+h0 = x.data.new(*state_size).zero_()
+c0 = x.data.new(*state_size).zero_()
-packed_outs, (final_hiddens, final_cells) = self.lstm(
-packed_x,
-(h0, c0),
-)
+packed_outs, (final_hiddens, final_cells) = self.lstm(packed_x, (h0, c0))
# unpack outputs and apply dropout
-x, _ = nn.utils.rnn.pad_packed_sequence(
-packed_outs, padding_value=self.padding_value)
+x, _ = nn.utils.rnn.pad_packed_sequence(packed_outs, padding_value=self.padding_value)
x = F.dropout(x, p=self.dropout_out, training=self.training)
assert list(x.size()) == [seqlen, bsz, self.output_units]
if self.bidirectional:
-bi_final_hiddens, bi_final_cells = [], []
-for i in range(self.num_layers):
-bi_final_hiddens.append(
-torch.cat(
-(final_hiddens[2 * i], final_hiddens[2 * i + 1]),
-dim=0).view(bsz, self.output_units))
-bi_final_cells.append(
-torch.cat(
-(final_cells[2 * i], final_cells[2 * i + 1]),
-dim=0).view(bsz, self.output_units))
-return x, bi_final_hiddens, bi_final_cells
-return x, final_hiddens, final_cells
+def combine_bidir(outs):
+return torch.cat([
+torch.cat([outs[2 * i], outs[2 * i + 1]], dim=0).view(1, bsz, self.output_units)
+for i in range(self.num_layers)
+], dim=0)
+final_hiddens = combine_bidir(final_hiddens)
+final_cells = combine_bidir(final_cells)
+encoder_padding_mask = src_tokens.eq(self.padding_idx).t()
+return {
+'encoder_out': (x, final_hiddens, final_cells),
+'encoder_padding_mask': encoder_padding_mask if encoder_padding_mask.any() else None
+}
def max_positions(self):
"""Maximum input length supported by the encoder."""
@@ -223,7 +209,7 @@ class AttentionLayer(nn.Module):
self.input_proj = Linear(input_embed_dim, output_embed_dim, bias=False)
self.output_proj = Linear(2*output_embed_dim, output_embed_dim, bias=False)
-def forward(self, input, source_hids, src_lengths=None):
+def forward(self, input, source_hids, encoder_padding_mask):
# input: bsz x input_embed_dim
# source_hids: srclen x bsz x output_embed_dim
@@ -232,7 +218,15 @@ class AttentionLayer(nn.Module):
# compute attention
attn_scores = (source_hids * x.unsqueeze(0)).sum(dim=2)
-attn_scores = F.softmax(attn_scores.t(), dim=1).t()  # srclen x bsz
+# don't attend over padding
+if encoder_padding_mask is not None:
+attn_scores = attn_scores.float().masked_fill_(
+encoder_padding_mask,
+float('-inf')
+).type_as(attn_scores)  # FP16 support: cast to float and back
+attn_scores = F.softmax(attn_scores, dim=0)  # srclen x bsz
# sum weighted sources
x = (attn_scores.unsqueeze(2) * source_hids).sum(dim=0)
@@ -244,10 +238,9 @@ class AttentionLayer(nn.Module):
class LSTMDecoder(FairseqIncrementalDecoder):
"""LSTM decoder."""
def __init__(
self, dictionary, embed_dim=512, hidden_size=512, out_embed_dim=512,
num_layers=1, dropout_in=0.1, dropout_out=0.1, attention=True,
-encoder_embed_dim=512, encoder_output_units=512,
-pretrained_embed=None,
+encoder_embed_dim=512, encoder_output_units=512, pretrained_embed=None,
):
super().__init__(dictionary)
self.dropout_in = dropout_in
@@ -263,7 +256,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
self.encoder_output_units = encoder_output_units
assert encoder_output_units == hidden_size, \
-'{} {}'.format(encoder_output_units, hidden_size)
+'encoder_output_units ({}) != hidden_size ({})'.format(encoder_output_units, hidden_size)
# TODO another Linear layer if not equal
self.layers = nn.ModuleList([
@@ -278,7 +271,10 @@ class LSTMDecoder(FairseqIncrementalDecoder):
self.additional_fc = Linear(hidden_size, out_embed_dim)
self.fc_out = Linear(out_embed_dim, num_embeddings, dropout=dropout_out)
-def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
+def forward(self, prev_output_tokens, encoder_out_dict, incremental_state=None):
+encoder_out = encoder_out_dict['encoder_out']
+encoder_padding_mask = encoder_out_dict['encoder_padding_mask']
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
bsz, seqlen = prev_output_tokens.size()
@@ -303,9 +299,9 @@ class LSTMDecoder(FairseqIncrementalDecoder):
num_layers = len(self.layers)
prev_hiddens = [encoder_hiddens[i] for i in range(num_layers)]
prev_cells = [encoder_cells[i] for i in range(num_layers)]
-input_feed = Variable(x.data.new(bsz, self.encoder_output_units).zero_())
-attn_scores = Variable(x.data.new(srclen, seqlen, bsz).zero_())
+input_feed = x.data.new(bsz, self.encoder_output_units).zero_()
+attn_scores = x.data.new(srclen, seqlen, bsz).zero_()
outs = []
for j in range(seqlen):
# input feeding: concatenate context vector from previous time step
@@ -324,7 +320,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
# apply attention using the last layer's hidden state
if self.attention is not None:
-out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs)
+out, attn_scores[:, j, :] = self.attention(hidden, encoder_outs, encoder_padding_mask)
else:
out = hidden
out = F.dropout(out, p=self.dropout_out, training=self.training)
@@ -357,6 +353,7 @@ class LSTMDecoder(FairseqIncrementalDecoder):
return x, attn_scores
def reorder_incremental_state(self, incremental_state, new_order):
+super().reorder_incremental_state(incremental_state, new_order)
cached_state = utils.get_incremental_state(self, incremental_state, 'cached_state')
if cached_state is None:
return
@@ -366,11 +363,19 @@ class LSTMDecoder(FairseqIncrementalDecoder):
return [reorder_state(state_i) for state_i in state]
return state.index_select(0, new_order)
-if not isinstance(new_order, Variable):
-new_order = Variable(new_order)
new_state = tuple(map(reorder_state, cached_state))
utils.set_incremental_state(self, incremental_state, 'cached_state', new_state)
+def reorder_encoder_out(self, encoder_out_dict, new_order):
+encoder_out_dict['encoder_out'] = tuple(
+eo.index_select(1, new_order)
+for eo in encoder_out_dict['encoder_out']
+)
+if encoder_out_dict['encoder_padding_mask'] is not None:
+encoder_out_dict['encoder_padding_mask'] = \
+encoder_out_dict['encoder_padding_mask'].index_select(1, new_order)
+return encoder_out_dict
def max_positions(self):
"""Maximum output length supported by the decoder."""
return int(1e5)  # an arbitrary large number
@@ -378,7 +383,8 @@ class LSTMDecoder(FairseqIncrementalDecoder):
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
-m.weight.data.uniform_(-0.1, 0.1)
+nn.init.uniform_(m.weight, -0.1, 0.1)
+nn.init.constant_(m.weight[padding_idx], 0)
return m
@@ -410,48 +416,41 @@ def Linear(in_features, out_features, bias=True, dropout=0):
@register_model_architecture('lstm', 'lstm')
def base_architecture(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
-args.encoder_hidden_size = getattr(args, 'encoder_hidden_size', 512)
+args.encoder_embed_path = getattr(args, 'encoder_embed_path', None)
+args.encoder_hidden_size = getattr(args, 'encoder_hidden_size', args.encoder_embed_dim)
args.encoder_layers = getattr(args, 'encoder_layers', 1)
args.encoder_bidirectional = getattr(args, 'encoder_bidirectional', False)
args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', args.dropout)
args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', args.dropout)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 512)
-args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', 512)
+args.decoder_embed_path = getattr(args, 'decoder_embed_path', None)
+args.decoder_hidden_size = getattr(args, 'decoder_hidden_size', args.decoder_embed_dim)
args.decoder_layers = getattr(args, 'decoder_layers', 1)
args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 512)
-args.decoder_attention = getattr(args, 'decoder_attention', True)
+args.decoder_attention = getattr(args, 'decoder_attention', '1')
args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', args.dropout)
args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
@register_model_architecture('lstm', 'lstm_wiseman_iwslt_de_en')
def lstm_wiseman_iwslt_de_en(args):
+args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
+args.encoder_dropout_in = getattr(args, 'encoder_dropout_in', 0)
+args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', 0)
+args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
+args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 256)
+args.decoder_dropout_in = getattr(args, 'decoder_dropout_in', 0)
+args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', args.dropout)
base_architecture(args)
-args.encoder_embed_dim = 256
-args.encoder_hidden_size = 256
-args.encoder_layers = 1
-args.encoder_bidirectional = False
-args.encoder_dropout_in = 0
-args.encoder_dropout_out = 0
-args.decoder_embed_dim = 256
-args.decoder_hidden_size = 256
-args.decoder_layers = 1
-args.decoder_out_embed_dim = 256
-args.decoder_attention = True
-args.decoder_dropout_in = 0
@register_model_architecture('lstm', 'lstm_luong_wmt_en_de')
def lstm_luong_wmt_en_de(args):
+args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1000)
+args.encoder_layers = getattr(args, 'encoder_layers', 4)
+args.encoder_dropout_out = getattr(args, 'encoder_dropout_out', 0)
+args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1000)
+args.decoder_layers = getattr(args, 'decoder_layers', 4)
+args.decoder_out_embed_dim = getattr(args, 'decoder_out_embed_dim', 1000)
+args.decoder_dropout_out = getattr(args, 'decoder_dropout_out', 0)
base_architecture(args)
-args.encoder_embed_dim = 1000
-args.encoder_hidden_size = 1000
-args.encoder_layers = 4
-args.encoder_dropout_out = 0
-args.encoder_bidirectional = False
-args.decoder_embed_dim = 1000
-args.decoder_hidden_size = 1000
-args.decoder_layers = 4
-args.decoder_out_embed_dim = 1000
-args.decoder_attention = True
-args.decoder_dropout_out = 0
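A standalone illustration of the getattr-defaulting convention the rewritten architecture functions above now follow (the Namespace values are made up): explicitly set options survive, unspecified ones fall back to the architecture defaults:
from argparse import Namespace

args = Namespace(dropout=0.1, encoder_embed_dim=128)              # pretend CLI overrides
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)  # already set: stays 128
args.encoder_layers = getattr(args, 'encoder_layers', 1)          # unset: defaults to 1
print(args.encoder_embed_dim, args.encoder_layers)                # 128 1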
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.modules import (
LearnedPositionalEmbedding, MultiheadAttention,
SinusoidalPositionalEmbedding,
)
from . import (
FairseqIncrementalDecoder, FairseqEncoder, FairseqModel,
register_model, register_model_architecture,
)
@register_model('transformer')
class TransformerModel(FairseqModel):
def __init__(self, encoder, decoder):
super().__init__(encoder, decoder)
@staticmethod
def add_args(parser):
"""Add model-specific arguments to the parser."""
parser.add_argument('--dropout', type=float, metavar='D',
help='dropout probability')
parser.add_argument('--attention-dropout', type=float, metavar='D',
help='dropout probability for attention weights')
parser.add_argument('--relu-dropout', type=float, metavar='D',
help='dropout probability after ReLU in FFN')
parser.add_argument('--encoder-embed-dim', type=int, metavar='N',
help='encoder embedding dimension')
parser.add_argument('--encoder-ffn-embed-dim', type=int, metavar='N',
help='encoder embedding dimension for FFN')
parser.add_argument('--encoder-layers', type=int, metavar='N',
help='num encoder layers')
parser.add_argument('--encoder-attention-heads', type=int, metavar='N',
help='num encoder attention heads')
parser.add_argument('--encoder-normalize-before', default=False, action='store_true',
help='apply layernorm before each encoder block')
parser.add_argument('--encoder-learned-pos', default=False, action='store_true',
help='use learned positional embeddings in the encoder')
parser.add_argument('--decoder-embed-dim', type=int, metavar='N',
help='decoder embedding dimension')
parser.add_argument('--decoder-ffn-embed-dim', type=int, metavar='N',
help='decoder embedding dimension for FFN')
parser.add_argument('--decoder-layers', type=int, metavar='N',
help='num decoder layers')
parser.add_argument('--decoder-attention-heads', type=int, metavar='N',
help='num decoder attention heads')
parser.add_argument('--decoder-learned-pos', default=False, action='store_true',
help='use learned positional embeddings in the decoder')
parser.add_argument('--decoder-normalize-before', default=False, action='store_true',
help='apply layernorm before each decoder block')
parser.add_argument('--share-decoder-input-output-embed', default=False, action='store_true',
help='share decoder input and output embeddings')
parser.add_argument('--share-all-embeddings', default=False, action='store_true',
help='share encoder, decoder and output embeddings'
' (requires shared dictionary and embed dim)')
@classmethod
def build_model(cls, args, task):
"""Build a new model instance."""
src_dict, tgt_dict = task.source_dictionary, task.target_dictionary
def build_embedding(dictionary, embed_dim):
num_embeddings = len(dictionary)
padding_idx = dictionary.pad()
return Embedding(num_embeddings, embed_dim, padding_idx)
if args.share_all_embeddings:
if src_dict != tgt_dict:
raise RuntimeError('--share-all-embeddings requires a joined dictionary')
if args.encoder_embed_dim != args.decoder_embed_dim:
raise RuntimeError(
'--share-all-embeddings requires --encoder-embed-dim to match --decoder-embed-dim')
encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
decoder_embed_tokens = encoder_embed_tokens
args.share_decoder_input_output_embed = True
else:
encoder_embed_tokens = build_embedding(src_dict, args.encoder_embed_dim)
decoder_embed_tokens = build_embedding(tgt_dict, args.decoder_embed_dim)
encoder = TransformerEncoder(args, src_dict, encoder_embed_tokens)
decoder = TransformerDecoder(args, tgt_dict, decoder_embed_tokens)
return TransformerModel(encoder, decoder)
class TransformerEncoder(FairseqEncoder):
"""Transformer encoder."""
def __init__(self, args, dictionary, embed_tokens, left_pad=True):
super().__init__(dictionary)
self.dropout = args.dropout
embed_dim = embed_tokens.embedding_dim
self.padding_idx = embed_tokens.padding_idx
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
1024, embed_dim, self.padding_idx,
left_pad=left_pad,
learned=args.encoder_learned_pos,
)
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerEncoderLayer(args)
for i in range(args.encoder_layers)
])
def forward(self, src_tokens, src_lengths):
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(src_tokens)
x += self.embed_positions(src_tokens)
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# compute padding mask
encoder_padding_mask = src_tokens.eq(self.padding_idx)
if not encoder_padding_mask.any():
encoder_padding_mask = None
# encoder layers
for layer in self.layers:
x = layer(x, encoder_padding_mask)
return {
'encoder_out': x, # T x B x C
'encoder_padding_mask': encoder_padding_mask, # B x T
}
def max_positions(self):
"""Maximum input length supported by the encoder."""
return self.embed_positions.max_positions()
def upgrade_state_dict(self, state_dict):
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
if 'encoder.embed_positions.weights' in state_dict:
del state_dict['encoder.embed_positions.weights']
if 'encoder.embed_positions._float_tensor' not in state_dict:
state_dict['encoder.embed_positions._float_tensor'] = torch.FloatTensor()
return state_dict
class TransformerDecoder(FairseqIncrementalDecoder):
"""Transformer decoder."""
def __init__(self, args, dictionary, embed_tokens, left_pad=False):
super().__init__(dictionary)
self.dropout = args.dropout
self.share_input_output_embed = args.share_decoder_input_output_embed
embed_dim = embed_tokens.embedding_dim
padding_idx = embed_tokens.padding_idx
self.embed_tokens = embed_tokens
self.embed_scale = math.sqrt(embed_dim)
self.embed_positions = PositionalEmbedding(
1024, embed_dim, padding_idx,
left_pad=left_pad,
learned=args.decoder_learned_pos,
)
self.layers = nn.ModuleList([])
self.layers.extend([
TransformerDecoderLayer(args)
for i in range(args.decoder_layers)
])
if not self.share_input_output_embed:
self.embed_out = nn.Parameter(torch.Tensor(len(dictionary), embed_dim))
nn.init.normal_(self.embed_out, mean=0, std=embed_dim ** -0.5)
def forward(self, prev_output_tokens, encoder_out, incremental_state=None):
# embed positions
positions = self.embed_positions(
prev_output_tokens,
incremental_state=incremental_state,
)
if incremental_state is not None:
prev_output_tokens = prev_output_tokens[:, -1:]
positions = positions[:, -1:]
# embed tokens and positions
x = self.embed_scale * self.embed_tokens(prev_output_tokens)
x += positions
x = F.dropout(x, p=self.dropout, training=self.training)
# B x T x C -> T x B x C
x = x.transpose(0, 1)
# decoder layers
for layer in self.layers:
x, attn = layer(
x,
encoder_out['encoder_out'],
encoder_out['encoder_padding_mask'],
incremental_state,
)
# T x B x C -> B x T x C
x = x.transpose(0, 1)
# project back to size of vocabulary
if self.share_input_output_embed:
x = F.linear(x, self.embed_tokens.weight)
else:
x = F.linear(x, self.embed_out)
return x, attn
def reorder_encoder_out(self, encoder_out_dict, new_order):
if encoder_out_dict['encoder_padding_mask'] is not None:
encoder_out_dict['encoder_padding_mask'] = \
encoder_out_dict['encoder_padding_mask'].index_select(0, new_order)
return encoder_out_dict
def max_positions(self):
"""Maximum output length supported by the decoder."""
return self.embed_positions.max_positions()
def upgrade_state_dict(self, state_dict):
if isinstance(self.embed_positions, SinusoidalPositionalEmbedding):
if 'decoder.embed_positions.weights' in state_dict:
del state_dict['decoder.embed_positions.weights']
if 'decoder.embed_positions._float_tensor' not in state_dict:
state_dict['decoder.embed_positions._float_tensor'] = torch.FloatTensor()
return state_dict
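# Note on incremental decoding (explanatory note added here, not in the original
# diff): when `incremental_state` is passed to TransformerDecoder.forward, only
# the most recent target token is embedded (see the slicing above) and each
# attention layer reuses the keys/values it cached in `incremental_state`, so
# each generation step processes one new timestep instead of re-running the
# whole prefix through the decoder.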
class TransformerEncoderLayer(nn.Module):
"""Encoder layer block.
In the original paper each operation (multi-head attention or FFN) is
postprocessed with: dropout -> add residual -> layernorm.
In the tensor2tensor code they suggest that learning is more robust when
preprocessing each layer with layernorm and postprocessing with:
dropout -> add residual.
We default to the approach in the paper, but the tensor2tensor approach can
be enabled by setting `normalize_before=True`.
"""
def __init__(self, args):
super().__init__()
self.embed_dim = args.encoder_embed_dim
self.self_attn = MultiheadAttention(
self.embed_dim, args.encoder_attention_heads,
dropout=args.attention_dropout,
)
self.dropout = args.dropout
self.relu_dropout = args.relu_dropout
self.normalize_before = args.encoder_normalize_before
self.fc1 = Linear(self.embed_dim, args.encoder_ffn_embed_dim)
self.fc2 = Linear(args.encoder_ffn_embed_dim, self.embed_dim)
self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(2)])
def forward(self, x, encoder_padding_mask):
residual = x
x = self.maybe_layer_norm(0, x, before=True)
x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(0, x, after=True)
residual = x
x = self.maybe_layer_norm(1, x, before=True)
x = F.relu(self.fc1(x))
x = F.dropout(x, p=self.relu_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(1, x, after=True)
return x
def maybe_layer_norm(self, i, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return self.layer_norms[i](x)
else:
return x
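# Illustrative sketch only (not part of the original diff): the two residual
# orderings described in the TransformerEncoderLayer docstring, with `sublayer`
# standing in for self-attention or the FFN. The helper names below are
# hypothetical.
def _post_norm_step(x, sublayer, layer_norm, dropout_p=0.1):
    # paper-style (default): sublayer -> dropout -> residual add -> layernorm
    return layer_norm(x + F.dropout(sublayer(x), p=dropout_p))
def _pre_norm_step(x, sublayer, layer_norm, dropout_p=0.1):
    # tensor2tensor-style (normalize_before=True): layernorm -> sublayer ->
    # dropout -> residual add, with no layernorm afterwards
    return x + F.dropout(sublayer(layer_norm(x)), p=dropout_p)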
class TransformerDecoderLayer(nn.Module):
"""Decoder layer block."""
def __init__(self, args):
super().__init__()
self.embed_dim = args.decoder_embed_dim
self.self_attn = MultiheadAttention(
self.embed_dim, args.decoder_attention_heads,
dropout=args.attention_dropout,
)
self.dropout = args.dropout
self.relu_dropout = args.relu_dropout
self.normalize_before = args.decoder_normalize_before
self.encoder_attn = MultiheadAttention(
self.embed_dim, args.decoder_attention_heads,
dropout=args.attention_dropout,
)
self.fc1 = Linear(self.embed_dim, args.decoder_ffn_embed_dim)
self.fc2 = Linear(args.decoder_ffn_embed_dim, self.embed_dim)
self.layer_norms = nn.ModuleList([LayerNorm(self.embed_dim) for i in range(3)])
def forward(self, x, encoder_out, encoder_padding_mask, incremental_state):
residual = x
x = self.maybe_layer_norm(0, x, before=True)
x, _ = self.self_attn(
query=x,
key=x,
value=x,
mask_future_timesteps=True,
incremental_state=incremental_state,
need_weights=False,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(0, x, after=True)
residual = x
x = self.maybe_layer_norm(1, x, before=True)
x, attn = self.encoder_attn(
query=x,
key=encoder_out,
value=encoder_out,
key_padding_mask=encoder_padding_mask,
incremental_state=incremental_state,
static_kv=True,
)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(1, x, after=True)
residual = x
x = self.maybe_layer_norm(2, x, before=True)
x = F.relu(self.fc1(x))
x = F.dropout(x, p=self.relu_dropout, training=self.training)
x = self.fc2(x)
x = F.dropout(x, p=self.dropout, training=self.training)
x = residual + x
x = self.maybe_layer_norm(2, x, after=True)
return x, attn
def maybe_layer_norm(self, i, x, before=False, after=False):
assert before ^ after
if after ^ self.normalize_before:
return self.layer_norms[i](x)
else:
return x
def Embedding(num_embeddings, embedding_dim, padding_idx):
m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
return m
def LayerNorm(embedding_dim):
m = nn.LayerNorm(embedding_dim)
return m
def Linear(in_features, out_features, bias=True):
m = nn.Linear(in_features, out_features, bias)
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0.)
return m
def PositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad, learned=False):
if learned:
m = LearnedPositionalEmbedding(num_embeddings, embedding_dim, padding_idx, left_pad)
nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
nn.init.constant_(m.weight[padding_idx], 0)
else:
m = SinusoidalPositionalEmbedding(embedding_dim, padding_idx, left_pad, init_size=num_embeddings)
return m
@register_model_architecture('transformer', 'transformer')
def base_architecture(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 512)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 2048)
args.encoder_layers = getattr(args, 'encoder_layers', 6)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 8)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', args.encoder_embed_dim)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', args.encoder_ffn_embed_dim)
args.decoder_layers = getattr(args, 'decoder_layers', 6)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 8)
args.attention_dropout = getattr(args, 'attention_dropout', 0.)
args.relu_dropout = getattr(args, 'relu_dropout', 0.)
args.dropout = getattr(args, 'dropout', 0.1)
@register_model_architecture('transformer', 'transformer_iwslt_de_en')
def transformer_iwslt_de_en(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 256)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 512)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 4)
args.encoder_layers = getattr(args, 'encoder_layers', 3)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 256)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 512)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 4)
args.decoder_layers = getattr(args, 'decoder_layers', 3)
base_architecture(args)
@register_model_architecture('transformer', 'transformer_wmt_en_de')
def transformer_wmt_en_de(args):
base_architecture(args)
# parameters used in the "Attention Is All You Need" paper (Vaswani et al., 2017)
@register_model_architecture('transformer', 'transformer_vaswani_wmt_en_de_big')
def transformer_vaswani_wmt_en_de_big(args):
args.encoder_embed_dim = getattr(args, 'encoder_embed_dim', 1024)
args.encoder_ffn_embed_dim = getattr(args, 'encoder_ffn_embed_dim', 4096)
args.encoder_attention_heads = getattr(args, 'encoder_attention_heads', 16)
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', False)
args.decoder_embed_dim = getattr(args, 'decoder_embed_dim', 1024)
args.decoder_ffn_embed_dim = getattr(args, 'decoder_ffn_embed_dim', 4096)
args.decoder_attention_heads = getattr(args, 'decoder_attention_heads', 16)
args.dropout = getattr(args, 'dropout', 0.3)
base_architecture(args)
@register_model_architecture('transformer', 'transformer_vaswani_wmt_en_fr_big')
def transformer_vaswani_wmt_en_fr_big(args):
args.dropout = getattr(args, 'dropout', 0.1)
transformer_vaswani_wmt_en_de_big(args)
@register_model_architecture('transformer', 'transformer_wmt_en_de_big')
def transformer_wmt_en_de_big(args):
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
transformer_vaswani_wmt_en_de_big(args)
# default parameters used in tensor2tensor implementation
@register_model_architecture('transformer', 'transformer_wmt_en_de_big_t2t')
def transformer_wmt_en_de_big_t2t(args):
args.encoder_normalize_before = getattr(args, 'encoder_normalize_before', True)
args.decoder_normalize_before = getattr(args, 'decoder_normalize_before', True)
args.attention_dropout = getattr(args, 'attention_dropout', 0.1)
args.relu_dropout = getattr(args, 'relu_dropout', 0.1)
transformer_vaswani_wmt_en_de_big(args)
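# Illustrative sketch only (not part of the original diff): how the named
# architectures above compose. Each one fills in its own defaults via getattr()
# and then falls through to base_architecture() for everything it did not set,
# so user-supplied values always win. The Namespace stands in for parsed
# command-line arguments and the function name is hypothetical.
def _example_architecture_override():
    from argparse import Namespace
    args = Namespace(dropout=0.2)            # pretend the user passed --dropout 0.2
    transformer_iwslt_de_en(args)
    assert args.encoder_embed_dim == 256     # IWSLT-specific default
    assert args.encoder_ffn_embed_dim == 512
    assert args.dropout == 0.2               # user override preserved by getattr
    return args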
...@@ -5,16 +5,26 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
from .adaptive_softmax import AdaptiveSoftmax
from .beamable_mm import BeamableMM
from .conv_tbc import ConvTBC
from .downsampled_multihead_attention import DownsampledMultiHeadAttention
from .grad_multiply import GradMultiply
from .learned_positional_embedding import LearnedPositionalEmbedding
from .linearized_convolution import LinearizedConvolution
from .multihead_attention import MultiheadAttention
from .scalar_bias import ScalarBias
from .sinusoidal_positional_embedding import SinusoidalPositionalEmbedding
__all__ = [
'AdaptiveSoftmax',
'BeamableMM',
'ConvTBC',
'DownsampledMultiHeadAttention',
'GradMultiply',
'LearnedPositionalEmbedding',
'LinearizedConvolution',
'MultiheadAttention',
'ScalarBias',
'SinusoidalPositionalEmbedding',
]
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch.nn.functional as F
from torch import nn
class AdaptiveSoftmax(nn.Module):
"""
This is an implementation of the efficient softmax approximation for
graphics processing units (GPUs), described in the paper "Efficient softmax
approximation for GPUs" (http://arxiv.org/abs/1609.04309).
"""
def __init__(self, vocab_size, input_dim, cutoff, dropout):
super().__init__()
if vocab_size > cutoff[-1]:
cutoff = cutoff + [vocab_size]
output_dim = cutoff[0] + len(cutoff) - 1
self.vocab_size = vocab_size
self.cutoff = cutoff
self.dropout = dropout
self.lsm = nn.LogSoftmax(dim=1)
self.head = nn.Linear(input_dim, output_dim, bias=False)
self.tail = nn.ModuleList()
for i in range(len(cutoff) - 1):
self.tail.append(
nn.Sequential(
nn.Linear(input_dim, input_dim // 4 ** i, bias=False),
nn.Dropout(dropout),
nn.Linear(input_dim // 4 ** i, cutoff[i + 1] - cutoff[i], bias=False)
)
)
def init_weights(m):
if hasattr(m, 'weight'):
nn.init.xavier_uniform_(m.weight)
self.apply(init_weights)
def adapt_target(self, target):
"""
In order to be efficient, the AdaptiveSoftmax does not compute the
scores for all the words of the vocabulary for all the examples. It is
thus necessary to call the adapt_target method of the AdaptiveSoftmax
layer inside each forward pass.
"""
target = target.view(-1)
new_target = [target.clone()]
target_idxs = []
for i in range(len(self.cutoff) - 1):
mask = target.ge(self.cutoff[i]).mul(target.lt(self.cutoff[i + 1]))
new_target[0][mask] = self.cutoff[0] + i - 1
if mask.any():
target_idxs.append(mask.nonzero().squeeze(1))
new_target.append(target[mask].add(-self.cutoff[i]))
else:
target_idxs.append(None)
new_target.append(None)
return new_target, target_idxs
def forward(self, input, target):
"""
Args:
input: (b x t x d)
target: (b x t)
Returns:
two lists: the output for each cutoff section and the targets remapped per cutoff
"""
input = input.contiguous().view(-1, input.size(-1))
input = F.dropout(input, p=self.dropout, training=self.training)
new_target, target_idxs = self.adapt_target(target)
output = [self.head(input)]
for i in range(len(target_idxs)):
if target_idxs[i] is not None:
output.append(self.tail[i](input.index_select(0, target_idxs[i])))
else:
output.append(None)
return output, new_target
def get_log_prob(self, input, target):
"""
Computes the log probabilities for all the words of the vocabulary,
given a B x T x C tensor of hidden vectors.
"""
bsz, length, dim = input.size()
input = input.contiguous().view(-1, dim)
if target is not None:
_, target_idxs = self.adapt_target(target)
else:
target_idxs = None
head_y = self.head(input)
log_probs = head_y.new_zeros(input.size(0), self.vocab_size)
head_sz = self.cutoff[0] + len(self.tail)
log_probs[:, :head_sz] = self.lsm(head_y)
tail_priors = log_probs[:, self.cutoff[0] - 1: head_sz - 1].clone()
for i in range(len(self.tail)):
start = self.cutoff[i]
end = self.cutoff[i + 1]
if target_idxs is None:
tail_out = log_probs[:, start:end]
tail_out.copy_(self.tail[i](input))
log_probs[:, start:end] = self.lsm(tail_out).add_(tail_priors[:, i, None])
elif target_idxs[i] is not None:
idxs = target_idxs[i]
tail_out = log_probs[idxs, start:end]
tail_out.copy_(self.tail[i](input[idxs]))
log_probs[idxs, start:end] = self.lsm(tail_out).add_(tail_priors[idxs, i, None])
log_probs = log_probs.view(bsz, length, -1)
return log_probs
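# Illustrative usage sketch only (not part of the original diff): shapes follow
# the docstrings above; the vocabulary size, cutoffs and dimensions are made up,
# and the function name is hypothetical.
def _example_adaptive_softmax_usage():
    import torch
    vocab_size, dim = 10000, 512
    asm = AdaptiveSoftmax(vocab_size, dim, cutoff=[2000, 6000], dropout=0.1)
    hidden = torch.randn(8, 20, dim)                   # B x T x C decoder states
    target = torch.randint(0, vocab_size, (8, 20)).long()
    output, new_target = asm(hidden, target)           # per-cluster logits + remapped targets
    log_probs = asm.get_log_prob(hidden, target)       # B x T x vocab_size
    return output, new_target, log_probs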
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq.modules.scalar_bias import scalar_bias
class SingleHeadAttention(nn.Module):
"""
Single-head attention that supports Gating and Downsampling
"""
def __init__(
self, out_channels, embed_dim, head_dim, head_index, dropout=0.,
bias=True, project_input=True, gated=False, downsample=False,
num_heads=1,
):
super().__init__()
self.embed_dim = embed_dim
self.dropout = dropout
self.head_index = head_index
self.head_dim = head_dim
self.project_input = project_input
self.gated = gated
self.downsample = downsample
self.num_heads = num_heads
self.projection = None
k_layers = []
v_layers = []
if self.downsample:
k_layers.append(Downsample(self.head_index))
v_layers.append(Downsample(self.head_index))
out_proj_size = self.head_dim
else:
out_proj_size = self.head_dim * self.num_heads
if self.gated:
k_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
self.in_proj_q = GatedLinear(self.embed_dim, out_proj_size, bias=bias)
v_layers.append(GatedLinear(self.embed_dim, out_proj_size, bias=bias))
else:
k_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))
self.in_proj_q = Linear(self.embed_dim, out_proj_size, bias=bias)
v_layers.append(Linear(self.embed_dim, out_proj_size, bias=bias))
self.in_proj_k = nn.Sequential(*k_layers)
self.in_proj_v = nn.Sequential(*v_layers)
if self.downsample:
self.out_proj = Linear(out_proj_size, self.head_dim, bias=bias)
else:
self.out_proj = Linear(out_proj_size, out_channels, bias=bias)
self.scaling = self.head_dim**-0.5
def forward(
self, query, key, value, mask_future_timesteps=False,
key_padding_mask=None, use_scalar_bias=False,
):
"""Input shape: Time x Batch x Channel
Self-attention can be implemented by passing in the same arguments for
query, key and value. Future timesteps can be masked with the
`mask_future_timesteps` argument. Padding elements can be excluded from
the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
batch x src_len, where padding elements are indicated by 1s.
"""
src_len, bsz, out_channels = key.size()
tgt_len = query.size(0)
assert list(query.size()) == [tgt_len, bsz, out_channels]
assert key.size() == value.size()
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.downsample:
size = bsz
else:
size = bsz * self.num_heads
k = key
v = value
q = query
if self.project_input:
q = self.in_proj_q(q)
k = self.in_proj_k(k)
v = self.in_proj_v(v)
src_len = k.size()[0]
q *= self.scaling
if not self.downsample:
q = q.view(tgt_len, size, self.head_dim)
k = k.view(src_len, size, self.head_dim)
v = v.view(src_len, size, self.head_dim)
q = q.transpose(0, 1)
k = k.transpose(0, 1)
v = v.transpose(0, 1)
attn_weights = torch.bmm(q, k.transpose(1, 2))
if mask_future_timesteps:
assert query.size() == key.size(), \
'mask_future_timesteps only applies to self-attention'
attn_weights *= torch.tril(
attn_weights.data.new([1]).expand(tgt_len, tgt_len).clone(),
diagonal=-1,
)[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0)
attn_weights += torch.triu(
attn_weights.data.new([-math.inf]).expand(tgt_len, tgt_len).clone(),
diagonal=0
)[:, ::self.head_index + 1 if self.downsample else 1].unsqueeze(0)
tgt_size = tgt_len
if use_scalar_bias:
attn_weights = scalar_bias(attn_weights, 2)
v = scalar_bias(v, 1)
tgt_size += 1
if key_padding_mask is not None:
# don't attend to padding symbols
if key_padding_mask.max() > 0:
if self.downsample:
attn_weights = attn_weights.view(bsz, 1, tgt_len, src_len)
else:
attn_weights = attn_weights.view(size, self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
-math.inf,
)
attn_weights = attn_weights.view(size, tgt_len, src_len)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn = torch.bmm(attn_weights, v)
if self.downsample:
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.head_dim)
else:
attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
attn = self.out_proj(attn)
return attn, attn_weights
class DownsampledMultiHeadAttention(nn.ModuleList):
"""
Multi-headed attention with Gating and Downsampling
"""
def __init__(
self, out_channels, embed_dim, num_heads, dropout=0., bias=True,
project_input=True, gated=False, downsample=False,
):
self.embed_dim = embed_dim
self.num_heads = num_heads
self.dropout = dropout
self.head_dim = embed_dim // num_heads
self.downsample = downsample
self.gated = gated
self.project_input = project_input
assert self.head_dim * num_heads == embed_dim
if self.downsample:
attention_heads = []
for index in range(self.num_heads):
attention_heads.append(
SingleHeadAttention(
out_channels, self.embed_dim, self.head_dim, index,
self.dropout, bias, self.project_input, self.gated,
self.downsample, self.num_heads,
)
)
super().__init__(modules=attention_heads)
self.out_proj = Linear(embed_dim, out_channels, bias=bias)
else:
# when downsampling, each head is a separate module (see above); otherwise all
# heads share one SingleHeadAttention whose projections cover every head at once
super().__init__()
self.attention_module = SingleHeadAttention(
out_channels, self.embed_dim, self.head_dim, 1, self.dropout,
bias, self.project_input, self.gated, self.downsample, self.num_heads,
)
def forward(
self, query, key, value, mask_future_timesteps=False,
key_padding_mask=None, use_scalar_bias=False,
):
src_len, bsz, embed_dim = key.size()
tgt_len = query.size(0)
assert embed_dim == self.embed_dim
assert list(query.size()) == [tgt_len, bsz, embed_dim]
assert key.size() == value.size()
tgt_size = tgt_len
if use_scalar_bias:
tgt_size += 1
attn = []
attn_weights = []
if self.downsample:
for attention_head_number in range(self.num_heads):
# call the forward of each attention head
_attn, _attn_weight = self[attention_head_number](
query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias,
)
attn.append(_attn)
attn_weights.append(_attn_weight)
full_attn = torch.cat(attn, dim=2)
full_attn = self.out_proj(full_attn)
return full_attn, attn_weights[0].clone()
else:
_attn, _attn_weight = self.attention_module(
query, key, value, mask_future_timesteps, key_padding_mask, use_scalar_bias,
)
attn.append(_attn)
attn_weights.append(_attn_weight)
full_attn = torch.cat(attn, dim=2)
full_attn_weights = torch.cat(attn_weights)
full_attn_weights = full_attn_weights.view(bsz, self.num_heads, tgt_size, src_len)
full_attn_weights = full_attn_weights.sum(dim=1) / self.num_heads
return full_attn, full_attn_weights
class Downsample(nn.Module):
"""
Keeps every (index + 1)-th element along the first (time) dimension
"""
def __init__(self, index):
super().__init__()
self.index = index
def forward(self, x):
return x[::self.index+1]
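# For example (illustrative note, not in the original diff): Downsample(index=2)
# keeps every third timestep, i.e. x[::3], so a length-9 sequence becomes length 3;
# head i therefore attends over a sequence downsampled by a factor of i + 1.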
def Linear(in_features, out_features, dropout=0., bias=True):
"""Weight-normalized Linear layer (input: B x T x C)"""
m = nn.Linear(in_features, out_features, bias=bias)
m.weight.data.normal_(mean=0, std=math.sqrt((1 - dropout) / in_features))
m.bias.data.zero_()
return nn.utils.weight_norm(m)
def GatedLinear(in_features, out_features, dropout=0., bias=True):
"""Weight-normalized Linear layer (input: B x T x C) with interspersed GLU units"""
return nn.Sequential(
Linear(in_features, out_features*4, dropout, bias),
nn.GLU(),
Linear(out_features*2, out_features*2, dropout, bias),
nn.GLU(),
Linear(out_features, out_features, dropout, bias)
)
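# Illustrative usage sketch only (not part of the original diff): a masked
# self-attention call in the T x B x C layout described in the forward()
# docstring above. Sizes and names are made up; use_scalar_bias=True prepends a
# bias element so the fully-masked first timestep still has something to attend
# to (avoiding an all -inf softmax row).
def _example_downsampled_attention_usage():
    embed_dim, num_heads, tgt_len, bsz = 256, 4, 10, 2
    attn_layer = DownsampledMultiHeadAttention(
        out_channels=embed_dim, embed_dim=embed_dim, num_heads=num_heads,
        dropout=0.1, gated=True, downsample=False,
    )
    x = torch.rand(tgt_len, bsz, embed_dim)            # T x B x C
    attn, attn_weights = attn_layer(
        x, x, x, mask_future_timesteps=True, use_scalar_bias=True,
    )
    return attn, attn_weights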
...@@ -13,7 +13,6 @@ class GradMultiply(torch.autograd.Function):
def forward(ctx, x, scale):
ctx.scale = scale
res = x.new(x)
ctx.mark_shared_storage((x, res))
return res
@staticmethod
...