"git@developer.sourcefind.cn:ox696c/ktransformers.git" did not exist on "4dc5518e4d2ae89a687709bcbe05d2f3f80e00ad"
Commit dbe96371 authored by Alexei Baevski, committed by Myle Ott

option to print language model words and their log probs during evaluation

parent e7b494f8
@@ -60,8 +60,10 @@ def main(args):
    if args.remove_bpe is not None:
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
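
A note on the new bookkeeping above: `bpe_toks` collects the dictionary indices of BPE continuation symbols, and `bpe_len` is the length of the continuation marker that gets stripped when pieces are merged back into words. A minimal sketch with a toy symbol list standing in for `task.dictionary` (the marker `'@@ '` and the symbols are assumptions for illustration, not fairseq objects):

```python
# Hypothetical toy dictionary: a plain list of symbols stands in for task.dictionary.
toy_dictionary = ['the', 'new@@', 'est', 'model', '.']

remove_bpe = '@@ '              # value typically passed via --remove-bpe
bpe_cont = remove_bpe.rstrip()  # '@@' -- the continuation marker
bpe_toks = set(
    i for i in range(len(toy_dictionary))
    if toy_dictionary[i].endswith(bpe_cont)
)                               # indices of symbols that continue into the next token
bpe_len = len(bpe_cont)         # number of trailing characters to strip when merging

print(bpe_toks)  # {1} -- only 'new@@' ends with the marker
print(bpe_len)   # 2
```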
@@ -85,6 +87,20 @@
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs:
                    w = ''
                    word_prob = []
                    for i in range(len(hypo['tokens'])):
                        w_ind = hypo['tokens'][i].item()
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                        else:
                            word_prob.append((w, pos_scores[i].item()))
                            w = ''
                    print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})
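
The `--output-word-probs` block added above rebuilds whole words from BPE pieces and pairs each completed word with the positional log prob of its final piece. A self-contained sketch of the same merging logic on made-up data (the dictionary, token ids, and scores are toy values; the real code reads them from `task.dictionary`, `hypo['tokens']`, and the hypothesis' positional scores):

```python
# Toy stand-ins for task.dictionary, hypo['tokens'] and pos_scores.
toy_dictionary = ['the', 'new@@', 'est', 'model', '.']
tokens = [0, 1, 2, 3, 4]                     # "the newest model ."
pos_scores = [-1.2, -2.0, -0.5, -3.1, -0.2]  # per-token log probs (made up)

bpe_cont = '@@'
bpe_toks = {i for i, s in enumerate(toy_dictionary) if s.endswith(bpe_cont)}
bpe_len = len(bpe_cont)

w = ''
word_prob = []
for i in range(len(tokens)):
    w_ind = tokens[i]
    w += toy_dictionary[w_ind]
    if w_ind in bpe_toks:
        # Continuation piece: drop the '@@' marker and keep accumulating.
        w = w[:-bpe_len]
    else:
        # Word is complete: record it with the log prob of its last piece.
        word_prob.append((w, pos_scores[i]))
        w = ''

print('\t'.join('{} [{:2f}]'.format(word, p) for word, p in word_prob))
# the [-1.200000]    newest [-0.500000]    model [-3.100000]    . [-0.200000]
```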
@@ -249,6 +249,8 @@ def add_common_eval_args(group):

def add_eval_lm_args(parser):
    group = parser.add_argument_group('LM Evaluation')
    add_common_eval_args(group)
    group.add_argument('--output-word-probs', action='store_true',
                       help='if set, outputs words and their predicted log probabilities to standard output')


def add_generation_args(parser):
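
For reference, a standalone argparse sketch of how the new flag is expected to parse; this mirrors only the option added above rather than fairseq's full parser setup. A typical run would pass `--output-word-probs` (plus `--remove-bpe` for BPE-encoded data) to the LM evaluation script and read the word/score pairs from standard output.

```python
import argparse

# Hypothetical minimal parser mirroring the new option; fairseq wires this up
# through add_eval_lm_args(parser) in practice.
parser = argparse.ArgumentParser()
group = parser.add_argument_group('LM Evaluation')
group.add_argument('--output-word-probs', action='store_true',
                   help='if set, outputs words and their predicted log probabilities to standard output')

args = parser.parse_args(['--output-word-probs'])
print(args.output_word_probs)  # True -- evaluation then prints "word [log_prob]" pairs per sentence
```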