Unverified Commit eb613b56 authored by Kevin Canwen Xu's avatar Kevin Canwen Xu Committed by GitHub
Browse files

Use hash to clean the test dirs (#6475)

* Use hash to clean the test dirs

* Use hash to clean the test dirs

* Use hash to clean the test dirs

* fix
parent 680f1337
...@@ -20,7 +20,7 @@ def get_setup_file(): ...@@ -20,7 +20,7 @@ def get_setup_file():
return args.f return args.f
def clean_test_dir(path="./tests/fixtures/tests_samples/temp_dir"): def clean_test_dir(path):
shutil.rmtree(path, ignore_errors=True) shutil.rmtree(path, ignore_errors=True)
...@@ -37,7 +37,6 @@ class PabeeTests(unittest.TestCase): ...@@ -37,7 +37,6 @@ class PabeeTests(unittest.TestCase):
--task_name mrpc --task_name mrpc
--do_train --do_train
--do_eval --do_eval
--output_dir ./tests/fixtures/tests_samples/temp_dir
--per_gpu_train_batch_size=2 --per_gpu_train_batch_size=2
--per_gpu_eval_batch_size=1 --per_gpu_eval_batch_size=1
--learning_rate=2e-5 --learning_rate=2e-5
...@@ -46,10 +45,13 @@ class PabeeTests(unittest.TestCase): ...@@ -46,10 +45,13 @@ class PabeeTests(unittest.TestCase):
--overwrite_output_dir --overwrite_output_dir
--seed=42 --seed=42
--max_seq_length=128 --max_seq_length=128
""".split() """
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
result = run_glue_with_pabee.main() result = run_glue_with_pabee.main()
for value in result.values(): for value in result.values():
self.assertGreaterEqual(value, 0.75) self.assertGreaterEqual(value, 0.75)
clean_test_dir() clean_test_dir(output_dir)
...@@ -52,7 +52,7 @@ def get_setup_file(): ...@@ -52,7 +52,7 @@ def get_setup_file():
return args.f return args.f
def clean_test_dir(path="./tests/fixtures/tests_samples/temp_dir"): def clean_test_dir(path):
shutil.rmtree(path, ignore_errors=True) shutil.rmtree(path, ignore_errors=True)
...@@ -68,7 +68,6 @@ class ExamplesTests(unittest.TestCase): ...@@ -68,7 +68,6 @@ class ExamplesTests(unittest.TestCase):
--task_name mrpc --task_name mrpc
--do_train --do_train
--do_eval --do_eval
--output_dir ./tests/fixtures/tests_samples/temp_dir
--per_device_train_batch_size=2 --per_device_train_batch_size=2
--per_device_eval_batch_size=1 --per_device_eval_batch_size=1
--learning_rate=1e-4 --learning_rate=1e-4
...@@ -77,13 +76,16 @@ class ExamplesTests(unittest.TestCase): ...@@ -77,13 +76,16 @@ class ExamplesTests(unittest.TestCase):
--overwrite_output_dir --overwrite_output_dir
--seed=42 --seed=42
--max_seq_length=128 --max_seq_length=128
""".split() """
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
result = run_glue.main() result = run_glue.main()
del result["eval_loss"] del result["eval_loss"]
for value in result.values(): for value in result.values():
self.assertGreaterEqual(value, 0.75) self.assertGreaterEqual(value, 0.75)
clean_test_dir() clean_test_dir(output_dir)
def test_run_pl_glue(self): def test_run_pl_glue(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
...@@ -96,13 +98,15 @@ class ExamplesTests(unittest.TestCase): ...@@ -96,13 +98,15 @@ class ExamplesTests(unittest.TestCase):
--task mrpc --task mrpc
--do_train --do_train
--do_predict --do_predict
--output_dir ./tests/fixtures/tests_samples/temp_dir
--train_batch_size=32 --train_batch_size=32
--learning_rate=1e-4 --learning_rate=1e-4
--num_train_epochs=1 --num_train_epochs=1
--seed=42 --seed=42
--max_seq_length=128 --max_seq_length=128
""".split() """
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
if torch.cuda.is_available(): if torch.cuda.is_available():
testargs += ["--fp16", "--gpus=1"] testargs += ["--fp16", "--gpus=1"]
...@@ -119,7 +123,7 @@ class ExamplesTests(unittest.TestCase): ...@@ -119,7 +123,7 @@ class ExamplesTests(unittest.TestCase):
# for k, v in result.items(): # for k, v in result.items():
# self.assertGreaterEqual(v, 0.75, f"({k})") # self.assertGreaterEqual(v, 0.75, f"({k})")
# #
clean_test_dir() clean_test_dir(output_dir)
def test_run_language_modeling(self): def test_run_language_modeling(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
...@@ -133,17 +137,19 @@ class ExamplesTests(unittest.TestCase): ...@@ -133,17 +137,19 @@ class ExamplesTests(unittest.TestCase):
--line_by_line --line_by_line
--train_data_file ./tests/fixtures/sample_text.txt --train_data_file ./tests/fixtures/sample_text.txt
--eval_data_file ./tests/fixtures/sample_text.txt --eval_data_file ./tests/fixtures/sample_text.txt
--output_dir ./tests/fixtures/tests_samples/temp_dir
--overwrite_output_dir --overwrite_output_dir
--do_train --do_train
--do_eval --do_eval
--num_train_epochs=1 --num_train_epochs=1
--no_cuda --no_cuda
""".split() """
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
result = run_language_modeling.main() result = run_language_modeling.main()
self.assertLess(result["perplexity"], 35) self.assertLess(result["perplexity"], 35)
clean_test_dir() clean_test_dir(output_dir)
def test_run_squad(self): def test_run_squad(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
...@@ -154,7 +160,6 @@ class ExamplesTests(unittest.TestCase): ...@@ -154,7 +160,6 @@ class ExamplesTests(unittest.TestCase):
--model_type=distilbert --model_type=distilbert
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad --model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
--data_dir=./tests/fixtures/tests_samples/SQUAD --data_dir=./tests/fixtures/tests_samples/SQUAD
--output_dir=./tests/fixtures/tests_samples/temp_dir
--max_steps=10 --max_steps=10
--warmup_steps=2 --warmup_steps=2
--do_train --do_train
...@@ -165,12 +170,15 @@ class ExamplesTests(unittest.TestCase): ...@@ -165,12 +170,15 @@ class ExamplesTests(unittest.TestCase):
--per_gpu_eval_batch_size=1 --per_gpu_eval_batch_size=1
--overwrite_output_dir --overwrite_output_dir
--seed=42 --seed=42
""".split() """
output_dir = "./tests/fixtures/tests_samples/temp_dir_{}".format(hash(testargs))
testargs += "--output_dir " + output_dir
testargs = testargs.split()
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
result = run_squad.main() result = run_squad.main()
self.assertGreaterEqual(result["f1"], 25) self.assertGreaterEqual(result["f1"], 25)
self.assertGreaterEqual(result["exact"], 21) self.assertGreaterEqual(result["exact"], 21)
clean_test_dir() clean_test_dir(output_dir)
def test_generation(self): def test_generation(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment