Commit ff2e8cf2 authored by Myle Ott
Browse files

Fix handling of continuation tokens that precede <unk> in generate.py

parent 8943fc78
...@@ -122,12 +122,12 @@ def main(): ...@@ -122,12 +122,12 @@ def main():
display_hypotheses(None, tokens, line, None, hypos[:min(len(hypos), args.nbest)]) display_hypotheses(None, tokens, line, None, hypos[:min(len(hypos), args.nbest)])
else: else:
def maybe_remove_bpe(tokens): def maybe_remove_bpe(tokens, escape_unk=False):
"""Helper for removing BPE symbols from a hypothesis.""" """Helper for removing BPE symbols from a hypothesis."""
if args.remove_bpe is None: if args.remove_bpe is None:
return tokens return tokens
assert (tokens == dataset.dst_dict.pad()).sum() == 0 assert (tokens == dataset.dst_dict.pad()).sum() == 0
hypo_minus_bpe = dataset.dst_dict.string(tokens, args.remove_bpe) hypo_minus_bpe = dataset.dst_dict.string(tokens, args.remove_bpe, escape_unk)
return tokenizer.Tokenizer.tokenize(hypo_minus_bpe, dataset.dst_dict, add_if_not_exist=True) return tokenizer.Tokenizer.tokenize(hypo_minus_bpe, dataset.dst_dict, add_if_not_exist=True)
# Generate and compute BLEU score # Generate and compute BLEU score
...@@ -145,7 +145,7 @@ def main(): ...@@ -145,7 +145,7 @@ def main():
for id, src, ref, hypos in translations: for id, src, ref, hypos in translations:
ref = ref.int().cpu() ref = ref.int().cpu()
top_hypo = hypos[0]['tokens'].int().cpu() top_hypo = hypos[0]['tokens'].int().cpu()
scorer.add(maybe_remove_bpe(ref), maybe_remove_bpe(top_hypo)) scorer.add(maybe_remove_bpe(ref, escape_unk=True), maybe_remove_bpe(top_hypo))
display_hypotheses(id, src, None, ref, hypos[:min(len(hypos), args.nbest)]) display_hypotheses(id, src, None, ref, hypos[:min(len(hypos), args.nbest)])
wps_meter.update(src.size(0)) wps_meter.update(src.size(0))
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment