Unverified Commit 9f7b2b24 authored by Stas Bekman, committed by GitHub

[s2s testing] turn all to unittests, use auto-delete temp dirs (#7859)

parent dc552b9b
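The change follows one pattern throughout: module-level pytest test functions become methods of classes deriving from transformers.testing_utils.TestCasePlus, and each tempfile.mkdtemp() call becomes self.get_auto_remove_tmp_dir(), which hands back a scratch directory that is deleted automatically when the test ends. A minimal sketch of that pattern, with an illustrative class name and payload that are not part of the diff:

import os

from transformers.testing_utils import TestCasePlus


class ExampleTest(TestCasePlus):
    def test_writes_to_scratch_dir(self):
        # get_auto_remove_tmp_dir() returns a fresh directory and registers
        # it for deletion at test teardown, replacing tempfile.mkdtemp()
        # plus manual cleanup.
        tmp_dir = self.get_auto_remove_tmp_dir()
        path = os.path.join(tmp_dir, "out.txt")
        with open(path, "w") as f:
            f.write("scratch output")
        self.assertTrue(os.path.exists(path))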
@@ -3,7 +3,6 @@
 import argparse
 import os
 import sys
-import tempfile
 from pathlib import Path
 from unittest.mock import patch
@@ -16,7 +15,7 @@ from distillation import BartSummarizationDistiller, distill_main
 from finetune import SummarizationModule, main
 from test_seq2seq_examples import CUDA_AVAILABLE, MBART_TINY
 from transformers import BartForConditionalGeneration, MarianMTModel
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import load_json
@@ -24,18 +23,18 @@ MODEL_NAME = MBART_TINY
 MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
-@slow
-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
-def test_model_download():
+class TestAll(TestCasePlus):
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_model_download(self):
         """This warms up the cache so that we can time the next test without including download time, which varies between machines."""
         BartForConditionalGeneration.from_pretrained(MODEL_NAME)
         MarianMTModel.from_pretrained(MARIAN_MODEL)
-@timeout_decorator.timeout(120)
-@slow
-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
-def test_train_mbart_cc25_enro_script():
+    @timeout_decorator.timeout(120)
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_train_mbart_cc25_enro_script(self):
         data_dir = "examples/seq2seq/test_data/wmt_en_ro"
         env_vars_to_replace = {
             "--fp16_opt_level=O1": "",
@@ -53,7 +52,7 @@ def test_train_mbart_cc25_enro_script():
         bash_script = bash_script.replace("\\\n", "").strip().replace('"$@"', "")
         for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
-    output_dir = tempfile.mkdtemp(prefix="output_mbart")
+        output_dir = self.get_auto_remove_tmp_dir()
         bash_script = bash_script.replace("--fp16 ", "")
         testargs = (
@@ -81,7 +80,9 @@ def test_train_mbart_cc25_enro_script():
         metrics = load_json(model.metrics_save_path)
         first_step_stats = metrics["val"][0]
         last_step_stats = metrics["val"][-1]
-    assert len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1  # +1 accounts for val_sanity_check
+        assert (
+            len(metrics["val"]) == (args.max_epochs / args.val_check_interval) + 1
+        )  # +1 accounts for val_sanity_check
         assert last_step_stats["val_avg_gen_time"] >= 0.01
@@ -106,11 +107,10 @@ def test_train_mbart_cc25_enro_script():
         # assert len(metrics["val"]) == desired_n_evals
         assert len(metrics["test"]) == 1
-@timeout_decorator.timeout(600)
-@slow
-@pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
-def test_opus_mt_distill_script():
+    @timeout_decorator.timeout(600)
+    @slow
+    @pytest.mark.skipif(not CUDA_AVAILABLE, reason="too slow to run on CPU")
+    def test_opus_mt_distill_script(self):
         data_dir = "examples/seq2seq/test_data/wmt_en_ro"
         env_vars_to_replace = {
             "--fp16_opt_level=O1": "",
@@ -131,7 +131,7 @@ def test_opus_mt_distill_script():
         for k, v in env_vars_to_replace.items():
            bash_script = bash_script.replace(k, str(v))
-    output_dir = tempfile.mkdtemp(prefix="marian_output")
+        output_dir = self.get_auto_remove_tmp_dir()
         bash_script = bash_script.replace("--fp16", "")
         epochs = 6
         testargs = (
......
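The two bash-script tests above share one mechanism: read the example shell script, strip line continuations, substitute flags, then call the training entry point with sys.argv patched so that argparse sees the constructed arguments. A hedged sketch of that invocation step, with main_fn and testargs as illustrative names:

import sys
from unittest.mock import patch


def run_with_patched_argv(main_fn, testargs):
    # patch.object swaps sys.argv for the duration of the block, so any
    # argparse parsing inside main_fn consumes testargs instead of the real
    # command line; the original argv is restored on exit.
    with patch.object(sys, "argv", testargs):
        return main_fn()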
 import os
-import tempfile
 from pathlib import Path
 import numpy as np
@@ -7,11 +6,12 @@ import pytest
 from torch.utils.data import DataLoader
 from pack_dataset import pack_data_dir
+from parameterized import parameterized
 from save_len_file import save_len_file
 from test_seq2seq_examples import ARTICLES, BART_TINY, MARIAN_TINY, MBART_TINY, SUMMARIES, T5_TINY, make_test_data_dir
 from transformers import AutoTokenizer
 from transformers.modeling_bart import shift_tokens_right
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from utils import FAIRSEQ_AVAILABLE, DistributedSortishSampler, LegacySeq2SeqDataset, Seq2SeqDataset
@@ -19,9 +19,8 @@ BERT_BASE_CASED = "bert-base-cased"
 PEGASUS_XSUM = "google/pegasus-xsum"
-@slow
-@pytest.mark.parametrize(
-    "tok_name",
+class TestAll(TestCasePlus):
+    @parameterized.expand(
         [
             MBART_TINY,
             MARIAN_TINY,
@@ -29,10 +28,11 @@ PEGASUS_XSUM = "google/pegasus-xsum"
             BART_TINY,
             PEGASUS_XSUM,
         ],
     )
-def test_seq2seq_dataset_truncation(tok_name):
+    @slow
+    def test_seq2seq_dataset_truncation(self, tok_name):
         tokenizer = AutoTokenizer.from_pretrained(tok_name)
-    tmp_dir = make_test_data_dir()
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
         max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
         max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
         max_src_len = 4
@@ -68,11 +68,10 @@ def test_seq2seq_dataset_truncation(tok_name):
             break  # No need to test every batch
-@pytest.mark.parametrize("tok", [BART_TINY, BERT_BASE_CASED])
-def test_legacy_dataset_truncation(tok):
+    @parameterized.expand([BART_TINY, BERT_BASE_CASED])
+    def test_legacy_dataset_truncation(self, tok):
         tokenizer = AutoTokenizer.from_pretrained(tok)
-    tmp_dir = make_test_data_dir()
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
         max_len_source = max(len(tokenizer.encode(a)) for a in ARTICLES)
         max_len_target = max(len(tokenizer.encode(a)) for a in SUMMARIES)
         trunc_target = 4
@@ -94,13 +93,12 @@ def test_legacy_dataset_truncation(tok):
         assert max_len_target > trunc_target  # Truncated
         break  # No need to test every batch
-def test_pack_dataset():
+    def test_pack_dataset(self):
         tokenizer = AutoTokenizer.from_pretrained("facebook/mbart-large-cc25")
-    tmp_dir = Path(make_test_data_dir())
+        tmp_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
         orig_examples = tmp_dir.joinpath("train.source").open().readlines()
-    save_dir = Path(tempfile.mkdtemp(prefix="packed_"))
+        save_dir = Path(make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()))
         pack_data_dir(tokenizer, tmp_dir, 128, save_dir)
         orig_paths = {x.name for x in tmp_dir.iterdir()}
         new_paths = {x.name for x in save_dir.iterdir()}
@@ -112,12 +110,11 @@ def test_pack_dataset():
         assert len(packed_examples[0]) == sum(len(x) for x in orig_examples)
         assert orig_paths == new_paths
-@pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
-def test_dynamic_batch_size():
+    @pytest.mark.skipif(not FAIRSEQ_AVAILABLE, reason="This test requires fairseq")
+    def test_dynamic_batch_size(self):
         if not FAIRSEQ_AVAILABLE:
             return
-    ds, max_tokens, tokenizer = _get_dataset(max_len=64)
+        ds, max_tokens, tokenizer = self._get_dataset(max_len=64)
         required_batch_size_multiple = 64
         batch_sampler = ds.make_dynamic_sampler(max_tokens, required_batch_size_multiple=required_batch_size_multiple)
         batch_sizes = [len(x) for x in batch_sampler]
@@ -138,9 +135,8 @@ def test_dynamic_batch_size():
         if failures:
             raise AssertionError(f"too many tokens in {len(failures)} batches")
-def test_sortish_sampler_reduces_padding():
-    ds, _, tokenizer = _get_dataset(max_len=512)
+    def test_sortish_sampler_reduces_padding(self):
+        ds, _, tokenizer = self._get_dataset(max_len=512)
         bs = 2
         sortish_sampler = ds.make_sortish_sampler(bs, shuffle=False)
@@ -156,8 +152,7 @@ def test_sortish_sampler_reduces_padding():
         assert sum(count_pad_tokens(sortish_dl)) < sum(count_pad_tokens(naive_dl))
         assert len(sortish_dl) == len(naive_dl)
-def _get_dataset(n_obs=1000, max_len=128):
+    def _get_dataset(self, n_obs=1000, max_len=128):
         if os.getenv("USE_REAL_DATA", False):
             data_dir = "examples/seq2seq/wmt_en_ro"
             max_tokens = max_len * 2 * 64
@@ -179,16 +174,13 @@ def _get_dataset(n_obs=1000, max_len=128):
         )
         return ds, max_tokens, tokenizer
-def test_distributed_sortish_sampler_splits_indices_between_procs():
-    ds, max_tokens, tokenizer = _get_dataset()
+    def test_distributed_sortish_sampler_splits_indices_between_procs(self):
+        ds, max_tokens, tokenizer = self._get_dataset()
         ids1 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=0, add_extra_examples=False))
         ids2 = set(DistributedSortishSampler(ds, 256, num_replicas=2, rank=1, add_extra_examples=False))
         assert ids1.intersection(ids2) == set()
-@pytest.mark.parametrize(
-    "tok_name",
+    @parameterized.expand(
         [
             MBART_TINY,
             MARIAN_TINY,
@@ -196,13 +188,13 @@ def test_distributed_sortish_sampler_splits_indices_between_procs():
            BART_TINY,
            PEGASUS_XSUM,
        ],
     )
-def test_dataset_kwargs(tok_name):
+    def test_dataset_kwargs(self, tok_name):
         tokenizer = AutoTokenizer.from_pretrained(tok_name)
         if tok_name == MBART_TINY:
             train_dataset = Seq2SeqDataset(
                 tokenizer,
-            data_dir=make_test_data_dir(),
+                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
                 type_path="train",
                 max_source_length=4,
                 max_target_length=8,
@@ -213,7 +205,11 @@ def test_dataset_kwargs(tok_name):
             assert "src_lang" in kwargs and "tgt_lang" in kwargs
         else:
             train_dataset = Seq2SeqDataset(
-            tokenizer, data_dir=make_test_data_dir(), type_path="train", max_source_length=4, max_target_length=8
+                tokenizer,
+                data_dir=make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir()),
+                type_path="train",
+                max_source_length=4,
+                max_target_length=8,
             )
             kwargs = train_dataset.dataset_kwargs
             assert "add_prefix_space" not in kwargs if tok_name != BART_TINY else "add_prefix_space" in kwargs
......
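The other recurring swap in this file: pytest.mark.parametrize does not work on unittest.TestCase methods, so once the tests live in a class, parameterization moves to parameterized.expand from the parameterized package, which generates a separate unittest test per list entry. A minimal sketch (the class name and parameter values here are illustrative):

import unittest

from parameterized import parameterized


class ExampleTest(unittest.TestCase):
    # expand() synthesizes one named test method per parameter value, so each
    # case reports and fails independently under plain unittest.
    @parameterized.expand(["t5-small", "facebook/bart-base"])
    def test_model_name(self, model_name):
        self.assertIsInstance(model_name, str)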
 import os
 import sys
-import tempfile
 from unittest.mock import patch
-from transformers.testing_utils import slow
+from transformers.testing_utils import TestCasePlus, slow
 from transformers.trainer_callback import TrainerState
 from transformers.trainer_utils import set_seed
@@ -15,18 +14,18 @@ set_seed(42)
 MARIAN_MODEL = "sshleifer/student_marian_en_ro_6_1"
-def test_finetune_trainer():
-    output_dir = run_trainer(1, "12", MBART_TINY, 1)
+class TestFinetuneTrainer(TestCasePlus):
+    def test_finetune_trainer(self):
+        output_dir = self.run_trainer(1, "12", MBART_TINY, 1)
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
         eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
         first_step_stats = eval_metrics[0]
         assert "eval_bleu" in first_step_stats
-@slow
-def test_finetune_trainer_slow():
+    @slow
+    def test_finetune_trainer_slow(self):
         # There is a missing call to __init__process_group somewhere
-    output_dir = run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
+        output_dir = self.run_trainer(eval_steps=2, max_len="128", model_name=MARIAN_MODEL, num_train_epochs=3)
         # Check metrics
         logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
@@ -43,10 +42,9 @@ def test_finetune_trainer_slow():
         assert "test_generations.txt" in contents
         assert "test_results.json" in contents
-def run_trainer(eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
+    def run_trainer(self, eval_steps: int, max_len: str, model_name: str, num_train_epochs: int):
         data_dir = "examples/seq2seq/test_data/wmt_en_ro"
-    output_dir = tempfile.mkdtemp(prefix="test_output")
+        output_dir = self.get_auto_remove_tmp_dir()
         argv = f"""
             --model_name_or_path {model_name}
             --data_dir {data_dir}
......
@@ -3,7 +3,6 @@ import logging
 import os
 import sys
 import tempfile
-import unittest
 from pathlib import Path
 from unittest.mock import patch
@@ -15,11 +14,12 @@ import lightning_base
 from convert_pl_checkpoint_to_hf import convert_pl_to_hf
 from distillation import distill_main
 from finetune import SummarizationModule, main
+from parameterized import parameterized
 from run_eval import generate_summaries_or_translations, run_generate
 from run_eval_search import run_search
 from transformers import AutoConfig, AutoModelForSeq2SeqLM
 from transformers.hf_api import HfApi
-from transformers.testing_utils import CaptureStderr, CaptureStdout, require_multigpu, require_torch_and_cuda, slow
+from transformers.testing_utils import CaptureStderr, CaptureStdout, TestCasePlus, require_torch_and_cuda, slow
 from utils import ROUGE_KEYS, label_smoothed_nll_loss, lmap, load_json
@@ -52,7 +52,7 @@ CHEAP_ARGS = {
     "student_decoder_layers": 1,
     "val_check_interval": 1.0,
     "output_dir": "",
-    "fp16": False,  # TODO: set this to CUDA_AVAILABLE if ci installs apex or start using native amp
+    "fp16": False,  # TODO(SS): set this to CUDA_AVAILABLE if ci installs apex or start using native amp
     "no_teacher": False,
     "fp16_opt_level": "O1",
     "gpus": 1 if CUDA_AVAILABLE else 0,
@@ -88,6 +88,7 @@ CHEAP_ARGS = {
     "student_encoder_layers": 1,
     "freeze_encoder": False,
     "auto_scale_batch_size": False,
+    "overwrite_output_dir": False,
 }
@@ -110,15 +111,14 @@ logger.addHandler(stream_handler)
 logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks
-def make_test_data_dir(**kwargs):
-    tmp_dir = Path(tempfile.mkdtemp(**kwargs))
+def make_test_data_dir(tmp_dir):
     for split in ["train", "val", "test"]:
-        _dump_articles((tmp_dir / f"{split}.source"), ARTICLES)
-        _dump_articles((tmp_dir / f"{split}.target"), SUMMARIES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.source"), ARTICLES)
+        _dump_articles(os.path.join(tmp_dir, f"{split}.target"), SUMMARIES)
     return tmp_dir
-class TestSummarizationDistiller(unittest.TestCase):
+class TestSummarizationDistiller(TestCasePlus):
     @classmethod
     def setUpClass(cls):
         logging.disable(logging.CRITICAL)  # remove noisy download output from tracebacks
@@ -143,17 +143,6 @@ class TestSummarizationDistiller(unittest.TestCase):
             failures.append(m)
         assert not failures, f"The following models could not be loaded through AutoConfig: {failures}"
-    @require_multigpu
-    @unittest.skip("Broken at the moment")
-    def test_multigpu(self):
-        updates = dict(
-            no_teacher=True,
-            freeze_encoder=True,
-            gpus=2,
-            sortish_sampler=True,
-        )
-        self._test_distiller_cli(updates, check_contents=False)
     def test_distill_no_teacher(self):
         updates = dict(student_encoder_layers=2, student_decoder_layers=1, no_teacher=True)
         self._test_distiller_cli(updates)
@@ -173,12 +162,12 @@ class TestSummarizationDistiller(unittest.TestCase):
         self.assertEqual(1, len(ckpts))
         transformer_ckpts = list(Path(model.output_dir).glob("**/*.bin"))
         self.assertEqual(len(transformer_ckpts), 2)
-        examples = lmap(str.strip, model.hparams.data_dir.joinpath("test.source").open().readlines())
+        examples = lmap(str.strip, Path(model.hparams.data_dir).joinpath("test.source").open().readlines())
-        out_path = tempfile.mktemp()
+        out_path = tempfile.mktemp()  # XXX: not being cleaned up
         generate_summaries_or_translations(examples, out_path, str(model.output_dir / "best_tfmr"))
         self.assertTrue(Path(out_path).exists())
-        out_path_new = tempfile.mkdtemp()
+        out_path_new = self.get_auto_remove_tmp_dir()
         convert_pl_to_hf(ckpts[0], transformer_ckpts[0].parent, out_path_new)
         assert os.path.exists(os.path.join(out_path_new, "pytorch_model.bin"))
@@ -253,8 +242,8 @@ class TestSummarizationDistiller(unittest.TestCase):
         )
         default_updates.update(updates)
         args_d: dict = CHEAP_ARGS.copy()
-        tmp_dir = make_test_data_dir()
-        output_dir = tempfile.mkdtemp(prefix="output_")
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d.update(data_dir=tmp_dir, output_dir=output_dir, **default_updates)
         model = distill_main(argparse.Namespace(**args_d))
@@ -279,13 +268,15 @@ class TestSummarizationDistiller(unittest.TestCase):
         return model
-def run_eval_tester(model):
-    input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
+class TestTheRest(TestCasePlus):
+    def run_eval_tester(self, model):
+        input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
         output_file_name = input_file_name.parent / "utest_output.txt"
         assert not output_file_name.exists()
         articles = [" New York (CNN)When Liana Barrientos was 23 years old, she got married in Westchester County."]
         _dump_articles(input_file_name, articles)
-    score_path = str(Path(tempfile.mkdtemp()) / "scores.json")
+        score_path = str(Path(self.get_auto_remove_tmp_dir()) / "scores.json")
         task = "translation_en_to_de" if model == T5_TINY else "summarization"
         testargs = f"""
             run_eval_search.py
@@ -301,27 +292,24 @@ def run_eval_tester(model):
         with patch.object(sys, "argv", testargs):
             run_generate()
         assert Path(output_file_name).exists()
-    os.remove(Path(output_file_name))
+        # os.remove(Path(output_file_name))
-# test one model to quickly (no-@slow) catch simple problems and do an
-# extensive testing of functionality with multiple models as @slow separately
-def test_run_eval():
-    run_eval_tester(T5_TINY)
+    # test one model to quickly (no-@slow) catch simple problems and do an
+    # extensive testing of functionality with multiple models as @slow separately
+    def test_run_eval(self):
+        self.run_eval_tester(T5_TINY)
     # any extra models should go into the list here - can be slow
-@slow
-@pytest.mark.parametrize("model", [BART_TINY, MBART_TINY])
-def test_run_eval_slow(model):
-    run_eval_tester(model)
+    @parameterized.expand([BART_TINY, MBART_TINY])
+    @slow
+    def test_run_eval_slow(self, model):
+        self.run_eval_tester(model)
     # testing with 2 models to validate: 1. translation (t5) 2. summarization (mbart)
-@slow
-@pytest.mark.parametrize("model", [T5_TINY, MBART_TINY])
-def test_run_eval_search(model):
-    input_file_name = Path(tempfile.mkdtemp()) / "utest_input.source"
+    @parameterized.expand([T5_TINY, MBART_TINY])
+    @slow
+    def test_run_eval_search(self, model):
+        input_file_name = Path(self.get_auto_remove_tmp_dir()) / "utest_input.source"
         output_file_name = input_file_name.parent / "utest_output.txt"
         assert not output_file_name.exists()
@@ -334,7 +322,7 @@ def test_run_eval_search(model):
             ],
         }
-    tmp_dir = Path(tempfile.mkdtemp())
+        tmp_dir = Path(self.get_auto_remove_tmp_dir())
         score_path = str(tmp_dir / "scores.json")
         reference_path = str(tmp_dir / "val.target")
         _dump_articles(input_file_name, text["en"])
@@ -367,18 +355,16 @@ def test_run_eval_search(model):
         assert Path(output_file_name).exists()
         os.remove(Path(output_file_name))
-@pytest.mark.parametrize(
-    "model",
-    [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
-)
-def test_finetune(model):
+    @parameterized.expand(
+        [T5_TINY, BART_TINY, MBART_TINY, MARIAN_TINY, FSMT_TINY],
+    )
+    def test_finetune(self, model):
         args_d: dict = CHEAP_ARGS.copy()
         task = "translation" if model in [MBART_TINY, MARIAN_TINY, FSMT_TINY] else "summarization"
         args_d["label_smoothing"] = 0.1 if task == "translation" else 0
-    tmp_dir = make_test_data_dir()
-    output_dir = tempfile.mkdtemp(prefix="output_")
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d.update(
             data_dir=tmp_dir,
             model_name_or_path=model,
@@ -423,12 +409,11 @@ def test_finetune(model):
         assert isinstance(example_batch, dict)
         assert len(example_batch) >= 4
-def test_finetune_extra_model_args():
+    def test_finetune_extra_model_args(self):
         args_d: dict = CHEAP_ARGS.copy()
         task = "summarization"
-    tmp_dir = make_test_data_dir()
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
         args_d.update(
             data_dir=tmp_dir,
@@ -445,7 +430,7 @@ def test_finetune_extra_model_args():
         # test models whose config includes the extra_model_args
         model = BART_TINY
-    output_dir = tempfile.mkdtemp(prefix="output_1_")
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d1 = args_d.copy()
         args_d1.update(
             model_name_or_path=model,
@@ -461,7 +446,7 @@ def test_finetune_extra_model_args():
         # test models whose config doesn't include the extra_model_args
         model = T5_TINY
-    output_dir = tempfile.mkdtemp(prefix="output_2_")
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d2 = args_d.copy()
         args_d2.update(
             model_name_or_path=model,
@@ -474,15 +459,14 @@ def test_finetune_extra_model_args():
         model = main(args)
         assert str(excinfo.value) == f"model config doesn't have a `{unsupported_param}` attribute"
-def test_finetune_lr_schedulers():
+    def test_finetune_lr_schedulers(self):
         args_d: dict = CHEAP_ARGS.copy()
         task = "summarization"
-    tmp_dir = make_test_data_dir()
+        tmp_dir = make_test_data_dir(tmp_dir=self.get_auto_remove_tmp_dir())
         model = BART_TINY
-    output_dir = tempfile.mkdtemp(prefix="output_1_")
+        output_dir = self.get_auto_remove_tmp_dir()
         args_d.update(
             data_dir=tmp_dir,
@@ -531,4 +515,6 @@ def test_finetune_lr_schedulers():
         args_d1["lr_scheduler"] = supported_param
         args = argparse.Namespace(**args_d1)
         model = main(args)
-    assert getattr(model.hparams, "lr_scheduler") == supported_param, f"lr_scheduler={supported_param} shouldn't fail"
+        assert (
+            getattr(model.hparams, "lr_scheduler") == supported_param
+        ), f"lr_scheduler={supported_param} shouldn't fail"