Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
b19f3e69
Unverified
Commit
b19f3e69
authored
Mar 08, 2022
by
NielsRogge
Committed by
GitHub
Mar 08, 2022
Browse files
[Tests] Fix ViTMAE integration test (#15949)
* Fix test across both cpu and gpu * Fix typo
parent
9879a1d5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
3 deletions
+12
-3
tests/vit_mae/test_modeling_vit_mae.py
tests/vit_mae/test_modeling_vit_mae.py
+12
-3
No files found.
tests/vit_mae/test_modeling_vit_mae.py
View file @
b19f3e69
...
...
# @@ -401,6 +401,9 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_for_pretraining(self):
    """Integration test for ViTMAEForPreTraining inference.

    NOTE(review): this is the opening hunk of the method as shown in the
    diff view; the body continues past this hunk (elided in SOURCE).
    """
    # make random mask reproducible
    # note that the same seed on CPU and on GPU doesn't mean they spew the same
    # random number sequences, as they both have fairly different PRNGs (for
    # efficiency reasons).
    # source: https://discuss.pytorch.org/t/random-seed-that-spans-across-devices/19735
    torch.manual_seed(2)

    # torch_device is a project-level constant ("cpu" or a CUDA device string)
    model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base").to(torch_device)
...
...
# @@ -417,8 +420,14 @@ class ViTMAEModelIntegrationTest(unittest.TestCase):
# (post-image of the diff; the pre-image used a single expected_slice moved to
# torch_device and compared without .to(torch_device) in the final assert)

# verify the predicted pixel logits have the expected (batch, patches, pixels) shape
expected_shape = torch.Size((1, 196, 768))
self.assertEqual(outputs.logits.shape, expected_shape)

# CPU and GPU produce different random masks from the same seed (different
# PRNGs), so the expected values differ per device
expected_slice_cpu = torch.tensor(
    [[0.7366, -1.3663, -0.2844], [0.7919, -1.3839, -0.3241], [0.4313, -0.7168, -0.2878]]
)
expected_slice_gpu = torch.tensor(
    [[0.8948, -1.0680, 0.0030], [0.9758, -1.1181, -0.0290], [1.0602, -1.1522, -0.0528]]
)

# set expected slice depending on device
expected_slice = expected_slice_cpu if torch_device == "cpu" else expected_slice_gpu

# move the expected slice onto the model's device before the element-wise compare
self.assertTrue(torch.allclose(outputs.logits[0, :3, :3], expected_slice.to(torch_device), atol=1e-4))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment