Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c3e60749
Unverified
Commit
c3e60749
authored
Jun 16, 2020
by
Sam Shleifer
Committed by
GitHub
Jun 16, 2020
Browse files
[cleanup] examples test_run_squad uses tiny model (#5059)
parent
439aa1d6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
7 deletions
+7
-7
examples/test_examples.py
examples/test_examples.py
+7
-7
No files found.
examples/test_examples.py
View file @
c3e60749
...
@@ -55,7 +55,7 @@ class ExamplesTests(unittest.TestCase):
...
@@ -55,7 +55,7 @@ class ExamplesTests(unittest.TestCase):
testargs
=
"""
testargs
=
"""
run_glue.py
run_glue.py
--model_name_or_path bert-base-uncased
--model_name_or_path
distil
bert-base-uncased
--data_dir ./tests/fixtures/tests_samples/MRPC/
--data_dir ./tests/fixtures/tests_samples/MRPC/
--task_name mrpc
--task_name mrpc
--do_train
--do_train
...
@@ -79,6 +79,7 @@ class ExamplesTests(unittest.TestCase):
...
@@ -79,6 +79,7 @@ class ExamplesTests(unittest.TestCase):
def
test_run_language_modeling
(
self
):
def
test_run_language_modeling
(
self
):
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
logger
.
addHandler
(
stream_handler
)
logger
.
addHandler
(
stream_handler
)
# TODO: switch to smaller model like sshleifer/tiny-distilroberta-base
testargs
=
"""
testargs
=
"""
run_language_modeling.py
run_language_modeling.py
...
@@ -105,10 +106,9 @@ class ExamplesTests(unittest.TestCase):
...
@@ -105,10 +106,9 @@ class ExamplesTests(unittest.TestCase):
testargs
=
"""
testargs
=
"""
run_squad.py
run_squad.py
--model_type=bert
--model_type=
distil
bert
--model_name_or_path=bert-base-
un
cased
--model_name_or_path=
sshleifer/tiny-distil
bert-base-cased
-distilled-squad
--data_dir=./tests/fixtures/tests_samples/SQUAD
--data_dir=./tests/fixtures/tests_samples/SQUAD
--model_name=bert-base-uncased
--output_dir=./tests/fixtures/tests_samples/temp_dir
--output_dir=./tests/fixtures/tests_samples/temp_dir
--max_steps=10
--max_steps=10
--warmup_steps=2
--warmup_steps=2
...
@@ -123,15 +123,15 @@ class ExamplesTests(unittest.TestCase):
...
@@ -123,15 +123,15 @@ class ExamplesTests(unittest.TestCase):
"""
.
split
()
"""
.
split
()
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
result
=
run_squad
.
main
()
result
=
run_squad
.
main
()
self
.
assertGreaterEqual
(
result
[
"f1"
],
30
)
self
.
assertGreaterEqual
(
result
[
"f1"
],
25
)
self
.
assertGreaterEqual
(
result
[
"exact"
],
30
)
self
.
assertGreaterEqual
(
result
[
"exact"
],
21
)
def
test_generation
(
self
):
def
test_generation
(
self
):
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
logger
.
addHandler
(
stream_handler
)
logger
.
addHandler
(
stream_handler
)
testargs
=
[
"run_generation.py"
,
"--prompt=Hello"
,
"--length=10"
,
"--seed=42"
]
testargs
=
[
"run_generation.py"
,
"--prompt=Hello"
,
"--length=10"
,
"--seed=42"
]
model_type
,
model_name
=
(
"--model_type=
openai-
gpt"
,
"--model_name_or_path=
openai
-gpt"
)
model_type
,
model_name
=
(
"--model_type=gpt
2
"
,
"--model_name_or_path=
sshleifer/tiny
-gpt
2
"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
+
[
model_type
,
model_name
]):
with
patch
.
object
(
sys
,
"argv"
,
testargs
+
[
model_type
,
model_name
]):
result
=
run_generation
.
main
()
result
=
run_generation
.
main
()
self
.
assertGreaterEqual
(
len
(
result
[
0
]),
10
)
self
.
assertGreaterEqual
(
len
(
result
[
0
]),
10
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment