Unverified Commit 9edafaeb authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[s2s] test_bash_script.py - actually learn something (#8318)

* use decorator

* remove hardcoded paths

* make the test use more data and do real quality tests

* shave off 10 secs

* add --eval_beams 2, reformat

* reduce train size, use smaller custom dataset
parent 17450397
......@@ -3,92 +3,107 @@
import argparse
import os
import sys
from pathlib import Path
from unittest.mock import patch
import pytest
import pytorch_lightning as pl
import timeout_decorator
import torch
from distillation import BartSummarizationDistiller, distill_main
from finetune import SummarizationModule, main
from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
from transformers import BartForConditionalGeneration, MarianMTModel
from transformers.testing_utils import TestCasePlus, slow
from transformers import MarianMTModel
from transformers.file_utils import cached_path
from transformers.testing_utils import TestCasePlus, require_torch_gpu, slow
from utils import load_json
MODEL_NAME = MBART_TINY
MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
MARIAN_MODEL = "sshleifer/mar_enro_6_3_student"
class TestAll(TestCasePlus):
class TestMbartCc25Enro(TestCasePlus):
def setUp(self):
super().setUp()
data_cached = cached_path(
"https://cdn-datasets.huggingface.co/translation/wmt_en_ro-tr40k-va0.5k-te0.5k.tar.gz",
extract_compressed_file=True,
)
self.data_dir = f"{data_cached}/wmt_en_ro-tr40k-va0.5k-te0.5k"
@slow
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
@require_torch_gpu
def test_model_download(self):
"""This warms up the cache so that we can time the next test without including download time, which varies between machines."""
BartForConditionalGeneration.from_pretrained(MODEL_NAME)
MarianMTModel.from_pretrained(MARIAN_MODEL)
@timeout_decorator.timeout(120)
# @timeout_decorator.timeout(1200)
@slow
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
@require_torch_gpu
def test_train_mbart_cc25_enro_script(self):
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
env_vars_to_replace = {
"--fp16_opt_level=O1": "",
"$MAX_LEN": 128,
"$BS": 4,
"$MAX_LEN": 64,
"$BS": 64,
"$GAS": 1,
"$ENRO_DIR": data_dir,
"facebook/mbart-large-cc25": MODEL_NAME,
# Download is 120MB in previous test.
"val_check_interval=0.25": "val_check_interval=1.0",
"$ENRO_DIR": self.data_dir,
"facebook/mbart-large-cc25": MARIAN_MODEL,
# "val_check_interval=0.25": "val_check_interval=1.0",
"--learning_rate=3e-5": "--learning_rate 3e-4",
"--num_train_epochs 6": "--num_train_epochs 1",
}
# Clean up bash script
bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
bash_script = (self.test_file_dir / "train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
for k, v in env_vars_to_replace.items():
bash_script = bash_script.replace(k, str(v))
output_dir = self.get_auto_remove_tmp_dir()
bash_script = bash_script.replace("--fp16 ", "")
testargs = (
["finetune.py"]
+ bash_script.split()
+ [
f"--output_dir={output_dir}",
"--gpus=1",
"--learning_rate=3e-1",
"--warmup_steps=0",
"--val_check_interval=1.0",
"--tokenizer_name=facebook/mbart-large-en-ro",
]
)
# bash_script = bash_script.replace("--fp16 ", "")
args = f"""
--output_dir {output_dir}
--tokenizer_name Helsinki-NLP/opus-mt-en-ro
--sortish_sampler
--do_predict
--gpus 1
--freeze_encoder
--n_train 40000
--n_val 500
--n_test 500
--fp16_opt_level O1
--num_sanity_val_steps 0
--eval_beams 2
""".split()
# XXX: args.gpus > 1 : handle multigpu in the future
testargs = ["finetune.py"] + bash_script.split() + args
with patch.object(sys, "argv", testargs):
parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
args = parser.parse_args()
args.do_predict = False
# assert args.gpus == gpus THIS BREAKS for multigpu
model = main(args)
# Check metrics
metrics = load_json(model.metrics_save_path)
first_step_stats = metrics["val"][0]
last_step_stats = metrics["val"][-1]
assert (
len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1
) # +1 accounts for val_sanity_check
self.assertEqual(len(metrics["val"]), (args.max_epochs / args.val_check_interval))
assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
assert last_step_stats["val_avg_gen_time"] >= 0.01
self.assertGreater(last_step_stats["val_avg_gen_time"], 0.01)
# model hanging on generate. Maybe bad config was saved. (XXX: old comment/assert?)
self.assertLessEqual(last_step_stats["val_avg_gen_time"], 1.0)
assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"] # model learned nothing
assert 1.0 >= last_step_stats["val_avg_gen_time"] # model hanging on generate. Maybe bad config was saved.
assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
# test learning requirements:
# 1. BLEU improves over the course of training by more than 2 pts
self.assertGreater(last_step_stats["val_avg_bleu"] - first_step_stats["val_avg_bleu"], 2)
# 2. BLEU finishes above 17
self.assertGreater(last_step_stats["val_avg_bleu"], 17)
# 3. test BLEU and val BLEU within ~1.1 pt.
self.assertLess(abs(metrics["val"][-1]["val_avg_bleu"] - metrics["test"][-1]["test_avg_bleu"]), 1.1)
# check lightning ckpt can be loaded and has a reasonable statedict
contents = os.listdir(output_dir)
......@@ -107,11 +122,13 @@ class TestAll(TestCasePlus):
# assert len(metrics["val"]) == desired_n_evals
assert len(metrics["test"]) == 1
class TestDistilMarianNoTeacher(TestCasePlus):
@timeout_decorator.timeout(600)
@slow
@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
@require_torch_gpu
def test_opus_mt_distill_script(self):
data_dir = "examples/seq2seq/test_data/wmt_en_ro"
data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
env_vars_to_replace = {
"--fp16_opt_level=O1": "",
"$MAX_LEN": 128,
......@@ -124,7 +141,7 @@ class TestAll(TestCasePlus):
# Clean up bash script
bash_script = (
Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
(self.test_file_dir / "distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
)
bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
bash_script = bash_script.replace("--fp16 ", " ")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment