Commit 42be3ebd authored by Myle Ott, committed by Facebook Github Bot

Merge internal changes (#483)

Summary:
Changelog:
- `4889802`: can now detokenize sentencepiece output with `--remove-bpe=sentencepiece` (fixes #331). Also added `--sacrebleu` for computing detokenized BLEU.
- `0d76427`: fix assertion error when training language model with dataset containing empty sentences
- minor bug and style fixes
Pull Request resolved: https://github.com/pytorch/fairseq/pull/483

Differential Revision: D13867899

Pulled By: myleott

fbshipit-source-id: 25c940b847fe270262ac8f5ac838407b3977fdda
parent 66ce2175
@@ -35,6 +35,31 @@ class BleuStat(ctypes.Structure):
     ]


+class SacrebleuScorer(object):
+
+    def __init__(self):
+        import sacrebleu
+        self.sacrebleu = sacrebleu
+        self.reset()
+
+    def reset(self, one_init=False):
+        if one_init:
+            raise NotImplementedError
+        self.ref = []
+        self.sys = []
+
+    def add_string(self, ref, pred):
+        self.ref.append(ref)
+        self.sys.append(pred)
+
+    def score(self, order=4):
+        return self.result_string(order).bleu
+
+    def result_string(self, order=4):
+        if order != 4:
+            raise NotImplementedError
+        return self.sacrebleu.corpus_bleu(self.sys, [self.ref])
+
+
 class Scorer(object):

     def __init__(self, pad, eos, unk):
         self.stat = BleuStat()
......
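The new `SacrebleuScorer` accumulates detokenized reference/hypothesis strings and defers scoring to `sacrebleu.corpus_bleu`. A minimal usage sketch, assuming fairseq is installed with its extensions built and the `sacrebleu` package is available; the sentence pairs below are made up for illustration:

```python
from fairseq import bleu

# illustrative detokenized (reference, hypothesis) pairs
pairs = [
    ('The cat sat on the mat.', 'The cat sat on a mat.'),
    ('It is raining today.', 'It rains today.'),
]

scorer = bleu.SacrebleuScorer()
for ref, hyp in pairs:
    scorer.add_string(ref, hyp)  # references and system outputs are plain strings

# result_string() returns the corpus-level BLEU computed by sacrebleu over all added pairs
print(scorer.result_string())
```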
@@ -92,9 +92,10 @@ def filter_by_size(indices, size_fn, max_positions, raise_exception=False):
             assert isinstance(idx_size, dict)
             intersect_keys = set(max_positions.keys()) & set(idx_size.keys())
             return all(
                 all(a is None or b is None or a <= b
                     for a, b in zip(idx_size[key], max_positions[key]))
-                for key in intersect_keys)
+                for key in intersect_keys
+            )
         else:
             return all(a is None or b is None or a <= b
                        for a, b in zip(size_fn(idx), max_positions))
......
@@ -57,8 +57,12 @@ class Dictionary(object):
             else:
                 return self[i]

-        sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
-        if bpe_symbol is not None:
+        if bpe_symbol == 'sentencepiece':
+            sent = ''.join(token_string(i) for i in tensor if i != self.eos())
+            sent = sent.replace('\u2581', ' ').strip()
+        else:
+            sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
+        if bpe_symbol is not None and bpe_symbol != 'sentencepiece':
             sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
         return sent
......
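The `sentencepiece` branch above works because sentencepiece marks word boundaries with U+2581 ('▁'): pieces are joined without spaces and the marker is turned back into spaces. A standalone sketch of that detokenization step, independent of the `Dictionary` class (the function name and sample pieces are illustrative):

```python
def detokenize_sentencepiece(pieces):
    """Join sentencepiece pieces, restoring spaces at the U+2581 word-boundary marker."""
    sent = ''.join(pieces)
    return sent.replace('\u2581', ' ').strip()

print(detokenize_sentencepiece(['\u2581Hello', '\u2581wor', 'ld', '.']))  # -> 'Hello world.'
```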
@@ -66,18 +66,17 @@ class TokenBlockDataset(FairseqDataset):
                 if curr_size > 0:
                     self.slice_indices.append((tok_idx, tok_idx + curr_size))
         elif break_mode == 'eos':
-            self.slice_indices = np.empty((sum(sizes > 1), 2), dtype=int)
+            self.slice_indices = np.empty((len(sizes), 2), dtype=int)
             curr = 0
             for i, sz in enumerate(sizes):
-                # skip samples with just 1 example (which would be just the eos token)
-                if sz > 1:
-                    self.slice_indices[i] = (curr, curr + sz)
+                self.slice_indices[i] = (curr, curr + sz)
                 curr += sz
         else:
             raise ValueError('Invalid break_mode: ' + break_mode)

         self.sizes = np.array([e - s for s, e in self.slice_indices])
         self.slice_indices = np.array(self.slice_indices, dtype=int)

         # build index mapping block indices to the underlying dataset indices
         self.block_to_dataset_index = np.empty((len(self.slice_indices), 3), dtype=int)
         ds_idx, ds_remaining = -1, 0
......
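With the rewritten `eos` break mode, every sentence becomes exactly one block, including a sentence that consists only of the eos token (the empty-sentence case mentioned in the changelog). A minimal sketch of the index computation on its own, with an illustrative helper name:

```python
import numpy as np

def eos_slice_indices(sizes):
    """sizes[i] is the token count (including eos) of sentence i in the flat token stream."""
    slice_indices = np.empty((len(sizes), 2), dtype=int)
    curr = 0
    for i, sz in enumerate(sizes):
        # one block per sentence; size-1 "empty" sentences (just eos) are kept, not skipped
        slice_indices[i] = (curr, curr + sz)
        curr += sz
    return slice_indices

print(eos_slice_indices([5, 1, 4]))  # blocks: (0, 5), (5, 6), (6, 10)
```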
@@ -29,7 +29,7 @@ class BaseFairseqModel(nn.Module):
     @classmethod
     def build_model(cls, args, task):
         """Build a new model instance."""
-        raise NotImplementedError
+        raise NotImplementedError('FairseqModels must implement the build_model method')

     def get_targets(self, sample, net_output):
         """Get targets from either the sample or the net's output."""
......
@@ -298,7 +298,7 @@ def add_common_eval_args(group):
     group.add_argument('--path', metavar='FILE',
                        help='path(s) to model file(s), colon separated')
     group.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None,
-                       help='remove BPE tokens before scoring')
+                       help='remove BPE tokens before scoring (can be set to sentencepiece)')
     group.add_argument('--quiet', action='store_true',
                        help='only print final scores')
     group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT',
@@ -350,6 +350,8 @@ def add_generation_args(parser):
                        help='unknown word penalty: <0 produces more unks, >0 produces fewer')
     group.add_argument('--replace-unk', nargs='?', const=True, default=None,
                        help='perform unknown replacement (optionally with alignment dictionary)')
+    group.add_argument('--sacrebleu', action='store_true',
+                       help='score with sacrebleu')
     group.add_argument('--score-reference', action='store_true',
                        help='just score the reference translation')
     group.add_argument('--prefix-size', default=0, type=int, metavar='PS',
......
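The `--remove-bpe` option can be used both as a bare flag and as `--remove-bpe=sentencepiece` because of how argparse treats `nargs='?'`: `default` applies when the option is absent, `const` when it is given without a value, and the supplied value otherwise. A small standalone demonstration (not fairseq's own parser):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--remove-bpe', nargs='?', const='@@ ', default=None)

print(parser.parse_args([]).remove_bpe)                              # None (option absent)
print(repr(parser.parse_args(['--remove-bpe']).remove_bpe))          # '@@ ' (bare flag -> const)
print(parser.parse_args(['--remove-bpe=sentencepiece']).remove_bpe)  # sentencepiece
```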
@@ -246,7 +246,8 @@ class MultilingualTranslationTask(FairseqTask):
             for k, v in agg_logging_output.items()
         }
         flat_logging_output['loss'] = sum_over_languages('loss')
-        flat_logging_output['nll_loss'] = sum_over_languages('nll_loss')
+        if any('nll_loss' in logging_output for logging_output in agg_logging_outputs.values()):
+            flat_logging_output['nll_loss'] = sum_over_languages('nll_loss')
         flat_logging_output['sample_size'] = sum_over_languages('sample_size')
         flat_logging_output['nsentences'] = sum_over_languages('nsentences')
         flat_logging_output['ntokens'] = sum_over_languages('ntokens')
......
@@ -438,14 +438,12 @@ def resolve_max_positions(*args):
 def import_user_module(args):
-    if hasattr(args, 'user_dir'):
-        module_path = args.user_dir
-
-        if module_path is not None:
-            module_path = os.path.abspath(args.user_dir)
-            module_parent, module_name = os.path.split(module_path)
-
-            if module_name not in sys.modules:
-                sys.path.insert(0, module_parent)
-                importlib.import_module(module_name)
-                sys.path.pop(0)
+    module_path = getattr(args, 'user_dir', None)
+    if module_path is not None:
+        module_path = os.path.abspath(args.user_dir)
+        module_parent, module_name = os.path.split(module_path)
+
+        if module_name not in sys.modules:
+            sys.path.insert(0, module_parent)
+            importlib.import_module(module_name)
+            sys.path.pop(0)
@@ -95,7 +95,10 @@ def main(args):
         translator.cuda()

     # Generate and compute BLEU score
-    scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
+    if args.sacrebleu:
+        scorer = bleu.SacrebleuScorer()
+    else:
+        scorer = bleu.Scorer(tgt_dict.pad(), tgt_dict.eos(), tgt_dict.unk())
     num_sentences = 0
     has_target = True
     with progress_bar.build_progress_bar(args, itr) as t:
@@ -160,7 +163,10 @@ def main(args):
                         # Convert back to tokens for evaluation with unk replacement and/or without BPE
                         target_tokens = tokenizer.Tokenizer.tokenize(
                             target_str, tgt_dict, add_if_not_exist=True)
-                    scorer.add(target_tokens, hypo_tokens)
+                    if hasattr(scorer, 'add_string'):
+                        scorer.add_string(target_str, hypo_str)
+                    else:
+                        scorer.add(target_tokens, hypo_tokens)

                 wps_meter.update(src_tokens.size(0))
                 t.log({'wps': round(wps_meter.avg)})
......
@@ -26,6 +26,8 @@ def get_parser():
                         type=int, help='consider ngrams up to this order')
     parser.add_argument('--ignore-case', action='store_true',
                         help='case-insensitive scoring')
+    parser.add_argument('--sacrebleu', action='store_true',
+                        help='score with sacrebleu')
     # fmt: on
     return parser
@@ -49,14 +51,21 @@ def main():
             else:
                 yield line

-    def score(fdsys):
-        with open(args.ref) as fdref:
-            scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
-            for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
-                sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict)
-                ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict)
-                scorer.add(ref_tok, sys_tok)
-            print(scorer.result_string(args.order))
+    if args.sacrebleu:
+        import sacrebleu
+
+        def score(fdsys):
+            with open(args.ref) as fdref:
+                print(sacrebleu.corpus_bleu(fdsys, [fdref]))
+    else:
+        def score(fdsys):
+            with open(args.ref) as fdref:
+                scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
+                for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
+                    sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dict)
+                    ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dict)
+                    scorer.add(ref_tok, sys_tok)
+                print(scorer.result_string(args.order))

     if args.sys == '-':
         score(sys.stdin)
......
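In the sacrebleu branch above, the open file objects are passed straight to `corpus_bleu`, which accepts any iterable of hypothesis lines plus a list of reference streams. A sketch of the same call on in-memory lists, assuming the `sacrebleu` package is installed; the sentences are made up:

```python
import sacrebleu

hyps = ['The cat sat on a mat.', 'It rains today.']         # system outputs, already detokenized
refs = ['The cat sat on the mat.', 'It is raining today.']  # matching references

# one hypothesis stream, and a list containing one reference stream
print(sacrebleu.corpus_bleu(hyps, [refs]))
```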
@@ -15,7 +15,7 @@ echo 'Cloning Moses github repository (for tokenization scripts)...'
 git clone https://github.com/moses-smt/mosesdecoder.git

 SCRIPTS=mosesdecoder/scripts
-DETOKENIZER=$SCRIPTS/tokenizer/tokenizer.perl
+DETOKENIZER=$SCRIPTS/tokenizer/detokenizer.perl

 grep ^H $GEN \
 | sed 's/^H\-//' \
......
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import unittest

import torch

from fairseq.data import TokenBlockDataset

import tests.utils as test_utils


class TestTokenBlockDataset(unittest.TestCase):

    def _build_dataset(self, data, **kwargs):
        sizes = [len(x) for x in data]
        underlying_ds = test_utils.TestDataset(data)
        return TokenBlockDataset(underlying_ds, sizes, **kwargs)
    def test_eos_break_mode(self):
        data = [
            torch.LongTensor([5, 4, 3, 2, 1]),
            torch.LongTensor([1]),  # empty sentence (just eos); must not be dropped
            torch.LongTensor([8, 7, 6, 1]),
        ]
        ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
        self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
        self.assertEqual(ds[1].tolist(), [1])
        self.assertEqual(ds[2].tolist(), [8, 7, 6, 1])
        data = [
            torch.LongTensor([5, 4, 3, 2, 1]),
            torch.LongTensor([8, 7, 6, 1]),
            torch.LongTensor([1]),  # empty sentence (just eos); must not be dropped
        ]
        ds = self._build_dataset(data, block_size=None, pad=0, eos=1, break_mode='eos')
        self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
        self.assertEqual(ds[1].tolist(), [8, 7, 6, 1])
        self.assertEqual(ds[2].tolist(), [1])
    def test_block_break_mode(self):
        data = [
            torch.LongTensor([5, 4, 3, 2, 1]),
            torch.LongTensor([8, 7, 6, 1]),
            torch.LongTensor([9, 1]),
        ]
        ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='none')
        self.assertEqual(ds[0].tolist(), [5, 4, 3])
        self.assertEqual(ds[1].tolist(), [2, 1, 8])
        self.assertEqual(ds[2].tolist(), [7, 6, 1])
        self.assertEqual(ds[3].tolist(), [9, 1])

    def test_complete_break_mode(self):
        data = [
            torch.LongTensor([5, 4, 3, 2, 1]),
            torch.LongTensor([8, 7, 6, 1]),
            torch.LongTensor([9, 1]),
        ]
        ds = self._build_dataset(data, block_size=6, pad=0, eos=1, break_mode='complete')
        self.assertEqual(ds[0].tolist(), [5, 4, 3, 2, 1])
        self.assertEqual(ds[1].tolist(), [8, 7, 6, 1, 9, 1])

        data = [
            torch.LongTensor([4, 3, 2, 1]),
            torch.LongTensor([5, 1]),
            torch.LongTensor([1]),
            torch.LongTensor([6, 1]),
        ]
        ds = self._build_dataset(data, block_size=3, pad=0, eos=1, break_mode='complete')
        self.assertEqual(ds[0].tolist(), [4, 3, 2, 1])
        self.assertEqual(ds[1].tolist(), [5, 1, 1])
        self.assertEqual(ds[2].tolist(), [6, 1])


if __name__ == "__main__":
    unittest.main()