Commit 0e101e9c authored by Myle Ott's avatar Myle Ott
Browse files

Misc changes to simplify upcoming tutorial

parent d473620e
......@@ -109,7 +109,7 @@ def main(parsed_args):
print('| Skipping tokens with inf scores:',
task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
pos_scores = pos_scores[(~inf_scores).nonzero()]
score_sum += pos_scores.sum()
score_sum += utils.item(pos_scores.sum())
count += pos_scores.numel() - skipped_toks
if args.output_word_probs or args.output_word_stats:
......
......@@ -88,7 +88,8 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
if isinstance(max_positions, float) or isinstance(max_positions, int):
return size_fn(idx) <= max_positions
else:
return all(a <= b for a, b in zip(size_fn(idx), max_positions))
return all(a is None or b is None or a <= b
for a, b in zip(size_fn(idx), max_positions))
ignored = []
itr = collect_filtered(check_size, indices, ignored)
......
......@@ -13,7 +13,10 @@ from fairseq import utils
from . import data_utils, FairseqDataset
def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False):
def collate(
samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False,
input_feeding=True,
):
if len(samples) == 0:
return {}
......@@ -35,6 +38,10 @@ def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=Fal
target = None
if samples[0].get('target', None) is not None:
target = merge('target', left_pad=left_pad_target)
target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
if input_feeding:
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
......@@ -43,21 +50,21 @@ def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=Fal
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples)
else:
ntokens = sum(len(s['source']) for s in samples)
return {
batch = {
'id': id,
'ntokens': ntokens,
'net_input': {
'src_tokens': src_tokens,
'src_lengths': src_lengths,
'prev_output_tokens': prev_output_tokens,
},
'target': target,
}
if prev_output_tokens is not None:
batch['net_input']['prev_output_tokens'] = prev_output_tokens
return batch
class LanguagePairDataset(FairseqDataset):
......@@ -81,6 +88,9 @@ class LanguagePairDataset(FairseqDataset):
sentence. Default: ``1024``
shuffle (bool, optional): shuffle dataset elements before batching.
Default: ``True``
input_feeding (bool, optional): create a shifted version of the targets
to be passed into the model for input feeding/teacher forcing.
Default: ``True``
"""
def __init__(
......@@ -88,7 +98,7 @@ class LanguagePairDataset(FairseqDataset):
tgt=None, tgt_sizes=None, tgt_dict=None,
left_pad_source=True, left_pad_target=False,
max_source_positions=1024, max_target_positions=1024,
shuffle=True,
shuffle=True, input_feeding=True,
):
if tgt_dict is not None:
assert src_dict.pad() == tgt_dict.pad()
......@@ -105,6 +115,7 @@ class LanguagePairDataset(FairseqDataset):
self.max_source_positions = max_source_positions
self.max_target_positions = max_target_positions
self.shuffle = shuffle
self.input_feeding = input_feeding
def __getitem__(self, index):
return {
......@@ -119,22 +130,23 @@ class LanguagePairDataset(FairseqDataset):
def collater(self, samples):
"""Merge a list of samples to form a mini-batch.
Returned mini-batches contain the following keys:
Returns mini-batches with the following keys:
- `id` (torch.LongTensor): example IDs in the original input order
- `ntokens` (int): total number of tokens in the batch
- `net_input` (dict): the input to the Model, containing keys:
- `src_tokens` (torch.LongTensor): a padded 2D Tensor of tokens in
the source sentence of shape `(bsz, src_len)`. Padding will appear
on the left if ``left_pad_source`` is True.
on the left if *left_pad_source* is True.
- `src_lengths` (torch.LongTensor): 1D Tensor of the unpadded lengths
of each source sentence of shape `(bsz)`
- `prev_output_tokens` (torch.LongTensor): a padded 2D Tensor of
tokens in the target sentence, shifted right by one position for
input feeding/teacher forcing, of shape `(bsz, tgt_len)`. Padding
will appear on the left if ``left_pad_target`` is True.
input feeding/teacher forcing, of shape `(bsz, tgt_len)`. This key
will only be present if *input_feeding* is ``True``. Padding will
appear on the left if *left_pad_target* is ``True``.
- `target` (torch.LongTensor): a padded 2D Tensor of tokens in the
target sentence of shape `(bsz, tgt_len)`. Padding will appear on the
left if ``left_pad_target`` is True.
left if *left_pad_target* is ``True``.
Args:
samples (List[dict]): samples to collate
......@@ -145,6 +157,7 @@ class LanguagePairDataset(FairseqDataset):
return collate(
samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
input_feeding=self.input_feeding,
)
def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
......
......@@ -25,6 +25,9 @@ def collate(samples, pad_idx, eos_idx):
'ntokens': sum(len(s['target']) for s in samples),
'net_input': {
'src_tokens': merge('source'),
'src_lengths': torch.LongTensor([
s['source'].numel() for s in samples
]),
},
'target': merge('target'),
}
......@@ -42,7 +45,7 @@ class MonolingualDataset(FairseqDataset):
Default: ``True``
"""
def __init__(self, dataset, sizes, vocab, shuffle):
def __init__(self, dataset, sizes, vocab, shuffle=True):
self.dataset = dataset
self.sizes = np.array(sizes)
self.vocab = vocab
......@@ -58,7 +61,7 @@ class MonolingualDataset(FairseqDataset):
def collater(self, samples):
"""Merge a list of samples to form a mini-batch.
Returned mini-batches contain the following keys:
Returns mini-batches with the following keys:
- `id` (torch.LongTensor): example IDs in the original input order
- `ntokens` (int): total number of tokens in the batch
- `net_input` (dict): the input to the Model, containing keys:
......
......@@ -5,8 +5,9 @@
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import FairseqDecoder, FairseqEncoder
......@@ -34,11 +35,19 @@ class BaseFairseqModel(nn.Module):
def get_normalized_probs(self, net_output, log_probs, sample=None):
    """Get normalized probabilities (or log probs) from a net's output."""
    # Models with a decoder delegate normalization to it.
    if hasattr(self, 'decoder'):
        return self.decoder.get_normalized_probs(net_output, log_probs, sample)
    # Otherwise, fall back to (log-)softmax over raw tensor logits.
    if torch.is_tensor(net_output):
        logits = net_output.float()
        normalize = F.log_softmax if log_probs else F.softmax
        return normalize(logits, dim=-1)
    raise NotImplementedError
def max_positions(self):
    """Maximum length supported by the model.

    Returns:
        None, which callers appear to treat as "no length limit"
        (see ``resolve_max_positions``/``nullsafe_min``, which skip
        None entries); subclasses may override with a concrete limit.
    """
    # Removed the unreachable `return None` that followed a
    # `raise NotImplementedError` — the intended behavior is to
    # return None rather than raise.
    return None
def max_decoder_positions(self):
"""Maximum length supported by the decoder."""
......@@ -138,7 +147,7 @@ class FairseqLanguageModel(BaseFairseqModel):
self.decoder = decoder
assert isinstance(self.decoder, FairseqDecoder)
def forward(self, src_tokens, src_lengths):
    """Run the language model's decoder on the given tokens.

    Args:
        src_tokens: input token indices passed straight to the decoder
        src_lengths: unpadded input lengths; accepted for interface
            compatibility with encoder-decoder models but not used here

    Returns:
        whatever ``self.decoder`` returns for *src_tokens*
    """
    # The duplicate single-argument `def forward(self, src_tokens):`
    # header (an empty-bodied leftover) is removed; only the
    # two-argument signature is kept.
    return self.decoder(src_tokens)
def max_positions(self):
......
......@@ -25,7 +25,7 @@ class FairseqTask(object):
@classmethod
def setup_task(cls, args, **kwargs):
    """Instantiate the task from parsed command-line *args*.

    The default implementation simply constructs the task with *args*;
    subclasses can override to do additional setup (e.g. loading
    dictionaries) before construction.
    """
    # Removed the unreachable `raise NotImplementedError` that preceded
    # the return — the default now just delegates to the constructor.
    return cls(args)
def load_dataset(self, split, combine=False):
raise NotImplementedError
......
......@@ -403,6 +403,16 @@ def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
def resolve_max_positions(*args):
"""Resolve max position constraints from multiple sources."""
def nullsafe_min(l):
    """Return the smallest non-None item of *l*, or None if there is none."""
    present = [item for item in l if item is not None]
    return min(present) if present else None
max_positions = None
for arg in args:
if max_positions is None:
......@@ -411,5 +421,7 @@ def resolve_max_positions(*args):
if isinstance(arg, float) or isinstance(arg, int):
max_positions = min(max_positions, arg)
else:
max_positions = tuple(map(min, zip(max_positions, arg)))
max_positions = tuple(
map(nullsafe_min, zip(max_positions, arg))
)
return max_positions
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment