"vscode:/vscode.git/clone" did not exist on "ba01560f21516805fc5ceba5c2566dcbd1cf66d8"
eval_lm.py 7.89 KB
Newer Older
alexeib's avatar
alexeib committed
1
#!/usr/bin/env python3 -u
2
# Copyright (c) Facebook, Inc. and its affiliates.
alexeib's avatar
alexeib committed
3
#
4
5
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
Myle Ott's avatar
Myle Ott committed
6

Myle Ott's avatar
Myle Ott committed
7
8
9
"""
Evaluate the perplexity of a trained language model.
"""

import numpy as np
import torch

from fairseq import checkpoint_utils, options, progress_bar, tasks, utils
from fairseq.data import LMContextWindowDataset
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer


class WordStat(object):
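    """Accumulates per-word statistics across the evaluation set: occurrence
    count, summed log probability, and the summed log probability of the word
    that follows each occurrence."""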
    def __init__(self, word, is_bpe):
        self.word = word
        self.is_bpe = is_bpe
        self.log_prob = 0
        self.next_word_prob = 0
        self.count = 0
        self.missing_next_words = 0

    def add(self, log_prob, next_word_prob):
        """ increments counters for the sum of log probs of current word and next
            word (given context ending at current word). Since the next word might be at the end of the example,
            or it might be not counted because it is not an ending subword unit,
            also keeps track of how many of those we have seen """
        if next_word_prob is not None:
            self.next_word_prob += next_word_prob
        else:
            self.missing_next_words += 1
        self.log_prob += log_prob
        self.count += 1

    def __str__(self):
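        # columns: word, count, summed log prob, BPE flag, summed next-word
        # log prob, and the number of occurrences that had a scored next word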
        return '{}\t{}\t{}\t{}\t{}\t{}'.format(self.word, self.count, self.log_prob, self.is_bpe,
                                               self.next_word_prob, self.count - self.missing_next_words)


def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    utils.import_user_module(parsed_args)

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = checkpoint_utils.load_model_ensemble(
        parsed_args.path.split(':'),
        arg_overrides=eval(parsed_args.model_overrides),
        task=task,
    )

    for arg in vars(parsed_args).keys():
        if arg not in {
            'self_target', 'future_target', 'past_target', 'tokens_per_sample',
            'output_size_dictionary', 'add_bos_token',
        }:
            setattr(args, arg, getattr(parsed_args, arg))

    # reduce tokens per sample by the required context window size
    args.tokens_per_sample -= args.context_window
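    # the held-back context tokens are re-attached by LMContextWindowDataset
    # below, so the total input length stays within the model's limit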
    task = tasks.setup_task(args)

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    dataset = task.dataset(args.gen_subset)
    if args.context_window > 0:
        dataset = LMContextWindowDataset(
            dataset=dataset,
            tokens_per_sample=args.tokens_per_sample,
            context_window=args.context_window,
            pad_idx=task.source_dictionary.pad(),
        )
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset)))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()
        if use_cuda:
            model.cuda()
    assert len(models) > 0

    print('num. model params: {}'.format(sum(p.numel() for p in models[0].parameters())))

    itr = task.get_batch_iterator(
        dataset=dataset,
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        ignore_invalid_inputs=True,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        num_workers=args.num_workers,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(task.target_dictionary, args.softmax_batch)
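    # --softmax-batch breaks the output softmax computation into blocks of this
    # many tokens, trading some speed for lower peak memory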

    score_sum = 0.
    count = 0

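    # With --remove-bpe set (e.g. '@@ '), collect the dictionary entries that
    # end with the continuation marker so subword scores can be merged back
    # into whole-word scores below.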
    if args.remove_bpe is not None:
        if args.remove_bpe == 'sentencepiece':
            raise NotImplementedError
        else:
            bpe_cont = args.remove_bpe.rstrip()
            bpe_toks = set(
                i
                for i in range(len(task.source_dictionary))
                if task.source_dictionary[i].endswith(bpe_cont)
            )
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

    with progress_bar.build_progress_bar(args, itr) as t:
        wps_meter = TimeMeter()

        for sample in t:
            if 'net_input' not in sample:
                continue

            sample = utils.move_to_cuda(sample) if use_cuda else sample

            gen_timer.start()
            hypos = scorer.generate(models, sample)
            gen_timer.stop(sample['ntokens'])
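            # scorer.generate returns a list of hypotheses per sentence; for LM
            # scoring the first (only) hypothesis carries the target tokens and
            # their per-position log probabilities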

            for sample_i, hypos_i in enumerate(hypos):
                hypo = hypos_i[0]
                sample_id = sample['id'][sample_i]

                tokens = hypo['tokens']
                tgt_len = tokens.numel()
                pos_scores = hypo['positional_scores'].float()

                if args.add_bos_token:
                    assert hypo['tokens'][0].item() == task.target_dictionary.bos()
                    tokens = tokens[1:]
                    pos_scores = pos_scores[1:]

                skipped_toks = 0
                if bpe_toks is not None:
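                    # fold each continuation subword's score into the following
                    # token, so a merged word ends up with a single log prob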
                    for i in range(tgt_len - 1):
                        if tokens[i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(tokens[inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += pos_scores.sum().cpu()
                count += pos_scores.numel() - skipped_toks

                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(tokens)):
                        w_ind = tokens[i].item()
                        w += task.source_dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))

                            next_prob = None
                            ind = i + 1
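                            # scan forward for the next token whose score was
                            # not zeroed out by the BPE merging above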
                            while ind < len(tokens):
                                if pos_scores[ind].item() != 0:
                                    next_prob = pos_scores[ind]
                                    break
                                ind += 1

                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item(), next_prob)
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print(str(int(sample_id)) + " " +
                              ('\t'.join('{} [{:.2f}]'.format(x[0], x[1]) for x in word_prob)))

            wps_meter.update(sample['ntokens'])
            t.log({'wps': round(wps_meter.avg)})

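    # positional scores are natural-log probabilities, hence perplexity is
    # computed with np.exp below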
    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)

def cli_main():
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    main(args)


if __name__ == '__main__':
    cli_main()