eval_lm.py 2.88 KB
Newer Older
alexeib's avatar
alexeib committed
1
2
3
4
5
6
7
8
9
10
11
#!/usr/bin/env python3 -u
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import numpy as np
import torch

Myle Ott's avatar
Myle Ott committed
12
from fairseq import data, options, progress_bar, tasks, utils
alexeib's avatar
alexeib committed
13
14
15
16
17
18
19
from fairseq.meters import StopwatchMeter, TimeMeter
from fairseq.sequence_scorer import SequenceScorer


def main(args):
    """Evaluate a (possibly ensembled) language model and report perplexity.

    Loads the ``args.gen_subset`` split of the task's dataset and scores every
    example with the model(s) listed in ``args.path`` (colon-separated paths).
    It then prints the average negative log-likelihood per token and
    ``np.exp`` of it as the perplexity. NOTE(review): the perplexity figure
    assumes the positional scores are natural-log probabilities; confirm
    against SequenceScorer.

    Tokens whose positional score is +inf or -inf are reported and excluded
    from both the loss sum and the token count.

    Args:
        args: parsed command-line namespace (from ``options.get_eval_lm_parser``
            via ``options.parse_args_and_arch``). ``args.tokens_per_sample`` is
            set to 1024 if the namespace does not already define it.
    """
    assert args.path is not None, '--path required for evaluation!'

    args.tokens_per_sample = getattr(args, 'tokens_per_sample', 1024)
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dataset splits
    task = tasks.setup_task(args)
    task.load_dataset(args.gen_subset)
    print('| {} {} {} examples'.format(args.data, args.gen_subset, len(task.dataset(args.gen_subset))))

    # Load ensemble
    print('| loading model(s) from {}'.format(args.path))
    models, _ = utils.load_ensemble_for_inference(args.path.split(':'), task)

    # Prepare each ensemble member for inference.
    for model in models:
        model.make_generation_fast_()

    # Every ensemble member must accept each batch, so use the tightest
    # position limit rather than relying on the loop variable above leaking
    # out of the loop (which would pick whichever model came last).
    limits = [m.max_positions() for m in models]
    if all(isinstance(limit, int) for limit in limits):
        max_positions = min(limits)
    else:
        # NOTE(review): non-int limits (e.g. tuples) have no obvious ordering;
        # keep using the last model's limit in that case.
        max_positions = limits[-1]

    itr = data.EpochBatchIterator(
        dataset=task.dataset(args.gen_subset),
        max_sentences=args.max_sentences or 4,
        max_positions=max_positions,
        num_shards=args.num_shards,
        shard_id=args.shard_id,
    ).next_epoch_itr(shuffle=False)

    gen_timer = StopwatchMeter()
    scorer = SequenceScorer(models, task.target_dictionary)
    if use_cuda:
        scorer.cuda()

    # Accumulate in a host-side Python float (double precision). A tensor sum
    # would stay on the GPU when use_cuda is set, which breaks np.exp below.
    # It would also lose precision over long evaluations in float32.
    score_sum = 0.
    count = 0
    with progress_bar.build_progress_bar(args, itr) as t:
        results = scorer.score_batched_itr(t, cuda=use_cuda, timer=gen_timer)
        wps_meter = TimeMeter()
        for _, src_tokens, __, hypos in results:
            for hypo in hypos:
                pos_scores = hypo['positional_scores']
                inf_scores = pos_scores.eq(float('inf')) | pos_scores.eq(float('-inf'))
                if inf_scores.any():
                    print('| Skipping tokens with inf scores:',
                          task.target_dictionary.string(hypo['tokens'][inf_scores.nonzero()]))
                    # Keep only the finite scores.
                    pos_scores = pos_scores[~inf_scores]
                score_sum += float(pos_scores.sum())
                count += pos_scores.numel()
            # NOTE(review): src_tokens.size(0) is presumably the batch size, so
            # this "wps" meter may be counting sentences rather than words; confirm.
            wps_meter.update(src_tokens.size(0))
            t.log({'wps': round(wps_meter.avg)})

    # Avoid a ZeroDivisionError on an empty split or when every score was inf.
    if count == 0:
        print('| No tokens with finite scores were evaluated; cannot compute loss')
        return

    avg_nll_loss = -score_sum / count
    print('| Evaluated {} tokens in {:.1f}s ({:.2f} tokens/s)'.format(gen_timer.n, gen_timer.sum, 1. / gen_timer.avg))
    print('| Loss: {:.4f}, Perplexity: {:.2f}'.format(avg_nll_loss, np.exp(avg_nll_loss)))


if __name__ == '__main__':
    # Script entry point: parse the eval-LM command line (including any
    # architecture-specific options) and run the evaluation.
    parser = options.get_eval_lm_parser()
    args = options.parse_args_and_arch(parser)
    main(args)