Commit f3050860 authored by Myle Ott's avatar Myle Ott Committed by Facebook Github Bot
Browse files

Update scoring script for MoE paper

Summary: Pull Request resolved: https://github.com/pytorch/fairseq/pull/586

Differential Revision: D14517550

Pulled By: myleott

fbshipit-source-id: fab68a8f597a98cf28d812d89eff845c5776b65b
parent 66f033e6
......@@ -14,34 +14,48 @@ See `"Mixture Models for Diverse Machine Translation: Tricks of the Trade"
"""
import argparse
from itertools import chain
import sys
import numpy as np
import random
from fairseq import bleu, tokenizer
from fairseq.data import dictionary
from sacrebleu import compute_bleu, corpus_bleu as _corpus_bleu
parser = argparse.ArgumentParser(sys.argv[0])
parser.add_argument('--sys', nargs='*', default='', metavar='FILE',
help='path to system output')
parser.add_argument('--ref', default='', metavar='FILE',
help='path to references')
parser.add_argument('--output', default='', metavar='FILE',
help='print outputs into a pretty format')
args = parser.parse_args()
dict = dictionary.Dictionary()
scorer = bleu.Scorer(dict.pad(), dict.eos(), dict.unk())
def main():
parser = argparse.ArgumentParser(sys.argv[0])
parser.add_argument('--sys', nargs='*', default='', metavar='FILE',
help='path to system output')
parser.add_argument('--ref', default='', metavar='FILE',
help='path to references')
parser.add_argument('--output', default='', metavar='FILE',
help='print outputs into a pretty format')
args = parser.parse_args()
if args.sys:
src, tgt, hypos, log_probs = load_sys(args.sys)
print('pairwise BLEU: %.2f' % pairwise(hypos))
if args.output:
merge(src, tgt, hypos, log_probs, args.output)
if args.ref:
_, _, refs = load_ref(args.ref)
if args.sys:
multi_ref(refs, hypos)
else:
intra_ref(refs)
def dictolist(d):
a = sorted(d.items(), key=lambda i: i[0])
return [i[1] for i in a]
def load_sys(paths):
src, tgt, hypos, log_probs = {}, {}, {}, {}
for path in paths:
with open(path) as f:
for line in f:
line = line.rstrip()
if line.startswith(('S-', 'T-', 'H-')):
i = int(line[line.find('-')+1:line.find('\t')])
if line.startswith('S-'):
......@@ -56,6 +70,7 @@ def load_sys(paths):
log_probs[i].append(float(line.split('\t')[1]))
return dictolist(src), dictolist(tgt), dictolist(hypos), dictolist(log_probs)
def load_ref(path):
with open(path) as f:
lines = f.readlines()
......@@ -63,43 +78,48 @@ def load_ref(path):
i = 0
while i < len(lines):
if lines[i].startswith('S-'):
src.append(lines[i].split('\t')[1])
src.append(lines[i].split('\t')[1].rstrip())
i += 1
elif lines[i].startswith('T-'):
tgt.append(lines[i].split('\t')[1])
tgt.append(lines[i].split('\t')[1].rstrip())
i += 1
else:
a = []
while i < len(lines) and lines[i].startswith('R'):
a.append(lines[i].split('\t')[1])
a.append(lines[i].split('\t')[1].rstrip())
i += 1
refs.append(a)
return src, tgt, refs
def merge(src, tgt, hypos, log_probs, path):
with open(path, 'w') as f:
for s, t, hs, lps in zip(src, tgt, hypos, log_probs):
f.write(s)
f.write(t)
f.write(s + '\n')
f.write(t + '\n')
f.write('\n')
for h, lp in zip(hs, lps):
f.write('%f\t' % lp + h)
f.write('\t%f\t%s\n' % (lp, h.strip()))
f.write('------------------------------------------------------\n')
def corpus_bleu(ref, hypo):
scorer.reset()
for r, h in zip(ref, hypo):
r_tok = tokenizer.Tokenizer.tokenize(r, dict)
h_tok = tokenizer.Tokenizer.tokenize(h, dict)
scorer.add(r_tok, h_tok)
return scorer.score()
def sentence_bleu(ref, hypo):
scorer.reset(one_init=True)
r_tok = tokenizer.Tokenizer.tokenize(ref, dict)
h_tok = tokenizer.Tokenizer.tokenize(hypo, dict)
scorer.add(r_tok, h_tok)
return scorer.score()
def corpus_bleu(sys_stream, ref_streams):
bleu = _corpus_bleu(sys_stream, ref_streams, tokenize='none')
return bleu.score
def sentence_bleu(hypothesis, reference):
bleu = _corpus_bleu(hypothesis, reference)
for i in range(1, 4):
bleu.counts[i] += 1
bleu.totals[i] += 1
bleu = compute_bleu(
bleu.counts, bleu.totals,
bleu.sys_len, bleu.ref_len,
smooth='exp', smooth_floor=0.0,
)
return bleu.score
def pairwise(sents):
_ref, _hypo = [], []
......@@ -109,46 +129,61 @@ def pairwise(sents):
if i != j:
_ref.append(s[i])
_hypo.append(s[j])
return corpus_bleu(_ref, _hypo)
return corpus_bleu(_hypo, [_ref])
def multi_ref(refs, hypos):
_ref, _hypo = [], []
ref_cnt = 0
assert len(refs) == len(hypos)
# count number of refs covered
for rs, hs in zip(refs, hypos):
a = set()
for h in hs:
s = [sentence_bleu(r, h) for r in rs]
s = [sentence_bleu(h, r) for r in rs]
j = np.argmax(s)
_ref.append(rs[j])
_hypo.append(h)
best = [k for k in range(len(rs)) if s[k] == s[j]]
a.add(random.choice(best))
ref_cnt += len(a)
print('avg oracle BLEU: %.2f' % corpus_bleu(_ref, _hypo))
print('#refs covered: %.2f' % (ref_cnt / len(refs)))
# transpose refs and hypos
refs = list(zip(*refs))
hypos = list(zip(*hypos))
# compute average corpus BLEU
k = len(hypos)
m = len(refs)
concat_hypos = []
concat_refs = [[] for j in range(m - 1)]
for i in range(m):
concat_hypos.append([h for hs in hypos for h in hs])
rest = refs[:i] + refs[i+1:]
for j in range(m - 1):
concat_refs[j].extend(rest[j] * k)
concat_hypos = list(chain.from_iterable(concat_hypos))
bleu = corpus_bleu(concat_hypos, concat_refs)
print('multi-reference BLEU (leave-one-out): %.2f' % bleu)
def intra_ref(refs):
print('ref pairwise BLEU: %.2f' % pairwise(refs))
_ref, _hypo = [], []
for rs in refs:
for i, h in enumerate(rs):
rest = rs[:i] + rs[i+1:]
s = [sentence_bleu(r, h) for r in rest]
j = np.argmax(s)
_ref.append(rest[j])
_hypo.append(h)
print('ref avg oracle BLEU (leave-one-out): %.2f' % corpus_bleu(_ref, _hypo))
refs = list(zip(*refs))
m = len(refs)
concat_h = []
concat_rest = [[] for j in range(m - 1)]
for i, h in enumerate(refs):
rest = refs[:i] + refs[i+1:]
concat_h.append(h)
for j in range(m - 1):
concat_rest[j].extend(rest[j])
concat_h = list(chain.from_iterable(concat_h))
bleu = corpus_bleu(concat_h, concat_rest)
print('multi-reference BLEU (leave-one-out): %.2f' % bleu)
if __name__ == '__main__':
if args.sys:
src, tgt, hypos, log_probs = load_sys(args.sys)
print('pairwise BLEU: %.2f' % pairwise(hypos))
if args.output:
merge(src, tgt, hypos, log_probs, args.output)
if args.ref:
_, _, refs = load_ref(args.ref)
if args.sys:
multi_ref(refs, hypos)
else:
intra_ref(refs)
if __name__ == '__main__':
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment