#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Evaluate the perplexity of a trained language model.
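
Example usage (illustrative; paths are placeholders and flags may differ across
fairseq versions):

    python eval_lm.py data-bin/wikitext-103 \
        --path checkpoints/checkpoint_best.pt \
        --gen-subset test \
        --max-tokens 3072 \
        --context-window 2560 \
        --softmax-batch 1024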
"""

import numpy as np
import torch

from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data import LMContextWindowDataset
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer


class WordStat(object):
    def __init__(self, word, is_bpe):
        self.word = word
        self.is_bpe = is_bpe
        self.log_prob = 0
        self.next_word_prob = 0
        self.count = 0
        self.missing_next_words = 0

    def add(self, log_prob, next_word_prob):
        """ increments counters for the sum of log probs of current word and next
            word (given context ending at current word). Since the next word might be at the end of the example,
            or it might be not counted because it is not an ending subword unit,
            also keeps track of how many of those we have seen """
        if next_word_prob is not None:
            self.next_word_prob += next_word_prob
        else:
            self.missing_next_words += 1
        self.log_prob += log_prob
        self.count += 1

    def __str__(self):
        return '{}\t{}\t{}\t{}\t{}\t{}'.format(self.word, self.count, self.log_prob, self.is_bpe,
                                               self.next_word_prob, self.count - self.missing_next_words)


def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(':'),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

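    # Override the checkpoint's stored args with the current command-line args, except
    # for a few settings that must stay consistent with how the model was trained.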
    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target', 'tokens_per_sample',
            'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
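    # Optionally wrap the dataset so that each example is scored with up to
    # --context-window tokens of extra left context, giving tokens near example
    # boundaries more history to condition on.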
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize each model for generation (and optionally cast to fp16 / move it to GPU)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

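    # Build the batch iterator over the evaluation split; if --max-tokens is unset,
    # fall back to batches of at most 36000 tokens.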
    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
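    # The scorer computes per-position log-probabilities for the reference tokens;
    # --softmax-batch bounds how many positions go through the output softmax at once.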
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

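    # Collect the dictionary indices of BPE continuation tokens (symbols ending with the
    # continuation marker, e.g. '@@') so that subword scores can be merged into word scores.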
    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = set(
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            )
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        for sample in t:
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for hypos_i in hypos:
                hypo = hypos_i[0]

                tokens = hypo['tokens']
                tgt_len = tokens.numel()
                pos_scores = hypo['positional_scores'].float()

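                # If a beginning-of-sentence token was prepended, drop it and its score
                # so that it does not count towards perplexity.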
                if args.add_bos_token:
                    assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

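                # Fold the score of each BPE continuation token into the token that follows
                # it, so that a word's full score ends up on its final subword unit.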
                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(tgt_len - 1):
                        if tokens[i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

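                # Drop positions whose scores are +/- infinity so they do not skew the totals.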
                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(tokens[inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

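                # Optionally reconstruct whole words from their subword pieces and
                # accumulate per-word probabilities and statistics.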
                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(tokens)):
                        w_ind = tokens[i].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

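                            # Look ahead for the next word's score, skipping positions
                            # zeroed out by the BPE merging above.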
                            next_prob = None
                            ind = i + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:.2f}]'.format(x[0], x[1]) for x in word_prob))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

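    # Perplexity is the exponential of the average negative log-likelihood per scored token.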
    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)


def cli_main():
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == '__main__':
    cli_main()