"vscode:/vscode.git/clone" did not exist on "bcf6cda14cf2f52f8a0a0021230eaa39143840c0"
eval_lm.py 3.2 KB
Newer Older
alexeib's avatar
alexeib committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import numpy as np
import torch

from fairseq import options, utils, progress_bar
from fairseq.data import data_utils, data_loaders
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer


def main(args):
    """Score a dataset split with one or more language models and print the results.

    Loads the split named by ``args.gen_subset``, loads the model(s) listed in
    ``args.path``, scores every batch with a ``SequenceScorer`` and prints the
    negated mean positional score as the loss, together with its exponential
    as the perplexity. Tokens whose positional score is +inf or -inf are
    reported and left out of the average.

    Args:
        args: parsed command-line namespace. This function reads ``path``,
            ``max_target_positions``, ``cpu``, ``gen_subset``, ``data``,
            ``max_sentences``, ``num_shards`` and ``shard_id``; ``args`` is
            also handed to the dataset loader and the progress bar.
            ``max_target_positions`` is set to 1024 in place when it is None.

    Raises:
        AssertionError: if ``args.path`` is None.
        ValueError: if sharding is requested (``num_shards > 1``) and
            ``args.shard_id`` is outside ``[0, args.num_shards)``.
    """
    assert args.path is not None, '--path required for evaluation!'
    print(args)

    # Fail fast on a bad shard spec, before any dataset or model is loaded.
    if args.num_shards > 1 and not 0 <= args.shard_id < args.num_shards:
        raise ValueError('--shard-id must be between 0 and num_shards')

    if args.max_target_positions is None:
        args.max_target_positions = 1024

    use_cuda = torch.cuda.is_available() and not args.cpu
    dataset = data_loaders.load_dataset(args, [args.gen_subset], False)

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models, _ = utils.load_ensemble_for_inference(args.path, dataset.src_dict, dataset.dst_dict)

    print('| Dictionary: {} types'.format(len(dataset.src_dict)))
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(dataset.splits[args.gen_subset])))

    # Optimize ensemble for generation and set the source and dest dicts on the model (required by scorer)
    for model in models:
        model.make_generation_fast_()
        model.src_dict = dataset.src_dict
        model.dst_dict = dataset.dst_dict

    itr = dataset.eval_dataloader(
        args.gen_subset,
        max_sentences=args.max_sentences or 4,
        max_positions=args.max_target_positions or 1024,
        descending=True,
    )

    # The shard id was validated above; here we only restrict the iterator to this shard.
    if args.num_shards > 1:
        itr = data_utils.sharded_iterator(itr, args.num_shards, args.shard_id)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models)
    if use_cuda:
        scorer.cuda()

    score_sum = 0.
    count = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']
                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          dataset.src_dict.string(hypo['tokens'][inf_scores.nonzero()]))
                    pos_scores = pos_scores[(~inf_scores).nonzero()]
                # Accumulate as a plain Python float: a tensor-valued total
                # (possibly living on the GPU) cannot be handed to np.exp below.
                score_sum += float(pos_scores.sum())
                count += pos_scores.numel()
            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    # Nothing was scored (e.g. an empty split or shard): report it instead of
    # dividing by zero when computing the averages below.
    if count == 0:
        print('| No tokens were scored; skipping loss and perplexity')
        return

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))


if __name__ == '__main__':
    # Script entry point: build the eval-lm argument parser, parse the
    # command line and hand the resulting namespace straight to main().
    main(options.get_eval_lm_parser().parse_args())