"web/extensions/vscode:/vscode.git/clone" did not exist on "37b70d798701295819a50bc64d373b0e6235cf14"
Unverified Commit c3e60749 authored by Sam Shleifer's avatar Sam Shleifer Committed by GitHub
Browse files

[cleanup] examples test_run_squad uses tiny model (#5059)

parent 439aa1d6
......@@ -55,7 +55,7 @@ class ExamplesTests(unittest.TestCase):
testargs = """
run_glue.py
--model_name_or_path bert-base-uncased
--model_name_or_path distilbert-base-uncased
--data_dir ./tests/fixtures/tests_samples/MRPC/
--task_name mrpc
--do_train
......@@ -79,6 +79,7 @@ class ExamplesTests(unittest.TestCase):
def test_run_language_modeling(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
# TODO: switch to smaller model like sshleifer/tiny-distilroberta-base
testargs = """
run_language_modeling.py
......@@ -105,10 +106,9 @@ class ExamplesTests(unittest.TestCase):
testargs = """
run_squad.py
--model_type=bert
--model_name_or_path=bert-base-uncased
--model_type=distilbert
--model_name_or_path=sshleifer/tiny-distilbert-base-cased-distilled-squad
--data_dir=./tests/fixtures/tests_samples/SQUAD
--model_name=bert-base-uncased
--output_dir=./tests/fixtures/tests_samples/temp_dir
--max_steps=10
--warmup_steps=2
......@@ -123,15 +123,15 @@ class ExamplesTests(unittest.TestCase):
""".split()
with patch.object(sys, "argv", testargs):
result = run_squad.main()
self.assertGreaterEqual(result["f1"], 30)
self.assertGreaterEqual(result["exact"], 30)
self.assertGreaterEqual(result["f1"], 25)
self.assertGreaterEqual(result["exact"], 21)
def test_generation(self):
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
model_type, model_name = ("--model_type=openai-gpt", "--model_name_or_path=openai-gpt")
model_type, model_name = ("--model_type=gpt2", "--model_name_or_path=sshleifer/tiny-gpt2")
with patch.object(sys, "argv", testargs + [model_type, model_name]):
result = run_generation.main()
self.assertGreaterEqual(len(result[0]), 10)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment