Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c3d9ac76
Unverified
Commit
c3d9ac76
authored
Jul 21, 2021
by
Lysandre Debut
Committed by
GitHub
Jul 21, 2021
Browse files
Expose get_config() on ModelTesters (#12812)
* Expose get_config() on ModelTesters * Typo
parent
cabcc751
Changes
53
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
667 additions
and
624 deletions
+667
-624
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_{{cookiecutter.lowercase_modelname}}.py
...e}}/test_modeling_{{cookiecutter.lowercase_modelname}}.py
+7
-4
tests/test_modeling_albert.py
tests/test_modeling_albert.py
+7
-5
tests/test_modeling_bart.py
tests/test_modeling_bart.py
+7
-6
tests/test_modeling_bert.py
tests/test_modeling_bert.py
+10
-5
tests/test_modeling_bert_generation.py
tests/test_modeling_bert_generation.py
+8
-5
tests/test_modeling_big_bird.py
tests/test_modeling_big_bird.py
+7
-5
tests/test_modeling_bigbird_pegasus.py
tests/test_modeling_bigbird_pegasus.py
+7
-6
tests/test_modeling_blenderbot.py
tests/test_modeling_blenderbot.py
+8
-7
tests/test_modeling_blenderbot_small.py
tests/test_modeling_blenderbot_small.py
+8
-12
tests/test_modeling_canine.py
tests/test_modeling_canine.py
+7
-5
tests/test_modeling_clip.py
tests/test_modeling_clip.py
+20
-8
tests/test_modeling_convbert.py
tests/test_modeling_convbert.py
+7
-5
tests/test_modeling_ctrl.py
tests/test_modeling_ctrl.py
+19
-17
tests/test_modeling_deberta.py
tests/test_modeling_deberta.py
+175
-175
tests/test_modeling_deberta_v2.py
tests/test_modeling_deberta_v2.py
+175
-175
tests/test_modeling_deit.py
tests/test_modeling_deit.py
+7
-4
tests/test_modeling_detr.py
tests/test_modeling_detr.py
+7
-5
tests/test_modeling_distilbert.py
tests/test_modeling_distilbert.py
+157
-156
tests/test_modeling_dpr.py
tests/test_modeling_dpr.py
+8
-5
tests/test_modeling_electra.py
tests/test_modeling_electra.py
+16
-14
No files found.
templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/test_modeling_{{cookiecutter.lowercase_modelname}}.py
View file @
c3d9ac76
...
...
@@ -22,6 +22,7 @@ from tests.test_modeling_common import floats_tensor
from
transformers
import
is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
transformers
import
{{
cookiecutter
.
camelcase_modelname
}}
Config
from
.test_configuration_common
import
ConfigTester
from
.test_modeling_common
import
ModelTesterMixin
,
ids_tensor
,
random_attention_mask
...
...
@@ -30,7 +31,6 @@ if is_torch_available():
import
torch
from
transformers
import
(
{{
cookiecutter
.
camelcase_modelname
}}
Config
,
{{
cookiecutter
.
camelcase_modelname
}}
ForCausalLM
,
{{
cookiecutter
.
camelcase_modelname
}}
ForMaskedLM
,
{{
cookiecutter
.
camelcase_modelname
}}
ForMultipleChoice
,
...
...
@@ -112,7 +112,12 @@ class {{cookiecutter.camelcase_modelname}}ModelTester:
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choice_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
config
=
{{
cookiecutter
.
camelcase_modelname
}}
Config
(
config
=
self
.
get_config
()
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
get_config
(
self
):
return
{{
cookiecutter
.
camelcase_modelname
}}
Config
(
vocab_size
=
self
.
vocab_size
,
hidden_size
=
self
.
hidden_size
,
num_hidden_layers
=
self
.
num_hidden_layers
,
...
...
@@ -127,8 +132,6 @@ class {{cookiecutter.camelcase_modelname}}ModelTester:
initializer_range
=
self
.
initializer_range
,
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
prepare_config_and_inputs_for_decoder
(
self
):
(
config
,
...
...
tests/test_modeling_albert.py
View file @
c3d9ac76
...
...
@@ -16,7 +16,7 @@
import
unittest
from
transformers
import
is_torch_available
from
transformers
import
AlbertConfig
,
is_torch_available
from
transformers.models.auto
import
get_values
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
...
...
@@ -29,7 +29,6 @@ if is_torch_available():
from
transformers
import
(
MODEL_FOR_PRETRAINING_MAPPING
,
AlbertConfig
,
AlbertForMaskedLM
,
AlbertForMultipleChoice
,
AlbertForPreTraining
,
...
...
@@ -90,7 +89,12 @@ class AlbertModelTester:
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choice_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
config
=
AlbertConfig
(
config
=
self
.
get_config
()
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
get_config
(
self
):
return
AlbertConfig
(
vocab_size
=
self
.
vocab_size
,
hidden_size
=
self
.
hidden_size
,
num_hidden_layers
=
self
.
num_hidden_layers
,
...
...
@@ -105,8 +109,6 @@ class AlbertModelTester:
num_hidden_groups
=
self
.
num_hidden_groups
,
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
create_and_check_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
...
...
tests/test_modeling_bart.py
View file @
c3d9ac76
...
...
@@ -21,7 +21,7 @@ import unittest
import
timeout_decorator
# noqa
from
transformers
import
is_torch_available
from
transformers
import
BartConfig
,
is_torch_available
from
transformers.file_utils
import
cached_property
from
transformers.testing_utils
import
require_sentencepiece
,
require_tokenizers
,
require_torch
,
slow
,
torch_device
...
...
@@ -35,7 +35,6 @@ if is_torch_available():
from
transformers
import
(
AutoModelForSequenceClassification
,
BartConfig
,
BartForCausalLM
,
BartForConditionalGeneration
,
BartForQuestionAnswering
,
...
...
@@ -78,7 +77,6 @@ def prepare_bart_inputs_dict(
}
@
require_torch
class
BartModelTester
:
def
__init__
(
self
,
...
...
@@ -127,7 +125,12 @@ class BartModelTester:
decoder_input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
config
=
BartConfig
(
config
=
self
.
get_config
()
inputs_dict
=
prepare_bart_inputs_dict
(
config
,
input_ids
,
decoder_input_ids
)
return
config
,
inputs_dict
def
get_config
(
self
):
return
BartConfig
(
vocab_size
=
self
.
vocab_size
,
d_model
=
self
.
hidden_size
,
encoder_layers
=
self
.
num_hidden_layers
,
...
...
@@ -143,8 +146,6 @@ class BartModelTester:
bos_token_id
=
self
.
bos_token_id
,
pad_token_id
=
self
.
pad_token_id
,
)
inputs_dict
=
prepare_bart_inputs_dict
(
config
,
input_ids
,
decoder_input_ids
)
return
config
,
inputs_dict
def
prepare_config_and_inputs_for_common
(
self
):
config
,
inputs_dict
=
self
.
prepare_config_and_inputs
()
...
...
tests/test_modeling_bert.py
View file @
c3d9ac76
...
...
@@ -16,7 +16,7 @@
import
unittest
from
transformers
import
is_torch_available
from
transformers
import
BertConfig
,
is_torch_available
from
transformers.models.auto
import
get_values
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
...
...
@@ -30,7 +30,6 @@ if is_torch_available():
from
transformers
import
(
MODEL_FOR_PRETRAINING_MAPPING
,
BertConfig
,
BertForMaskedLM
,
BertForMultipleChoice
,
BertForNextSentencePrediction
,
...
...
@@ -112,7 +111,15 @@ class BertModelTester:
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choice_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
config
=
BertConfig
(
config
=
self
.
get_config
()
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
get_config
(
self
):
"""
Returns a tiny configuration by default.
"""
return
BertConfig
(
vocab_size
=
self
.
vocab_size
,
hidden_size
=
self
.
hidden_size
,
num_hidden_layers
=
self
.
num_hidden_layers
,
...
...
@@ -127,8 +134,6 @@ class BertModelTester:
initializer_range
=
self
.
initializer_range
,
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
prepare_config_and_inputs_for_decoder
(
self
):
(
config
,
...
...
tests/test_modeling_bert_generation.py
View file @
c3d9ac76
...
...
@@ -16,7 +16,7 @@
import
unittest
from
transformers
import
is_torch_available
from
transformers
import
BertGenerationConfig
,
is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
...
...
@@ -27,7 +27,7 @@ from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, r
if
is_torch_available
():
import
torch
from
transformers
import
BertGenerationConfig
,
BertGenerationDecoder
,
BertGenerationEncoder
from
transformers
import
BertGenerationDecoder
,
BertGenerationEncoder
class
BertGenerationEncoderTester
:
...
...
@@ -79,7 +79,12 @@ class BertGenerationEncoderTester:
if
self
.
use_labels
:
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
config
=
BertGenerationConfig
(
config
=
self
.
get_config
()
return
config
,
input_ids
,
input_mask
,
token_labels
def
get_config
(
self
):
return
BertGenerationConfig
(
vocab_size
=
self
.
vocab_size
,
hidden_size
=
self
.
hidden_size
,
num_hidden_layers
=
self
.
num_hidden_layers
,
...
...
@@ -93,8 +98,6 @@ class BertGenerationEncoderTester:
initializer_range
=
self
.
initializer_range
,
)
return
config
,
input_ids
,
input_mask
,
token_labels
def
prepare_config_and_inputs_for_decoder
(
self
):
(
config
,
...
...
tests/test_modeling_big_bird.py
View file @
c3d9ac76
...
...
@@ -18,7 +18,7 @@
import
unittest
from
tests.test_modeling_common
import
floats_tensor
from
transformers
import
is_torch_available
from
transformers
import
BigBirdConfig
,
is_torch_available
from
transformers.models.auto
import
get_values
from
transformers.models.big_bird.tokenization_big_bird
import
BigBirdTokenizer
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
...
...
@@ -32,7 +32,6 @@ if is_torch_available():
from
transformers
import
(
MODEL_FOR_PRETRAINING_MAPPING
,
BigBirdConfig
,
BigBirdForCausalLM
,
BigBirdForMaskedLM
,
BigBirdForMultipleChoice
,
...
...
@@ -126,7 +125,12 @@ class BigBirdModelTester:
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choice_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
config
=
BigBirdConfig
(
config
=
self
.
get_config
()
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
get_config
(
self
):
return
BigBirdConfig
(
vocab_size
=
self
.
vocab_size
,
hidden_size
=
self
.
hidden_size
,
num_hidden_layers
=
self
.
num_hidden_layers
,
...
...
@@ -147,8 +151,6 @@ class BigBirdModelTester:
position_embedding_type
=
self
.
position_embedding_type
,
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
prepare_config_and_inputs_for_decoder
(
self
):
(
config
,
...
...
tests/test_modeling_bigbird_pegasus.py
View file @
c3d9ac76
...
...
@@ -19,7 +19,7 @@ import copy
import
tempfile
import
unittest
from
transformers
import
is_torch_available
from
transformers
import
BigBirdPegasusConfig
,
is_torch_available
from
transformers.testing_utils
import
require_sentencepiece
,
require_tokenizers
,
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
...
...
@@ -31,7 +31,6 @@ if is_torch_available():
import
torch
from
transformers
import
(
BigBirdPegasusConfig
,
BigBirdPegasusForCausalLM
,
BigBirdPegasusForConditionalGeneration
,
BigBirdPegasusForQuestionAnswering
,
...
...
@@ -69,7 +68,6 @@ def prepare_bigbird_pegasus_inputs_dict(
return
input_dict
@
require_torch
class
BigBirdPegasusModelTester
:
def
__init__
(
self
,
...
...
@@ -129,7 +127,12 @@ class BigBirdPegasusModelTester:
decoder_input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
config
=
BigBirdPegasusConfig
(
config
=
self
.
get_config
()
inputs_dict
=
prepare_bigbird_pegasus_inputs_dict
(
config
,
input_ids
,
decoder_input_ids
)
return
config
,
inputs_dict
def
get_config
(
self
):
return
BigBirdPegasusConfig
(
vocab_size
=
self
.
vocab_size
,
d_model
=
self
.
hidden_size
,
encoder_layers
=
self
.
num_hidden_layers
,
...
...
@@ -150,8 +153,6 @@ class BigBirdPegasusModelTester:
num_random_blocks
=
self
.
num_random_blocks
,
scale_embedding
=
self
.
scale_embedding
,
)
inputs_dict
=
prepare_bigbird_pegasus_inputs_dict
(
config
,
input_ids
,
decoder_input_ids
)
return
config
,
inputs_dict
def
prepare_config_and_inputs_for_common
(
self
):
config
,
inputs_dict
=
self
.
prepare_config_and_inputs
()
...
...
tests/test_modeling_blenderbot.py
View file @
c3d9ac76
...
...
@@ -17,7 +17,7 @@
import
tempfile
import
unittest
from
transformers
import
is_torch_available
from
transformers
import
BlenderbotConfig
,
is_torch_available
from
transformers.file_utils
import
cached_property
from
transformers.testing_utils
import
require_sentencepiece
,
require_tokenizers
,
require_torch
,
slow
,
torch_device
...
...
@@ -29,7 +29,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor
if
is_torch_available
():
import
torch
from
transformers
import
BlenderbotConfig
,
BlenderbotForConditionalGeneration
,
BlenderbotModel
,
BlenderbotTokenizer
from
transformers
import
BlenderbotForConditionalGeneration
,
BlenderbotModel
,
BlenderbotTokenizer
from
transformers.models.blenderbot.modeling_blenderbot
import
(
BlenderbotDecoder
,
BlenderbotEncoder
,
...
...
@@ -68,7 +68,6 @@ def prepare_blenderbot_inputs_dict(
}
@
require_torch
class
BlenderbotModelTester
:
def
__init__
(
self
,
...
...
@@ -109,7 +108,6 @@ class BlenderbotModelTester:
self
.
bos_token_id
=
bos_token_id
def
prepare_config_and_inputs
(
self
):
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
).
clamp
(
3
,
)
...
...
@@ -117,7 +115,12 @@ class BlenderbotModelTester:
decoder_input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
config
=
BlenderbotConfig
(
config
=
self
.
get_config
()
inputs_dict
=
prepare_blenderbot_inputs_dict
(
config
,
input_ids
,
decoder_input_ids
)
return
config
,
inputs_dict
def
get_config
(
self
):
return
BlenderbotConfig
(
vocab_size
=
self
.
vocab_size
,
d_model
=
self
.
hidden_size
,
encoder_layers
=
self
.
num_hidden_layers
,
...
...
@@ -133,8 +136,6 @@ class BlenderbotModelTester:
bos_token_id
=
self
.
bos_token_id
,
pad_token_id
=
self
.
pad_token_id
,
)
inputs_dict
=
prepare_blenderbot_inputs_dict
(
config
,
input_ids
,
decoder_input_ids
)
return
config
,
inputs_dict
def
prepare_config_and_inputs_for_common
(
self
):
config
,
inputs_dict
=
self
.
prepare_config_and_inputs
()
...
...
tests/test_modeling_blenderbot_small.py
View file @
c3d9ac76
...
...
@@ -17,7 +17,7 @@
import
tempfile
import
unittest
from
transformers
import
is_torch_available
from
transformers
import
BlenderbotSmallConfig
,
is_torch_available
from
transformers.file_utils
import
cached_property
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
...
...
@@ -29,12 +29,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor
if
is_torch_available
():
import
torch
from
transformers
import
(
BlenderbotSmallConfig
,
BlenderbotSmallForConditionalGeneration
,
BlenderbotSmallModel
,
BlenderbotSmallTokenizer
,
)
from
transformers
import
BlenderbotSmallForConditionalGeneration
,
BlenderbotSmallModel
,
BlenderbotSmallTokenizer
from
transformers.models.blenderbot_small.modeling_blenderbot_small
import
(
BlenderbotSmallDecoder
,
BlenderbotSmallEncoder
,
...
...
@@ -73,7 +68,6 @@ def prepare_blenderbot_small_inputs_dict(
}
@
require_torch
class
BlenderbotSmallModelTester
:
def
__init__
(
self
,
...
...
@@ -114,7 +108,6 @@ class BlenderbotSmallModelTester:
self
.
bos_token_id
=
bos_token_id
def
prepare_config_and_inputs
(
self
):
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
).
clamp
(
3
,
)
...
...
@@ -122,7 +115,12 @@ class BlenderbotSmallModelTester:
decoder_input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
config
=
BlenderbotSmallConfig
(
config
=
self
.
get_config
()
inputs_dict
=
prepare_blenderbot_small_inputs_dict
(
config
,
input_ids
,
decoder_input_ids
)
return
config
,
inputs_dict
def
get_config
(
self
):
return
BlenderbotSmallConfig
(
vocab_size
=
self
.
vocab_size
,
d_model
=
self
.
hidden_size
,
encoder_layers
=
self
.
num_hidden_layers
,
...
...
@@ -138,8 +136,6 @@ class BlenderbotSmallModelTester:
bos_token_id
=
self
.
bos_token_id
,
pad_token_id
=
self
.
pad_token_id
,
)
inputs_dict
=
prepare_blenderbot_small_inputs_dict
(
config
,
input_ids
,
decoder_input_ids
)
return
config
,
inputs_dict
def
prepare_config_and_inputs_for_common
(
self
):
config
,
inputs_dict
=
self
.
prepare_config_and_inputs
()
...
...
tests/test_modeling_canine.py
View file @
c3d9ac76
...
...
@@ -18,7 +18,7 @@
import
unittest
from
typing
import
List
,
Tuple
from
transformers
import
is_torch_available
from
transformers
import
CanineConfig
,
is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
...
...
@@ -29,7 +29,6 @@ if is_torch_available():
import
torch
from
transformers
import
(
CanineConfig
,
CanineForMultipleChoice
,
CanineForQuestionAnswering
,
CanineForSequenceClassification
,
...
...
@@ -106,7 +105,12 @@ class CanineModelTester:
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choice_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
config
=
CanineConfig
(
config
=
self
.
get_config
()
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
get_config
(
self
):
return
CanineConfig
(
hidden_size
=
self
.
hidden_size
,
num_hidden_layers
=
self
.
num_hidden_layers
,
num_attention_heads
=
self
.
num_attention_heads
,
...
...
@@ -120,8 +124,6 @@ class CanineModelTester:
initializer_range
=
self
.
initializer_range
,
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
create_and_check_model
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
...
...
tests/test_modeling_clip.py
View file @
c3d9ac76
...
...
@@ -21,6 +21,7 @@ import tempfile
import
unittest
import
requests
from
transformers
import
CLIPConfig
,
CLIPTextConfig
,
CLIPVisionConfig
from
transformers.file_utils
import
is_torch_available
,
is_vision_available
from
transformers.testing_utils
import
require_torch
,
require_vision
,
slow
,
torch_device
...
...
@@ -32,7 +33,7 @@ if is_torch_available():
import
torch
from
torch
import
nn
from
transformers
import
CLIPConfig
,
CLIPModel
,
CLIPText
Config
,
CLIPTextModel
,
CLIPVisionConfig
,
CLIPVisionModel
from
transformers
import
CLIPModel
,
CLIPText
Model
,
CLIPVisionModel
from
transformers.models.clip.modeling_clip
import
CLIP_PRETRAINED_MODEL_ARCHIVE_LIST
...
...
@@ -77,7 +78,12 @@ class CLIPVisionModelTester:
def
prepare_config_and_inputs
(
self
):
pixel_values
=
floats_tensor
([
self
.
batch_size
,
self
.
num_channels
,
self
.
image_size
,
self
.
image_size
])
config
=
CLIPVisionConfig
(
config
=
self
.
get_config
()
return
config
,
pixel_values
def
get_config
(
self
):
return
CLIPVisionConfig
(
image_size
=
self
.
image_size
,
patch_size
=
self
.
patch_size
,
num_channels
=
self
.
num_channels
,
...
...
@@ -90,8 +96,6 @@ class CLIPVisionModelTester:
initializer_range
=
self
.
initializer_range
,
)
return
config
,
pixel_values
def
create_and_check_model
(
self
,
config
,
pixel_values
):
model
=
CLIPVisionModel
(
config
=
config
)
model
.
to
(
torch_device
)
...
...
@@ -323,7 +327,12 @@ class CLIPTextModelTester:
if
self
.
use_input_mask
:
input_mask
=
random_attention_mask
([
self
.
batch_size
,
self
.
seq_length
])
config
=
CLIPTextConfig
(
config
=
self
.
get_config
()
return
config
,
input_ids
,
input_mask
def
get_config
(
self
):
return
CLIPTextConfig
(
vocab_size
=
self
.
vocab_size
,
hidden_size
=
self
.
hidden_size
,
num_hidden_layers
=
self
.
num_hidden_layers
,
...
...
@@ -335,8 +344,6 @@ class CLIPTextModelTester:
initializer_range
=
self
.
initializer_range
,
)
return
config
,
input_ids
,
input_mask
def
create_and_check_model
(
self
,
config
,
input_ids
,
input_mask
):
model
=
CLIPTextModel
(
config
=
config
)
model
.
to
(
torch_device
)
...
...
@@ -409,10 +416,15 @@ class CLIPModelTester:
text_config
,
input_ids
,
attention_mask
=
self
.
text_model_tester
.
prepare_config_and_inputs
()
vision_config
,
pixel_values
=
self
.
vision_model_tester
.
prepare_config_and_inputs
()
config
=
CLIPConfig
.
from_text_vision_configs
(
text_config
,
vision_config
,
projection_dim
=
64
)
config
=
self
.
get_config
(
)
return
config
,
input_ids
,
attention_mask
,
pixel_values
def
get_config
(
self
):
return
CLIPConfig
.
from_text_vision_configs
(
self
.
text_model_tester
.
get_config
(),
self
.
vision_model_tester
.
get_config
(),
projection_dim
=
64
)
def
create_and_check_model
(
self
,
config
,
input_ids
,
attention_mask
,
pixel_values
):
model
=
CLIPModel
(
config
).
to
(
torch_device
).
eval
()
result
=
model
(
input_ids
,
pixel_values
,
attention_mask
)
...
...
tests/test_modeling_convbert.py
View file @
c3d9ac76
...
...
@@ -18,7 +18,7 @@
import
unittest
from
tests.test_modeling_common
import
floats_tensor
from
transformers
import
is_torch_available
from
transformers
import
ConvBertConfig
,
is_torch_available
from
transformers.models.auto
import
get_values
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
...
...
@@ -31,7 +31,6 @@ if is_torch_available():
from
transformers
import
(
MODEL_FOR_QUESTION_ANSWERING_MAPPING
,
ConvBertConfig
,
ConvBertForMaskedLM
,
ConvBertForMultipleChoice
,
ConvBertForQuestionAnswering
,
...
...
@@ -110,7 +109,12 @@ class ConvBertModelTester:
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choice_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
config
=
ConvBertConfig
(
config
=
self
.
get_config
()
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
get_config
(
self
):
return
ConvBertConfig
(
vocab_size
=
self
.
vocab_size
,
hidden_size
=
self
.
hidden_size
,
num_hidden_layers
=
self
.
num_hidden_layers
,
...
...
@@ -125,8 +129,6 @@ class ConvBertModelTester:
initializer_range
=
self
.
initializer_range
,
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
prepare_config_and_inputs_for_decoder
(
self
):
(
config
,
...
...
tests/test_modeling_ctrl.py
View file @
c3d9ac76
...
...
@@ -15,7 +15,7 @@
import
unittest
from
transformers
import
is_torch_available
from
transformers
import
CTRLConfig
,
is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
...
...
@@ -28,7 +28,6 @@ if is_torch_available():
from
transformers
import
(
CTRL_PRETRAINED_MODEL_ARCHIVE_LIST
,
CTRLConfig
,
CTRLForSequenceClassification
,
CTRLLMHeadModel
,
CTRLModel
,
...
...
@@ -88,21 +87,7 @@ class CTRLModelTester:
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choice_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
config
=
CTRLConfig
(
vocab_size
=
self
.
vocab_size
,
n_embd
=
self
.
hidden_size
,
n_layer
=
self
.
num_hidden_layers
,
n_head
=
self
.
num_attention_heads
,
# intermediate_size=self.intermediate_size,
# hidden_act=self.hidden_act,
# hidden_dropout_prob=self.hidden_dropout_prob,
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
n_positions
=
self
.
max_position_embeddings
,
n_ctx
=
self
.
max_position_embeddings
,
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range,
pad_token_id
=
self
.
pad_token_id
,
)
config
=
self
.
get_config
()
head_mask
=
ids_tensor
([
self
.
num_hidden_layers
,
self
.
num_attention_heads
],
2
)
...
...
@@ -118,6 +103,23 @@ class CTRLModelTester:
choice_labels
,
)
def
get_config
(
self
):
return
CTRLConfig
(
vocab_size
=
self
.
vocab_size
,
n_embd
=
self
.
hidden_size
,
n_layer
=
self
.
num_hidden_layers
,
n_head
=
self
.
num_attention_heads
,
# intermediate_size=self.intermediate_size,
# hidden_act=self.hidden_act,
# hidden_dropout_prob=self.hidden_dropout_prob,
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
n_positions
=
self
.
max_position_embeddings
,
n_ctx
=
self
.
max_position_embeddings
,
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range,
pad_token_id
=
self
.
pad_token_id
,
)
def
create_and_check_ctrl_model
(
self
,
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
model
=
CTRLModel
(
config
=
config
)
model
.
to
(
torch_device
)
...
...
tests/test_modeling_deberta.py
View file @
c3d9ac76
...
...
@@ -12,10 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
transformers
import
is_torch_available
from
transformers
import
DebertaConfig
,
is_torch_available
from
transformers.testing_utils
import
require_sentencepiece
,
require_tokenizers
,
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
...
...
@@ -26,7 +25,6 @@ if is_torch_available():
import
torch
from
transformers
import
(
DebertaConfig
,
DebertaForMaskedLM
,
DebertaForQuestionAnswering
,
DebertaForSequenceClassification
,
...
...
@@ -36,27 +34,7 @@ if is_torch_available():
from
transformers.models.deberta.modeling_deberta
import
DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST
@
require_torch
class
DebertaModelTest
(
ModelTesterMixin
,
unittest
.
TestCase
):
all_model_classes
=
(
(
DebertaModel
,
DebertaForMaskedLM
,
DebertaForSequenceClassification
,
DebertaForTokenClassification
,
DebertaForQuestionAnswering
,
)
if
is_torch_available
()
else
()
)
test_torchscript
=
False
test_pruning
=
False
test_head_masking
=
False
is_encoder_decoder
=
False
class
DebertaModelTester
(
object
):
class
DebertaModelTester
(
object
):
def
__init__
(
self
,
parent
,
...
...
@@ -130,7 +108,12 @@ class DebertaModelTest(ModelTesterMixin, unittest.TestCase):
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choice_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
config
=
DebertaConfig
(
config
=
self
.
get_config
()
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
get_config
(
self
):
return
DebertaConfig
(
vocab_size
=
self
.
vocab_size
,
hidden_size
=
self
.
hidden_size
,
num_hidden_layers
=
self
.
num_hidden_layers
,
...
...
@@ -147,8 +130,6 @@ class DebertaModelTest(ModelTesterMixin, unittest.TestCase):
pos_att_type
=
self
.
pos_att_type
,
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
check_loss_output
(
self
,
result
):
self
.
parent
.
assertListEqual
(
list
(
result
.
loss
.
size
()),
[])
...
...
@@ -162,9 +143,7 @@ class DebertaModelTest(ModelTesterMixin, unittest.TestCase):
sequence_output
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)[
0
]
sequence_output
=
model
(
input_ids
)[
0
]
self
.
parent
.
assertListEqual
(
list
(
sequence_output
.
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
self
.
parent
.
assertListEqual
(
list
(
sequence_output
.
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
def
create_and_check_deberta_for_masked_lm
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
...
...
@@ -227,8 +206,29 @@ class DebertaModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict
=
{
"input_ids"
:
input_ids
,
"token_type_ids"
:
token_type_ids
,
"attention_mask"
:
input_mask
}
return
config
,
inputs_dict
@
require_torch
class
DebertaModelTest
(
ModelTesterMixin
,
unittest
.
TestCase
):
all_model_classes
=
(
(
DebertaModel
,
DebertaForMaskedLM
,
DebertaForSequenceClassification
,
DebertaForTokenClassification
,
DebertaForQuestionAnswering
,
)
if
is_torch_available
()
else
()
)
test_torchscript
=
False
test_pruning
=
False
test_head_masking
=
False
is_encoder_decoder
=
False
def
setUp
(
self
):
self
.
model_tester
=
DebertaModelTest
.
DebertaModelTester
(
self
)
self
.
model_tester
=
DebertaModelTester
(
self
)
self
.
config_tester
=
ConfigTester
(
self
,
config_class
=
DebertaConfig
,
hidden_size
=
37
)
def
test_config
(
self
):
...
...
tests/test_modeling_deberta_v2.py
View file @
c3d9ac76
...
...
@@ -12,10 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
transformers
import
is_torch_available
from
transformers
import
DebertaV2Config
,
is_torch_available
from
transformers.testing_utils
import
require_sentencepiece
,
require_tokenizers
,
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
...
...
@@ -26,7 +25,6 @@ if is_torch_available():
import
torch
from
transformers
import
(
DebertaV2Config
,
DebertaV2ForMaskedLM
,
DebertaV2ForQuestionAnswering
,
DebertaV2ForSequenceClassification
,
...
...
@@ -36,27 +34,7 @@ if is_torch_available():
from
transformers.models.deberta_v2.modeling_deberta_v2
import
DEBERTA_V2_PRETRAINED_MODEL_ARCHIVE_LIST
@
require_torch
class
DebertaV2ModelTest
(
ModelTesterMixin
,
unittest
.
TestCase
):
all_model_classes
=
(
(
DebertaV2Model
,
DebertaV2ForMaskedLM
,
DebertaV2ForSequenceClassification
,
DebertaV2ForTokenClassification
,
DebertaV2ForQuestionAnswering
,
)
if
is_torch_available
()
else
()
)
test_torchscript
=
False
test_pruning
=
False
test_head_masking
=
False
is_encoder_decoder
=
False
class
DebertaV2ModelTester
(
object
):
class
DebertaV2ModelTester
(
object
):
def
__init__
(
self
,
parent
,
...
...
@@ -130,7 +108,12 @@ class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase):
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choice_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
config
=
DebertaV2Config
(
config
=
self
.
get_config
()
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
get_config
(
self
):
return
DebertaV2Config
(
vocab_size
=
self
.
vocab_size
,
hidden_size
=
self
.
hidden_size
,
num_hidden_layers
=
self
.
num_hidden_layers
,
...
...
@@ -147,8 +130,6 @@ class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase):
pos_att_type
=
self
.
pos_att_type
,
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
check_loss_output
(
self
,
result
):
self
.
parent
.
assertListEqual
(
list
(
result
.
loss
.
size
()),
[])
...
...
@@ -162,9 +143,7 @@ class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase):
sequence_output
=
model
(
input_ids
,
token_type_ids
=
token_type_ids
)[
0
]
sequence_output
=
model
(
input_ids
)[
0
]
self
.
parent
.
assertListEqual
(
list
(
sequence_output
.
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
]
)
self
.
parent
.
assertListEqual
(
list
(
sequence_output
.
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
def
create_and_check_deberta_for_masked_lm
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
...
...
@@ -227,8 +206,29 @@ class DebertaV2ModelTest(ModelTesterMixin, unittest.TestCase):
inputs_dict
=
{
"input_ids"
:
input_ids
,
"token_type_ids"
:
token_type_ids
,
"attention_mask"
:
input_mask
}
return
config
,
inputs_dict
@
require_torch
class
DebertaV2ModelTest
(
ModelTesterMixin
,
unittest
.
TestCase
):
all_model_classes
=
(
(
DebertaV2Model
,
DebertaV2ForMaskedLM
,
DebertaV2ForSequenceClassification
,
DebertaV2ForTokenClassification
,
DebertaV2ForQuestionAnswering
,
)
if
is_torch_available
()
else
()
)
test_torchscript
=
False
test_pruning
=
False
test_head_masking
=
False
is_encoder_decoder
=
False
def
setUp
(
self
):
self
.
model_tester
=
DebertaV2ModelTest
.
DebertaV2ModelTester
(
self
)
self
.
model_tester
=
DebertaV2ModelTester
(
self
)
self
.
config_tester
=
ConfigTester
(
self
,
config_class
=
DebertaV2Config
,
hidden_size
=
37
)
def
test_config
(
self
):
...
...
tests/test_modeling_deit.py
View file @
c3d9ac76
...
...
@@ -18,6 +18,7 @@
import
inspect
import
unittest
from
transformers
import
DeiTConfig
from
transformers.file_utils
import
cached_property
,
is_torch_available
,
is_vision_available
from
transformers.testing_utils
import
require_torch
,
require_vision
,
slow
,
torch_device
...
...
@@ -31,7 +32,6 @@ if is_torch_available():
from
transformers
import
(
MODEL_MAPPING
,
DeiTConfig
,
DeiTForImageClassification
,
DeiTForImageClassificationWithTeacher
,
DeiTModel
,
...
...
@@ -92,7 +92,12 @@ class DeiTModelTester:
if
self
.
use_labels
:
labels
=
ids_tensor
([
self
.
batch_size
],
self
.
type_sequence_label_size
)
config
=
DeiTConfig
(
config
=
self
.
get_config
()
return
config
,
pixel_values
,
labels
def
get_config
(
self
):
return
DeiTConfig
(
image_size
=
self
.
image_size
,
patch_size
=
self
.
patch_size
,
num_channels
=
self
.
num_channels
,
...
...
@@ -107,8 +112,6 @@ class DeiTModelTester:
initializer_range
=
self
.
initializer_range
,
)
return
config
,
pixel_values
,
labels
def
create_and_check_model
(
self
,
config
,
pixel_values
,
labels
):
model
=
DeiTModel
(
config
=
config
)
model
.
to
(
torch_device
)
...
...
tests/test_modeling_detr.py
View file @
c3d9ac76
...
...
@@ -19,7 +19,7 @@ import inspect
import
math
import
unittest
from
transformers
import
is_timm_available
,
is_vision_available
from
transformers
import
DetrConfig
,
is_timm_available
,
is_vision_available
from
transformers.file_utils
import
cached_property
from
transformers.testing_utils
import
require_timm
,
require_vision
,
slow
,
torch_device
...
...
@@ -31,7 +31,7 @@ from .test_modeling_common import ModelTesterMixin, _config_zero_init, floats_te
if
is_timm_available
():
import
torch
from
transformers
import
DetrConfig
,
DetrForObjectDetection
,
DetrForSegmentation
,
DetrModel
from
transformers
import
DetrForObjectDetection
,
DetrForSegmentation
,
DetrModel
if
is_vision_available
():
...
...
@@ -40,7 +40,6 @@ if is_vision_available():
from
transformers
import
DetrFeatureExtractor
@
require_timm
class
DetrModelTester
:
def
__init__
(
self
,
...
...
@@ -102,7 +101,11 @@ class DetrModelTester:
target
[
"masks"
]
=
torch
.
rand
(
self
.
n_targets
,
self
.
min_size
,
self
.
max_size
,
device
=
torch_device
)
labels
.
append
(
target
)
config
=
DetrConfig
(
config
=
self
.
get_config
()
return
config
,
pixel_values
,
pixel_mask
,
labels
def
get_config
(
self
):
return
DetrConfig
(
d_model
=
self
.
hidden_size
,
encoder_layers
=
self
.
num_hidden_layers
,
decoder_layers
=
self
.
num_hidden_layers
,
...
...
@@ -115,7 +118,6 @@ class DetrModelTester:
num_queries
=
self
.
num_queries
,
num_labels
=
self
.
num_labels
,
)
return
config
,
pixel_values
,
pixel_mask
,
labels
def
prepare_config_and_inputs_for_common
(
self
):
config
,
pixel_values
,
pixel_mask
,
labels
=
self
.
prepare_config_and_inputs
()
...
...
tests/test_modeling_distilbert.py
View file @
c3d9ac76
...
...
@@ -16,7 +16,7 @@
import
unittest
from
transformers
import
is_torch_available
from
transformers
import
DistilBertConfig
,
is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
...
...
@@ -28,7 +28,6 @@ if is_torch_available():
from
transformers
import
(
DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST
,
DistilBertConfig
,
DistilBertForMaskedLM
,
DistilBertForMultipleChoice
,
DistilBertForQuestionAnswering
,
...
...
@@ -37,7 +36,8 @@ if is_torch_available():
DistilBertModel
,
)
class
DistilBertModelTester
(
object
):
class
DistilBertModelTester
(
object
):
def
__init__
(
self
,
parent
,
...
...
@@ -101,7 +101,12 @@ if is_torch_available():
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choice_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
config
=
DistilBertConfig
(
config
=
self
.
get_config
()
return
config
,
input_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
get_config
(
self
):
return
DistilBertConfig
(
vocab_size
=
self
.
vocab_size
,
dim
=
self
.
hidden_size
,
n_layers
=
self
.
num_hidden_layers
,
...
...
@@ -114,8 +119,6 @@ if is_torch_available():
initializer_range
=
self
.
initializer_range
,
)
return
config
,
input_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
create_and_check_distilbert_model
(
self
,
config
,
input_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
...
...
@@ -124,9 +127,7 @@ if is_torch_available():
model
.
eval
()
result
=
model
(
input_ids
,
input_mask
)
result
=
model
(
input_ids
)
self
.
parent
.
assertEqual
(
result
.
last_hidden_state
.
shape
,
(
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
)
)
self
.
parent
.
assertEqual
(
result
.
last_hidden_state
.
shape
,
(
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
))
def
create_and_check_distilbert_for_masked_lm
(
self
,
config
,
input_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
...
...
tests/test_modeling_dpr.py
View file @
c3d9ac76
...
...
@@ -16,7 +16,7 @@
import
unittest
from
transformers
import
is_torch_available
from
transformers
import
DPRConfig
,
is_torch_available
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
from
.test_configuration_common
import
ConfigTester
...
...
@@ -26,7 +26,7 @@ from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention
if
is_torch_available
():
import
torch
from
transformers
import
DPRConfig
,
DPRContextEncoder
,
DPRQuestionEncoder
,
DPRReader
,
DPRReaderTokenizer
from
transformers
import
DPRContextEncoder
,
DPRQuestionEncoder
,
DPRReader
,
DPRReaderTokenizer
from
transformers.models.dpr.modeling_dpr
import
(
DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST
,
DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST
,
...
...
@@ -104,7 +104,12 @@ class DPRModelTester:
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choice_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
config
=
DPRConfig
(
config
=
self
.
get_config
()
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
get_config
(
self
):
return
DPRConfig
(
projection_dim
=
self
.
projection_dim
,
vocab_size
=
self
.
vocab_size
,
hidden_size
=
self
.
hidden_size
,
...
...
@@ -119,8 +124,6 @@ class DPRModelTester:
initializer_range
=
self
.
initializer_range
,
)
return
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
def
create_and_check_context_encoder
(
self
,
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
):
...
...
tests/test_modeling_electra.py
View file @
c3d9ac76
...
...
@@ -16,7 +16,7 @@
import
unittest
from
transformers
import
is_torch_available
from
transformers
import
ElectraConfig
,
is_torch_available
from
transformers.models.auto
import
get_values
from
transformers.testing_utils
import
require_torch
,
slow
,
torch_device
...
...
@@ -29,7 +29,6 @@ if is_torch_available():
from
transformers
import
(
MODEL_FOR_PRETRAINING_MAPPING
,
ElectraConfig
,
ElectraForMaskedLM
,
ElectraForMultipleChoice
,
ElectraForPreTraining
,
...
...
@@ -89,7 +88,21 @@ class ElectraModelTester:
choice_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
fake_token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
1
)
config
=
ElectraConfig
(
config
=
self
.
get_config
()
return
(
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
,
fake_token_labels
,
)
def
get_config
(
self
):
return
ElectraConfig
(
vocab_size
=
self
.
vocab_size
,
hidden_size
=
self
.
hidden_size
,
num_hidden_layers
=
self
.
num_hidden_layers
,
...
...
@@ -104,17 +117,6 @@ class ElectraModelTester:
initializer_range
=
self
.
initializer_range
,
)
return
(
config
,
input_ids
,
token_type_ids
,
input_mask
,
sequence_labels
,
token_labels
,
choice_labels
,
fake_token_labels
,
)
def
create_and_check_electra_model
(
self
,
config
,
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment