Commit 42be3ebd authored by Myle Ott, committed by Facebook Github Bot

Merge internal changes (#483)

Summary:
Changelog:
- `4889802`: sentencepiece output can now be detokenized with `--remove-bpe=sentencepiece` (fixes #331). Also added `--sacrebleu` for computing detokenized BLEU.
- `0d76427`: fix assertion error when training a language model on a dataset containing empty sentences
- minor bug and style fixes
Pull Request resolved: https://github.com/pytorch/fairseq/pull/483

Differential Revision: D13867899

Pulled By: myleott

fbshipit-source-id: 25c940b847fe270262ac8f5ac838407b3977fdda
parent 66ce2175
@@ -35,6 +35,31 @@ class BleuStat(ctypes.Structure):
     ]
+
+class SacrebleuScorer(object):
+
+    def __init__(self):
+        import sacrebleu
+        self.sacrebleu = sacrebleu
+        self.reset()
+
+    def reset(self, one_init=False):
+        if one_init:
+            raise NotImplementedError
+        self.ref = []
+        self.sys = []
+
+    def add_string(self, ref, pred):
+        self.ref.append(ref)
+        self.sys.append(pred)
+
+    def score(self, order=4):
+        return self.result_string(order).bleu
+
+    def result_string(self, order=4):
+        if order != 4:
+            raise NotImplementedError
+        return self.sacrebleu.corpus_bleu(self.sys, [self.ref])
+
 class Scorer(object):
     def __init__(self, pad, eos, unk):
         self.stat = BleuStat()
......
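For orientation, a minimal sketch of how the new string-based scorer can be driven on its own (the sentence pair is a made-up example; generate.py selects this scorer via the new --sacrebleu flag further down):

from fairseq import bleu

scorer = bleu.SacrebleuScorer()
# add_string takes the detokenized reference first, then the hypothesis
scorer.add_string('the cat sat on the mat .', 'the cat sat on a mat .')
print(scorer.result_string())  # prints sacrebleu's corpus-level BLEU result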
@@ -94,7 +94,8 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
             return all(
                 all(a is None or b is None or a <= b
                     for a, b in zip(idx_size[key], max_positions[key]))
-                for key in intersect_keys)
+                for key in intersect_keys
+            )
         else:
             return all(a is None or b is None or a <= b
                        for a, b in zip(size_fn(idx), max_positions))
......
@@ -57,8 +57,12 @@ class Dictionary(object):
             else:
                 return self[i]

-        sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
-        if bpe_symbol is not None:
+        if bpe_symbol == 'sentencepiece':
+            sent = ''.join(token_string(i) for i in tensor if i != self.eos())
+            sent = sent.replace('\u2581', ' ').strip()
+        else:
+            sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
+        if bpe_symbol is not None and bpe_symbol != 'sentencepiece':
             sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
         return sent
......
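The new sentencepiece branch relies on sentencepiece marking word boundaries with the U+2581 character; a standalone sketch of that detokenization rule with made-up pieces:

# join the pieces, then turn the U+2581 word-boundary marker back into spaces
pieces = ['\u2581the', '\u2581cat', '\u2581s', 'at', '\u2581.']
sent = ''.join(pieces).replace('\u2581', ' ').strip()
print(sent)  # -> 'the cat sat .'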
@@ -66,11 +66,9 @@ class TokenBlockDataset(FairseqDataset):
             if curr_size > 0:
                 self.slice_indices.append((tok_idx, tok_idx + curr_size))
         elif break_mode == 'eos':
-            self.slice_indices = np.empty((sum(sizes > 1), 2), dtype=int)
+            self.slice_indices = np.empty((len(sizes), 2), dtype=int)
             curr = 0
             for i, sz in enumerate(sizes):
-                # skip samples with just 1 example (which would be just the eos token)
-                if sz > 1:
-                    self.slice_indices[i] = (curr, curr + sz)
+                self.slice_indices[i] = (curr, curr + sz)
                 curr += sz
         else:
@@ -78,6 +76,7 @@ class TokenBlockDataset(FairseqDataset):
         self.sizes = np.array([e - s for s, e in self.slice_indices])
         self.slice_indices = np.array(self.slice_indices, dtype=int)

+        # build index mapping block indices to the underlying dataset indices
         self.block_to_dataset_index = np.empty((len(self.slice_indices), 3), dtype=int)
         ds_idx, ds_remaining = -1, 0
......
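With this change every sentence keeps its own block under break_mode='eos', including an empty sentence that consists only of the eos token, which is the empty-sentence fix mentioned in the changelog. A standalone sketch of the new index computation with made-up sizes:

import numpy as np

# token count per sentence (eos included); a size of 1 means just the eos token
sizes = np.array([5, 1, 4])
slice_indices = np.empty((len(sizes), 2), dtype=int)
curr = 0
for i, sz in enumerate(sizes):
    slice_indices[i] = (curr, curr + sz)  # every sentence gets a slice now
    curr += sz
print(slice_indices.tolist())  # [[0, 5], [5, 6], [6, 10]]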
@@ -29,7 +29,7 @@ class BaseFairseqModel(nn.Module):
     @classmethod
     def build_model(cls, args, task):
         """Build a new model instance."""
-        raise NotImplementedError
+        raise NotImplementedError('FairseqModels must implement the build_model method')

     def get_targets(self, sample, net_output):
         """Get targets from either the sample or the net's output."""
......
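The clearer error message points at the contract subclasses must fulfil. A hypothetical subclass (the class name, the embedding layer, and the import path are illustrative assumptions, not part of this change) showing where build_model fits:

import torch.nn as nn

from fairseq.models import BaseFairseqModel  # assumed import path


class ToyModel(BaseFairseqModel):
    """Hypothetical model used only to illustrate the build_model contract."""

    def __init__(self, embed_tokens):
        super().__init__()
        self.embed_tokens = embed_tokens

    @classmethod
    def build_model(cls, args, task):
        # build sub-modules from the parsed args and the task's dictionaries,
        # then return a model instance; tasks call this instead of __init__
        embed_tokens = nn.Embedding(len(task.source_dictionary), args.encoder_embed_dim)
        return cls(embed_tokens)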
@@ -298,7 +298,7 @@ def add_common_eval_args(group):
     group.add_argument('--path', metavar='FILE',
                        help='path(s) to model file(s), colon separated')
     group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
-                       help='remove BPE tokens before scoring')
+                       help='remove BPE tokens before scoring (can be set to sentencepiece)')
     group.add_argument('--quiet', action='store_true',
                        help='only print final scores')
     group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
@@ -350,6 +350,8 @@ def add_generation_args(parser):
                        help='unknown word penalty: <0 produces more unks, >0 produces fewer')
     group.add_argument('--replace-unk', nargs='?', const=True, default=None,
                        help='perform unknown replacement (optionally with alignment dictionary)')
+    group.add_argument('--sacrebleu', action='store_true',
+                       help='score with sacrebleu')
     group.add_argument('--score-reference', action='store_true',
                        help='just score the reference translation')
     group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
......
@@ -246,6 +246,7 @@ class MultilingualTranslationTask(FairseqTask):
             for k, v in agg_logging_output.items()
         }
         flat_logging_output['loss'] = sum_over_languages('loss')
-        flat_logging_output['nll_loss'] = sum_over_languages('nll_loss')
+        if any('nll_loss' in logging_output for logging_output in agg_logging_outputs.values()):
+            flat_logging_output['nll_loss'] = sum_over_languages('nll_loss')
         flat_logging_output['sample_size'] = sum_over_languages('sample_size')
         flat_logging_output['nsentences'] = sum_over_languages('nsentences')
......
@@ -438,9 +438,7 @@ def resolve_max_positions(*args):

 def import_user_module(args):
-    if hasattr(args, 'user_dir'):
-        module_path = args.user_dir
+    module_path = getattr(args, 'user_dir', None)
     if module_path is not None:
         module_path = os.path.abspath(args.user_dir)
         module_parent, module_name = os.path.split(module_path)
......
@@ -95,6 +95,9 @@ def main(args):
         translator.cuda()

     # Generate and compute BLEU score
-    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
+    if args.sacrebleu:
+        scorer = bleu.SacrebleuScorer()
+    else:
+        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
     num_sentences = 0
     has_target = True
@@ -160,6 +163,9 @@ def main(args):
                     # Convert back to tokens for evaluation with unk replacement and/or without BPE
                     target_tokens = tokenizer.Tokenizer.tokenize(
                         target_str, tgt_dict, add_if_not_exist=True)
-                    scorer.add(target_tokens, hypo_tokens)
+                    if hasattr(scorer, 'add_string'):
+                        scorer.add_string(target_str, hypo_str)
+                    else:
+                        scorer.add(target_tokens, hypo_tokens)

             wps_meter.update(src_tokens.size(0))
......
@@ -26,6 +26,8 @@ def get_parser():
                         type=int, help='consider ngrams up to this order')
     parser.add_argument('--ignore-case', action='store_true',
                         help='case-insensitive scoring')
+    parser.add_argument('--sacrebleu', action='store_true',
+                        help='score with sacrebleu')
     # fmt: on
     return parser
@@ -49,6 +51,13 @@ def main():
             else:
                 yield line

-    def score(fdsys):
-        with open(args.ref) as fdref:
-            scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
+    if args.sacrebleu:
+        import sacrebleu
+
+        def score(fdsys):
+            with open(args.ref) as fdref:
+                print(sacrebleu.corpus_bleu(fdsys, [fdref]))
+    else:
+        def score(fdsys):
+            with open(args.ref) as fdref:
+                scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
......
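The sacrebleu path above hands the open file objects straight to corpus_bleu, which accepts an iterable of hypothesis lines plus a list of reference streams; the same call with made-up in-memory strings:

import sacrebleu

hyps = ['the cat sat on the mat .']
refs = [['the cat sat on the mat .']]  # one reference stream, one line per hypothesis
print(sacrebleu.corpus_bleu(hyps, refs))  # prints the corpus-level BLEU result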
@@ -15,7 +15,7 @@ echo 'Cloning Moses github repository (for tokenization scripts)...'
 git clone https://github.com/moses-smt/mosesdecoder.git
 SCRIPTS=mosesdecoder/scripts
-DETOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
+DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl

 grep ^H $GEN \
     | sed 's/^H\-//' \
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import unittest

import torch

from fairseq.data import TokenBlockDataset

import tests.utils as test_utils


class TestTokenBlockDataset(unittest.TestCase):

    def _build_dataset(self, data, **kwargs):
        sizes = [len(x) for x in data]
        underlying_ds = test_utils.TestDataset(data)
        return TokenBlockDataset(underlying_ds, sizes, **kwargs)

    def test_eos_break_mode(self):
        data = [
            torch.LongTensor([5, 4, 3, 2, 1]),
            torch.LongTensor([1]),  # this should be filtered
            torch.LongTensor([8, 7, 6, 1]),
        ]
        ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
        self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
        self.assertEqual(ds[1].tolist(), [1])
        self.assertEqual(ds[2].tolist(), [8, 7, 6, 1])

        data = [
            torch.LongTensor([5, 4, 3, 2, 1]),
            torch.LongTensor([8, 7, 6, 1]),
            torch.LongTensor([1]),  # this should be filtered
        ]
        ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
        self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
        self.assertEqual(ds[1].tolist(), [8, 7, 6, 1])
        self.assertEqual(ds[2].tolist(), [1])

    def test_block_break_mode(self):
        data = [
            torch.LongTensor([5, 4, 3, 2, 1]),
            torch.LongTensor([8, 7, 6, 1]),
            torch.LongTensor([9, 1]),
        ]
        ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='none')
        self.assertEqual(ds[0].tolist(), [5, 4, 3])
        self.assertEqual(ds[1].tolist(), [2, 1, 8])
        self.assertEqual(ds[2].tolist(), [7, 6, 1])
        self.assertEqual(ds[3].tolist(), [9, 1])

    def test_complete_break_mode(self):
        data = [
            torch.LongTensor([5, 4, 3, 2, 1]),
            torch.LongTensor([8, 7, 6, 1]),
            torch.LongTensor([9, 1]),
        ]
        ds = self._build_dataset(data, block_size=6, pad=0, eos=1, break_mode='complete')
        self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
        self.assertEqual(ds[1].tolist(), [8, 7, 6, 1, 9, 1])

        data = [
            torch.LongTensor([4, 3, 2, 1]),
            torch.LongTensor([5, 1]),
            torch.LongTensor([1]),
            torch.LongTensor([6, 1]),
        ]
        ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='complete')
        self.assertEqual(ds[0].tolist(), [4, 3, 2, 1])
        self.assertEqual(ds[1].tolist(), [5, 1, 1])
        self.assertEqual(ds[2].tolist(), [6, 1])


if __name__ == "__main__":
    unittest.main()