#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
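"""
Evaluate the perplexity of a trained language model.

Example usage (paths and dataset directory are illustrative, not shipped with
this script):

    python eval_lm.py data-bin/wikitext-103 \
        --path checkpoints/checkpoint_best.pt \
        --gen-subset test --max-tokens 3072
"""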

import numpy as np
import torch

from fairseq import data, options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer


class WordStat(object):
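    """Accumulates the total log-probability and occurrence count of one word."""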
    def __init__(self, word, is_bpe):
        self.word = word
        self.is_bpe = is_bpe
        self.log_prob = 0
        self.count = 0

    def add(self, log_prob):
        self.log_prob += log_prob
        self.count += 1

    def __str__(self):
        return '{}\t{}\t{}\t{}'.format(self.word, self.count, self.log_prob / self.count, self.is_bpe)


def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task)

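    # Command-line arguments override the values stored in the checkpoint.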
    args.__dict__.update(parsed_args.__dict__)
    print(args)

    task.args = args

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Optimize ensemble for generation (and optionally cast to fp16)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()

    assert len(models) > 0

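    # Build a batched, non-shuffled iterator over the evaluation split.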
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0

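    # When --remove-bpe is set, collect the dictionary indices of BPE
    # continuation tokens so their scores can be folded into full words below.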
    if args.remove_bpe is not None:
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']

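                # Fold the score of each BPE continuation token into the
                # following token so that each full word contributes one score.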
                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(len(hypo['tokens']) - 1):
                        if hypo['tokens'][i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

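                # Exclude tokens with +/-inf scores so they do not corrupt the sum.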
                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += utils.item(pos_scores.sum())
                count += pos_scores.numel() - skipped_toks

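                # Optionally reassemble BPE pieces into words and record
                # per-word probabilities and aggregate statistics.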
                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(hypo['tokens'])):
                        w_ind = hypo['tokens'][i].item()
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))
                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item())
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:.2f}]'.format(x[0], x[1]) for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

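    # The average negative log-likelihood uses the natural log, so perplexity = exp(loss).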
    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)


if __name__ == '__main__':
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    main(args)