#!/usr/bin/env python
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import itertools
import operator
import sys
from collections import OrderedDict

from run_eval import datetime_now, run_generate

from utils import ROUGE_KEYS


# A table of supported tasks and, for each, the list of scores to sort by, in order of importance.
# To add a new task, simply list the score names that `run_eval.run_generate()` returns for it;
# see the hypothetical example below the table.
task_score_names = {
    "translation": ["bleu"],
    "summarization": ROUGE_KEYS,
}
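# For illustration only (hypothetical entry): a question-answering task whose
# `run_generate()` returned "em" and "f1" scores would be registered above as:
#     "qa": ["em", "f1"],
# (neither that task name nor those score keys exist in run_eval.py as shipped)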


def parse_search_arg(search):
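    """
    Expand the ``--search`` string into a matrix of cli-arg combinations (the cross-product
    of all hparam values) plus the list of hparam names.

    Worked example (illustrative)::

        parse_search_arg("num_beams=5:10 length_penalty=0.8:1.0")

    returns ``entry_names == ["num_beams", "length_penalty"]`` and::

        matrix == [["--num_beams 5",  "--length_penalty 0.8"],
                   ["--num_beams 5",  "--length_penalty 1.0"],
                   ["--num_beams 10", "--length_penalty 0.8"],
                   ["--num_beams 10", "--length_penalty 1.0"]]
    """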
    groups = search.split()
    entries = {k: vs for k, vs in (g.split("=") for g in groups)}
    entry_names = list(entries.keys())
    sets = [[f"--{k} {v}" for v in vs.split(":")] for k, vs in entries.items()]
    matrix = [list(x) for x in itertools.product(*sets)]
    return matrix, entry_names


def run_search():
    """
     Run parametric search over the desired hparam space with help of ``run_eval.py``.

     All the arguments except ``--search`` are passed to ``run_eval.py`` as is. The values inside of "--search" are parsed, reformatted and fed to ``run_eval.py`` as additional args.

    The format for the ``--search`` value is a simple string with hparams and colon separated values to try, e.g.:
    ```
     --search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false"
    ```
    which will generate ``12`` ``(2*3*2)`` searches for a product of each hparam. For example the example that was just used will invoke ``run_eval.py`` repeatedly with:

    ```
     --num_beams 5 --length_penalty 0.8 --early_stopping true
     --num_beams 5 --length_penalty 0.8 --early_stopping false
     [...]
     --num_beams 10 --length_penalty 1.2 --early_stopping false
    ```

    On completion, this function prints a markdown table of the results sorted by the best BLEU score and the winning arguments.


    """
    prog = sys.argv[0]

    parser = argparse.ArgumentParser(
        usage=(
            "\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore"
            " refer to `run_eval.py -h` for the complete list."
        )
    )
    parser.add_argument(
        "--search",
        type=str,
        required=False,
        help='param space to search, e.g. "num_beams=5:10 length_penalty=0.8:1.0:1.2"',
    )
    parser.add_argument(
        "--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)"
    )
    parser.add_argument("--task", type=str, help="used for task_specific_params + metrics")
    parser.add_argument(
        "--info",
        nargs="?",
        type=str,
        const=datetime_now(),
        help=(
            "add custom notes to be printed before the results table. If no value is passed, the current datetime"
            " string will be used."
        ),
    )
    args, args_main = parser.parse_known_args()
    # we share some of the args
    args_main.extend(["--task", args.task])
    args_normal = [prog] + args_main

    # to support variations like "translation_en_to_de"
    task = "translation" if "translation" in args.task else "summarization"

    matrix, col_names = parse_search_arg(args.search)
    col_names[0:0] = task_score_names[task]  # score cols first
    col_widths = {col: len(str(col)) for col in col_names}
    results = []
    for r in matrix:
        hparams = {k: v for k, v in (x.replace("--", "").split() for x in r)}
        args_exp = " ".join(r).split()
        args_exp.extend(["--bs", str(args.bs)])  # in case we need to reduce its size due to CUDA OOM
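        # run_generate() parses sys.argv itself, so fake the full command line for this experiment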
        sys.argv = args_normal + args_exp

        # XXX: need to trap CUDA OOM and lower args.bs if that happens and retry
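        # A possible approach (sketch, not implemented): catch RuntimeError, check for
        # "out of memory" in its message, halve args.bs and retry until it fits or bs == 1.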

        scores = run_generate(verbose=False)
        # make sure scores are first in the table
        result = OrderedDict()
        for score in task_score_names[task]:
            result[score] = scores[score]
        result.update(hparams)
        results.append(result)

        # find widest entries
        for k, v in result.items():
            width = len(str(v))
            if width > col_widths[k]:
                col_widths[k] = width

    results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[task]), reverse=True)
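
    # The results print as a github-flavored markdown table, e.g. for a translation
    # search (values purely illustrative):
    #
    # bleu  | num_beams | length_penalty
    # ----- | --------- | --------------
    # 41.35 | 10        | 1.2
    # 40.98 | 5         | 0.8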
    print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
    print(" | ".join([f"{'-'*col_widths[col]}" for col in col_names]))
    for row in results_sorted:
        print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))

    best = results_sorted[0]
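    # drop the score columns from the best row, leaving just its winning hparams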
    for score in task_score_names[task]:
        del best[score]
    best_args = [f"--{k} {v}" for k, v in best.items()]
    dyn_args = ["--bs", str(args.bs)]
    if args.info:
        print(f"\nInfo: {args.info}")
    print("\nBest score args:")
    print(" ".join(args_main + best_args + dyn_args))

    return results_sorted


if __name__ == "__main__":
    # Usage:
    # [normal run_eval_search.py cmd plus] \
    # --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"
    #
    # Example:
    # PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval_search.py $MODEL_NAME \
    # $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target \
    # --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation \
    # --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"
    run_search()