#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

"""
Evaluate the perplexity of a trained language model.
"""

import numpy as np
import torch

from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data import LMContextWindowDataset
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer


class WordStat(object):
    """Accumulates scoring statistics for a single word across an eval run."""

    def __init__(self, word, is_bpe):
        self.word = word                  # surface form of the word
        self.is_bpe = is_bpe              # whether it was assembled from BPE pieces
        self.log_prob = 0                 # running sum of this word's log-probs
        self.next_word_prob = 0           # running sum of the following word's log-probs
        self.count = 0                    # number of occurrences recorded
        self.missing_next_words = 0       # occurrences with no scored next word

    def add(self, log_prob, next_word_prob):
        """Record one occurrence of the word.

        ``log_prob`` is the word's own log-probability. ``next_word_prob`` is
        the log-probability of the word that follows (given context ending at
        this word), or ``None`` when there is no countable next word — e.g. it
        falls at the end of the example or is a non-final subword unit — in
        which case we track the miss instead.
        """
        if next_word_prob is None:
            self.missing_next_words += 1
        else:
            self.next_word_prob += next_word_prob
        self.log_prob += log_prob
        self.count += 1

    def __str__(self):
        # Tab-separated summary; last field is the number of occurrences
        # that actually had a scored next word.
        fields = (
            self.word,
            self.count,
            self.log_prob,
            self.is_bpe,
            self.next_word_prob,
            self.count - self.missing_next_words,
        )
        return '\t'.join(str(f) for f in fields)
def main(parsed_args):
    """Evaluate the perplexity of a trained language model on a dataset split.

    Loads a model ensemble from ``parsed_args.path``, scores every batch of
    ``parsed_args.gen_subset`` with a SequenceScorer, and prints the average
    token-level NLL and the corresponding perplexity. Optionally prints
    per-word probabilities and aggregated per-word statistics.
    """
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    # NOTE(review): eval() on a CLI-supplied string executes arbitrary code;
    # tolerable only because --model-overrides comes from the local operator.
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(':'),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    # Copy current CLI args onto the checkpoint args, except the model/data
    # hyper-parameters that must keep their training-time values.
    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target', 'tokens_per_sample',
            'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
    # re-create the task with the merged args so dataset loading sees them
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        # Wrap the dataset so each sample is scored with extra left context.
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            # All dictionary indices whose symbol ends with the BPE
            # continuation marker, i.e. non-final subword units.
            bpe_toks = set(
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            )
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            # Skip dummy/padding batches that carry no network input.
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for hypos_i in hypos:
                hypo = hypos_i[0]

                tokens = hypo['tokens']
                tgt_len = tokens.numel()
                pos_scores = hypo['positional_scores'].float()

                if args.add_bos_token:
                    # Drop the BOS token; it must not count towards perplexity.
                    assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
                    # Fold the score of each non-final subword unit into the
                    # following position so whole words are scored as one unit.
                    for i in range(tgt_len - 1):
                        if tokens[i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                # Exclude positions with +/-inf scores from the average (they
                # would otherwise destroy the perplexity computation).
                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(tokens[inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(tokens)):
                        w_ind = tokens[i].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            # Still inside a word: strip the continuation
                            # marker and keep accumulating subword pieces.
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            # Log-prob of the next word: the first later
                            # position whose score was not folded away above.
                            next_prob = None
                            ind = i + 1
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:2f}]'.format(x[0], x[1]) for x in word_prob))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)

alexeib's avatar
alexeib committed
217

Myle Ott's avatar
Myle Ott committed
218
def cli_main():
    """Command-line entry point: parse eval-lm arguments and run :func:`main`."""
    arg_parser = options.get_eval_lm_parser()
    parsed = options.parse_args_and_arch(arg_parser)
    main(parsed)
# Script entry point: run argument parsing + evaluation when executed directly.
if __name__ == '__main__':
    cli_main()