Commit 75cc8821 authored by Myle Ott, committed by Facebook Github Bot

Update MoE README

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/619

Differential Revision: D15562983

Pulled By: myleott

fbshipit-source-id: 9240f56f18c87120b7d38e0db374d24a55999395
parent d5f76d74
@@ -63,20 +63,20 @@ $ for EXPERT in $(seq 0 2); do \
 fairseq-interactive data-bin/wmt17_en_de \
 --path checkpoints/checkpoint_best.pt \
 --beam 1 --remove-bpe \
---buffer 500 --max-tokens 6000 ; \
+--buffer-size 500 --max-tokens 6000 \
 --task translation_moe \
 --method hMoElp --mean-pool-gating-network \
 --num-experts 3 \
---gen-expert $EXPERT \
+--gen-expert $EXPERT ; \
 done > wmt14-en-de.extra_refs.tok.gen.3experts
 ```
-Finally use `scripts/score_moe.py` to compute pairwise BLEU and average oracle BLEU:
+Finally use `examples/translation_moe/score.py` to compute pairwise BLEU and multi-ref BLEU:
 ```
-$ python scripts/score_moe.py --sys wmt14-en-de.extra_refs.tok.gen.3experts --ref wmt14-en-de.extra_refs.tok
+$ python examples/translation_moe/score.py --sys wmt14-en-de.extra_refs.tok.gen.3experts --ref wmt14-en-de.extra_refs.tok
 pairwise BLEU: 48.26
-avg oracle BLEU: 49.50
 #refs covered: 2.11
+multi-reference BLEU (leave-one-out): 59.46
 ```
 This matches row 3 from Table 7 in the paper.
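
For context, "pairwise BLEU" in the output above is a diversity measure: the average corpus BLEU obtained when each expert's hypotheses are scored against every other expert's hypotheses (lower means more diverse outputs). Below is a minimal standalone sketch of that calculation, assuming `sacrebleu` is installed; the `pairwise_bleu` helper and the toy sentences are illustrative only and are not taken from the repo's `pairwise()` implementation.

```python
# Hypothetical sketch of the "pairwise BLEU" diversity metric: average corpus
# BLEU over every ordered pair of distinct expert hypothesis sets.
import itertools

import sacrebleu


def pairwise_bleu(hypos):
    """hypos[k][i] is expert k's translation of sentence i."""
    scores = [
        sacrebleu.corpus_bleu(sys, [ref]).score
        for sys, ref in itertools.permutations(hypos, 2)
    ]
    return sum(scores) / len(scores)


# Toy data: three "experts", two sentences each (made up for illustration).
hypos = [
    ["the cat sat on the mat", "a dog barks loudly"],
    ["the cat sits on the mat", "a dog is barking loudly"],
    ["a cat sat on a mat", "the dog barked loudly"],
]
print("pairwise BLEU: %.2f" % pairwise_bleu(hypos))
```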
@@ -6,7 +6,7 @@
 # the root directory of this source tree. An additional grant of patent rights
 # can be found in the PATENTS file in the same directory.
 """
-Scoring script for computing pairwise BLEU and oracle BLEU over a set of
+Scoring script for computing pairwise BLEU and multi-ref BLEU over a set of
 candidate hypotheses.

 See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
@@ -16,9 +16,9 @@ See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
 import argparse
 from itertools import chain
 import sys
-import numpy as np
 import random
+import numpy as np

 from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
@@ -37,6 +37,7 @@ def main():
     print('pairwise BLEU: %.2f' % pairwise(hypos))
     if args.output:
         merge(src, tgt, hypos, log_probs, args.output)
+
     if args.ref:
         _, _, refs = load_ref(args.ref)
         if args.sys:
@@ -154,19 +155,20 @@ def multi_ref(refs, hypos):
     refs = list(zip(*refs))
     hypos = list(zip(*hypos))
-    # compute average corpus BLEU
+    # compute multi-ref corpus BLEU (leave-one-out to be comparable to intra_ref)
     k = len(hypos)
     m = len(refs)
-    concat_hypos = []
-    concat_refs = [[] for j in range(m - 1)]
-    for i in range(m):
-        concat_hypos.append([h for hs in hypos for h in hs])
-        rest = refs[:i] + refs[i+1:]
-        for j in range(m - 1):
-            concat_refs[j].extend(rest[j] * k)
-    concat_hypos = list(chain.from_iterable(concat_hypos))
-    bleu = corpus_bleu(concat_hypos, concat_refs)
-    print('multi-reference BLEU (leave-one-out): %.2f' % bleu)
+    flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
+    duplicated_refs = [
+        [ref for ref in refs_i for _ in range(k)]
+        for refs_i in refs
+    ]
+    loo_bleus = []
+    for held_out_ref in range(m):
+        remaining_refs = duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref+1:]
+        assert len(remaining_refs) == m - 1
+        loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
+    print('average multi-reference BLEU (leave-one-out): %.2f' % np.mean(loo_bleus))

 def intra_ref(refs):
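To make the new leave-one-out computation concrete, here is a self-contained toy run of the same logic, calling `sacrebleu` directly in place of the script's `corpus_bleu` wrapper. The example sentences are made up, and the data is laid out as it would be after the `zip(*...)` calls in `multi_ref` (one list of sentences per reference set and per expert); only the variable names from the diff are reused.

```python
# Toy, self-contained version of the leave-one-out multi-reference BLEU above.
import numpy as np
import sacrebleu

# refs[j][i]  : j-th reference for sentence i   (m reference sets)
# hypos[j][i] : j-th expert hypothesis for sentence i (k hypothesis sets)
refs = [
    ["the cat sat on the mat", "a dog barks loudly"],
    ["a cat was sitting on the mat", "the dog is barking loudly"],
]
hypos = [
    ["the cat sat on the mat", "a dog barks loudly"],
    ["the cat sits on the mat", "a dog is barking loudly"],
    ["a cat sat on a mat", "the dog barked loudly"],
]

k, m = len(hypos), len(refs)

# Interleave hypotheses so the k expert outputs for each sentence are adjacent.
flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
# Repeat each reference k times so it lines up with the k hypotheses per sentence.
duplicated_refs = [[ref for ref in refs_i for _ in range(k)] for refs_i in refs]

# Hold out one reference set at a time and score against the rest.
loo_bleus = []
for held_out in range(m):
    remaining = duplicated_refs[:held_out] + duplicated_refs[held_out + 1:]
    loo_bleus.append(sacrebleu.corpus_bleu(flat_hypos, remaining).score)

print("average multi-reference BLEU (leave-one-out): %.2f" % np.mean(loo_bleus))
```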