"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "58d2b10a2e9cd32dd9765dc50aca98690f516287"
Commit 98daf039 authored by Myle Ott, committed by Facebook Github Bot

Support LM generation from interactive.py (fixes #526)

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/528

Differential Revision: D14218377

Pulled By: myleott

fbshipit-source-id: facb0a32f6aebf56a4fea7259080394ad2d2d846
parent 00493490
--- a/eval_lm.py
+++ b/eval_lm.py
@@ -103,8 +103,11 @@ def main(parsed_args):
     count = 0
 
     if args.remove_bpe is not None:
-        bpe_cont = args.remove_bpe.rstrip()
-        bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
+        if args.remove_bpe == 'sentencepiece':
+            raise NotImplementedError
+        else:
+            bpe_cont = args.remove_bpe.rstrip()
+            bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
         bpe_len = len(bpe_cont)
     else:
         bpe_toks = None
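Note: the retained branch builds `bpe_toks`, the set of dictionary indices whose symbols end with the BPE continuation marker; eval_lm uses it to merge the scores of word-internal pieces when computing word-level perplexity. A minimal sketch of that logic with a toy dictionary (the symbols and marker are invented for illustration):

```python
# Toy illustration: tokens ending in the BPE continuation marker (here '@@')
# are word-internal pieces whose scores get merged into the following token.
toy_dictionary = ['the', 'quick', 'fo@@', 'x', 'jum@@', 'ps']  # hypothetical symbols
bpe_cont = '@@ '.rstrip()  # mirrors args.remove_bpe.rstrip()
bpe_toks = set(
    i for i in range(len(toy_dictionary))
    if toy_dictionary[i].endswith(bpe_cont)
)
assert bpe_toks == {2, 4}  # 'fo@@' and 'jum@@' are continuation pieces
```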
--- a/fairseq/data/monolingual_dataset.py
+++ b/fairseq/data/monolingual_dataset.py
@@ -28,19 +28,24 @@ def collate(samples, pad_idx, eos_idx):
             [s[key] for s in samples], pad_idx, eos_idx, left_pad=False,
         )
 
-    is_target_list = isinstance(samples[0]['target'], list)
+    src_tokens = merge('source')
+    if samples[0]['target'] is not None:
+        is_target_list = isinstance(samples[0]['target'], list)
+        target = merge('target', is_target_list)
+    else:
+        target = src_tokens
 
     return {
         'id': torch.LongTensor([s['id'] for s in samples]),
         'nsentences': len(samples),
         'ntokens': sum(len(s['source']) for s in samples),
         'net_input': {
-            'src_tokens': merge('source'),
+            'src_tokens': src_tokens,
             'src_lengths': torch.LongTensor([
                 s['source'].numel() for s in samples
             ]),
         },
-        'target': merge('target', is_target_list),
+        'target': target,
     }
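Note: the `target = src_tokens` fallback is what lets prompt-only language-model batches flow through code that expects a `target` key. A toy sketch of the new branch (tensor contents are invented, and `torch.stack` stands in for the real `merge()`, which pads and batches variable-length samples):

```python
import torch

# hypothetical single-sample batch with no target, as produced by
# MonolingualDataset when targets is None (see the next hunk)
samples = [{'id': 0, 'source': torch.tensor([4, 5, 6]), 'target': None}]

src_tokens = torch.stack([s['source'] for s in samples])  # stand-in for merge('source')
if samples[0]['target'] is not None:
    target = torch.stack([s['target'] for s in samples])
else:
    # inference-time fallback: reuse the source as the target so downstream
    # code that indexes sample['target'] keeps working
    target = src_tokens

assert torch.equal(target, src_tokens)
```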
@@ -72,8 +77,12 @@ class MonolingualDataset(FairseqDataset):
         self.targets = targets
 
     def __getitem__(self, index):
-        source, future_target, past_target = self.dataset[index]
-        source, target = self._make_source_target(source, future_target, past_target)
+        if self.targets is not None:
+            source, future_target, past_target = self.dataset[index]
+            source, target = self._make_source_target(source, future_target, past_target)
+        else:
+            source = self.dataset[index]
+            target = None
         return {'id': index, 'source': source, 'target': target}
 
     def __len__(self):
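Note: a stripped-down sketch of the two `__getitem__` paths; the plain list stands in for the real `TokenBlockDataset`, and the training path is reduced to a single shifted target:

```python
import torch

class ToyMonolingualDataset:
    """Illustrative stand-in, not the real fairseq class."""

    def __init__(self, dataset, targets=None):
        self.dataset = dataset
        self.targets = targets

    def __getitem__(self, index):
        if self.targets is not None:
            source, future_target = self.dataset[index]
            return {'id': index, 'source': source, 'target': future_target}
        # inference: no shifted target exists, only the raw source
        return {'id': index, 'source': self.dataset[index], 'target': None}

infer_ds = ToyMonolingualDataset([torch.tensor([4, 5, 6])])
assert infer_ds[0]['target'] is None
```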
--- a/fairseq/data/transform_eos_dataset.py
+++ b/fairseq/data/transform_eos_dataset.py
@@ -32,6 +32,7 @@ class TransformEosDataset(FairseqDataset):
         remove_eos_from_src=False,
         append_eos_to_tgt=False,
         remove_eos_from_tgt=False,
+        has_target=True,
     ):
         if not isinstance(dataset, FairseqDataset):
             raise ValueError('dataset must be an instance of FairseqDataset')
@@ -46,6 +47,7 @@ class TransformEosDataset(FairseqDataset):
         self.remove_eos_from_src = remove_eos_from_src
         self.append_eos_to_tgt = append_eos_to_tgt
         self.remove_eos_from_tgt = remove_eos_from_tgt
+        self.has_target = has_target
 
         # precompute how we should adjust the reported sizes
         self._src_delta = 0
@@ -64,7 +66,7 @@ class TransformEosDataset(FairseqDataset):
             self._checked_src = True
 
     def _check_tgt(self, tgt, expect_eos):
-        if not self._checked_tgt:
+        if self.has_target and not self._checked_tgt:
             assert (tgt[-1] == self.eos[0]) == expect_eos
             self._checked_tgt = True
@@ -101,8 +103,11 @@ class TransformEosDataset(FairseqDataset):
         return self.dataset.num_tokens(index)
 
     def size(self, index):
-        src_len, tgt_len = self.dataset.size(index)
-        return (src_len + self._src_delta, tgt_len + self._tgt_delta)
+        if self.has_target:
+            src_len, tgt_len = self.dataset.size(index)
+            return (src_len + self._src_delta, tgt_len + self._tgt_delta)
+        else:
+            return self.dataset.size(index)
 
     def ordered_indices(self):
         # NOTE: we assume that the ordering does not change based on the
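Note: with `has_target=False`, `size()` can no longer assume the wrapped dataset reports a `(src_len, tgt_len)` pair. A hedged sketch of the dispatch, using stand-in classes rather than the real `FairseqDataset` API:

```python
class ToyTransformEos:
    """Illustrative stand-in for the size() dispatch added above."""

    def __init__(self, dataset, has_target=True, src_delta=0, tgt_delta=0):
        self.dataset = dataset
        self.has_target = has_target
        self._src_delta = src_delta
        self._tgt_delta = tgt_delta

    def size(self, index):
        if self.has_target:
            src_len, tgt_len = self.dataset.size(index)
            return (src_len + self._src_delta, tgt_len + self._tgt_delta)
        # source-only datasets report a single length; unpacking would raise
        return self.dataset.size(index)

class _ToySrcOnly:
    def size(self, index):
        return 5  # a single source length, no target length

assert ToyTransformEos(_ToySrcOnly(), has_target=False).size(0) == 5
```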
--- a/fairseq/sequence_generator.py
+++ b/fairseq/sequence_generator.py
@@ -343,16 +343,31 @@ class SequenceGenerator(object):
                 banned_tokens = [[] for bbsz_idx in range(bsz * beam_size)]
 
             for bbsz_idx in range(bsz * beam_size):
-                lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = float('-Inf')
+                lprobs[bbsz_idx, banned_tokens[bbsz_idx]] = -math.inf
 
         if prefix_tokens is not None and step < prefix_tokens.size(1):
             probs_slice = lprobs.view(bsz, -1, lprobs.size(-1))[:, 0, :]
             cand_scores = torch.gather(
                 probs_slice, dim=1,
                 index=prefix_tokens[:, step].view(-1, 1)
-            ).expand(-1, cand_size)
-            cand_indices = prefix_tokens[:, step].view(-1, 1).expand(bsz, cand_size)
+            ).view(-1, 1).repeat(1, cand_size)
+            if step > 0:
+                # save cumulative scores for each hypothesis
+                cand_scores.add_(scores[:, step - 1].view(bsz, beam_size).repeat(1, 2))
+            cand_indices = prefix_tokens[:, step].view(-1, 1).repeat(1, cand_size)
             cand_beams = torch.zeros_like(cand_indices)
+
+            # handle prefixes of different lengths
+            partial_prefix_mask = prefix_tokens[:, step].eq(self.pad)
+            if partial_prefix_mask.any():
+                partial_scores, partial_indices, partial_beams = self.search.step(
+                    step,
+                    lprobs.view(bsz, -1, self.vocab_size),
+                    scores.view(bsz, beam_size, -1)[:, :, :step],
+                )
+                cand_scores[partial_prefix_mask] = partial_scores[partial_prefix_mask]
+                cand_indices[partial_prefix_mask] = partial_indices[partial_prefix_mask]
+                cand_beams[partial_prefix_mask] = partial_beams[partial_prefix_mask]
         else:
             cand_scores, cand_indices, cand_beams = self.search.step(
                 step,
@@ -531,7 +546,13 @@ class EnsembleModel(torch.nn.Module):
     @torch.no_grad()
     def forward_decoder(self, tokens, encoder_outs):
         if len(self.models) == 1:
-            return self._decode_one(tokens, self.models[0], encoder_outs[0], self.incremental_states, log_probs=True)
+            return self._decode_one(
+                tokens,
+                self.models[0],
+                encoder_outs[0] if self.has_encoder() else None,
+                self.incremental_states,
+                log_probs=True,
+            )
 
         log_probs = []
         avg_attn = None
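Note: the key mechanism in the `SequenceGenerator` change is the pad-based mask over prefixes. Prompts of different lengths are right-padded into one batch, so once a row's prompt is exhausted its prefix token at the current step equals pad, and that row falls back to ordinary `search.step()` candidates via `partial_prefix_mask`. A toy illustration of the mask (prompt contents are invented; pad index 1 is fairseq's default):

```python
import torch

pad = 1  # fairseq's default pad index
# two prompts of different lengths, right-padded to the batch maximum:
# row 0 = [7, 8, 9], row 1 = [7, 8] + padding
prefix_tokens = torch.tensor([
    [7, 8, 9],
    [7, 8, pad],
])

step = 2  # third generation step
partial_prefix_mask = prefix_tokens[:, step].eq(pad)
assert partial_prefix_mask.tolist() == [False, True]
# row 0 is still forced to its prefix token at this step; row 1 gets
# regular beam-search candidates copied in through the mask
```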
--- a/fairseq/tasks/language_modeling.py
+++ b/fairseq/tasks/language_modeling.py
@@ -6,9 +6,11 @@
 # can be found in the PATENTS file in the same directory.
 
 import itertools
-import numpy as np
 import os
 
+import torch
+import numpy as np
+
 from fairseq.data import (
     ConcatDataset,
     Dictionary,
@@ -17,6 +19,7 @@ from fairseq.data import (
     IndexedRawTextDataset,
     MonolingualDataset,
     TokenBlockDataset,
+    TransformEosDataset,
     TruncatedDictionary,
 )
@@ -185,6 +188,43 @@ class LanguageModelingTask(FairseqTask):
             targets=self.targets,
         )
 
+    def build_dataset_for_inference(self, src_tokens, src_lengths):
+        return TransformEosDataset(
+            MonolingualDataset(
+                TokenBlockDataset(
+                    src_tokens,
+                    src_lengths,
+                    block_size=None,
+                    pad=self.source_dictionary.pad(),
+                    eos=self.source_dictionary.eos(),
+                    break_mode='eos',
+                    include_targets=False,
+                ),
+                src_lengths,
+                self.source_dictionary,
+                self.target_dictionary,
+                add_eos_for_other_targets=False,
+                shuffle=False,
+            ),
+            eos=self.source_dictionary.eos(),
+            # remove EOS since this will be used as a prefix for generation
+            remove_eos_from_src=True,
+            has_target=False,
+        )
+
+    def inference_step(self, generator, models, sample, prefix_tokens=None):
+        with torch.no_grad():
+            if prefix_tokens is None:
+                # note: EOS has already been removed in build_dataset_for_inference
+                prefix_tokens = sample['net_input']['src_tokens']
+            return generator.generate(models, sample, prefix_tokens=prefix_tokens)
+
+    @property
+    def source_dictionary(self):
+        """Return the :class:`~fairseq.data.Dictionary` for the language
+        model."""
+        return self.output_dictionary
+
     @property
     def target_dictionary(self):
         """Return the :class:`~fairseq.data.Dictionary` for the language
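Note: a hedged sketch of how the two new task hooks are meant to be driven together; `task`, `generator`, `models`, and the batch iterator are assumed to come from the usual fairseq setup (task construction, checkpoint loading, `SequenceGenerator` construction), which is elided here:

```python
def generate_from_prompts(task, generator, models, batches):
    """Drive the LanguageModelingTask hooks added in this commit.

    `batches` is assumed to yield samples collated from the dataset
    returned by task.build_dataset_for_inference(...), whose source
    lines have already had their EOS stripped.
    """
    results = []
    for sample in batches:
        # prefix_tokens defaults to the EOS-stripped source inside
        # inference_step, so each prompt conditions its own continuation
        results.append(task.inference_step(generator, models, sample))
    return results
```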
--- a/interactive.py
+++ b/interactive.py
@@ -134,8 +134,9 @@ def main(args):
     # sort output to match input order
     for id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
-        src_str = src_dict.string(src_tokens, args.remove_bpe)
-        print('S-{}\t{}'.format(id, src_str))
+        if src_dict is not None:
+            src_str = src_dict.string(src_tokens, args.remove_bpe)
+            print('S-{}\t{}'.format(id, src_str))
 
         # Process top predictions
         for hypo in hypos[:min(len(hypos), args.nbest)]:
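Note: when no source dictionary is available (`src_dict is None`), the `S-` source line is now skipped rather than crashing, while hypothesis lines still print. A toy sketch of the guard (the output format is simplified; the real script also prints scores, positional scores, and optional alignments):

```python
def print_result(id, src_str, hypo_strs, src_dict=None):
    # with src_dict=None (the language modeling case) only H- lines print
    if src_dict is not None:
        print('S-{}\t{}'.format(id, src_str))
    for hypo_str in hypo_strs:
        print('H-{}\t{}'.format(id, hypo_str))

print_result(0, None, ['the quick brown fox jumps ...'])  # H-0 only
```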