Commit dbe96371 authored by Alexei Baevski's avatar Alexei Baevski Committed by Myle Ott
Browse files

option to print language model words and their log probs during evaluation

parent e7b494f8
...@@ -60,8 +60,10 @@ def main(args): ...@@ -60,8 +60,10 @@ def main(args):
if args.remove_bpe is not None: if args.remove_bpe is not None:
bpe_cont = args.remove_bpe.rstrip() bpe_cont = args.remove_bpe.rstrip()
bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont)) bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
bpe_len = len(bpe_cont)
else: else:
bpe_toks = None bpe_toks = None
bpe_len = 0
with progress_bar.build_progress_bar(args, itr) as t: with progress_bar.build_progress_bar(args, itr) as t:
results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer) results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
...@@ -85,6 +87,20 @@ def main(args): ...@@ -85,6 +87,20 @@ def main(args):
pos_scores = pos_scores[(~inf_scores).nonzero()] pos_scores = pos_scores[(~inf_scores).nonzero()]
score_sum += pos_scores.sum() score_sum += pos_scores.sum()
count += pos_scores.numel() - skipped_toks count += pos_scores.numel() - skipped_toks
if args.output_word_probs:
w = ''
word_prob = []
for i in range(len(hypo['tokens'])):
w_ind = hypo['tokens'][i].item()
w += task.dictionary[w_ind]
if bpe_toks is not None and w_ind in bpe_toks:
w = w[:-bpe_len]
else:
word_prob.append((w, pos_scores[i].item()))
w = ''
print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
wps_meter.update(src_tokens.size(0)) wps_meter.update(src_tokens.size(0))
t.log({'wps': round(wps_meter.avg)}) t.log({'wps': round(wps_meter.avg)})
......
...@@ -249,6 +249,8 @@ def add_common_eval_args(group): ...@@ -249,6 +249,8 @@ def add_common_eval_args(group):
def add_eval_lm_args(parser): def add_eval_lm_args(parser):
group = parser.add_argument_group('LM Evaluation') group = parser.add_argument_group('LM Evaluation')
add_common_eval_args(group) add_common_eval_args(group)
group.add_argument('--output-word-probs', action='store_true',
help='if set, outputs words and their predicted log probabilities to standard output')
def add_generation_args(parser): def add_generation_args(parser):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment