Commit f3050860 authored by Myle Ott, committed by Facebook Github Bot

Update scoring script for MoE paper

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/586

Differential Revision: D14517550

Pulled By: myleott

fbshipit-source-id: fab68a8f597a98cf28d812d89eff845c5776b65b
parent 66f033e6
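The updated script scores fairseq generation output: it reads 'S-', 'T-' and 'H-' prefixed lines (and 'R'-prefixed lines in reference files) and reports pairwise BLEU, multi-reference BLEU and reference coverage through the --sys, --ref and --output flags defined in main() below. The short sketch that follows is illustrative only and not part of the commit; the exact layout of the tab-separated fields is an assumption inferred from how load_sys() indexes them, and the script's file name is not visible in this view.

# Hypothetical sample of the input format load_sys() appears to expect:
# the sentence index sits between the dash and the first tab, and H- lines
# carry the model score in the second tab-separated field (assumed layout).
sample = (
    "S-0\tsource sentence number zero\n"
    "T-0\treference translation for sentence zero\n"
    "H-0\t-0.4815\tmodel hypothesis for sentence zero\n"
)
for line in sample.splitlines():
    if line.startswith(('S-', 'T-', 'H-')):
        i = int(line[line.find('-') + 1:line.find('\t')])  # sentence index
        print(i, line.split('\t')[1])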
@@ -14,34 +14,48 @@ See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
 """
 import argparse
+from itertools import chain
 import sys
 import numpy as np
 import random
-from fairseq import bleu, tokenizer
-from fairseq.data import dictionary
+
+from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
 
-parser = argparse.ArgumentParser(sys.argv[0])
-parser.add_argument('--sys', nargs='*', default='', metavar='FILE',
-                    help='path to system output')
-parser.add_argument('--ref', default='', metavar='FILE',
-                    help='path to references')
-parser.add_argument('--output', default='', metavar='FILE',
-                    help='print outputs into a pretty format')
-args = parser.parse_args()
 
-dict = dictionary.Dictionary()
-scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
+def main():
+    parser = argparse.ArgumentParser(sys.argv[0])
+    parser.add_argument('--sys', nargs='*', default='', metavar='FILE',
+                        help='path to system output')
+    parser.add_argument('--ref', default='', metavar='FILE',
+                        help='path to references')
+    parser.add_argument('--output', default='', metavar='FILE',
+                        help='print outputs into a pretty format')
+    args = parser.parse_args()
+
+    if args.sys:
+        src, tgt, hypos, log_probs = load_sys(args.sys)
+        print('pairwise BLEU: %.2f' % pairwise(hypos))
+        if args.output:
+            merge(src, tgt, hypos, log_probs, args.output)
+
+    if args.ref:
+        _, _, refs = load_ref(args.ref)
+        if args.sys:
+            multi_ref(refs, hypos)
+        else:
+            intra_ref(refs)
 
 
 def dictolist(d):
     a = sorted(d.items(), key=lambda i: i[0])
     return [i[1] for i in a]
 
 
 def load_sys(paths):
     src, tgt, hypos, log_probs = {}, {}, {}, {}
     for path in paths:
         with open(path) as f:
             for line in f:
+                line = line.rstrip()
                 if line.startswith(('S-', 'T-', 'H-')):
                     i = int(line[line.find('-')+1:line.find('\t')])
                     if line.startswith('S-'):
@@ -56,6 +70,7 @@ def load_sys(paths):
                         log_probs[i].append(float(line.split('\t')[1]))
     return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)
 
+
 def load_ref(path):
     with open(path) as f:
         lines = f.readlines()
@@ -63,43 +78,48 @@ def load_ref(path):
     i = 0
     while i < len(lines):
         if lines[i].startswith('S-'):
-            src.append(lines[i].split('\t')[1])
+            src.append(lines[i].split('\t')[1].rstrip())
             i += 1
         elif lines[i].startswith('T-'):
-            tgt.append(lines[i].split('\t')[1])
+            tgt.append(lines[i].split('\t')[1].rstrip())
             i += 1
         else:
            a = []
            while i < len(lines) and lines[i].startswith('R'):
-                a.append(lines[i].split('\t')[1])
+                a.append(lines[i].split('\t')[1].rstrip())
                i += 1
            refs.append(a)
     return src, tgt, refs
 
+
 def merge(src, tgt, hypos, log_probs, path):
     with open(path, 'w') as f:
         for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
-            f.write(s)
-            f.write(t)
+            f.write(s + '\n')
+            f.write(t + '\n')
             f.write('\n')
             for h, lp in zip(hs, lps):
-                f.write('%f\t' % lp + h)
+                f.write('\t%f\t%s\n' % (lp, h.strip()))
             f.write('------------------------------------------------------\n')
 
+
-def corpus_bleu(ref, hypo):
-    scorer.reset()
-    for r, h in zip(ref, hypo):
-        r_tok = tokenizer.Tokenizer.tokenize(r, dict)
-        h_tok = tokenizer.Tokenizer.tokenize(h, dict)
-        scorer.add(r_tok, h_tok)
-    return scorer.score()
+def corpus_bleu(sys_stream, ref_streams):
+    bleu = _corpus_bleu(sys_stream, ref_streams, tokenize='none')
+    return bleu.score
 
+
-def sentence_bleu(ref, hypo):
-    scorer.reset(one_init=True)
-    r_tok = tokenizer.Tokenizer.tokenize(ref, dict)
-    h_tok = tokenizer.Tokenizer.tokenize(hypo, dict)
-    scorer.add(r_tok, h_tok)
-    return scorer.score()
+def sentence_bleu(hypothesis, reference):
+    bleu = _corpus_bleu(hypothesis, reference)
+    for i in range(1, 4):
+        bleu.counts[i] += 1
+        bleu.totals[i] += 1
+    bleu = compute_bleu(
+        bleu.counts, bleu.totals,
+        bleu.sys_len, bleu.ref_len,
+        smooth='exp', smooth_floor=0.0,
+    )
+    return bleu.score
 
+
 def pairwise(sents):
     _ref, _hypo = [], []
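A note on the sacrebleu-based wrappers added in the hunk above (illustrative, assuming a sacrebleu 1.x release that exposes corpus_bleu and compute_bleu with the signatures used in the commit): corpus_bleu takes one hypothesis stream and a list of parallel reference streams, which is why the pairwise() change in the next hunk passes [_ref] rather than _ref, and tokenize='none' scores the text as given without re-tokenization. The sentence_bleu wrapper adds one to the 2/3/4-gram counts and totals before recomputing the score, which appears intended to keep short sentences with no higher-order n-gram matches from collapsing to zero. A hypothetical toy call under those assumptions:

from sacrebleu import corpus_bleu

# One hypothesis stream, a list of reference streams parallel to it
# (toy sentences; assumes the sacrebleu 1.x calling convention).
hypos = ['the cat sat on the mat', 'he read the book']
refs_a = ['the cat sat on the mat', 'he read the book']
refs_b = ['a cat was on the mat', 'he was reading a book']
bleu = corpus_bleu(hypos, [refs_a, refs_b], tokenize='none')
print('%.2f' % bleu.score)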
@@ -109,46 +129,61 @@ def pairwise(sents):
             if i != j:
                 _ref.append(s[i])
                 _hypo.append(s[j])
-    return corpus_bleu(_ref, _hypo)
+    return corpus_bleu(_hypo, [_ref])
 
+
 def multi_ref(refs, hypos):
     _ref, _hypo = [], []
     ref_cnt = 0
+    assert len(refs) == len(hypos)
+
+    # count number of refs covered
     for rs, hs in zip(refs, hypos):
         a = set()
         for h in hs:
-            s = [sentence_bleu(r, h) for r in rs]
+            s = [sentence_bleu(h, r) for r in rs]
             j = np.argmax(s)
             _ref.append(rs[j])
             _hypo.append(h)
             best = [k for k in range(len(rs)) if s[k] == s[j]]
             a.add(random.choice(best))
         ref_cnt += len(a)
-    print('avg oracle BLEU: %.2f' % corpus_bleu(_ref, _hypo))
     print('#refs covered: %.2f' % (ref_cnt / len(refs)))
+
+    # transpose refs and hypos
+    refs = list(zip(*refs))
+    hypos = list(zip(*hypos))
+
+    # compute average corpus BLEU
+    k = len(hypos)
+    m = len(refs)
+    concat_hypos = []
+    concat_refs = [[] for j in range(m - 1)]
+    for i in range(m):
+        concat_hypos.append([h for hs in hypos for h in hs])
+        rest = refs[:i] + refs[i+1:]
+        for j in range(m - 1):
+            concat_refs[j].extend(rest[j] * k)
+    concat_hypos = list(chain.from_iterable(concat_hypos))
+    bleu = corpus_bleu(concat_hypos, concat_refs)
+    print('multi-reference BLEU (leave-one-out): %.2f' % bleu)
 
+
 def intra_ref(refs):
     print('ref pairwise BLEU: %.2f' % pairwise(refs))
-    _ref, _hypo = [], []
-    for rs in refs:
-        for i, h in enumerate(rs):
-            rest = rs[:i] + rs[i+1:]
-            s = [sentence_bleu(r, h) for r in rest]
-            j = np.argmax(s)
-            _ref.append(rest[j])
-            _hypo.append(h)
-    print('ref avg oracle BLEU (leave-one-out): %.2f' % corpus_bleu(_ref, _hypo))
+    refs = list(zip(*refs))
+    m = len(refs)
+    concat_h = []
+    concat_rest = [[] for j in range(m - 1)]
+    for i, h in enumerate(refs):
+        rest = refs[:i] + refs[i+1:]
+        concat_h.append(h)
+        for j in range(m - 1):
+            concat_rest[j].extend(rest[j])
+    concat_h = list(chain.from_iterable(concat_h))
+    bleu = corpus_bleu(concat_h, concat_rest)
+    print('multi-reference BLEU (leave-one-out): %.2f' % bleu)
 
+
 if __name__ == '__main__':
-    if args.sys:
-        src, tgt, hypos, log_probs = load_sys(args.sys)
-        print('pairwise BLEU: %.2f' % pairwise(hypos))
-        if args.output:
-            merge(src, tgt, hypos, log_probs, args.output)
-
-    if args.ref:
-        _, _, refs = load_ref(args.ref)
-        if args.sys:
-            multi_ref(refs, hypos)
-        else:
-            intra_ref(refs)
+    main()
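To make the new leave-one-out scoring easier to follow, here is a toy walk-through of the bookkeeping intra_ref() does before calling corpus_bleu (illustrative only; the r0_*/r1_* strings are made-up placeholders and no BLEU is computed): each of the m reference "systems" is held out in turn and treated as the hypothesis against the remaining m - 1 systems, with all rounds concatenated so one corpus-level call covers them.

# Toy data: two source sentences, three human references each (placeholders).
refs = [
    ['r0_a', 'r0_b', 'r0_c'],
    ['r1_a', 'r1_b', 'r1_c'],
]
refs = list(zip(*refs))              # transpose: one tuple per reference system
m = len(refs)                        # m = 3 reference systems

concat_h = []
concat_rest = [[] for _ in range(m - 1)]
for i, h in enumerate(refs):
    rest = refs[:i] + refs[i + 1:]   # hold out system i
    concat_h.append(h)               # held-out system plays the hypothesis role
    for j in range(m - 1):
        concat_rest[j].extend(rest[j])  # remaining systems act as references

concat_h = [s for block in concat_h for s in block]
print(concat_h)        # ['r0_a', 'r1_a', 'r0_b', 'r1_b', 'r0_c', 'r1_c']
print(concat_rest[0])  # first parallel reference stream for those hypotheses
print(concat_rest[1])  # second parallel reference stream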