Commit 9dbe4094 (unverified), authored by Stas Bekman, committed by GitHub
Browse files

[testing] a new TestCasePlus subclass + get_auto_remove_tmp_dir() (#6494)

* [testing] switch to a new TestCasePlus + get_auto_remove_tmp_dir() for auto-removal of tmp dirs

* respect after=True for tempfile, simplify code

* comments

* comment fix

* put `before` last in args, so can make debug even faster
parent 36010cb1
import argparse import argparse
import logging import logging
import shutil
import sys import sys
import unittest
from unittest.mock import patch from unittest.mock import patch
import run_glue_with_pabee import run_glue_with_pabee
from transformers.testing_utils import TestCasePlus
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
...@@ -20,20 +19,19 @@ def get_setup_file(): ...@@ -20,20 +19,19 @@ def get_setup_file():
return args.f return args.f
def clean_test_dir(path): class PabeeTests(TestCasePlus):
shutil.rmtree(path, ignore_errors=True)
class PabeeTests(unittest.TestCase):
def test_run_glue(self): def test_run_glue(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
testargs = """ tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_glue_with_pabee.py run_glue_with_pabee.py
--model_type albert --model_type albert
--model_name_or_path albert-base-v2 --model_name_or_path albert-base-v2
--data_dir ./tests/fixtures/tests_samples/MRPC/ --data_dir ./tests/fixtures/tests_samples/MRPC/
--output_dir {tmp_dir}
--overwrite_output_dir
--task_name mrpc --task_name mrpc
--do_train --do_train
--do_eval --do_eval
...@@ -42,16 +40,11 @@ class PabeeTests(unittest.TestCase): ...@@ -42,16 +40,11 @@ class PabeeTests(unittest.TestCase):
--learning_rate=2e-5 --learning_rate=2e-5
--max_steps=50 --max_steps=50
--warmup_steps=2 --warmup_steps=2
--overwrite_output_dir
--seed=42 --seed=42
--max_seq_length=128 --max_seq_length=128
""" """.split()
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
result = run_glue_with_pabee.main() result = run_glue_with_pabee.main()
for value in result.values(): for value in result.values():
self.assertGreaterEqual(value, 0.75) self.assertGreaterEqual(value, 0.75)
clean_test_dir(output_dir)
...@@ -17,13 +17,13 @@ ...@@ -17,13 +17,13 @@
import argparse import argparse
import logging import logging
import os import os
import shutil
import sys import sys
import unittest
from unittest.mock import patch from unittest.mock import patch
import torch import torch
from transformers.testing_utils import TestCasePlus
SRC_DIRS = [ SRC_DIRS = [
os.path.join(os.path.dirname(__file__), dirname) os.path.join(os.path.dirname(__file__), dirname)
...@@ -52,19 +52,18 @@ def get_setup_file(): ...@@ -52,19 +52,18 @@ def get_setup_file():
return args.f return args.f
def clean_test_dir(path): class ExamplesTests(TestCasePlus):
shutil.rmtree(path, ignore_errors=True)
class ExamplesTests(unittest.TestCase):
def test_run_glue(self): def test_run_glue(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
testargs = """ tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_glue.py run_glue.py
--model_name_or_path distilbert-base-uncased --model_name_or_path distilbert-base-uncased
--data_dir ./tests/fixtures/tests_samples/MRPC/ --data_dir ./tests/fixtures/tests_samples/MRPC/
--output_dir {tmp_dir}
--overwrite_output_dir
--task_name mrpc --task_name mrpc
--do_train --do_train
--do_eval --do_eval
...@@ -73,28 +72,26 @@ class ExamplesTests(unittest.TestCase): ...@@ -73,28 +72,26 @@ class ExamplesTests(unittest.TestCase):
--learning_rate=1e-4 --learning_rate=1e-4
--max_steps=10 --max_steps=10
--warmup_steps=2 --warmup_steps=2
--overwrite_output_dir
--seed=42 --seed=42
--max_seq_length=128 --max_seq_length=128
""" """.split()
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
result = run_glue.main() result = run_glue.main()
del result["eval_loss"] del result["eval_loss"]
for value in result.values(): for value in result.values():
self.assertGreaterEqual(value, 0.75) self.assertGreaterEqual(value, 0.75)
clean_test_dir(output_dir)
def test_run_pl_glue(self): def test_run_pl_glue(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
testargs = """ tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_pl_glue.py run_pl_glue.py
--model_name_or_path bert-base-cased --model_name_or_path bert-base-cased
--data_dir ./tests/fixtures/tests_samples/MRPC/ --data_dir ./tests/fixtures/tests_samples/MRPC/
--output_dir {tmp_dir}
--task mrpc --task mrpc
--do_train --do_train
--do_predict --do_predict
...@@ -103,11 +100,7 @@ class ExamplesTests(unittest.TestCase): ...@@ -103,11 +100,7 @@ class ExamplesTests(unittest.TestCase):
--num_train_epochs=1 --num_train_epochs=1
--seed=42 --seed=42
--max_seq_length=128 --max_seq_length=128
""" """.split()
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
if torch.cuda.is_available(): if torch.cuda.is_available():
testargs += ["--fp16", "--gpus=1"] testargs += ["--fp16", "--gpus=1"]
...@@ -123,13 +116,13 @@ class ExamplesTests(unittest.TestCase): ...@@ -123,13 +116,13 @@ class ExamplesTests(unittest.TestCase):
# for k, v in result.items(): # for k, v in result.items():
# self.assertGreaterEqual(v, 0.75, f"({k})") # self.assertGreaterEqual(v, 0.75, f"({k})")
# #
clean_test_dir(output_dir)
def test_run_language_modeling(self): def test_run_language_modeling(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
testargs = """ tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_language_modeling.py run_language_modeling.py
--model_name_or_path distilroberta-base --model_name_or_path distilroberta-base
--model_type roberta --model_type roberta
...@@ -137,29 +130,30 @@ class ExamplesTests(unittest.TestCase): ...@@ -137,29 +130,30 @@ class ExamplesTests(unittest.TestCase):
--line_by_line --line_by_line
--train_data_file ./tests/fixtures/sample_text.txt --train_data_file ./tests/fixtures/sample_text.txt
--eval_data_file ./tests/fixtures/sample_text.txt --eval_data_file ./tests/fixtures/sample_text.txt
--output_dir {tmp_dir}
--overwrite_output_dir --overwrite_output_dir
--do_train --do_train
--do_eval --do_eval
--num_train_epochs=1 --num_train_epochs=1
--no_cuda --no_cuda
""" """.split()
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
result = run_language_modeling.main() result = run_language_modeling.main()
self.assertLess(result["perplexity"], 35) self.assertLess(result["perplexity"], 35)
clean_test_dir(output_dir)
def test_run_squad(self): def test_run_squad(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
testargs = """ tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f"""
run_squad.py run_squad.py
--model_type=distilbert --model_type=distilbert
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad --model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
--data_dir=./tests/fixtures/tests_samples/SQUAD --data_dir=./tests/fixtures/tests_samples/SQUAD
--output_dir {tmp_dir}
--overwrite_output_dir
--max_steps=10 --max_steps=10
--warmup_steps=2 --warmup_steps=2
--do_train --do_train
...@@ -168,17 +162,13 @@ class ExamplesTests(unittest.TestCase): ...@@ -168,17 +162,13 @@ class ExamplesTests(unittest.TestCase):
--learning_rate=2e-4 --learning_rate=2e-4
--per_gpu_train_batch_size=2 --per_gpu_train_batch_size=2
--per_gpu_eval_batch_size=1 --per_gpu_eval_batch_size=1
--overwrite_output_dir
--seed=42 --seed=42
""" """.split()
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
result = run_squad.main() result = run_squad.main()
self.assertGreaterEqual(result["f1"], 25) self.assertGreaterEqual(result["f1"], 25)
self.assertGreaterEqual(result["exact"], 21) self.assertGreaterEqual(result["exact"], 21)
clean_test_dir(output_dir)
def test_generation(self): def test_generation(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
......
import os import os
import re import re
import shutil
import sys import sys
import tempfile
import unittest import unittest
from distutils.util import strtobool from distutils.util import strtobool
from io import StringIO from io import StringIO
from pathlib import Path
from .file_utils import _tf_available, _torch_available, _torch_tpu_available from .file_utils import _tf_available, _torch_available, _torch_tpu_available
...@@ -255,3 +258,92 @@ class CaptureStderr(CaptureStd): ...@@ -255,3 +258,92 @@ class CaptureStderr(CaptureStd):
def __init__(self): def __init__(self):
super().__init__(out=False) super().__init__(out=False)
class TestCasePlus(unittest.TestCase):
    """Extend `unittest.TestCase` with additional features.

    Feature 1: Flexible auto-removable temp dirs which are guaranteed to get
    removed at the end of the test.

    In all the following scenarios the temp dir will be auto-removed at the end
    of the test, unless `after=False`.

    # 1. create a unique temp dir, `tmp_dir` will contain the path to the created temp dir
    def test_whatever(self):
        tmp_dir = self.get_auto_remove_tmp_dir()

    # 2. create a temp dir of my choice and delete it at the end - useful for debug when you want to
    # monitor a specific directory
    def test_whatever(self):
        tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test")

    # 3. create a temp dir of my choice and do not delete it at the end - useful for when you want
    # to look at the temp results
    def test_whatever(self):
        tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", after=False)

    # 4. create a temp dir of my choice and ensure to delete it right away - useful for when you
    # disabled deletion in the previous test run and want to make sure that the tmp dir is empty
    # before the new test is run
    def test_whatever(self):
        tmp_dir = self.get_auto_remove_tmp_dir(tmp_dir="./tmp/run/test", before=True)

    Note 1: In order to run the equivalent of `rm -r` safely, an explicit
    `tmp_dir` must start with `./` and must point strictly below the current
    working directory, so that by mistake no `/tmp`, parent directory or the
    working directory itself will get nuked.

    Note 2: Each test can register multiple temp dirs and they all will get
    auto-removed, unless requested otherwise.
    """

    def setUp(self):
        """Reset the per-test registry of temp dirs scheduled for removal."""
        super().setUp()
        self.teardown_tmp_dirs = []

    def get_auto_remove_tmp_dir(self, tmp_dir=None, after=True, before=False):
        """Create a temp dir and optionally register it for removal in `tearDown`.

        Args:
            tmp_dir (:obj:`string`, `optional`, defaults to :obj:`None`):
                use this path, if None a unique path will be assigned
            after (:obj:`bool`, `optional`, defaults to :obj:`True`):
                delete the tmp dir at the end of the test
            before (:obj:`bool`, `optional`, defaults to :obj:`False`):
                if `True` and tmp dir already exists make sure to empty it right away

        Returns:
            tmp_dir(:obj:`string`):
                either the same value as passed via `tmp_dir` or the path to the auto-created tmp dir

        Raises:
            ValueError: if `tmp_dir` does not start with `./`, or if it does not
                point strictly below the current working directory.
        """
        if tmp_dir is not None:
            # using provided path
            path = Path(tmp_dir).resolve()

            # to avoid nuking parts of the filesystem, only relative paths are allowed
            if not tmp_dir.startswith("./"):
                raise ValueError(
                    f"`tmp_dir` can only be a relative path, i.e. `./some/path`, but received `{tmp_dir}`"
                )

            # the `./` prefix alone is not enough: `./..` climbs out of the working
            # directory and `./` is the working directory itself, so compare the
            # normalized absolute paths before anything gets deleted
            root_dir = os.path.abspath(os.getcwd())
            cleanup_path = os.path.abspath(tmp_dir)
            if cleanup_path == root_dir or os.path.commonpath([root_dir, cleanup_path]) != root_dir:
                raise ValueError(
                    "`tmp_dir` must point to a subdirectory of the current working directory, "
                    f"but received `{tmp_dir}`"
                )

            # ensure the dir is empty to start with
            if before is True and path.exists():
                shutil.rmtree(tmp_dir, ignore_errors=True)

            path.mkdir(parents=True, exist_ok=True)
        else:
            # using unique tmp dir (always empty, regardless of `before`)
            tmp_dir = tempfile.mkdtemp()
            cleanup_path = tmp_dir

        if after is True:
            # a subclass may override `setUp` without chaining to ours, so create
            # the registry on demand instead of failing with AttributeError
            if not hasattr(self, "teardown_tmp_dirs"):
                self.teardown_tmp_dirs = []
            # register the absolute path, so that the right directory gets removed
            # even if the test changes the working directory before `tearDown`
            self.teardown_tmp_dirs.append(cleanup_path)

        return tmp_dir

    def tearDown(self):
        """Remove every temp dir registered during the test."""
        # `getattr` keeps this safe when `setUp` was overridden and nothing got registered
        for path in getattr(self, "teardown_tmp_dirs", []):
            shutil.rmtree(path, ignore_errors=True)
        self.teardown_tmp_dirs = []
        super().tearDown()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment