#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

"""
Evaluate the perplexity of a trained language model.
"""

import numpy as np
import torch

from fairseq import options, progress_bar, tasks, utils
from fairseq.data import LMContextWindowDataset
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module


class WordStat(object):
    def __init__(self, word, is_bpe):
        self.word = word
        self.is_bpe = is_bpe
        self.log_prob = 0
        self.next_word_prob = 0
        self.count = 0
        self.missing_next_words = 0

    def add(self, log_prob, next_word_prob):
        """ increments counters for the sum of log probs of current word and next
            word (given context ending at current word). Since the next word might be at the end of the example,
            or it might be not counted because it is not an ending subword unit,
            also keeps track of how many of those we have seen """
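        # Illustrative calls: add(-2.3, -1.7) accumulates both log probs;
        # add(-2.3, None) additionally counts a missing next word.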
        if next_word_prob is not None:
            self.next_word_prob += next_word_prob
        else:
            self.missing_next_words += 1
        self.log_prob += log_prob
        self.count += 1

    def __str__(self):
        return '{}\t{}\t{}\t{}\t{}\t{}'.format(self.word, self.count, self.log_prob, self.is_bpe,
                                               self.next_word_prob, self.count - self.missing_next_words)


def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    import_user_module(parsed_args)

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
    )

    for arg in vars(parsed_args).keys():
        if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = set(
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            )
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        for sample in t:
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for hypos_i in hypos:
                hypo = hypos_i[0]

                tokens = hypo['tokens']
                tgt_len = tokens.numel()
                pos_scores = hypo['positional_scores'].float()

                skipped_toks = 0
                if bpe_toks is not None:
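                    # Fold the score of each BPE continuation token into the
                    # token that follows it, so the full word's score lands on
                    # its final subword and continuations are skipped in the count.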
                    for i in range(tgt_len - 1):
                        if tokens[i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(tokens[inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(tokens)):
                        w_ind = tokens[i].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            next_prob = None
                            ind = i + 1
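                            # Scan ahead for the next non-zero score: zeroed
                            # positions are BPE continuations whose scores were
                            # folded into the following token above.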
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:.2f}]'.format(x[0], x[1]) for x in word_prob))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)


def cli_main():
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == '__main__':
    cli_main()