"git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "4b9858ec346a2020ac4c0c2c1c7abd684d2481a4"
Unverified Commit 92f8ce2e authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

Fix deebert tests (#6102)

parent c49cd927
...@@ -21,11 +21,13 @@ def get_setup_file(): ...@@ -21,11 +21,13 @@ def get_setup_file():
class DeeBertTests(unittest.TestCase): class DeeBertTests(unittest.TestCase):
@slow def setup(self) -> None:
def test_glue_deebert(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
@slow
def test_glue_deebert_train(self):
train_args = """ train_args = """
run_glue_deebert.py run_glue_deebert.py
--model_type roberta --model_type roberta
...@@ -48,6 +50,10 @@ class DeeBertTests(unittest.TestCase): ...@@ -48,6 +50,10 @@ class DeeBertTests(unittest.TestCase):
--overwrite_cache --overwrite_cache
--eval_after_first_stage --eval_after_first_stage
""".split() """.split()
with patch.object(sys, "argv", train_args):
result = run_glue_deebert.main()
for value in result.values():
self.assertGreaterEqual(value, 0.666)
eval_args = """ eval_args = """
run_glue_deebert.py run_glue_deebert.py
...@@ -65,6 +71,10 @@ class DeeBertTests(unittest.TestCase): ...@@ -65,6 +71,10 @@ class DeeBertTests(unittest.TestCase):
--overwrite_cache --overwrite_cache
--per_gpu_eval_batch_size=1 --per_gpu_eval_batch_size=1
""".split() """.split()
with patch.object(sys, "argv", eval_args):
result = run_glue_deebert.main()
for value in result.values():
self.assertGreaterEqual(value, 0.666)
entropy_eval_args = """ entropy_eval_args = """
run_glue_deebert.py run_glue_deebert.py
...@@ -82,18 +92,7 @@ class DeeBertTests(unittest.TestCase): ...@@ -82,18 +92,7 @@ class DeeBertTests(unittest.TestCase):
--overwrite_cache --overwrite_cache
--per_gpu_eval_batch_size=1 --per_gpu_eval_batch_size=1
""".split() """.split()
with patch.object(sys, "argv", train_args):
result = run_glue_deebert.main()
for value in result.values():
self.assertGreaterEqual(value, 0.75)
with patch.object(sys, "argv", eval_args):
result = run_glue_deebert.main()
for value in result.values():
self.assertGreaterEqual(value, 0.75)
with patch.object(sys, "argv", entropy_eval_args): with patch.object(sys, "argv", entropy_eval_args):
result = run_glue_deebert.main() result = run_glue_deebert.main()
for value in result.values(): for value in result.values():
self.assertGreaterEqual(value, 0.75) self.assertGreaterEqual(value, 0.666)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment