Unverified Commit c3e60749 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[cleanup] examples test_run_squad uses tiny model (#5059)

parent 439aa1d6
...@@ -55,7 +55,7 @@ class ExamplesTests(unittest.TestCase): ...@@ -55,7 +55,7 @@ class ExamplesTests(unittest.TestCase):
testargs = """ testargs = """
run_glue.py run_glue.py
--model_name_or_path bert-base-uncased --model_name_or_path distilbert-base-uncased
--data_dir ./tests/fixtures/tests_samples/MRPC/ --data_dir ./tests/fixtures/tests_samples/MRPC/
--task_name mrpc --task_name mrpc
--do_train --do_train
...@@ -79,6 +79,7 @@ class ExamplesTests(unittest.TestCase): ...@@ -79,6 +79,7 @@ class ExamplesTests(unittest.TestCase):
def test_run_language_modeling(self): def test_run_language_modeling(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
# TODO: switch to smaller model like sshleifer/tiny-distilroberta-base
testargs = """ testargs = """
run_language_modeling.py run_language_modeling.py
...@@ -105,10 +106,9 @@ class ExamplesTests(unittest.TestCase): ...@@ -105,10 +106,9 @@ class ExamplesTests(unittest.TestCase):
testargs = """ testargs = """
run_squad.py run_squad.py
--model_type=bert --model_type=distilbert
--model_name_or_path=bert-base-uncased --model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
--data_dir=./tests/fixtures/tests_samples/SQUAD --data_dir=./tests/fixtures/tests_samples/SQUAD
--model_name=bert-base-uncased
--output_dir=./tests/fixtures/tests_samples/temp_dir --output_dir=./tests/fixtures/tests_samples/temp_dir
--max_steps=10 --max_steps=10
--warmup_steps=2 --warmup_steps=2
...@@ -123,15 +123,15 @@ class ExamplesTests(unittest.TestCase): ...@@ -123,15 +123,15 @@ class ExamplesTests(unittest.TestCase):
""".split() """.split()
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
result = run_squad.main() result = run_squad.main()
self.assertGreaterEqual(result["f1"], 30) self.assertGreaterEqual(result["f1"], 25)
self.assertGreaterEqual(result["exact"], 30) self.assertGreaterEqual(result["exact"], 21)
def test_generation(self): def test_generation(self):
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"] testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
model_type, model_name = ("--model_type=openai-gpt", "--model_name_or_path=openai-gpt") model_type, model_name = ("--model_type=gpt2", "--model_name_or_path=sshleifer/tiny-gpt2")
with patch.object(sys, "argv", testargs + [model_type, model_name]): with patch.object(sys, "argv", testargs + [model_type, model_name]):
result = run_generation.main() result = run_generation.main()
self.assertGreaterEqual(len(result[0]), 10) self.assertGreaterEqual(len(result[0]), 10)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment