#!/usr/bin/env python

import argparse
import itertools
import operator
import sys
from collections import OrderedDict

from run_eval import datetime_now, run_generate


# A table of supported tasks and the list of scores in the order of importance to be sorted by.
# To add a new task, simply list the score names that `run_eval.run_generate()` returns for it (see the example below).
task_score_names = {
    "translation": ["bleu"],
    "summarization": ["rouge1", "rouge2", "rougeL"],
}
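# For example, a hypothetical new task scored by ROUGE-L only could be registered as:
#   task_score_names["question_generation"] = ["rougeL"]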


def parse_search_arg(search):
    """Parse the ``--search`` value into a matrix of ``--key value`` arg combinations and the list of hparam names."""
    groups = search.split()
    entries = {k: vs for k, vs in (g.split("=") for g in groups)}
    entry_names = list(entries.keys())
    sets = [[f"--{k} {v}" for v in vs.split(":")] for k, vs in entries.items()]
    matrix = [list(x) for x in itertools.product(*sets)]
    return matrix, entry_names
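
# For illustration only (derived from the function above, not executed): given the search string
# "num_beams=5:10 early_stopping=true:false", ``parse_search_arg`` returns:
#   matrix = [["--num_beams 5", "--early_stopping true"],
#             ["--num_beams 5", "--early_stopping false"],
#             ["--num_beams 10", "--early_stopping true"],
#             ["--num_beams 10", "--early_stopping false"]]
#   entry_names = ["num_beams", "early_stopping"]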


def run_search():
    """
     Run parametric search over the desired hparam space with help of ``run_eval.py``.

     All the arguments except ``--search`` are passed to ``run_eval.py`` as is. The values inside of "--search" are parsed, reformatted and fed to ``run_eval.py`` as additional args.

    The format for the ``--search`` value is a simple string with hparams and colon separated values to try, e.g.:
    ```
     --search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false"
    ```
    which will generate ``12`` ``(2*3*2)`` searches for a product of each hparam. For example the example that was just used will invoke ``run_eval.py`` repeatedly with:

    ```
     --num_beams 5 --length_penalty 0.8 --early_stopping true
     --num_beams 5 --length_penalty 0.8 --early_stopping false
     [...]
     --num_beams 10 --length_penalty 1.2 --early_stopping false
    ```

    On completion, this function prints a markdown table of the results sorted by the best BLEU score and the winning arguments.


    """
    prog = sys.argv[0]

    parser = argparse.ArgumentParser(
        usage="\n\nImportant: this script accepts all arguments `run_eval.py` accepts, plus a few of its own; refer to `run_eval.py -h` for the complete list."
    )
    parser.add_argument(
        "--search",
        type=str,
        required=False,
        help='param space to search, e.g. "num_beams=5:10 length_penalty=0.8:1.0:1.2"',
    )
    parser.add_argument(
        "--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)"
    )
    parser.add_argument("--task", type=str, help="used for task_specific_params + metrics")
    parser.add_argument(
        "--info",
        nargs="?",
        type=str,
        const=datetime_now(),
        help="add custom notes to be printed before the results table. If no value is passed, the current datetime string will be used.",
    )
    args, args_main = parser.parse_known_args()
    # `--task` is needed by both this script and `run_eval.py`, so pass it through
    args_main.extend(["--task", args.task])
    args_normal = [prog] + args_main

    # to support task variations like "translation_en_to_de"
    task = "translation" if "translation" in args.task else "summarization"

    matrix, col_names = parse_search_arg(args.search)
    col_names[0:0] = task_score_names[task]  # score cols first
    col_widths = {col: len(str(col)) for col in col_names}
    results = []
    for r in matrix:
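        # e.g. r = ["--num_beams 5", "--early_stopping true"]
        # -> hparams = {"num_beams": "5", "early_stopping": "true"}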
        hparams = {k: v for k, v in (x.replace("--", "").split() for x in r)}
        args_exp = " ".join(r).split()
        args_exp.extend(["--bs", str(args.bs)])  # in case we need to reduce its size due to CUDA OOM
        sys.argv = args_normal + args_exp

        # XXX: need to trap CUDA OOM and lower args.bs if that happens and retry

        scores = run_generate(verbose=False)
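        # ``scores`` is expected to contain an entry for each name in ``task_score_names[task]``
        # (e.g. a "bleu" key for translation); only those entries go into the results table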
        # make sure scores are first in the table
        result = OrderedDict()
        for score in task_score_names[task]:
            result[score] = scores[score]
        result.update(hparams)
        results.append(result)

        # find widest entries
        for k, v in result.items():
            width = len(str(v))
            if width > col_widths[k]:
                col_widths[k] = width

    results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[task]), reverse=True)
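    # Illustrative shape of the table printed below (the numbers are made up, not real results):
    #   bleu  | num_beams | length_penalty | early_stopping
    #   ----- | --------- | -------------- | --------------
    #   26.81 | 10        | 1.2            | true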
    print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
    print(" | ".join([f"{'-'*col_widths[col]}" for col in col_names]))
    for row in results_sorted:
        print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))

    best = results_sorted[0]
    for score in task_score_names[task]:
        del best[score]
    best_args = [f"--{k} {v}" for k, v in best.items()]
    dyn_args = ["--bs", str(args.bs)]
    if args.info:
        print(f"\nInfo: {args.info}")
    print("\nBest score args:")
    print(" ".join(args_main + best_args + dyn_args))

    return results_sorted


if __name__ == "__main__":
    # Usage:
    # [normal run_eval_search.py command] \
    # --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"
    #
    # Example:
    # PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval_search.py $MODEL_NAME \
    # $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target \
    # --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation \
    # --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"
    run_search()