Commit 19a3865d authored by Myle Ott
Browse files

Only consider EOS in beam search if it's among top-k candidates

parent f6ac1aec
...@@ -152,14 +152,14 @@ $ python generate.py data-bin/wmt14.en-fr.newstest2014 \ ...@@ -152,14 +152,14 @@ $ python generate.py data-bin/wmt14.en-fr.newstest2014 \
--path data-bin/wmt14.en-fr.fconv-py/model.pt \ --path data-bin/wmt14.en-fr.fconv-py/model.pt \
--beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out --beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out
... ...
| Translated 3003 sentences (95451 tokens) in 94.5s (1009.60 tokens/s) | Translated 3003 sentences (95451 tokens) in 93.8s (1018.09 tokens/s)
| Generate test with beam=5: BLEU4 = 40.70, 67.7/46.8/34.2/25.4 (BP=1.000, ratio=1.000, syslen=81190, reflen=81194) | Generate test with beam=5: BLEU4 = 40.67, 67.7/46.7/34.2/25.3 (BP=1.000, ratio=0.998, syslen=81377, reflen=81194)
# Scoring with score.py: # Scoring with score.py:
$ grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys $ grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys
$ grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref $ grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref
$ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref $ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
BLEU4 = 40.70, 67.7/46.8/34.2/25.4 (BP=1.000, ratio=1.000, syslen=81190, reflen=81194) BLEU4 = 40.67, 67.7/46.7/34.2/25.3 (BP=1.000, ratio=0.998, syslen=81377, reflen=81194)
``` ```
# Join the fairseq community # Join the fairseq community
......
...@@ -250,10 +250,11 @@ class SequenceGenerator(object): ...@@ -250,10 +250,11 @@ class SequenceGenerator(object):
eos_mask = cand_indices.eq(self.eos) eos_mask = cand_indices.eq(self.eos)
if step >= self.minlen: if step >= self.minlen:
eos_bbsz_idx = buffer('eos_bbsz_idx') eos_bbsz_idx = buffer('eos_bbsz_idx')
cand_bbsz_idx.masked_select(eos_mask, out=eos_bbsz_idx) # only consider eos when it's among the top beam_size indices
cand_bbsz_idx[:, :beam_size].masked_select(eos_mask[:, :beam_size], out=eos_bbsz_idx)
if eos_bbsz_idx.numel() > 0: if eos_bbsz_idx.numel() > 0:
eos_scores = buffer('eos_scores', type_of=scores) eos_scores = buffer('eos_scores', type_of=scores)
cand_scores.masked_select(eos_mask, out=eos_scores) cand_scores[:, :beam_size].masked_select(eos_mask[:, :beam_size], out=eos_scores)
num_remaining_sent -= finalize_hypos(step, eos_bbsz_idx, eos_scores) num_remaining_sent -= finalize_hypos(step, eos_bbsz_idx, eos_scores)
assert num_remaining_sent >= 0 assert num_remaining_sent >= 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment