Commit 0e101e9c authored by Myle Ott's avatar Myle Ott
Browse files

Misc changes to simplify upcoming tutorial

parent d473620e
...@@ -109,7 +109,7 @@ def main(parsed_args): ...@@ -109,7 +109,7 @@ def main(parsed_args):
print('| Skipping tokens with inf scores:', print('| Skipping tokens with inf scores:',
task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()])) task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
pos_scores = pos_scores[(~inf_scores).nonzero()] pos_scores = pos_scores[(~inf_scores).nonzero()]
score_sum += pos_scores.sum() score_sum += utils.item(pos_scores.sum())
count += pos_scores.numel() - skipped_toks count += pos_scores.numel() - skipped_toks
if args.output_word_probs or args.output_word_stats: if args.output_word_probs or args.output_word_stats:
......
...@@ -88,7 +88,8 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False): ...@@ -88,7 +88,8 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
if isinstance(max_positions, float) or isinstance(max_positions, int): if isinstance(max_positions, float) or isinstance(max_positions, int):
return size_fn(idx) <= max_positions return size_fn(idx) <= max_positions
else: else:
return all(a <= b for a, b in zip(size_fn(idx), max_positions)) return all(a is None or b is None or a <= b
for a, b in zip(size_fn(idx), max_positions))
ignored = [] ignored = []
itr = collect_filtered(check_size, indices, ignored) itr = collect_filtered(check_size, indices, ignored)
......
...@@ -13,7 +13,10 @@ from fairseq import utils ...@@ -13,7 +13,10 @@ from fairseq import utils
from . import data_utils, FairseqDataset from . import data_utils, FairseqDataset
def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False): def collate(
samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=False,
input_feeding=True,
):
if len(samples) == 0: if len(samples) == 0:
return {} return {}
...@@ -35,29 +38,33 @@ def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=Fal ...@@ -35,29 +38,33 @@ def collate(samples, pad_idx, eos_idx, left_pad_source=True, left_pad_target=Fal
target = None target = None
if samples[0].get('target', None) is not None: if samples[0].get('target', None) is not None:
target = merge('target', left_pad=left_pad_target) target = merge('target', left_pad=left_pad_target)
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=left_pad_target,
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
target = target.index_select(0, sort_order) target = target.index_select(0, sort_order)
ntokens = sum(len(s['target']) for s in samples) ntokens = sum(len(s['target']) for s in samples)
if input_feeding:
# we create a shifted version of targets for feeding the
# previous output token(s) into the next decoder step
prev_output_tokens = merge(
'target',
left_pad=left_pad_target,
move_eos_to_beginning=True,
)
prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
else: else:
ntokens = sum(len(s['source']) for s in samples) ntokens = sum(len(s['source']) for s in samples)
return { batch = {
'id': id, 'id': id,
'ntokens': ntokens, 'ntokens': ntokens,
'net_input': { 'net_input': {
'src_tokens': src_tokens, 'src_tokens': src_tokens,
'src_lengths': src_lengths, 'src_lengths': src_lengths,
'prev_output_tokens': prev_output_tokens,
}, },
'target': target, 'target': target,
} }
if prev_output_tokens is not None:
batch['net_input']['prev_output_tokens'] = prev_output_tokens
return batch
class LanguagePairDataset(FairseqDataset): class LanguagePairDataset(FairseqDataset):
...@@ -81,6 +88,9 @@ class LanguagePairDataset(FairseqDataset): ...@@ -81,6 +88,9 @@ class LanguagePairDataset(FairseqDataset):
sentence. Default: ``1024`` sentence. Default: ``1024``
shuffle (bool, optional): shuffle dataset elements before batching. shuffle (bool, optional): shuffle dataset elements before batching.
Default: ``True`` Default: ``True``
input_feeding (bool, optional): create a shifted version of the targets
to be passed into the model for input feeding/teacher forcing.
Default: ``True``
""" """
def __init__( def __init__(
...@@ -88,7 +98,7 @@ class LanguagePairDataset(FairseqDataset): ...@@ -88,7 +98,7 @@ class LanguagePairDataset(FairseqDataset):
tgt=None, tgt_sizes=None, tgt_dict=None, tgt=None, tgt_sizes=None, tgt_dict=None,
left_pad_source=True, left_pad_target=False, left_pad_source=True, left_pad_target=False,
max_source_positions=1024, max_target_positions=1024, max_source_positions=1024, max_target_positions=1024,
shuffle=True, shuffle=True, input_feeding=True,
): ):
if tgt_dict is not None: if tgt_dict is not None:
assert src_dict.pad() == tgt_dict.pad() assert src_dict.pad() == tgt_dict.pad()
...@@ -105,6 +115,7 @@ class LanguagePairDataset(FairseqDataset): ...@@ -105,6 +115,7 @@ class LanguagePairDataset(FairseqDataset):
self.max_source_positions = max_source_positions self.max_source_positions = max_source_positions
self.max_target_positions = max_target_positions self.max_target_positions = max_target_positions
self.shuffle = shuffle self.shuffle = shuffle
self.input_feeding = input_feeding
def __getitem__(self, index): def __getitem__(self, index):
return { return {
...@@ -119,22 +130,23 @@ class LanguagePairDataset(FairseqDataset): ...@@ -119,22 +130,23 @@ class LanguagePairDataset(FairseqDataset):
def collater(self, samples): def collater(self, samples):
"""Merge a list of samples to form a mini-batch. """Merge a list of samples to form a mini-batch.
Returned mini-batches contain the following keys: Returns mini-batches with the following keys:
- `id` (torch.LongTensor): example IDs in the original input order - `id` (torch.LongTensor): example IDs in the original input order
- `ntokens` (int): total number of tokens in the batch - `ntokens` (int): total number of tokens in the batch
- `net_input` (dict): the input to the Model, containing keys: - `net_input` (dict): the input to the Model, containing keys:
- `src_tokens` (torch.LongTensor): a padded 2D Tensor of tokens in - `src_tokens` (torch.LongTensor): a padded 2D Tensor of tokens in
the source sentence of shape `(bsz, src_len)`. Padding will appear the source sentence of shape `(bsz, src_len)`. Padding will appear
on the left if ``left_pad_source`` is True. on the left if *left_pad_source* is True.
- `src_lengths` (torch.LongTensor): 1D Tensor of the unpadded lengths - `src_lengths` (torch.LongTensor): 1D Tensor of the unpadded lengths
of each source sentence of shape `(bsz)` of each source sentence of shape `(bsz)`
- `prev_output_tokens` (torch.LongTensor): a padded 2D Tensor of - `prev_output_tokens` (torch.LongTensor): a padded 2D Tensor of
tokens in the target sentence, shifted right by one position for tokens in the target sentence, shifted right by one position for
input feeding/teacher forcing, of shape `(bsz, tgt_len)`. Padding input feeding/teacher forcing, of shape `(bsz, tgt_len)`. This key
will appear on the left if ``left_pad_target`` is True. will only be present if *input_feeding* is ``True``. Padding will
appear on the left if *left_pad_target* is ``True``.
- `target` (torch.LongTensor): a padded 2D Tensor of tokens in the - `target` (torch.LongTensor): a padded 2D Tensor of tokens in the
target sentence of shape `(bsz, tgt_len)`. Padding will appear on the target sentence of shape `(bsz, tgt_len)`. Padding will appear on the
left if ``left_pad_target`` is True. left if *left_pad_target* is ``True``.
Args: Args:
samples (List[dict]): samples to collate samples (List[dict]): samples to collate
...@@ -145,6 +157,7 @@ class LanguagePairDataset(FairseqDataset): ...@@ -145,6 +157,7 @@ class LanguagePairDataset(FairseqDataset):
return collate( return collate(
samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(), samples, pad_idx=self.src_dict.pad(), eos_idx=self.src_dict.eos(),
left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target, left_pad_source=self.left_pad_source, left_pad_target=self.left_pad_target,
input_feeding=self.input_feeding,
) )
def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128): def get_dummy_batch(self, num_tokens, max_positions, src_len=128, tgt_len=128):
......
...@@ -25,6 +25,9 @@ def collate(samples, pad_idx, eos_idx): ...@@ -25,6 +25,9 @@ def collate(samples, pad_idx, eos_idx):
'ntokens': sum(len(s['target']) for s in samples), 'ntokens': sum(len(s['target']) for s in samples),
'net_input': { 'net_input': {
'src_tokens': merge('source'), 'src_tokens': merge('source'),
'src_lengths': torch.LongTensor([
s['source'].numel() for s in samples
]),
}, },
'target': merge('target'), 'target': merge('target'),
} }
...@@ -42,7 +45,7 @@ class MonolingualDataset(FairseqDataset): ...@@ -42,7 +45,7 @@ class MonolingualDataset(FairseqDataset):
Default: ``True`` Default: ``True``
""" """
def __init__(self, dataset, sizes, vocab, shuffle): def __init__(self, dataset, sizes, vocab, shuffle=True):
self.dataset = dataset self.dataset = dataset
self.sizes = np.array(sizes) self.sizes = np.array(sizes)
self.vocab = vocab self.vocab = vocab
...@@ -58,7 +61,7 @@ class MonolingualDataset(FairseqDataset): ...@@ -58,7 +61,7 @@ class MonolingualDataset(FairseqDataset):
def collater(self, samples): def collater(self, samples):
"""Merge a list of samples to form a mini-batch. """Merge a list of samples to form a mini-batch.
Returned mini-batches contain the following keys: Returns mini-batches with the following keys:
- `id` (torch.LongTensor): example IDs in the original input order - `id` (torch.LongTensor): example IDs in the original input order
- `ntokens` (int): total number of tokens in the batch - `ntokens` (int): total number of tokens in the batch
- `net_input` (dict): the input to the Model, containing keys: - `net_input` (dict): the input to the Model, containing keys:
......
...@@ -5,8 +5,9 @@ ...@@ -5,8 +5,9 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
from . import FairseqDecoder, FairseqEncoder from . import FairseqDecoder, FairseqEncoder
...@@ -34,11 +35,19 @@ class BaseFairseqModel(nn.Module): ...@@ -34,11 +35,19 @@ class BaseFairseqModel(nn.Module):
def get_normalized_probs(self, net_output, log_probs, sample=None): def get_normalized_probs(self, net_output, log_probs, sample=None):
"""Get normalized probabilities (or log probs) from a net's output.""" """Get normalized probabilities (or log probs) from a net's output."""
return self.decoder.get_normalized_probs(net_output, log_probs, sample) if hasattr(self, 'decoder'):
return self.decoder.get_normalized_probs(net_output, log_probs, sample)
elif torch.is_tensor(net_output):
logits = net_output.float()
if log_probs:
return F.log_softmax(logits, dim=-1)
else:
return F.softmax(logits, dim=-1)
raise NotImplementedError
def max_positions(self): def max_positions(self):
"""Maximum length supported by the model.""" """Maximum length supported by the model."""
raise NotImplementedError return None
def max_decoder_positions(self): def max_decoder_positions(self):
"""Maximum length supported by the decoder.""" """Maximum length supported by the decoder."""
...@@ -138,7 +147,7 @@ class FairseqLanguageModel(BaseFairseqModel): ...@@ -138,7 +147,7 @@ class FairseqLanguageModel(BaseFairseqModel):
self.decoder = decoder self.decoder = decoder
assert isinstance(self.decoder, FairseqDecoder) assert isinstance(self.decoder, FairseqDecoder)
def forward(self, src_tokens): def forward(self, src_tokens, src_lengths):
return self.decoder(src_tokens) return self.decoder(src_tokens)
def max_positions(self): def max_positions(self):
......
...@@ -25,7 +25,7 @@ class FairseqTask(object): ...@@ -25,7 +25,7 @@ class FairseqTask(object):
@classmethod @classmethod
def setup_task(cls, args, **kwargs): def setup_task(cls, args, **kwargs):
raise NotImplementedError return cls(args)
def load_dataset(self, split, combine=False): def load_dataset(self, split, combine=False):
raise NotImplementedError raise NotImplementedError
......
...@@ -403,6 +403,16 @@ def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'): ...@@ -403,6 +403,16 @@ def checkpoint_paths(path, pattern=r'checkpoint(\d+)\.pt'):
def resolve_max_positions(*args): def resolve_max_positions(*args):
"""Resolve max position constraints from multiple sources.""" """Resolve max position constraints from multiple sources."""
def nullsafe_min(l):
    """Return the smallest non-None element of *l*, or None if every
    element is None (or *l* is empty).

    Used to combine per-source max-position constraints where an entry
    of None means "no limit".
    """
    candidates = [item for item in l if item is not None]
    return min(candidates) if candidates else None
max_positions = None max_positions = None
for arg in args: for arg in args:
if max_positions is None: if max_positions is None:
...@@ -411,5 +421,7 @@ def resolve_max_positions(*args): ...@@ -411,5 +421,7 @@ def resolve_max_positions(*args):
if isinstance(arg, float) or isinstance(arg, int): if isinstance(arg, float) or isinstance(arg, int):
max_positions = min(max_positions, arg) max_positions = min(max_positions, arg)
else: else:
max_positions = tuple(map(min, zip(max_positions, arg))) max_positions = tuple(
map(nullsafe_min, zip(max_positions, arg))
)
return max_positions return max_positions
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment