Unverified Commit efeab6a3 authored by Stas Bekman, committed by GitHub

[s2s] run_eval/run_eval_search tweaks (#7192)


Co-authored-by: Sam Shleifer <sshleifer@gmail.com>
parent 9c5bcab5
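
The main behavioral change in the search script is that --task no longer has to match a key of task_score_names exactly: any task containing "translation" is mapped to the generic "translation" entry, so variant names such as translation_en_to_de no longer need their own rows in the table. A minimal sketch of that lookup, not part of the commit (score_names_for is a hypothetical helper written here only to illustrate the mapping):

task_score_names = {
    "translation": ["bleu"],
    "summarization": ["rouge1", "rouge2", "rougeL"],
}


def score_names_for(task_arg: str) -> list:
    # mirror of the normalization added in the diff below: variants like
    # "translation_en_to_de" fall back to the generic "translation" score names
    task = "translation" if "translation" in task_arg else "summarization"
    return task_score_names[task]


assert score_names_for("translation_en_to_de") == ["bleu"]
assert score_names_for("summarization") == ["rouge1", "rouge2", "rougeL"]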
@@ -15,7 +15,6 @@ except ImportError:
 # To add a new task, simply list the score names that `run_eval.run_generate()` returns
 task_score_names = {
     "translation": ["bleu"],
-    "translation_en_to_de": ["bleu"],
     "summarization": ["rouge1", "rouge2", "rougeL"],
 }
@@ -66,9 +65,7 @@ def run_search():
     parser.add_argument(
         "--bs", type=int, default=8, required=False, help="initial batch size (may get reduced if it's too big)"
     )
-    parser.add_argument(
-        "--task", type=str, help="used for task_specific_params + metrics", choices=task_score_names.keys()
-    )
+    parser.add_argument("--task", type=str, help="used for task_specific_params + metrics")
     parser.add_argument(
         "--info",
         nargs="?",
@@ -81,8 +78,11 @@ def run_search():
     args_main.extend(["--task", args.task])
     args_normal = [prog] + args_main

+    # to support variations like translation_en_to_de"
+    task = "translation" if "translation" in args.task else "summarization"
+
     matrix, col_names = parse_search_arg(args.search)
-    col_names[0:0] = task_score_names[args.task]  # score cols first
+    col_names[0:0] = task_score_names[task]  # score cols first
     col_widths = {col: len(str(col)) for col in col_names}
     results = []
     for r in matrix:
@@ -96,7 +96,7 @@ def run_search():
         scores = run_generate(verbose=False)
         # make sure scores are first in the table
         result = OrderedDict()
-        for score in task_score_names[args.task]:
+        for score in task_score_names[task]:
             result[score] = scores[score]
         result.update(hparams)
         results.append(result)
@@ -107,14 +107,14 @@ def run_search():
             if l > col_widths[k]:
                 col_widths[k] = l

-    results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[args.task]), reverse=True)
+    results_sorted = sorted(results, key=operator.itemgetter(*task_score_names[task]), reverse=True)
     print(" | ".join([f"{col:{col_widths[col]}}" for col in col_names]))
     print(" | ".join([f"{'-'*col_widths[col]}" for col in col_names]))
     for row in results_sorted:
         print(" | ".join([f"{row[col]:{col_widths[col]}}" for col in col_names]))

     best = results_sorted[0]
-    for score in task_score_names[args.task]:
+    for score in task_score_names[task]:
         del best[score]
     best_args = [f"--{k} {v}" for k, v in best.items()]
     dyn_args = ["--bs", str(args.bs)]
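
The tests below drive the search with --search num_beams=1:2 length_penalty=0.9:1.0. The diff does not show parse_search_arg itself; assuming it expands each param=v1:v2 group into every combination of values (a Cartesian product), a rough stand-in could look like this (expand_search is hypothetical, written here only for illustration):

import itertools


def expand_search(search: str):
    # hypothetical stand-in for parse_search_arg(): "k=a:b c=d:e" -> one dict per value combination
    keys, value_lists = [], []
    for group in search.split():
        key, values = group.split("=")
        keys.append(key)
        value_lists.append(values.split(":"))
    matrix = [dict(zip(keys, combo)) for combo in itertools.product(*value_lists)]
    return matrix, keys


matrix, col_names = expand_search("num_beams=1:2 length_penalty=0.9:1.0")
assert len(matrix) == 4  # {1, 2} x {0.9, 1.0}
assert matrix[0] == {"num_beams": "1", "length_penalty": "0.9"}

Each dict in matrix then becomes one run's hyperparameters, and the run's scores are prepended to the row before the table is sorted.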
@@ -106,6 +106,9 @@ T5_TINY = "patrickvonplaten/t5-tiny-random"
 BART_TINY = "sshleifer/bart-tiny-random"
 MBART_TINY = "sshleifer/tiny-mbart"
 MARIAN_TINY = "sshleifer/tiny-marian-en-de"
+BERT_BASE_CASED = "bert-base-cased"
+PEGASUS_XSUM = "google/pegasus-xsum"
+
 stream_handler = logging.StreamHandler(sys.stdout)
 logger.addHandler(stream_handler)
 logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks
@@ -284,8 +287,7 @@ class TestSummarizationDistiller(unittest.TestCase):
         return model


-@pytest.mark.parametrize("model", [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY)])
-def test_run_eval(model):
+def run_eval_tester(model):
     input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
     output_file_name = input_file_name.parent / "utest_output.txt"
     assert not output_file_name.exists()
@@ -293,28 +295,39 @@ def test_run_eval(model):
     _dump_articles(input_file_name, articles)
     score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
     task = "translation_en_to_de" if model == T5_TINY else "summarization"
-    testargs = [
-        "run_eval.py",
-        model,
-        str(input_file_name),
-        str(output_file_name),
-        "--score_path",
-        score_path,
-        "--task",
-        task,
-        "--num_beams",
-        "2",
-        "--length_penalty",
-        "2.0",
-    ]
+    testargs = f"""
+        run_eval_search.py
+        {model}
+        {input_file_name}
+        {output_file_name}
+        --score_path {score_path}
+        --task {task}
+        --num_beams 2
+        --length_penalty 2.0
+    """.split()
+
     with patch.object(sys, "argv", testargs):
         run_generate()
         assert Path(output_file_name).exists()
         os.remove(Path(output_file_name))


+# test one model to quickly (no-@slow) catch simple problems and do an
+# extensive testing of functionality with multiple models as @slow separately
+def test_run_eval():
+    run_eval_tester(T5_TINY)
+
+
+# any extra models should go into the list here - can be slow
 @slow
-@pytest.mark.parametrize("model", [pytest.param(T5_TINY)])
+@pytest.mark.parametrize("model", [BART_TINY, MBART_TINY])
+def test_run_eval_slow(model):
+    run_eval_tester(model)
+
+
+# testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
+@slow
+@pytest.mark.parametrize("model", [T5_TINY, MBART_TINY])
 def test_run_eval_search(model):
     input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
     output_file_name = input_file_name.parent / "utest_output.txt"
@@ -335,20 +348,17 @@ def test_run_eval_search(model):
     _dump_articles(input_file_name, text["en"])
     _dump_articles(reference_path, text["de"])
     task = "translation_en_to_de" if model == T5_TINY else "summarization"
-    testargs = [
-        "run_eval_search.py",
-        model,
-        str(input_file_name),
-        str(output_file_name),
-        "--score_path",
-        score_path,
-        "--reference_path",
-        reference_path,
-        "--task",
-        task,
-        "--search",
-        "num_beams=1:2 length_penalty=0.9:1.0",
-    ]
+    testargs = f"""
+        run_eval_search.py
+        --model_name {model}
+        --data_dir {str(input_file_name)}
+        --save_dir {str(output_file_name)}
+        --score_path {score_path}
+        --reference_path {reference_path}
+        --task {task}
+        --search num_beams=1:2 length_penalty=0.9:1.0
+    """.split()
+
     with patch.object(sys, "argv", testargs):
         with CaptureStdout() as cs:
             run_search()
@@ -367,8 +377,8 @@ def test_run_eval_search(model):


 @pytest.mark.parametrize(
-    ["model"],
-    [pytest.param(T5_TINY), pytest.param(BART_TINY), pytest.param(MBART_TINY), pytest.param(MARIAN_TINY)],
+    "model",
+    [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY],
 )
 def test_finetune(model):
     args_d: dict = CHEAP_ARGS.copy()
@@ -541,13 +551,13 @@ def test_pack_dataset():


 @pytest.mark.parametrize(
-    ["tok_name"],
+    "tok_name",
     [
-        pytest.param(MBART_TINY),
-        pytest.param(MARIAN_TINY),
-        pytest.param(T5_TINY),
-        pytest.param(BART_TINY),
-        pytest.param("google/pegasus-xsum"),
+        MBART_TINY,
+        MARIAN_TINY,
+        T5_TINY,
+        BART_TINY,
+        PEGASUS_XSUM,
     ],
 )
 def test_seq2seq_dataset_truncation(tok_name):
@@ -589,7 +599,7 @@ def test_seq2seq_dataset_truncation(tok_name):
         break  # No need to test every batch


-@pytest.mark.parametrize(["tok"], [pytest.param(BART_TINY), pytest.param("bert-base-cased")])
+@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED])
 def test_legacy_dataset_truncation(tok):
     tokenizer = AutoTokenizer.from_pretrained(tok)
     tmp_dir = make_test_data_dir()
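
A note on the test refactor above: building sys.argv by whitespace-splitting an f-string is equivalent to spelling out the list, as long as no argument value contains spaces. A self-contained sketch of the idiom (the tiny parser here is hypothetical and only stands in for the real CLI entry points):

import sys
from argparse import ArgumentParser
from unittest.mock import patch


def fake_cli():
    # hypothetical stand-in for an argparse-based CLI such as run_eval: it just parses sys.argv
    parser = ArgumentParser()
    parser.add_argument("model_name")
    parser.add_argument("--num_beams", type=int)
    parser.add_argument("--length_penalty", type=float)
    return parser.parse_args()


model = "sshleifer/bart-tiny-random"
testargs = f"""
    run_eval.py
    {model}
    --num_beams 2
    --length_penalty 2.0
""".split()  # str.split() with no argument drops all whitespace, including newlines and indentation

with patch.object(sys, "argv", testargs):
    args = fake_cli()

assert args.model_name == model
assert args.num_beams == 2 and args.length_penalty == 2.0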