"vscode:/vscode.git/clone" did not exist on "7fbab730bd8e91b85e3b2ee2defc9a6de2a09a7c"
eval_lm.py 5.29 KB
Newer Older
alexeib's avatar
alexeib committed
1
2
3
4
5
6
7
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
"""
Evaluate the perplexity of a trained language model.
"""

import numpy as np
import torch

from fairseq import data, options, progress_bar, tasks, utils
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer


class WordStat(object):
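    """Accumulate the occurrence count and total log-probability of a single word."""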
    def __init__(self, word, is_bpe):
        self.word = word
        self.is_bpe = is_bpe
        self.log_prob = 0
        self.count = 0

    def add(self, log_prob):
        self.log_prob += log_prob
        self.count += 1

    def __str__(self):
        return '{}\t{}\t{}\t{}'.format(self.word, self.count, self.log_prob / self.count, self.is_bpe)


def main(parsed_args):
    assert parsed_args.path is not None, '--path required for evaluation!'

    print(parsed_args)

    use_cuda = torch.cuda.is_available() and not parsed_args.cpu

    task = tasks.setup_task(parsed_args)

    # Load ensemble
    print('| loading model(s) from {}'.format(parsed_args.path))
    models, args = utils.load_ensemble_for_inference(parsed_args.path.split(':'), task)

    args.__dict__.update(parsed_args.__dict__)
    print(args)

    task.args = args

    # Load dataset splits
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Optimize each model in the ensemble for generation
    for model in models:
        model.make_generation_fast_()
        if args.fp16:
            model.half()

    assert len(models) > 0

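    # Build the batch iterator over the evaluation split; default to a 36000-token
    # batch budget when --max-tokens is not given.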
    itr = task.get_batch_iterator(
        dataset=task.dataset(args.gen_subset),
        max_tokens=args.max_tokens or 36000,
        max_sentences=args.max_sentences,
        max_positions=utils.resolve_max_positions(*[
            model.max_positions() for model in models
        ]),
        num_shards=args.num_shards,
        shard_id=args.shard_id,
        ignore_invalid_inputs=True,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0

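    # If --remove-bpe is given, treat it as the BPE continuation marker and collect the dictionary
    # indices of all symbols ending with it, so subword pieces can later be merged back into words.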
    if args.remove_bpe is not None:
        bpe_cont = args.remove_bpe.rstrip()
        bpe_toks = set(i for i in range(len(task.dictionary)) if task.dictionary[i].endswith(bpe_cont))
        bpe_len = len(bpe_cont)
    else:
        bpe_toks = None
        bpe_len = 0

    word_stats = dict()

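    # Score the evaluation data batch by batch; each hypothesis carries per-token
    # log-probabilities in hypo['positional_scores'].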
    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']

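                # Fold the score of every BPE continuation token into the following token so that
                # the perplexity is computed over whole words rather than subword pieces.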
                skipped_toks = 0
                if bpe_toks is not None:
                    for i in range(len(hypo['tokens']) - 1):
                        if hypo['tokens'][i].item() in bpe_toks:
                            skipped_toks += 1
                            pos_scores[i + 1] += pos_scores[i]
                            pos_scores[i] = 0

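                # Drop tokens with infinite scores, which would otherwise corrupt the average.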
                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                score_sum += utils.item(pos_scores.sum())
                count += pos_scores.numel() - skipped_toks

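                # Optionally reassemble subwords into words and record per-word probabilities and statistics.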
                if args.output_word_probs or args.output_word_stats:
                    w = ''
                    word_prob = []
                    is_bpe = False
                    for i in range(len(hypo['tokens'])):
                        w_ind = hypo['tokens'][i].item()
                        w += task.dictionary[w_ind]
                        if bpe_toks is not None and w_ind in bpe_toks:
                            w = w[:-bpe_len]
                            is_bpe = True
                        else:
                            word_prob.append((w, pos_scores[i].item()))
                            word_stats.setdefault(w, WordStat(w, is_bpe)).add(pos_scores[i].item())
                            is_bpe = False
                            w = ''
                    if args.output_word_probs:
                        print('\t'.join('{} [{:.2f}]'.format(x[0], x[1]) for x in word_prob))

            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

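    # Perplexity is the exponential of the average negative log-likelihood per (word-level) token.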
    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))

    if args.output_word_stats:
        for ws in sorted(word_stats.values(), key=lambda x: x.count, reverse=True):
            print(ws)


if __name__ == '__main__':
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    main(args)