Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
25e83894
"test/git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "cf0b6d6ddcfde3fbe49dab57245e48728d5161a5"
Commit
25e83894
authored
Aug 26, 2019
by
LysandreJik
Browse files
Tests for added AutoModels
parent
dc43215c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
41 additions
and
1 deletion
+41
-1
pytorch_transformers/tests/modeling_auto_test.py
pytorch_transformers/tests/modeling_auto_test.py
+41
-1
No files found.
pytorch_transformers/tests/modeling_auto_test.py
View file @
25e83894
...
...
@@ -21,7 +21,11 @@ import shutil
import
pytest
import
logging
from
pytorch_transformers
import
AutoConfig
,
BertConfig
,
AutoModel
,
BertModel
from
pytorch_transformers
import
(
AutoConfig
,
BertConfig
,
AutoModel
,
BertModel
,
AutoModelWithLMHead
,
BertForMaskedLM
,
AutoModelForSequenceClassification
,
BertForSequenceClassification
,
AutoModelForQuestionAnswering
,
BertForQuestionAnswering
)
from
pytorch_transformers.modeling_bert
import
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from
.modeling_common_test
import
(
CommonTestCases
,
ConfigTester
,
ids_tensor
)
...
...
@@ -42,6 +46,42 @@ class AutoModelTest(unittest.TestCase):
for
value
in
loading_info
.
values
():
self
.
assertEqual
(
len
(
value
),
0
)
def
test_lmhead_model_from_pretrained
(
self
):
logging
.
basicConfig
(
level
=
logging
.
INFO
)
for
model_name
in
list
(
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
config
=
AutoConfig
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
config
)
self
.
assertIsInstance
(
config
,
BertConfig
)
model
=
AutoModelWithLMHead
.
from_pretrained
(
model_name
)
model
,
loading_info
=
AutoModelWithLMHead
.
from_pretrained
(
model_name
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsInstance
(
model
,
BertForMaskedLM
)
def
test_sequence_classification_model_from_pretrained
(
self
):
logging
.
basicConfig
(
level
=
logging
.
INFO
)
for
model_name
in
list
(
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
config
=
AutoConfig
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
config
)
self
.
assertIsInstance
(
config
,
BertConfig
)
model
=
AutoModelForSequenceClassification
.
from_pretrained
(
model_name
)
model
,
loading_info
=
AutoModelForSequenceClassification
.
from_pretrained
(
model_name
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsInstance
(
model
,
BertForSequenceClassification
)
def
test_question_answering_model_from_pretrained
(
self
):
logging
.
basicConfig
(
level
=
logging
.
INFO
)
for
model_name
in
list
(
BERT_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
config
=
AutoConfig
.
from_pretrained
(
model_name
)
self
.
assertIsNotNone
(
config
)
self
.
assertIsInstance
(
config
,
BertConfig
)
model
=
AutoModelForQuestionAnswering
.
from_pretrained
(
model_name
)
model
,
loading_info
=
AutoModelForQuestionAnswering
.
from_pretrained
(
model_name
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsInstance
(
model
,
BertForQuestionAnswering
)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment