Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
dfc76b25
Unverified
Commit
dfc76b25
authored
Jun 09, 2022
by
amyeroberts
Committed by
GitHub
Jun 09, 2022
Browse files
has_attentions - consistent test skipping logic and tf tests (#17495)
parent
66e86567
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
146 additions
and
120 deletions
+146
-120
tests/models/convnext/test_modeling_convnext.py
tests/models/convnext/test_modeling_convnext.py
+4
-0
tests/models/cvt/test_modeling_cvt.py
tests/models/cvt/test_modeling_cvt.py
+4
-0
tests/models/flava/test_modeling_flava.py
tests/models/flava/test_modeling_flava.py
+4
-0
tests/models/poolformer/test_modeling_poolformer.py
tests/models/poolformer/test_modeling_poolformer.py
+4
-0
tests/models/regnet/test_modeling_regnet.py
tests/models/regnet/test_modeling_regnet.py
+4
-0
tests/models/resnet/test_modeling_resnet.py
tests/models/resnet/test_modeling_resnet.py
+4
-0
tests/models/van/test_modeling_van.py
tests/models/van/test_modeling_van.py
+4
-0
tests/test_modeling_common.py
tests/test_modeling_common.py
+105
-109
tests/test_modeling_tf_common.py
tests/test_modeling_tf_common.py
+13
-11
No files found.
tests/models/convnext/test_modeling_convnext.py
View file @
dfc76b25
...
@@ -158,6 +158,10 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -158,6 +158,10 @@ class ConvNextModelTest(ModelTesterMixin, unittest.TestCase):
def
create_and_test_config_common_properties
(
self
):
def
create_and_test_config_common_properties
(
self
):
return
return
@
unittest
.
skip
(
reason
=
"ConvNext does not output attentions"
)
def
test_attention_outputs
(
self
):
pass
@
unittest
.
skip
(
reason
=
"ConvNext does not use inputs_embeds"
)
@
unittest
.
skip
(
reason
=
"ConvNext does not use inputs_embeds"
)
def
test_inputs_embeds
(
self
):
def
test_inputs_embeds
(
self
):
pass
pass
...
...
tests/models/cvt/test_modeling_cvt.py
View file @
dfc76b25
...
@@ -173,6 +173,10 @@ class CvtModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -173,6 +173,10 @@ class CvtModelTest(ModelTesterMixin, unittest.TestCase):
def
create_and_test_config_common_properties
(
self
):
def
create_and_test_config_common_properties
(
self
):
return
return
@
unittest
.
skip
(
reason
=
"Cvt does not output attentions"
)
def
test_attention_outputs
(
self
):
pass
@
unittest
.
skip
(
reason
=
"Cvt does not use inputs_embeds"
)
@
unittest
.
skip
(
reason
=
"Cvt does not use inputs_embeds"
)
def
test_inputs_embeds
(
self
):
def
test_inputs_embeds
(
self
):
pass
pass
...
...
tests/models/flava/test_modeling_flava.py
View file @
dfc76b25
...
@@ -695,6 +695,10 @@ class FlavaImageCodebookTest(ModelTesterMixin, unittest.TestCase):
...
@@ -695,6 +695,10 @@ class FlavaImageCodebookTest(ModelTesterMixin, unittest.TestCase):
expected_arg_names
=
[
"pixel_values"
]
expected_arg_names
=
[
"pixel_values"
]
self
.
assertListEqual
(
arg_names
[:
1
],
expected_arg_names
)
self
.
assertListEqual
(
arg_names
[:
1
],
expected_arg_names
)
@
unittest
.
skip
(
reason
=
"Flava does not output attentions"
)
def
test_attention_outputs
(
self
):
pass
def
test_model_common_attributes
(
self
):
def
test_model_common_attributes
(
self
):
# No embedding in multimodal model
# No embedding in multimodal model
pass
pass
...
...
tests/models/poolformer/test_modeling_poolformer.py
View file @
dfc76b25
...
@@ -142,6 +142,10 @@ class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -142,6 +142,10 @@ class PoolFormerModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_model
(
*
config_and_inputs
)
self
.
model_tester
.
create_and_check_model
(
*
config_and_inputs
)
@
unittest
.
skip
(
reason
=
"PoolFormer does not output attentions"
)
def
test_attention_outputs
(
self
):
pass
@
unittest
.
skip
(
"PoolFormer does not use inputs_embeds"
)
@
unittest
.
skip
(
"PoolFormer does not use inputs_embeds"
)
def
test_inputs_embeds
(
self
):
def
test_inputs_embeds
(
self
):
pass
pass
...
...
tests/models/regnet/test_modeling_regnet.py
View file @
dfc76b25
...
@@ -147,6 +147,10 @@ class RegNetModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -147,6 +147,10 @@ class RegNetModelTest(ModelTesterMixin, unittest.TestCase):
def
create_and_test_config_common_properties
(
self
):
def
create_and_test_config_common_properties
(
self
):
return
return
@
unittest
.
skip
(
reason
=
"RegNet does not output attentions"
)
def
test_attention_outputs
(
self
):
pass
@
unittest
.
skip
(
reason
=
"RegNet does not use inputs_embeds"
)
@
unittest
.
skip
(
reason
=
"RegNet does not use inputs_embeds"
)
def
test_inputs_embeds
(
self
):
def
test_inputs_embeds
(
self
):
pass
pass
...
...
tests/models/resnet/test_modeling_resnet.py
View file @
dfc76b25
...
@@ -147,6 +147,10 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -147,6 +147,10 @@ class ResNetModelTest(ModelTesterMixin, unittest.TestCase):
def
create_and_test_config_common_properties
(
self
):
def
create_and_test_config_common_properties
(
self
):
return
return
@
unittest
.
skip
(
reason
=
"ResNet does not output attentions"
)
def
test_attention_outputs
(
self
):
pass
@
unittest
.
skip
(
reason
=
"ResNet does not use inputs_embeds"
)
@
unittest
.
skip
(
reason
=
"ResNet does not use inputs_embeds"
)
def
test_inputs_embeds
(
self
):
def
test_inputs_embeds
(
self
):
pass
pass
...
...
tests/models/van/test_modeling_van.py
View file @
dfc76b25
...
@@ -144,6 +144,10 @@ class VanModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -144,6 +144,10 @@ class VanModelTest(ModelTesterMixin, unittest.TestCase):
def
create_and_test_config_common_properties
(
self
):
def
create_and_test_config_common_properties
(
self
):
return
return
@
unittest
.
skip
(
reason
=
"Van does not output attentions"
)
def
test_attention_outputs
(
self
):
pass
@
unittest
.
skip
(
reason
=
"Van does not use inputs_embeds"
)
@
unittest
.
skip
(
reason
=
"Van does not use inputs_embeds"
)
def
test_inputs_embeds
(
self
):
def
test_inputs_embeds
(
self
):
pass
pass
...
...
tests/test_modeling_common.py
View file @
dfc76b25
...
@@ -485,10 +485,6 @@ class ModelTesterMixin:
...
@@ -485,10 +485,6 @@ class ModelTesterMixin:
loss
.
backward
()
loss
.
backward
()
def
test_attention_outputs
(
self
):
def
test_attention_outputs
(
self
):
if
not
self
.
has_attentions
:
pass
else
:
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
.
return_dict
=
True
config
.
return_dict
=
True
...
...
tests/test_modeling_tf_common.py
View file @
dfc76b25
...
@@ -978,6 +978,7 @@ class TFModelTesterMixin:
...
@@ -978,6 +978,7 @@ class TFModelTesterMixin:
dict_inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
dict_inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
check_equivalence
(
model
,
tuple_inputs
,
dict_inputs
,
{
"output_hidden_states"
:
True
})
check_equivalence
(
model
,
tuple_inputs
,
dict_inputs
,
{
"output_hidden_states"
:
True
})
if
self
.
has_attentions
:
tuple_inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
tuple_inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
dict_inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
dict_inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
)
check_equivalence
(
model
,
tuple_inputs
,
dict_inputs
,
{
"output_attentions"
:
True
})
check_equivalence
(
model
,
tuple_inputs
,
dict_inputs
,
{
"output_attentions"
:
True
})
...
@@ -992,6 +993,7 @@ class TFModelTesterMixin:
...
@@ -992,6 +993,7 @@ class TFModelTesterMixin:
dict_inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
,
return_labels
=
True
)
dict_inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
,
return_labels
=
True
)
check_equivalence
(
model
,
tuple_inputs
,
dict_inputs
,
{
"output_hidden_states"
:
True
})
check_equivalence
(
model
,
tuple_inputs
,
dict_inputs
,
{
"output_hidden_states"
:
True
})
if
self
.
has_attentions
:
tuple_inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
,
return_labels
=
True
)
tuple_inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
,
return_labels
=
True
)
dict_inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
,
return_labels
=
True
)
dict_inputs
=
self
.
_prepare_for_class
(
inputs_dict
,
model_class
,
return_labels
=
True
)
check_equivalence
(
model
,
tuple_inputs
,
dict_inputs
,
{
"output_attentions"
:
True
})
check_equivalence
(
model
,
tuple_inputs
,
dict_inputs
,
{
"output_attentions"
:
True
})
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment