transformers · Commit cbf036f7 (unverified)

Add test (#14810)

Authored Dec 17, 2021 by NielsRogge; committed via GitHub on Dec 17, 2021.
Parent: c4a0fb51

1 changed file, 26 additions, 1 deletion:

tests/test_modeling_perceiver.py (+26, −1)
```diff
@@ -28,7 +28,7 @@ from datasets import load_dataset
 from transformers import PerceiverConfig
 from transformers.file_utils import is_torch_available, is_vision_available
 from transformers.models.auto import get_values
-from transformers.testing_utils import require_torch, require_vision, slow, torch_device
+from transformers.testing_utils import require_torch, require_torch_multi_gpu, require_vision, slow, torch_device
 
 from .test_configuration_common import ConfigTester
 from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor, random_attention_mask
```
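The only import change pulls in `require_torch_multi_gpu` from `transformers.testing_utils`, a decorator that skips the test unless more than one GPU is visible. A minimal sketch of an equivalent guard (the `require_multi_gpu` helper below is hypothetical, not the library's implementation):

```python
import unittest

import torch


def require_multi_gpu(test_case):
    """Skip the decorated test unless at least two CUDA devices are visible.

    Hypothetical stand-in for transformers.testing_utils.require_torch_multi_gpu.
    """
    return unittest.skipUnless(
        torch.cuda.device_count() > 1, "test requires multiple GPUs"
    )(test_case)
```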
```diff
@@ -757,6 +757,31 @@ class PerceiverModelTest(ModelTesterMixin, unittest.TestCase):
             loss.backward()
 
+    @require_torch_multi_gpu
+    def test_multi_gpu_data_parallel_forward(self):
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_model_class(model_class)
+
+            # some params shouldn't be scattered by nn.DataParallel
+            # so just remove them if they are present.
+            blacklist_non_batched_params = ["head_mask", "decoder_head_mask", "cross_attn_head_mask"]
+            for k in blacklist_non_batched_params:
+                inputs_dict.pop(k, None)
+
+            # move input tensors to cuda:0
+            for k, v in inputs_dict.items():
+                if torch.is_tensor(v):
+                    inputs_dict[k] = v.to(0)
+
+            model = model_class(config=config)
+            model.to(0)
+            model.eval()
+
+            # Wrap model in nn.DataParallel
+            model = nn.DataParallel(model)
+            with torch.no_grad():
+                _ = model(**self._prepare_for_class(inputs_dict, model_class))
+
     @unittest.skip(reason="Perceiver models don't have a typical head like is the case with BERT")
     def test_save_load_fast_init_from_base(self):
         pass
```
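For context on the blacklist step: `nn.DataParallel` scatters every tensor input along dimension 0 across the visible GPUs, so inputs like `head_mask` that are not batched along that dimension would be split incorrectly, which is why the test pops them before wrapping the model. A minimal standalone sketch of the same pattern (the `DummyModel` class is hypothetical, not from the commit; assumes at least two CUDA devices):

```python
import torch
from torch import nn


class DummyModel(nn.Module):
    """Hypothetical stand-in for a Perceiver model class (not from the commit)."""

    def __init__(self, hidden_size: int = 32):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.linear(inputs)


if torch.cuda.device_count() > 1:
    # Batched input: DataParallel chunks dim 0 across GPUs,
    # so each replica sees a slice of the batch.
    inputs = torch.randn(8, 32).to(0)

    model = DummyModel().to(0)
    model.eval()

    # Replicate the model on every visible GPU; the forward pass
    # scatters inputs, runs the replicas in parallel, and gathers
    # the outputs back on cuda:0.
    model = nn.DataParallel(model)
    with torch.no_grad():
        outputs = model(inputs=inputs)
    assert outputs.shape == (8, 32)
```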