Unverified commit 9edafaeb, authored by Stas Bekman and committed by GitHub

[s2s] test_bash_script.py - actually learn something (#8318)

* use decorator (see the sketch below the commit header)

* remove hardcoded paths

* make the test use more data and do real quality tests

* shave off 10 secs

* add --eval_beams 2, reformat

* reduce train size, use smaller custom dataset
parent 17450397
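
Note on the "use decorator" bullet: the diff replaces the per-test `@pytest.mark.skipif(not CUDA_AVAILABLE, ...)` guards with `@require_torch_gpu` from `transformers.testing_utils`. As a rough sketch of what such a guard boils down to (an approximation for illustration, not the library's exact implementation):

import unittest

import torch


def require_torch_gpu(test_case):
    # Skip the decorated test (or test class) unless a CUDA device is available.
    return unittest.skipUnless(torch.cuda.is_available(), "test requires CUDA")(test_case)

Stacked under @slow, it keeps these tests out of the default CI run and off CPU-only machines.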
@@ -3,92 +3,107 @@
 import argparse
 import os
 import sys
-from pathlib import Path
 from unittest.mock import patch
 
-import pytest
 import pytorch_lightning as pl
 import timeout_decorator
 import torch
 from distillation import BartSummarizationDistiller, distill_main
 from finetune import SummarizationModule, main
-from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
-from transformers import BartForConditionalGeneration, MarianMTModel
-from transformers.testing_utils import TestCasePlus, slow
+from transformers import MarianMTModel
+from transformers.file_utils import cached_path
+from transformers.testing_utils import TestCasePlus, require_torch_gpu, slow
 from utils import load_json
 
-MODEL_NAME = MBART_TINY
-MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
+MARIAN_MODEL = "sshleifer/mar_enro_6_3_student"
 
 
-class TestAll(TestCasePlus):
+class TestMbartCc25Enro(TestCasePlus):
+    def setUp(self):
+        super().setUp()
+
+        data_cached = cached_path(
+            "https://cdn-datasets.huggingface.co/translation/wmt_en_ro-tr40k-va0.5k-te0.5k.tar.gz",
+            extract_compressed_file=True,
+        )
+        self.data_dir = f"{data_cached}/wmt_en_ro-tr40k-va0.5k-te0.5k"
+
     @slow
-    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    @require_torch_gpu
    def test_model_download(self):
         """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
-        BartForConditionalGeneration.from_pretrained(MODEL_NAME)
         MarianMTModel.from_pretrained(MARIAN_MODEL)
 
-    @timeout_decorator.timeout(120)
+    # @timeout_decorator.timeout(1200)
     @slow
-    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    @require_torch_gpu
     def test_train_mbart_cc25_enro_script(self):
-        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
         env_vars_to_replace = {
-            "--fp16_opt_level=O1": "",
-            "$MAX_LEN": 128,
-            "$BS": 4,
+            "$MAX_LEN": 64,
+            "$BS": 64,
             "$GAS": 1,
-            "$ENRO_DIR": data_dir,
-            "facebook/mbart-large-cc25": MODEL_NAME,
-            # Download is 120MB in previous test.
-            "val_check_interval=0.25": "val_check_interval=1.0",
+            "$ENRO_DIR": self.data_dir,
+            "facebook/mbart-large-cc25": MARIAN_MODEL,
+            # "val_check_interval=0.25": "val_check_interval=1.0",
+            "--learning_rate=3e-5": "--learning_rate 3e-4",
+            "--num_train_epochs 6": "--num_train_epochs 1",
         }
 
         # Clean up bash script
-        bash_script = Path("examples/seq2seq/train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
+        bash_script = (self.test_file_dir / "train_mbart_cc25_enro.sh").open().read().split("finetune.py")[1].strip()
         bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
         for k, v in env_vars_to_replace.items():
             bash_script = bash_script.replace(k, str(v))
         output_dir = self.get_auto_remove_tmp_dir()
 
-        bash_script = bash_script.replace("--fp16 ", "")
-        testargs = (
-            ["finetune.py"]
-            + bash_script.split()
-            + [
-                f"--output_dir={output_dir}",
-                "--gpus=1",
-                "--learning_rate=3e-1",
-                "--warmup_steps=0",
-                "--val_check_interval=1.0",
-                "--tokenizer_name=facebook/mbart-large-en-ro",
-            ]
-        )
+        # bash_script = bash_script.replace("--fp16 ", "")
+        args = f"""
+            --output_dir {output_dir}
+            --tokenizer_name Helsinki-NLP/opus-mt-en-ro
+            --sortish_sampler
+            --do_predict
+            --gpus 1
+            --freeze_encoder
+            --n_train 40000
+            --n_val 500
+            --n_test 500
+            --fp16_opt_level O1
+            --num_sanity_val_steps 0
+            --eval_beams 2
+        """.split()
+        # XXX: args.gpus > 1 : handle multigpu in the future
+        testargs = ["finetune.py"] + bash_script.split() + args
 
         with patch.object(sys, "argv", testargs):
             parser = argparse.ArgumentParser()
             parser = pl.Trainer.add_argparse_args(parser)
             parser = SummarizationModule.add_model_specific_args(parser, os.getcwd())
             args = parser.parse_args()
-            args.do_predict = False
-            # assert args.gpus == gpus THIS BREAKS for multigpu
             model = main(args)
 
         # Check metrics
         metrics = load_json(model.metrics_save_path)
         first_step_stats = metrics["val"][0]
         last_step_stats = metrics["val"][-1]
-        assert (
-            len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1
-        )  # +1 accounts for val_sanity_check
-        assert last_step_stats["val_avg_gen_time"] >= 0.01
-
-        assert first_step_stats["val_avg_bleu"] < last_step_stats["val_avg_bleu"]  # model learned nothing
-        assert 1.0 >= last_step_stats["val_avg_gen_time"]  # model hanging on generate. Maybe bad config was saved.
-        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
+        self.assertEqual(len(metrics["val"]), (args.max_epochs / args.val_check_interval))
+        assert isinstance(last_step_stats[f"val_avg_{model.val_metric}"], float)
+
+        self.assertGreater(last_step_stats["val_avg_gen_time"], 0.01)
+        # model hanging on generate. Maybe bad config was saved. (XXX: old comment/assert?)
+        self.assertLessEqual(last_step_stats["val_avg_gen_time"], 1.0)
+
+        # test learning requirements:
+
+        # 1. BLEU improves over the course of training by more than 2 pts
+        self.assertGreater(last_step_stats["val_avg_bleu"] - first_step_stats["val_avg_bleu"], 2)
+
+        # 2. BLEU finishes above 17
+        self.assertGreater(last_step_stats["val_avg_bleu"], 17)
+
+        # 3. test BLEU and val BLEU within ~1.1 pt.
+        self.assertLess(abs(metrics["val"][-1]["val_avg_bleu"] - metrics["test"][-1]["test_avg_bleu"]), 1.1)
 
         # check lightning ckpt can be loaded and has a reasonable statedict
         contents = os.listdir(output_dir)
@@ -107,11 +122,13 @@ class TestAll(TestCasePlus):
         # assert len(metrics["val"]) == desired_n_evals
         assert len(metrics["test"]) == 1
 
+
+class TestDistilMarianNoTeacher(TestCasePlus):
     @timeout_decorator.timeout(600)
     @slow
-    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    @require_torch_gpu
     def test_opus_mt_distill_script(self):
-        data_dir = "examples/seq2seq/test_data/wmt_en_ro"
+        data_dir = f"{self.test_file_dir_str}/test_data/wmt_en_ro"
         env_vars_to_replace = {
             "--fp16_opt_level=O1": "",
             "$MAX_LEN": 128,
@@ -124,7 +141,7 @@ class TestAll(TestCasePlus):
         # Clean up bash script
         bash_script = (
-            Path("examples/seq2seq/distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
+            (self.test_file_dir / "distil_marian_no_teacher.sh").open().read().split("distillation.py")[1].strip()
         )
         bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
         bash_script = bash_script.replace("--fp16 ", " ")
...
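
Both tests in this file follow the same pattern: read the example bash script, strip the shell syntax, substitute its $VARS, and hand the result to the Python entry point through a patched sys.argv. A condensed, standalone sketch of that pattern (the script path, program name, and values here are illustrative, not the test's exact ones):

import sys
from pathlib import Path
from unittest.mock import patch


def bash_to_argv(script_path, program, replacements):
    # Keep only what follows the program name; drop line continuations and "$@".
    text = Path(script_path).read_text().split(program)[1].strip()
    text = text.replace("\\\n", "").strip().replace('"$@"', "")
    # Substitute the script's $VARS (and any flags to override) with test values.
    for key, value in replacements.items():
        text = text.replace(key, str(value))
    return [program] + text.split()


# Illustrative usage (assumes the script exists in the working directory):
argv = bash_to_argv("train_mbart_cc25_enro.sh", "finetune.py", {"$MAX_LEN": 64, "$BS": 64})
with patch.object(sys, "argv", argv):
    pass  # an argparse-based main() would now see the script's arguments

This keeps the .sh files as the single source of truth for the training recipe, while the test shrinks the run (smaller --n_train, one epoch) so it finishes in CI-scale time yet still learns enough to pass the BLEU checks.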