#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

"""
Evaluate the perplexity of a trained language model.
"""

import numpy as np
import torch

from fairseq import options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer


class WordStat(object):
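    """Accumulates per-word statistics over the evaluation corpus: occurrence
    count, summed log prob, and the summed log prob of the following word
    (when one is available)."""
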
    def __init__(self, word, is_bpe):
        self.word = word
        self.is_bpe = is_bpe
        self.log_prob = 0
        self.next_word_prob = 0
        self.count = 0
        self.missing_next_words = 0

    def add(self, log_prob, next_word_prob):
        """ increments counters for the sum of log probs of current word and next
            word (given context ending at current word). Since the next word might be at the end of the example,
            or it might be not counted because it is not an ending subword unit,
            also keeps track of how many of those we have seen """
        if next_word_prob is not None:
            self.next_word_prob += next_word_prob
        else:
            self.missing_next_words += 1
        self.log_prob += log_prob
        self.count += 1

    def __str__(self):
        return '{}\t{}\t{}\t{}\t{}\t{}'.format(self.word, self.count, self.log_prob, self.is_bpe,
                                               self.next_word_prob, self.count - self.missing_next_words)


def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
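    # Note: model_overrides is expected to be a Python dict literal (e.g. '{}')
    # and is parsed with eval() into overrides applied on top of the args
    # stored in the checkpoint.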
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
    )

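    # Replace the checkpoint args with evaluation-time args, except for the
    # options below, which keep their training-time values.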
    for arg in vars(parsed_args).keys():
        if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
            setattr(args, arg, getattr(parsed_args, arg))
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()

    assert len(models) > 0

    print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
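    # SequenceScorer computes per-token log probs for each batch under the
    # (possibly ensembled) models.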
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0

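    # If BPE was applied to the data, --remove-bpe holds the continuation
    # marker (e.g. '@@'); collect the indices of all dictionary symbols that
    # end with it, i.e. all word-internal subword units.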
    if args.remove_bpe is not None:
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']

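                # Fold the score of each BPE continuation token into the
                # following position so that a whole word's log prob lands on
                # its final subword unit; folded positions are zeroed here and
                # excluded from the token count below.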
                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(len(hypo['tokens']) - 1):
                        if hypo['tokens'][i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

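                # Positions with +/-inf scores would make the perplexity
                # undefined, so report and drop them.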
                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(hypo['tokens'])):
                        w_ind = hypo['tokens'][i].item()
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

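                            # The next word's score is the next nonzero entry:
                            # zeros were produced above when BPE continuation
                            # scores were folded forward.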
                            next_prob = None
                            ind = i + 1
                            while ind < len(hypo['tokens']):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:.2f}]'.format(x[0], x[1]) for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

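    # pos_scores are natural-log probabilities, so perplexity is the
    # exponential of the average negative log likelihood.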
    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)


if __name__ == '__main__':
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    main(args)
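
# Example invocation (the data dir, checkpoint path, and sizes below are
# illustrative, not part of this script):
#   python eval_lm.py data-bin/wikitext-103 \
#       --path checkpoints/checkpoint_best.pt --gen-subset test --max-tokens 4096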