rerank_score_bw.py 4.1 KB
Newer Older
Nathan Ng's avatar
Nathan Ng committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import rerank_utils
import os
from fairseq import options
from examples.noisychannel import rerank_options
from contextlib import redirect_stdout
import generate


def _scorer_langs(args, backwards):
    """Return the ``(source, target)`` language pair passed to a scoring model.

    The pair is ``(args.source_lang, args.target_lang)``, swapped when
    ``backwards`` is truthy.
    """
    if backwards:
        return args.target_lang, args.source_lang
    return args.source_lang, args.target_lang


def _pick_rerank_data(right_to_left, backwards,
                      left_to_right_dir, right_to_left_dir, backwards_dir):
    """Choose the preprocessed data directory for one scoring model.

    ``right_to_left`` takes precedence over ``backwards``; when neither flag
    is set the left-to-right directory is returned.
    """
    if right_to_left:
        return right_to_left_dir
    if backwards:
        return backwards_dir
    return left_to_right_dir


def _write_score_file(score_file, generation_argv):
    """Run ``generate.main`` on ``generation_argv`` and save its stdout to ``score_file``.

    The output is written to a temporary sibling file that is moved into
    place only after ``generate.main`` returns. Callers treat an existing
    ``score_file`` as "already scored", so a failed or interrupted run must
    not leave a truncated file under the final name.
    """
    gen_parser = options.get_generation_parser()
    input_args = options.parse_args_and_arch(gen_parser, generation_argv)

    tmp_file = score_file + ".tmp"
    try:
        with open(tmp_file, 'w') as f, redirect_stdout(f):
            generate.main(input_args)
        os.replace(tmp_file, score_file)
    finally:
        # Only still present when scoring failed before the rename above.
        if os.path.isfile(tmp_file):
            os.remove(tmp_file)


def score_bw(args):
    """Run the scoring pass for the configured rescoring model(s).

    For ``args.score_model1`` and, when it is not None, ``args.score_model2``,
    this calls ``generate.main`` with ``--score-reference`` on the matching
    preprocessed data directory and stores everything it prints to stdout in
    that model's score file (named by ``rerank_utils.rescore_file_name``).

    A model is skipped when its score file already exists, or when the model
    equals ``args.gen_model`` and ``args.source_prefix_frac`` is None.
    NOTE(review): presumably the generation output already carries that
    model's scores in this case -- confirm against the reranking pipeline.

    Args:
        args: parsed reranking options. Reads the language pair, the
            ``backwards*`` / ``right_to_left*`` flags, the model paths and
            names, and the sharding / prefix settings forwarded to
            ``rerank_utils``.
    """
    scorer1_src, scorer1_tgt = _scorer_langs(args, args.backwards1)

    has_model2 = args.score_model2 is not None
    if has_model2:
        scorer2_src, scorer2_tgt = _scorer_langs(args, args.backwards2)

    rerank1_is_gen = args.gen_model == args.score_model1 and args.source_prefix_frac is None
    rerank2_is_gen = args.gen_model == args.score_model2 and args.source_prefix_frac is None

    # The last returned directory (language-model data) is not needed here.
    (pre_gen, left_to_right_preprocessed_dir, right_to_left_preprocessed_dir,
     backwards_preprocessed_dir, _lm_preprocessed_dir) = \
        rerank_utils.get_directories(args.data_dir_name, args.num_rescore, args.gen_subset,
                                     args.gen_model_name, args.shard_id, args.num_shards,
                                     args.sampling, args.prefix_len, args.target_prefix_frac,
                                     args.source_prefix_frac)

    score1_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model1_name,
                                                 target_prefix_frac=args.target_prefix_frac,
                                                 source_prefix_frac=args.source_prefix_frac,
                                                 backwards=args.backwards1)

    if has_model2:
        score2_file = rerank_utils.rescore_file_name(pre_gen, args.prefix_len, args.model2_name,
                                                     target_prefix_frac=args.target_prefix_frac,
                                                     source_prefix_frac=args.source_prefix_frac,
                                                     backwards=args.backwards2)

    rerank_data1 = _pick_rerank_data(args.right_to_left1, args.backwards1,
                                     left_to_right_preprocessed_dir,
                                     right_to_left_preprocessed_dir,
                                     backwards_preprocessed_dir)

    # Generation flags shared by both scoring runs.
    gen_param = ["--batch-size", str(128), "--score-reference", "--gen-subset", "train"]

    if not rerank1_is_gen and not os.path.isfile(score1_file):
        print("STEP 4: score the translations for model 1")

        model_param1 = ["--path", args.score_model1,
                        "--source-lang", scorer1_src, "--target-lang", scorer1_tgt]
        _write_score_file(score1_file, [rerank_data1] + gen_param + model_param1)

    if has_model2 and not os.path.isfile(score2_file) and not rerank2_is_gen:
        print("STEP 4: score the translations for model 2")

        rerank_data2 = _pick_rerank_data(args.right_to_left2, args.backwards2,
                                         left_to_right_preprocessed_dir,
                                         right_to_left_preprocessed_dir,
                                         backwards_preprocessed_dir)

        model_param2 = ["--path", args.score_model2,
                        "--source-lang", scorer2_src, "--target-lang", scorer2_tgt]
        _write_score_file(score2_file, [rerank_data2] + gen_param + model_param2)

def cli_main():
    """Parse the reranking command-line options and hand them to ``score_bw``."""
    reranking_parser = rerank_options.get_reranking_parser()
    score_bw(options.parse_args_and_arch(reranking_parser))


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    cli_main()