#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

"""
Evaluate the perplexity of a trained language model.
"""

import numpy as np
import torch

from fairseq import options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer
from fairseq.utils import import_user_module


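# Per-word accumulator (occurrence count, summed log prob, and the log prob
# of the following word) used for --output-word-stats reporting.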
class WordStat(object):
    def __init__(self, word, is_bpe):
        self.word = word
        self.is_bpe = is_bpe
        self.log_prob = 0
        self.next_word_prob = 0
        self.count = 0
        self.missing_next_words = 0

    def add(self, log_prob, next_word_prob):
        """ increments counters for the sum of log probs of current word and next
            word (given context ending at current word). Since the next word might be at the end of the example,
            or it might be not counted because it is not an ending subword unit,
            also keeps track of how many of those we have seen """
        if next_word_prob is not None:
            self.next_word_prob += next_word_prob
        else:
            self.missing_next_words += 1
        self.log_prob += log_prob
        self.count += 1

    def __str__(self):
        return '{}\t{}\t{}\t{}\t{}\t{}'.format(self.word, self.count, self.log_prob, self.is_bpe,
                                               self.next_word_prob, self.count - self.missing_next_words)


def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    import_user_module(parsed_args)

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(
        parsed_args.path.split(':'), task, model_arg_overrides=eval(parsed_args.model_overrides),
    )

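    # Carry the current command-line args over onto the checkpoint's saved
    # args, except for the target/sample-size settings that must match
    # training, then rebuild the task from the merged configuration.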
    for arg in vars(parsed_args).keys():
        if arg not in {'self_target', 'future_target', 'past_target', 'tokens_per_sample', 'output_size_dictionary'}:
            setattr(args, arg, getattr(parsed_args, arg))
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()

    assert len(models) > 0

    print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

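    # Batch the evaluation split; iteration is deterministic (shuffle=False).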
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
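    # The scorer computes per-token log probabilities of the reference
    # targets under the model ensemble.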
    scorer = SequenceScorer(task.target_dictionary)

    score_sum = 0.
    count = 0

    if args.remove_bpe is not None:
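        # Tokens ending with the BPE continuation marker (e.g. '@@') are
        # non-final subword units; collect their indices so per-word scores
        # can be merged below.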
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()
        for sample in t:
            sample = utils.move_to_cuda(sample) if use_cuda else sample
            if 'net_input' not in sample:
                continue

            gen_timer.start()
            hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])

            for hypos_i in hypos:
                hypo = hypos_i[0]
                pos_scores = hypo['positional_scores']

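                # Fold each BPE continuation token's score into the following
                # token, so the final subword of a word carries the whole
                # word's log prob; the zeroed positions are skipped later.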
                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(len(hypo['tokens']) - 1):
                        if hypo['tokens'][i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

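                # Tokens with infinite scores would make the NLL undefined,
                # so report and exclude them.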
                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(hypo['tokens'])):
                        w_ind = hypo['tokens'][i].item()
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            next_prob = None
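                            # Scan forward for the next non-zero score; zeros
                            # are BPE continuations whose scores were folded
                            # into a later position.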
                            ind = i + 1
                            while ind < len(hypo['tokens']):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:.2f}]'.format(x[0], x[1]) for x in word_prob))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

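    # Positional scores are natural-log probabilities, so perplexity is the
    # exponential of the average negative log likelihood.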
    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)


def cli_main():
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == '__main__':
    cli_main()