Commit 75cc8821 authored by Myle Ott, committed by Facebook Github Bot
Browse files

Update MoE README

Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/619

Differential Revision: D15562983

Pulled By: myleott

fbshipit-source-id: 9240f56f18c87120b7d38e0db374d24a55999395
parent d5f76d74
...@@ -63,20 +63,20 @@ $ for EXPERT in $(seq 0 2); do \ ...@@ -63,20 +63,20 @@ $ for EXPERT in $(seq 0 2); do \
fairseq-interactive data-bin/wmt17_en_de \ fairseq-interactive data-bin/wmt17_en_de \
--path checkpoints/checkpoint_best.pt \ --path checkpoints/checkpoint_best.pt \
--beam 1 --remove-bpe \ --beam 1 --remove-bpe \
--buffer 500 --max-tokens 6000 ; \ --buffer-size 500 --max-tokens 6000 \
--task translation_moe \ --task translation_moe \
--method hMoElp --mean-pool-gating-network \ --method hMoElp --mean-pool-gating-network \
--num-experts 3 \ --num-experts 3 \
--gen-expert $EXPERT \ --gen-expert $EXPERT ; \
done > wmt14-en-de.extra_refs.tok.gen.3experts done > wmt14-en-de.extra_refs.tok.gen.3experts
``` ```
Finally use `scripts/score_moe.py` to compute pairwise BLEU and average oracle BLEU: Finally use `score_moe.py` to compute pairwise BLEU and average oracle BLEU:
``` ```
$ python scripts/score_moe.py --sys wmt14-en-de.extra_refs.tok.gen.3experts --ref wmt14-en-de.extra_refs.tok $ python examples/translation_moe/score.py --sys wmt14-en-de.extra_refs.tok.gen.3experts --ref wmt14-en-de.extra_refs.tok
pairwise BLEU: 48.26 pairwise BLEU: 48.26
avg oracle BLEU: 49.50
#refs covered: 2.11 #refs covered: 2.11
multi-reference BLEU (leave-one-out): 59.46
``` ```
This matches row 3 from Table 7 in the paper. This matches row 3 from Table 7 in the paper.
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
# the root directory of this source tree. An additional grant of patent rights # the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory. # can be found in the PATENTS file in the same directory.
""" """
Scoring script for computing pairwise BLEU and oracle BLEU over a set of Scoring script for computing pairwise BLEU and multi-ref BLEU over a set of
candidate hypotheses. candidate hypotheses.
See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade" See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
...@@ -16,9 +16,9 @@ See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade" ...@@ -16,9 +16,9 @@ See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
import argparse import argparse
from itertools import chain from itertools import chain
import sys import sys
import numpy as np
import random import random
import numpy as np
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
...@@ -37,6 +37,7 @@ def main(): ...@@ -37,6 +37,7 @@ def main():
print('pairwise BLEU: %.2f' % pairwise(hypos)) print('pairwise BLEU: %.2f' % pairwise(hypos))
if args.output: if args.output:
merge(src, tgt, hypos, log_probs, args.output) merge(src, tgt, hypos, log_probs, args.output)
if args.ref: if args.ref:
_, _, refs = load_ref(args.ref) _, _, refs = load_ref(args.ref)
if args.sys: if args.sys:
...@@ -154,19 +155,20 @@ def multi_ref(refs, hypos): ...@@ -154,19 +155,20 @@ def multi_ref(refs, hypos):
refs = list(zip(*refs)) refs = list(zip(*refs))
hypos = list(zip(*hypos)) hypos = list(zip(*hypos))
# compute average corpus BLEU # compute multi-ref corpus BLEU (leave-one-out to be comparable to intra_ref)
k = len(hypos) k = len(hypos)
m = len(refs) m = len(refs)
concat_hypos = [] flat_hypos = [hypos[j][i] for i in range(len(hypos[0])) for j in range(k)]
concat_refs = [[] for j in range(m - 1)] duplicated_refs = [
for i in range(m): [ref for ref in refs_i for _ in range(k)]
concat_hypos.append([h for hs in hypos for h in hs]) for refs_i in refs
rest = refs[:i] + refs[i+1:] ]
for j in range(m - 1): loo_bleus = []
concat_refs[j].extend(rest[j] * k) for held_out_ref in range(m):
concat_hypos = list(chain.from_iterable(concat_hypos)) remaining_refs = duplicated_refs[:held_out_ref] + duplicated_refs[held_out_ref+1:]
bleu = corpus_bleu(concat_hypos, concat_refs) assert len(remaining_refs) == m - 1
print('multi-reference BLEU (leave-one-out): %.2f' % bleu) loo_bleus.append(corpus_bleu(flat_hypos, remaining_refs))
print('average multi-reference BLEU (leave-one-out): %.2f' % np.mean(loo_bleus))
def intra_ref(refs): def intra_ref(refs):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment