Unverified Commit fdaf8ab3 authored by Stas Bekman, committed by GitHub

[s2s run_eval] new features (#7109)


Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent df165065
......@@ -46,7 +46,7 @@ export DATA_DIR=${PWD}/wmt_en_de
#### Private Data
If you are using your own data, it must be formatted as one directory with 6 files:
```
train.source
train.target
......@@ -228,6 +228,67 @@ python run_eval.py sshleifer/distilbart-cnn-12-6 $DATA_DIR/val.source dbart_val_
--bs 32
```
#### run_eval tips and tricks
When using `run_eval.py`, the following features can be useful:
* if you are running the script multiple times and want to make it easier to track what arguments produced a given output, use `--dump-args`. Along with the results, it will also dump any custom params that were passed to the script. For example, if you used `--num_beams 8 --early_stopping true`, the output will be:
```
{'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True}
```
`--info` is an additional argument available for the same purpose of tracking the conditions of the experiment. It's useful for passing things that weren't in the argument list, e.g. a language pair: `--info "lang:en-ru"`. If you pass `--info` without a value, it will fall back to the current date/time string, e.g. `2020-09-13 18:44:43`. A small sketch showing one way to compare such dumped results across runs follows this bullet.
If using `--dump-args --info`, the output will be:
```
{'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True, 'info': '2020-09-13 18:44:43'}
```
If using `--dump-args --info "pair:en-ru chkpt=best"`, the output will be:
```
{'bleu': 26.887, 'n_obs': 10, 'runtime': 1, 'seconds_per_sample': 0.1, 'num_beams': 8, 'early_stopping': True, 'info': 'pair:en-ru chkpt=best'}
```
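Since `--score_path` saves the same dict as JSON, the dumped args make it easy to compare runs after the fact. Here is a minimal Python sketch, assuming each run was saved to a separate file matching `scores-*.json` (a hypothetical naming scheme; use whatever paths you actually passed to `--score_path`):
```
# illustrative only: aggregate score files produced with --score_path and --dump-args
import glob
import json

results = []
for path in glob.glob("scores-*.json"):  # hypothetical naming - adjust to your --score_path values
    with open(path) as f:
        scores = json.load(f)  # e.g. {'bleu': 26.887, 'num_beams': 8, 'early_stopping': True, ...}
    scores["file"] = path
    results.append(scores)

# print runs from best to worst BLEU, together with the hparams that produced them
for scores in sorted(results, key=lambda s: s["bleu"], reverse=True):
    print(scores)
```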
* if you need to perform a parametric search to find the hparam values that lead to the highest BLEU score, let `run_eval_search.py` do the searching for you.
The script accepts the exact same arguments as `run_eval.py`, plus an additional argument `--search`. The value of `--search` is parsed, reformatted and fed to `run_eval.py` as additional args.
The format for the `--search` value is a simple string with hparams and colon separated values to try, e.g.:
```
--search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false"
```
which will generate `12` (`2*3*2`) searches over the cartesian product of the hparam values (a minimal sketch of this expansion appears at the end of this section). The example above will invoke `run_eval.py` repeatedly with:
```
--num_beams 5 --length_penalty 0.8 --early_stopping true
--num_beams 5 --length_penalty 0.8 --early_stopping false
[...]
--num_beams 10 --length_penalty 1.2 --early_stopping false
```
On completion, the script prints a markdown table of the results sorted by the best BLEU score, followed by the winning arguments:
```
bleu | num_beams | length_penalty | early_stopping
----- | --------- | -------------- | --------------
26.71 | 5 | 1.1 | 1
26.66 | 5 | 0.9 | 1
26.66 | 5 | 0.9 | 0
26.41 | 5 | 1.1 | 0
21.94 | 1 | 0.9 | 1
21.94 | 1 | 0.9 | 0
21.94 | 1 | 1.1 | 1
21.94 | 1 | 1.1 | 0
Best score args:
stas/wmt19-en-ru data/en-ru/val.source data/en-ru/test_translations.txt --reference_path data/en-ru/val.target --score_path data/en-ru/test_bleu.json --bs 8 --task translation --num_beams 5 --length_penalty 1.1 --early_stopping True
```
If you pass `--info "some experiment-specific info"`, it will get printed before the results table. This is useful when scripting multiple runs, so that the different sets of results can easily be told apart.
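For reference, the expansion of a `--search` value into individual argument sets is just a cartesian product over the colon-separated values. The following is a minimal standalone sketch mirroring what `run_eval_search.py` does internally (see `parse_search_arg` in the diff below):
```
# minimal sketch of how a --search value expands into run_eval.py argument sets
import itertools

search = "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false"

# "num_beams=5:10" -> ("num_beams", "5:10"), etc.
entries = dict(group.split("=") for group in search.split())
sets = [[f"--{name} {value}" for value in values.split(":")] for name, values in entries.items()]

combinations = [" ".join(combo) for combo in itertools.product(*sets)]
print(len(combinations))  # 12, i.e. 2*3*2
for combo in combinations:
    print(combo)  # e.g. --num_beams 5 --length_penalty 0.8 --early_stopping true
```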
### DistilBART
![DBART](https://huggingface.co/front/thumbnails/distilbart_large.png)
......
import argparse
import datetime
import json
import time
import warnings
......@@ -15,9 +16,9 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
logger = getLogger(__name__)
try:
from .utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params
from .utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
except ImportError:
from utils import calculate_bleu, calculate_rouge, parse_numeric_cl_kwargs, use_task_specific_params
from utils import calculate_bleu, calculate_rouge, parse_numeric_n_bool_cl_kwargs, use_task_specific_params
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
......@@ -72,7 +73,26 @@ def generate_summaries_or_translations(
return dict(n_obs=n_obs, runtime=runtime, seconds_per_sample=round(runtime / n_obs, 4))
def run_generate():
def datetime_now():
return datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
def run_generate(verbose=True):
"""
Takes input text, generates output, and then, using the reference, calculates the BLEU scores.
The results are saved to a file and returned to the caller, and printed out unless ``verbose=False`` is passed.
Args:
verbose (:obj:`bool`, `optional`, defaults to :obj:`True`): print results to stdout
Returns:
a dict of scores, e.g. ``{'bleu': 39.6501, 'n_obs': 2000, 'runtime': 186, 'seconds_per_sample': 0.093}``
- if ``--dump-args`` is passed, the custom params are merged in, e.g. ``{'num_beams': 5, 'length_penalty': 0.8}``
- if ``--info`` is passed, its value is added under the ``info`` key
"""
parser = argparse.ArgumentParser()
parser.add_argument("model_name", type=str, help="like facebook/bart-large-cnn,t5-base, etc.")
parser.add_argument("input_path", type=str, help="like cnn_dm/test.source")
......@@ -89,11 +109,19 @@ def run_generate():
"--n_obs", type=int, default=-1, required=False, help="How many observations. Defaults to all."
)
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--dump-args", action="store_true", help="print the custom hparams with the results")
parser.add_argument(
"--info",
nargs="?",
type=str,
const=datetime_now(),
help="use in conjunction w/ --dump-args to print with the results whatever other info you'd like, e.g. lang=en-ru. If no value is passed, the current datetime string will be used.",
)
# Unspecified args like --num_beams=2 --decoder_start_token_id=4 are passed to model.generate
args, rest = parser.parse_known_args()
parsed = parse_numeric_cl_kwargs(rest)
if parsed:
print(f"parsed the following generate kwargs: {parsed}")
parsed_args = parse_numeric_n_bool_cl_kwargs(rest)
if parsed_args and verbose:
print(f"parsed the following generate kwargs: {parsed_args}")
examples = [" " + x.rstrip() if "t5" in args.model_name else x.rstrip() for x in open(args.input_path).readlines()]
if args.n_obs > 0:
examples = examples[: args.n_obs]
......@@ -109,23 +137,35 @@ def run_generate():
fp16=args.fp16,
task=args.task,
prefix=args.prefix,
**parsed,
**parsed_args,
)
if args.reference_path is None:
return
return {}
# Compute scores
score_fn = calculate_bleu if "translation" in args.task else calculate_rouge
output_lns = [x.rstrip() for x in open(args.save_path).readlines()]
reference_lns = [x.rstrip() for x in open(args.reference_path).readlines()][: len(output_lns)]
scores: dict = score_fn(output_lns, reference_lns)
scores.update(runtime_metrics)
print(scores)
if args.dump_args:
scores.update(parsed_args)
if args.info:
scores["info"] = args.info
if verbose:
print(scores)
if args.score_path is not None:
json.dump(scores, open(args.score_path, "w"))
path = args.score_path
json.dump(scores, open(path, "w"))
return scores
if __name__ == "__main__":
# Usage for MT:
# python run_eval.py MODEL_NAME $DATA_DIR/test.source $save_dir/test_translations.txt --reference_path $DATA_DIR/test.target --score_path $save_dir/test_bleu.json --task translation $@
run_generate()
run_generate(verbose=True)
import argparse
import itertools
import operator
import sys
from collections import OrderedDict
try:
from .run_eval import datetime_now, run_generate
except ImportError:
from run_eval import datetime_now, run_generate
# A table of supported tasks and the list of scores in the order of importance to be sorted by.
# To add a new task, simply list the score names that `run_eval.run_generate()` returns
task_score_names = {
"translation": ["bleu"],
"translation_en_to_de": ["bleu"],
"summarization": ["rouge1", "rouge2", "rougeL"],
}
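# e.g. to support a hypothetical "translation_en_to_ro" task sorted by BLEU, you would add:
#     "translation_en_to_ro": ["bleu"],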
def parse_search_arg(search):
groups = search.split()
entries = {k: vs for k, vs in (g.split("=") for g in groups)}
entry_names = list(entries.keys())
sets = [list((f"--{k} {v}") for v in vs.split(":")) for k, vs in entries.items()]
matrix = [list(x) for x in itertools.product(*sets)]
return matrix, entry_names
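# For example, parse_search_arg("num_beams=5:10 early_stopping=true:false") returns:
#   matrix: [["--num_beams 5", "--early_stopping true"], ["--num_beams 5", "--early_stopping false"],
#            ["--num_beams 10", "--early_stopping true"], ["--num_beams 10", "--early_stopping false"]]
#   entry_names: ["num_beams", "early_stopping"]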
def run_search():
"""
Run parametric search over the desired hparam space with help of ``run_eval.py``.
All the arguments except ``--search`` are passed to ``run_eval.py`` as is. The values inside of "--search" are parsed, reformatted and fed to ``run_eval.py`` as additional args.
The format for the ``--search`` value is a simple string with hparams and colon separated values to try, e.g.:
```
--search "num_beams=5:10 length_penalty=0.8:1.0:1.2 early_stopping=true:false"
```
which will generate ``12`` (``2*3*2``) searches over the cartesian product of the hparam values. The example above will invoke ``run_eval.py`` repeatedly with:
```
--num_beams 5 --length_penalty 0.8 --early_stopping true
--num_beams 5 --length_penalty 0.8 --early_stopping false
[...]
--num_beams 10 --length_penalty 1.2 --early_stopping false
```
On completion, this function prints a markdown table of the results sorted by the best BLEU score, followed by the winning arguments.
"""
prog = sys.argv[0]
parser = argparse.ArgumentParser(
usage="\n\nImportant: this script accepts all arguments `run_eval.py` accepts and then a few extra, therefore refer to `run_eval.py -h` for the complete list."
)
parser.add_argument(
"--search",
type=str,
required=False,
help='param space to search, e.g. "num_beams=5:10 length_penalty=0.8:1.0:1.2"',
)
parser.add_argument(
"--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)"
)
parser.add_argument(
"--task", type=str, help="used for task_specific_params + metrics", choices=task_score_names.keys()
)
parser.add_argument(
"--info",
nargs="?",
type=str,
const=datetime_now(),
help="add custom notes to be printed before the results table. If no value is passed, the current datetime string will be used.",
)
args, args_main = parser.parse_known_args()
# we share some of the args
args_main.extend(["--task", args.task])
args_normal = [prog] + args_main
matrix, col_names = parse_search_arg(args.search)
col_names[0:0] = task_score_names[args.task] # score cols first
col_widths = {col: len(str(col)) for col in col_names}
results = []
for r in matrix:
hparams = {k: v for k, v in (x.replace("--", "").split() for x in r)}
args_exp = " ".join(r).split()
args_exp.extend(["--bs", str(args.bs)]) # in case we need to reduce its size due to CUDA OOM
sys.argv = args_normal + args_exp
# XXX: need to trap CUDA OOM and lower args.bs if that happens and retry
scores = run_generate(verbose=False)
# make sure scores are first in the table
result = OrderedDict()
for score in task_score_names[args.task]:
result[score] = scores[score]
result.update(hparams)
results.append(result)
# find widest entries
for k, v in result.items():
l = len(str(v))
if l > col_widths[k]:
col_widths[k] = l
results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[args.task]), reverse=True)
print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
print(" | ".join([f"{'-'*col_widths[col]}" for col in col_names]))
for row in results_sorted:
print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))
best = results_sorted[0]
for score in task_score_names[args.task]:
del best[score]
best_args = [f"--{k} {v}" for k, v in best.items()]
dyn_args = ["--bs", str(args.bs)]
if args.info:
print(f"\nInfo: {args.info}")
print("\nBest score args:")
print(" ".join(args_main + best_args + dyn_args))
return results_sorted
if __name__ == "__main__":
# Usage:
# [normal-run_eval_search.py cmd plus] \
# --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"
#
# Example:
# PYTHONPATH="src:examples/seq2seq" python examples/seq2seq/run_eval_search.py $MODEL_NAME \
# $DATA_DIR/val.source $SAVE_DIR/test_translations.txt --reference_path $DATA_DIR/val.target \
# --score_path $SAVE_DIR/test_bleu.json --bs $BS --task translation \
# --search="num_beams=1:5:10 length_penalty=0.8:1:1.2 early_stopping=true:false"
run_search()
......@@ -23,6 +23,7 @@ from .distillation import distill_main, evaluate_checkpoint
from .finetune import SummarizationModule, main
from .pack_dataset import pack_data_dir
from .run_eval import generate_summaries_or_translations, run_generate
from .run_eval_search import run_search
from .utils import LegacySeq2SeqDataset, Seq2SeqDataset, label_smoothed_nll_loss, lmap, load_json
......@@ -283,7 +284,7 @@ class TestSummarizationDistiller(unittest.TestCase):
return model
@pytest.mark.parametrize(["model"], [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
@pytest.mark.parametrize("model", [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
def test_run_eval(model):
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
output_file_name = input_file_name.parent / "utest_output.txt"
......@@ -311,6 +312,58 @@ def test_run_eval(model):
assert Path(output_file_name).exists()
os.remove(Path(output_file_name))
@slow
@pytest.mark.parametrize("model", [pytest.param(T5_TINY)])
def test_run_eval_search(model):
input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
output_file_name = input_file_name.parent / "utest_output.txt"
assert not output_file_name.exists()
text = {
"en": ["Machine learning is great, isn't it?", "I like to eat bananas", "Tomorrow is another great day!"],
"de": [
"Maschinelles Lernen ist großartig, oder?",
"Ich esse gerne Bananen",
"Morgen ist wieder ein toller Tag!",
],
}
tmp_dir = Path(tempfile.mkdtemp())
score_path = str(tmp_dir / "scores.json")
reference_path = str(tmp_dir / "val.target")
_dump_articles(input_file_name, text["en"])
_dump_articles(reference_path, text["de"])
task = "translation_en_to_de" if model == T5_TINY else "summarization"
testargs = [
"run_eval_search.py",
model,
str(input_file_name),
str(output_file_name),
"--score_path",
score_path,
"--reference_path",
reference_path,
"--task",
task,
"--search",
"num_beams=1:2 length_penalty=0.9:1.0",
]
with patch.object(sys, "argv", testargs):
with CaptureStdout() as cs:
run_search()
expected_strings = [" num_beams | length_penalty", model, "Best score args"]
un_expected_strings = ["Info"]
if "translation" in task:
expected_strings.append("bleu")
else:
expected_strings.extend(["rouge1", "rouge2", "rougeL"])
for w in expected_strings:
assert w in cs.out
for w in un_expected_strings:
assert w not in cs.out
assert Path(output_file_name).exists()
os.remove(Path(output_file_name))
@pytest.mark.parametrize(
["model"],
......
......@@ -377,18 +377,26 @@ def assert_not_all_frozen(model):
# CLI Parsing utils
def parse_numeric_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float]]:
"""Parse an argv list of unspecified command line args to a dict. Assumes all values are numeric."""
def parse_numeric_n_bool_cl_kwargs(unparsed_args: List[str]) -> Dict[str, Union[int, float, bool]]:
"""
Parse an argv list of unspecified command line args to a dict.
Assumes all values are either numeric or boolean in the form of true/false.
"""
result = {}
assert len(unparsed_args) % 2 == 0, f"got odd number of unparsed args: {unparsed_args}"
num_pairs = len(unparsed_args) // 2
for pair_num in range(num_pairs):
i = 2 * pair_num
assert unparsed_args[i].startswith("--")
try:
value = int(unparsed_args[i + 1])
except ValueError:
value = float(unparsed_args[i + 1]) # this can raise another informative ValueError
if unparsed_args[i + 1].lower() == "true":
value = True
elif unparsed_args[i + 1].lower() == "false":
value = False
else:
try:
value = int(unparsed_args[i + 1])
except ValueError:
value = float(unparsed_args[i + 1]) # this can raise another informative ValueError
result[unparsed_args[i][2:]] = value
return result
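# For example, parse_numeric_n_bool_cl_kwargs(["--num_beams", "5", "--early_stopping", "true"])
# returns {"num_beams": 5, "early_stopping": True}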
......