Unverified Commit 1652ddad authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[seq2seq testing] improve readability (#7845)

parent 466115b2
...@@ -47,58 +47,38 @@ def test_finetune_trainer_slow(): ...@@ -47,58 +47,38 @@ def test_finetune_trainer_slow():
def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int): def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
data_dir = "examples/seq2seq/test_data/wmt_en_ro" data_dir = "examples/seq2seq/test_data/wmt_en_ro"
output_dir = tempfile.mkdtemp(prefix="test_output") output_dir = tempfile.mkdtemp(prefix="test_output")
argv = [ argv = f"""
"--model_name_or_path", --model_name_or_path {model_name}
model_name, --data_dir {data_dir}
"--data_dir", --output_dir {output_dir}
data_dir, --overwrite_output_dir
"--output_dir", --n_train 8
output_dir, --n_val 8
"--overwrite_output_dir", --max_source_length {max_len}
"--n_train", --max_target_length {max_len}
"8", --val_max_target_length {max_len}
"--n_val", --do_train
"8", --do_eval
"--max_source_length", --do_predict
max_len, --num_train_epochs {str(num_train_epochs)}
"--max_target_length", --per_device_train_batch_size 4
max_len, --per_device_eval_batch_size 4
"--val_max_target_length", --learning_rate 3e-4
max_len, --warmup_steps 8
"--do_train", --evaluate_during_training
"--do_eval", --predict_with_generate
"--do_predict", --logging_steps 0
"--num_train_epochs", --save_steps {str(eval_steps)}
str(num_train_epochs), --eval_steps {str(eval_steps)}
"--per_device_train_batch_size", --sortish_sampler
"4", --label_smoothing 0.1
"--per_device_eval_batch_size", --adafactor
"4", --task translation
"--learning_rate", --tgt_lang ro_RO
"3e-4", --src_lang en_XX
"--warmup_steps", """.split()
"8", # --eval_beams 2
"--evaluate_during_training",
"--predict_with_generate",
"--logging_steps",
0,
"--save_steps",
str(eval_steps),
"--eval_steps",
str(eval_steps),
"--sortish_sampler",
"--label_smoothing",
"0.1",
# "--eval_beams",
# "2",
"--adafactor",
"--task",
"translation",
"--tgt_lang",
"ro_RO",
"--src_lang",
"en_XX",
]
testargs = ["finetune_trainer.py"] + argv testargs = ["finetune_trainer.py"] + argv
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
main() main()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment