Commit 19a3865d authored by Myle Ott
Browse files

Only consider EOS in beam search if it's among top-k candidates

parent f6ac1aec
......@@ -152,14 +152,14 @@ $ python generate.py data-bin/wmt14.en-fr.newstest2014 \
--path data-bin/wmt14.en-fr.fconv-py/model.pt \
--beam 5 --batch-size 128 --remove-bpe | tee /tmp/gen.out
...
-| Translated 3003 sentences (95451 tokens) in 94.5s (1009.60 tokens/s)
-| Generate test with beam=5: BLEU4 = 40.70, 67.7/46.8/34.2/25.4 (BP=1.000, ratio=1.000, syslen=81190, reflen=81194)
+| Translated 3003 sentences (95451 tokens) in 93.8s (1018.09 tokens/s)
+| Generate test with beam=5: BLEU4 = 40.67, 67.7/46.7/34.2/25.3 (BP=1.000, ratio=0.998, syslen=81377, reflen=81194)
# Scoring with score.py:
$ grep ^H /tmp/gen.out | cut -f3- > /tmp/gen.out.sys
$ grep ^T /tmp/gen.out | cut -f2- > /tmp/gen.out.ref
$ python score.py --sys /tmp/gen.out.sys --ref /tmp/gen.out.ref
-BLEU4 = 40.70, 67.7/46.8/34.2/25.4 (BP=1.000, ratio=1.000, syslen=81190, reflen=81194)
+BLEU4 = 40.67, 67.7/46.7/34.2/25.3 (BP=1.000, ratio=0.998, syslen=81377, reflen=81194)
```
# Join the fairseq community
......
......@@ -250,10 +250,11 @@ class SequenceGenerator(object):
eos_mask = cand_indices.eq(self.eos)
if step >= self.minlen:
eos_bbsz_idx = buffer('eos_bbsz_idx')
cand_bbsz_idx.masked_select(eos_mask, out=eos_bbsz_idx)
# only consider eos when it's among the top beam_size indices
cand_bbsz_idx[:, :beam_size].masked_select(eos_mask[:, :beam_size], out=eos_bbsz_idx)
if eos_bbsz_idx.numel() > 0:
eos_scores = buffer('eos_scores', type_of=scores)
cand_scores.masked_select(eos_mask, out=eos_scores)
cand_scores[:, :beam_size].masked_select(eos_mask[:, :beam_size], out=eos_scores)
num_remaining_sent -= finalize_hypos(step, eos_bbsz_idx, eos_scores)
assert num_remaining_sent >= 0
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment