chenpangpang / transformers · Commits

Commit cf8a70bf, authored Jan 11, 2020 by Julien Chaumond
More AutoConfig tests
Parent: 6bb3edc3

Showing 5 changed files with 35 additions and 4 deletions (+35 -4)
Files changed:
  tests/test_configuration_auto.py   +7 -1
  tests/test_modeling_auto.py        +9 -1
  tests/test_modeling_tf_auto.py     +9 -1
  tests/test_tokenization_auto.py    +8 -1
  tests/utils.py                     +2 -0
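All of the new tests lean on one idea: the identifier "julien-c/dummy-unknown" gives no hint about the architecture in its name, so the Auto* classes have to read the model_type key from the repo's config.json to pick the right concrete class. A minimal, illustrative sketch of that kind of dispatch (the mapping and helper below are hypothetical, not the library's actual internals):

import json

# Hypothetical, simplified version of model_type dispatch:
# map the config.json "model_type" value to a concrete config class name.
MODEL_TYPE_TO_CONFIG = {
    "bert": "BertConfig",
    "roberta": "RobertaConfig",
}

def resolve_config_class(config_path):
    # A repo named "dummy-unknown" says nothing, but its config.json does.
    with open(config_path) as f:
        model_type = json.load(f).get("model_type")
    return MODEL_TYPE_TO_CONFIG[model_type]  # e.g. "roberta" -> "RobertaConfig"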
tests/test_configuration_auto.py

@@ -20,6 +20,8 @@ from transformers.configuration_auto import AutoConfig
 from transformers.configuration_bert import BertConfig
 from transformers.configuration_roberta import RobertaConfig
+
+from .utils import DUMMY_UNKWOWN_IDENTIFIER

 SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")

@@ -29,10 +31,14 @@ class AutoConfigTest(unittest.TestCase):
         config = AutoConfig.from_pretrained("bert-base-uncased")
         self.assertIsInstance(config, BertConfig)

-    def test_config_from_model_type(self):
+    def test_config_model_type_from_local_file(self):
         config = AutoConfig.from_pretrained(SAMPLE_ROBERTA_CONFIG)
         self.assertIsInstance(config, RobertaConfig)

+    def test_config_model_type_from_model_identifier(self):
+        config = AutoConfig.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
+        self.assertIsInstance(config, RobertaConfig)
+
     def test_config_for_model_str(self):
         config = AutoConfig.for_model("roberta")
         self.assertIsInstance(config, RobertaConfig)
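The renamed and added tests cover three ways of resolving a RoBERTa config; a usage sketch of those paths, assuming the transformers 2.x API shown in the diff (the local fixture path below is taken from SAMPLE_ROBERTA_CONFIG and may differ on your checkout):

from transformers import AutoConfig
from transformers.configuration_roberta import RobertaConfig

# 1) Local JSON file whose "model_type" field is "roberta".
config = AutoConfig.from_pretrained("tests/fixtures/dummy-config.json")

# 2) Hub identifier whose name does not mention the architecture.
config = AutoConfig.from_pretrained("julien-c/dummy-unknown")

# 3) Explicit model_type string, no download involved.
config = AutoConfig.for_model("roberta")

assert isinstance(config, RobertaConfig)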
tests/test_modeling_auto.py

@@ -19,7 +19,7 @@ import unittest
 from transformers import is_torch_available

-from .utils import SMALL_MODEL_IDENTIFIER, require_torch, slow
+from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_torch, slow

 if is_torch_available():

@@ -30,6 +30,7 @@ if is_torch_available():
         BertModel,
         AutoModelWithLMHead,
         BertForMaskedLM,
+        RobertaForMaskedLM,
         AutoModelForSequenceClassification,
         BertForSequenceClassification,
         AutoModelForQuestionAnswering,

@@ -102,3 +103,10 @@ class AutoModelTest(unittest.TestCase):
         self.assertIsInstance(model, BertForMaskedLM)
         self.assertEqual(model.num_parameters(), 14830)
         self.assertEqual(model.num_parameters(only_trainable=True), 14830)
+
+    def test_from_identifier_from_model_type(self):
+        logging.basicConfig(level=logging.INFO)
+        model = AutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
+        self.assertIsInstance(model, RobertaForMaskedLM)
+        self.assertEqual(model.num_parameters(), 14830)
+        self.assertEqual(model.num_parameters(only_trainable=True), 14830)
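Both assertions pin the same value, 14830, because nothing in the dummy model is frozen. A rough equivalent of what num_parameters(only_trainable=True) counts, sketched directly with torch (the method name comes from the diff; the manual computation below is an assumed equivalent, not the library code):

import torch

def count_parameters(model: torch.nn.Module, only_trainable: bool = False) -> int:
    # Sum element counts over the model's parameters, optionally skipping
    # those with requires_grad == False (i.e. frozen weights).
    params = (p for p in model.parameters() if p.requires_grad or not only_trainable)
    return sum(p.numel() for p in params)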
tests/test_modeling_tf_auto.py

@@ -19,7 +19,7 @@ import unittest
 from transformers import is_tf_available

-from .utils import SMALL_MODEL_IDENTIFIER, require_tf, slow
+from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_tf, slow

 if is_tf_available():

@@ -30,6 +30,7 @@ if is_tf_available():
         TFBertModel,
         TFAutoModelWithLMHead,
         TFBertForMaskedLM,
+        TFRobertaForMaskedLM,
         TFAutoModelForSequenceClassification,
         TFBertForSequenceClassification,
         TFAutoModelForQuestionAnswering,

@@ -101,3 +102,10 @@ class TFAutoModelTest(unittest.TestCase):
         self.assertIsInstance(model, TFBertForMaskedLM)
         self.assertEqual(model.num_parameters(), 14830)
         self.assertEqual(model.num_parameters(only_trainable=True), 14830)
+
+    def test_from_identifier_from_model_type(self):
+        logging.basicConfig(level=logging.INFO)
+        model = TFAutoModelWithLMHead.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
+        self.assertIsInstance(model, TFRobertaForMaskedLM)
+        self.assertEqual(model.num_parameters(), 14830)
+        self.assertEqual(model.num_parameters(only_trainable=True), 14830)
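The TF test mirrors the PyTorch one: the same identifier resolves to the TF RoBERTa LM head and the parameter counts match. A short usage sketch, assuming the checkpoint also ships TF weights (otherwise loading would need a conversion step such as from_pt=True):

from transformers import TFAutoModelWithLMHead

model = TFAutoModelWithLMHead.from_pretrained("julien-c/dummy-unknown")
print(type(model).__name__)    # expected: TFRobertaForMaskedLM
print(model.num_parameters())  # expected: 14830, per the assertions above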
tests/test_tokenization_auto.py

@@ -23,9 +23,10 @@ from transformers import (
     AutoTokenizer,
     BertTokenizer,
     GPT2Tokenizer,
+    RobertaTokenizer,
 )

-from .utils import SMALL_MODEL_IDENTIFIER, slow  # noqa: F401
+from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, slow  # noqa: F401


 class AutoTokenizerTest(unittest.TestCase):

@@ -49,3 +50,9 @@ class AutoTokenizerTest(unittest.TestCase):
         tokenizer = AutoTokenizer.from_pretrained(SMALL_MODEL_IDENTIFIER)
         self.assertIsInstance(tokenizer, BertTokenizer)
         self.assertEqual(len(tokenizer), 12)
+
+    def test_tokenizer_from_model_type(self):
+        logging.basicConfig(level=logging.INFO)
+        tokenizer = AutoTokenizer.from_pretrained(DUMMY_UNKWOWN_IDENTIFIER)
+        self.assertIsInstance(tokenizer, RobertaTokenizer)
+        self.assertEqual(len(tokenizer), 20)
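The length assertions check vocabulary size: the dummy BERT checkpoint carries 12 tokens and the dummy RoBERTa one 20. A usage sketch of the new test's behavior (the expected values are taken from the assertions above):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("julien-c/dummy-unknown")
print(type(tokenizer).__name__)  # expected: RobertaTokenizer, via model_type detection
print(len(tokenizer))            # vocabulary size of the dummy checkpoint: 20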
tests/utils.py

@@ -9,6 +9,8 @@ from transformers.file_utils import _tf_available, _torch_available
 CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")

 SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
+DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
+# Used to test Auto{Config, Model, Tokenizer} model_type detection.


 def parse_flag_from_env(key, default=False):
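parse_flag_from_env only appears as context here, but it is how the require_torch / require_tf / slow decorators imported by the tests are wired to environment variables. A hedged sketch of that pattern (the exact variable names, defaults, and parsing rules below are assumptions, not the file's actual code):

import os
import unittest

def parse_flag_from_env(key, default=False):
    # Read a boolean flag such as RUN_SLOW=1 from the environment,
    # falling back to `default` when the variable is unset.
    value = os.environ.get(key, None)
    if value is None:
        return default
    return value.lower() not in ("0", "false", "no")

_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)

def slow(test_case):
    # Skip tests marked @slow unless RUN_SLOW is set.
    return test_case if _run_slow_tests else unittest.skip("test is slow")(test_case)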