Commit 98daf039 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Support LM generation from interactive.py (fixes #526)

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/528

Differential Revision: D14218377

Pulled By: myleott

fbshipit-source-id: facb0a32f6aebf56a4fea7259080394ad2d2d846
parent 00493490
......@@ -103,6 +103,9 @@ def main(parsed_args):
count = 0
if args.remove_bpe is not None:
if args.remove_bpe == 'sentencepiece':
raise NotImplementedError
else:
bpe_cont = args.remove_bpe.rstrip()
bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
bpe_len = len(bpe_cont)
......
......@@ -28,19 +28,24 @@ def collate(samples, pad_idx, eos_idx):
[s[key] for s in samples], pad_idx, eos_idx, left_pad=False,
)
src_tokens = merge('source')
if samples[0]['target'] is not None:
is_target_list = isinstance(samples[0]['target'], list)
target = merge('target', is_target_list)
else:
target = src_tokens
return {
'id': torch.LongTensor([s['id'] for s in samples]),
'nsentences': len(samples),
'ntokens': sum(len(s['source']) for s in samples),
'net_input': {
'src_tokens': merge('source'),
'src_tokens': src_tokens,
'src_lengths': torch.LongTensor([
s['source'].numel() for s in samples
]),
},
'target': merge('target', is_target_list),
'target': target,
}
......@@ -72,8 +77,12 @@ class MonolingualDataset(FairseqDataset):
self.targets = targets
def __getitem__(self, index):
    """Return the example at *index* as a dict with 'id', 'source', 'target'.

    When the dataset was built without targets, 'target' is None and only
    the source tokens are loaded from the wrapped dataset.
    """
    if self.targets is None:
        src, tgt = self.dataset[index], None
    else:
        src, future_target, past_target = self.dataset[index]
        src, tgt = self._make_source_target(src, future_target, past_target)
    return {'id': index, 'source': src, 'target': tgt}
def __len__(self):
......
......@@ -32,6 +32,7 @@ class TransformEosDataset(FairseqDataset):
remove_eos_from_src=False,
append_eos_to_tgt=False,
remove_eos_from_tgt=False,
has_target=True,
):
if not isinstance(dataset, FairseqDataset):
raise ValueError('dataset must be an instance of FairseqDataset')
......@@ -46,6 +47,7 @@ class TransformEosDataset(FairseqDataset):
self.remove_eos_from_src = remove_eos_from_src
self.append_eos_to_tgt = append_eos_to_tgt
self.remove_eos_from_tgt = remove_eos_from_tgt
self.has_target = has_target
# precompute how we should adjust the reported sizes
self._src_delta = 0
......@@ -64,7 +66,7 @@ class TransformEosDataset(FairseqDataset):
self._checked_src = True
def _check_tgt(self, tgt, expect_eos):
    """One-time sanity check that the target's EOS state matches *expect_eos*.

    Compares whether the last target token equals EOS against the expected
    flag; the check runs only on the first call (cached via
    ``self._checked_tgt``) and is skipped entirely when the wrapped dataset
    has no targets (``self.has_target`` is False), since there is no target
    tensor to inspect.
    """
    # Only the post-change condition is kept here; the stale pre-change
    # `if not self._checked_tgt:` line was diff residue and must not remain.
    if self.has_target and not self._checked_tgt:
        assert (tgt[-1] == self.eos[0]) == expect_eos
        self._checked_tgt = True
......@@ -101,8 +103,11 @@ class TransformEosDataset(FairseqDataset):
return self.dataset.num_tokens(index)
def size(self, index):
    """Return the example size(s) at *index*, adjusted for EOS edits.

    For target-less datasets this is a plain pass-through to the wrapped
    dataset; otherwise the per-side deltas (precomputed from the EOS
    add/remove flags) are applied to the (src, tgt) lengths.
    """
    if not self.has_target:
        # source-only dataset: nothing to adjust
        return self.dataset.size(index)
    src_len, tgt_len = self.dataset.size(index)
    return (src_len + self._src_delta, tgt_len + self._tgt_delta)
def ordered_indices(self):
# NOTE: we assume that the ordering does not change based on the
......
......@@ -343,16 +343,31 @@ class SequenceGenerator(object):
banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]
for bbsz_idx in range(bsz * beam_size):
lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = float('-Inf')
lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf
if prefix_tokens is not None and step < prefix_tokens.size(1):
probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
cand_scores = torch.gather(
probs_slice, dim=1,
index=prefix_tokens[:, step].view(-1, 1)
).expand(-1, cand_size)
cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size)
).view(-1, 1).repeat(1, cand_size)
if step > 0:
# save cumulative scores for each hypothesis
cand_scores.add_(scores[:, step - 1].view(bsz, beam_size).repeat(1, 2))
cand_indices = prefix_tokens[:, step].view(-1, 1).repeat(1, cand_size)
cand_beams = torch.zeros_like(cand_indices)
# handle prefixes of different lengths
partial_prefix_mask = prefix_tokens[:, step].eq(self.pad)
if partial_prefix_mask.any():
partial_scores, partial_indices, partial_beams = self.search.step(
step,
lprobs.view(bsz, -1, self.vocab_size),
scores.view(bsz, beam_size, -1)[:, :, :step],
)
cand_scores[partial_prefix_mask] = partial_scores[partial_prefix_mask]
cand_indices[partial_prefix_mask] = partial_indices[partial_prefix_mask]
cand_beams[partial_prefix_mask] = partial_beams[partial_prefix_mask]
else:
cand_scores, cand_indices, cand_beams = self.search.step(
step,
......@@ -531,7 +546,13 @@ class EnsembleModel(torch.nn.Module):
@torch.no_grad()
def forward_decoder(self, tokens, encoder_outs):
if len(self.models) == 1:
return self._decode_one(tokens, self.models[0], encoder_outs[0], self.incremental_states, log_probs=True)
return self._decode_one(
tokens,
self.models[0],
encoder_outs[0] if self.has_encoder() else None,
self.incremental_states,
log_probs=True,
)
log_probs = []
avg_attn = None
......
......@@ -6,9 +6,11 @@
# can be found in the PATENTS file in the same directory.
import itertools
import numpy as np
import os
import torch
import numpy as np
from fairseq.data import (
ConcatDataset,
Dictionary,
......@@ -17,6 +19,7 @@ from fairseq.data import (
IndexedRawTextDataset,
MonolingualDataset,
TokenBlockDataset,
TransformEosDataset,
TruncatedDictionary,
)
......@@ -185,6 +188,43 @@ class LanguageModelingTask(FairseqTask):
targets=self.targets,
)
def build_dataset_for_inference(self, src_tokens, src_lengths):
    """Wrap raw source tokens into a dataset suitable for LM generation.

    The tokens are chunked on EOS boundaries, exposed as a target-less
    monolingual dataset, and the trailing EOS is stripped so each sentence
    can serve directly as a generation prefix.
    """
    blocks = TokenBlockDataset(
        src_tokens,
        src_lengths,
        block_size=None,
        pad=self.source_dictionary.pad(),
        eos=self.source_dictionary.eos(),
        break_mode='eos',
        include_targets=False,
    )
    mono = MonolingualDataset(
        blocks,
        src_lengths,
        self.source_dictionary,
        self.target_dictionary,
        add_eos_for_other_targets=False,
        shuffle=False,
    )
    # remove EOS since this will be used as a prefix for generation
    return TransformEosDataset(
        mono,
        eos=self.source_dictionary.eos(),
        remove_eos_from_src=True,
        has_target=False,
    )
def inference_step(self, generator, models, sample, prefix_tokens=None):
    """Generate continuations for *sample* under ``torch.no_grad()``.

    If the caller supplies no prefix, the sample's source tokens are used
    as the generation prefix (their EOS has already been removed by
    ``build_dataset_for_inference``).
    """
    with torch.no_grad():
        prefix = prefix_tokens
        if prefix is None:
            prefix = sample['net_input']['src_tokens']
        return generator.generate(models, sample, prefix_tokens=prefix)
@property
def source_dictionary(self):
    """The :class:`~fairseq.data.Dictionary` used by the language model;
    identical to the output dictionary."""
    return self.output_dictionary
@property
def target_dictionary(self):
"""Return the :class:`~fairseq.data.Dictionary` for the language
......
......@@ -134,6 +134,7 @@ def main(args):
# sort output to match input order
for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
if src_dict is not None:
src_str = src_dict.string(src_tokens, args.remove_bpe)
print('S-{}\t{}'.format(id, src_str))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment