#!/usr/bin/env python3
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.
#
import sys
import torch
from torch.autograd import Variable

from fairseq import data, options, tokenizer, utils
from fairseq.sequence_generator import SequenceGenerator


def main():
    """Interactively translate sentences read from stdin.

    Loads an ensemble of trained fairseq models, then reads one source
    sentence per line from stdin, generates hypotheses with beam search,
    and prints to stdout: the original sentence (O), and for each of the
    top-n hypotheses its alignment (A) and scored translation (H).
    """
    parser = options.get_parser('Generation')
    parser.add_argument('--path', metavar='FILE', required=True, action='append',
                        help='path(s) to model file(s)')
    options.add_dataset_args(parser)
    options.add_generation_args(parser)

    args = parser.parse_args()
    print(args)

    use_cuda = torch.cuda.is_available() and not args.cpu

    # Load dictionaries, inferring the language pair from the data
    # directory when it was not given on the command line.
    if args.source_lang is None or args.target_lang is None:
        args.source_lang, args.target_lang, _ = data.infer_language_pair(args.data, ['test'])
    src_dict, dst_dict = data.load_dictionaries(args.data, args.source_lang, args.target_lang)

    # Load ensemble
    print('| loading model(s) from {}'.format(', '.join(args.path)))
    models = utils.load_ensemble_for_inference(args.path, src_dict, dst_dict)

    print('| [{}] dictionary: {} types'.format(args.source_lang, len(src_dict)))
    print('| [{}] dictionary: {} types'.format(args.target_lang, len(dst_dict)))

    # Optimize ensemble for generation
    for model in models:
        model.make_generation_fast_(
            beamable_mm_beam_size=None if args.no_beamable_mm else args.beam)

    # Initialize generator
    translator = SequenceGenerator(
        models, beam_size=args.beam, stop_early=(not args.no_early_stop),
        normalize_scores=(not args.unnormalized), len_penalty=args.lenpen,
        unk_penalty=args.unkpen)
    if use_cuda:
        translator.cuda()

    # Load alignment dictionary for unknown word replacement
    # (None if no unknown word replacement, empty if no path to align dictionary)
    align_dict = utils.load_align_dict(args.replace_unk)

    print('Type the input sentence and press return:')
    for src_str in sys.stdin:
        src_str = src_str.strip()
        src_tokens = tokenizer.Tokenizer.tokenize(src_str, src_dict, add_if_not_exist=False).long()
        if use_cuda:
            src_tokens = src_tokens.cuda()
        # Generate for a batch of size 1.  The Variable wrapper is kept
        # for compatibility with the pre-0.4 torch API this file targets.
        translations = translator.generate(Variable(src_tokens.view(1, -1)))
        hypos = translations[0]
        print('O\t{}'.format(src_str))

        # Process top predictions.  Slicing already clamps at len(hypos),
        # so the previous min(len(hypos), args.nbest) bound was redundant.
        for hypo in hypos[:args.nbest]:
            hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                hypo_tokens=hypo['tokens'].int().cpu(),
                src_str=src_str,
                alignment=hypo['alignment'].int().cpu(),
                align_dict=align_dict,
                dst_dict=dst_dict,
                remove_bpe=args.remove_bpe)
            print('A\t{}'.format(' '.join(map(str, alignment))))
            print('H\t{}\t{}'.format(hypo['score'], hypo_str))

# Script entry point: run the interactive translation loop.
if __name__ == '__main__':
    main()