Commit c7c567a7 authored by Alexei Baevski's avatar Alexei Baevski Committed by Myle Ott
Browse files

word stats in eval_lm

parent c9b800d2
......@@ -14,6 +14,21 @@ from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
class WordStat(object):
def __init__(self, word, is_bpe):
self.word = word
self.is_bpe = is_bpe
self.log_prob = 0
self.count = 0
def add(self, log_prob):
self.log_prob += log_prob
self.count += 1
def __str__(self):
return '{}\t{}\t{}\t{}'.format(self.word, self.count, self.log_prob / self.count, self.is_bpe)
def main(parsed_args):
assert parsed_args.path is not None, '--path required for evaluation!'
......@@ -70,6 +85,8 @@ def main(parsed_args):
bpe_toks = None
bpe_len = 0
word_stats = dict()
with progress_bar.build_progress_bar(args, itr) as t:
results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
wps_meter = TimeMeter()
......@@ -93,17 +110,22 @@ def main(parsed_args):
score_sum += pos_scores.sum()
count += pos_scores.numel() - skipped_toks
if args.output_word_probs:
if args.output_word_probs or args.output_word_stats:
w = ''
word_prob = []
is_bpe = False
for i in range(len(hypo['tokens'])):
w_ind = hypo['tokens'][i].item()
w += task.dictionary[w_ind]
if bpe_toks is not None and w_ind in bpe_toks:
w = w[:-bpe_len]
is_bpe = True
else:
word_prob.append((w, pos_scores[i].item()))
word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item())
is_bpe = False
w = ''
if args.output_word_probs:
print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))
wps_meter.update(src_tokens.size(0))
......@@ -113,6 +135,10 @@ def main(parsed_args):
print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))
if args.output_word_stats:
for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
print(ws)
if __name__ == '__main__':
parser = options.get_eval_lm_parser()
......
......@@ -264,6 +264,8 @@ def add_eval_lm_args(parser):
add_common_eval_args(group)
group.add_argument('--output-word-probs', action='store_true',
help='if set, outputs words and their predicted log probabilities to standard output')
group.add_argument('--output-word-stats', action='store_true',
help='if set, outputs word statistics such as word count, average probability, etc')
def add_generation_args(parser):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment