Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
761f0297
Unverified
Commit
761f0297
authored
Sep 16, 2022
by
Anton Lozhkov
Committed by
GitHub
Sep 16, 2022
Browse files
[Tests] Fix spatial transformer tests on GPU (#531)
parent
c1796efd
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
4 deletions
+12
-4
tests/test_layers_utils.py
tests/test_layers_utils.py
+12
-4
No files found.
tests/test_layers_utils.py
View file @
761f0297
...
@@ -240,7 +240,9 @@ class AttentionBlockTests(unittest.TestCase):
...
@@ -240,7 +240,9 @@ class AttentionBlockTests(unittest.TestCase):
assert
attention_scores
.
shape
==
(
1
,
32
,
64
,
64
)
assert
attention_scores
.
shape
==
(
1
,
32
,
64
,
64
)
output_slice
=
attention_scores
[
0
,
-
1
,
-
3
:,
-
3
:]
output_slice
=
attention_scores
[
0
,
-
1
,
-
3
:,
-
3
:]
expected_slice
=
torch
.
tensor
([
-
1.4975
,
-
0.0038
,
-
0.7847
,
-
1.4567
,
1.1220
,
-
0.8962
,
-
1.7394
,
1.1319
,
-
0.5427
])
expected_slice
=
torch
.
tensor
(
[
-
1.4975
,
-
0.0038
,
-
0.7847
,
-
1.4567
,
1.1220
,
-
0.8962
,
-
1.7394
,
1.1319
,
-
0.5427
],
device
=
torch_device
)
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
...
@@ -264,7 +266,9 @@ class SpatialTransformerTests(unittest.TestCase):
...
@@ -264,7 +266,9 @@ class SpatialTransformerTests(unittest.TestCase):
assert
attention_scores
.
shape
==
(
1
,
32
,
64
,
64
)
assert
attention_scores
.
shape
==
(
1
,
32
,
64
,
64
)
output_slice
=
attention_scores
[
0
,
-
1
,
-
3
:,
-
3
:]
output_slice
=
attention_scores
[
0
,
-
1
,
-
3
:,
-
3
:]
expected_slice
=
torch
.
tensor
([
-
1.2447
,
-
0.0137
,
-
0.9559
,
-
1.5223
,
0.6991
,
-
1.0126
,
-
2.0974
,
0.8921
,
-
1.0201
])
expected_slice
=
torch
.
tensor
(
[
-
1.2447
,
-
0.0137
,
-
0.9559
,
-
1.5223
,
0.6991
,
-
1.0126
,
-
2.0974
,
0.8921
,
-
1.0201
],
device
=
torch_device
)
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
def
test_spatial_transformer_context_dim
(
self
):
def
test_spatial_transformer_context_dim
(
self
):
...
@@ -287,7 +291,9 @@ class SpatialTransformerTests(unittest.TestCase):
...
@@ -287,7 +291,9 @@ class SpatialTransformerTests(unittest.TestCase):
assert
attention_scores
.
shape
==
(
1
,
64
,
64
,
64
)
assert
attention_scores
.
shape
==
(
1
,
64
,
64
,
64
)
output_slice
=
attention_scores
[
0
,
-
1
,
-
3
:,
-
3
:]
output_slice
=
attention_scores
[
0
,
-
1
,
-
3
:,
-
3
:]
expected_slice
=
torch
.
tensor
([
-
0.2555
,
-
0.8877
,
-
2.4739
,
-
2.2251
,
1.2714
,
0.0807
,
-
0.4161
,
-
1.6408
,
-
0.0471
])
expected_slice
=
torch
.
tensor
(
[
-
0.2555
,
-
0.8877
,
-
2.4739
,
-
2.2251
,
1.2714
,
0.0807
,
-
0.4161
,
-
1.6408
,
-
0.0471
],
device
=
torch_device
)
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
def
test_spatial_transformer_dropout
(
self
):
def
test_spatial_transformer_dropout
(
self
):
...
@@ -313,5 +319,7 @@ class SpatialTransformerTests(unittest.TestCase):
...
@@ -313,5 +319,7 @@ class SpatialTransformerTests(unittest.TestCase):
assert
attention_scores
.
shape
==
(
1
,
32
,
64
,
64
)
assert
attention_scores
.
shape
==
(
1
,
32
,
64
,
64
)
output_slice
=
attention_scores
[
0
,
-
1
,
-
3
:,
-
3
:]
output_slice
=
attention_scores
[
0
,
-
1
,
-
3
:,
-
3
:]
expected_slice
=
torch
.
tensor
([
-
1.2448
,
-
0.0190
,
-
0.9471
,
-
1.5140
,
0.7069
,
-
1.0144
,
-
2.1077
,
0.9099
,
-
1.0091
])
expected_slice
=
torch
.
tensor
(
[
-
1.2448
,
-
0.0190
,
-
0.9471
,
-
1.5140
,
0.7069
,
-
1.0144
,
-
2.1077
,
0.9099
,
-
1.0091
],
device
=
torch_device
)
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment