Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
diffusers
Commits
4e267493
Unverified
Commit
4e267493
authored
Jul 04, 2022
by
Suraj Patil
Committed by
GitHub
Jul 04, 2022
Browse files
add tests for 1D Up/Downsample blocks (#72)
parent
53a42d0a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
106 additions
and
1 deletion
+106
-1
tests/test_layers_utils.py
tests/test_layers_utils.py
+106
-1
No files found.
tests/test_layers_utils.py
View file @
4e267493
...
...
@@ -22,7 +22,7 @@ import numpy as np
import
torch
from
diffusers.models.embeddings
import
get_timestep_embedding
from
diffusers.models.resnet
import
Downsample
2
D
,
Upsample2D
from
diffusers.models.resnet
import
Downsample
1D
,
Downsample2D
,
Upsample1
D
,
Upsample2D
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
...
...
@@ -219,3 +219,108 @@ class Downsample2DBlockTests(unittest.TestCase):
output_slice
=
downsampled
[
0
,
-
1
,
-
3
:,
-
3
:]
expected_slice
=
torch
.
tensor
([
-
0.6586
,
0.5985
,
0.0721
,
0.1256
,
-
0.1492
,
0.4436
,
-
0.2544
,
0.5021
,
1.1522
])
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
class
Upsample1DBlockTests
(
unittest
.
TestCase
):
def
test_upsample_default
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
32
)
upsample
=
Upsample1D
(
channels
=
32
,
use_conv
=
False
)
with
torch
.
no_grad
():
upsampled
=
upsample
(
sample
)
assert
upsampled
.
shape
==
(
1
,
32
,
64
)
output_slice
=
upsampled
[
0
,
-
1
,
-
8
:]
expected_slice
=
torch
.
tensor
([
-
1.6340
,
-
1.6340
,
0.5374
,
0.5374
,
1.0826
,
1.0826
,
-
1.7105
,
-
1.7105
])
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
def
test_upsample_with_conv
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
32
)
upsample
=
Upsample1D
(
channels
=
32
,
use_conv
=
True
)
with
torch
.
no_grad
():
upsampled
=
upsample
(
sample
)
assert
upsampled
.
shape
==
(
1
,
32
,
64
)
output_slice
=
upsampled
[
0
,
-
1
,
-
8
:]
expected_slice
=
torch
.
tensor
([
-
0.4546
,
-
0.5010
,
-
0.2996
,
0.2844
,
0.4040
,
-
0.7772
,
-
0.6862
,
0.3612
])
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
def
test_upsample_with_conv_out_dim
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
32
)
upsample
=
Upsample1D
(
channels
=
32
,
use_conv
=
True
,
out_channels
=
64
)
with
torch
.
no_grad
():
upsampled
=
upsample
(
sample
)
assert
upsampled
.
shape
==
(
1
,
64
,
64
)
output_slice
=
upsampled
[
0
,
-
1
,
-
8
:]
expected_slice
=
torch
.
tensor
([
-
0.0516
,
-
0.0972
,
0.9740
,
1.1883
,
0.4539
,
-
0.5285
,
-
0.5851
,
0.1152
])
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
def
test_upsample_with_transpose
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
32
)
upsample
=
Upsample1D
(
channels
=
32
,
use_conv
=
False
,
use_conv_transpose
=
True
)
with
torch
.
no_grad
():
upsampled
=
upsample
(
sample
)
assert
upsampled
.
shape
==
(
1
,
32
,
64
)
output_slice
=
upsampled
[
0
,
-
1
,
-
8
:]
expected_slice
=
torch
.
tensor
([
-
0.2238
,
-
0.5842
,
-
0.7165
,
0.6699
,
0.1033
,
-
0.4269
,
-
0.8974
,
-
0.3716
])
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
class
Downsample1DBlockTests
(
unittest
.
TestCase
):
def
test_downsample_default
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
64
)
downsample
=
Downsample1D
(
channels
=
32
,
use_conv
=
False
)
with
torch
.
no_grad
():
downsampled
=
downsample
(
sample
)
assert
downsampled
.
shape
==
(
1
,
32
,
32
)
output_slice
=
downsampled
[
0
,
-
1
,
-
8
:]
expected_slice
=
torch
.
tensor
([
-
0.8796
,
1.0945
,
-
0.3434
,
0.2910
,
0.3391
,
-
0.4488
,
-
0.9568
,
-
0.2909
])
max_diff
=
(
output_slice
.
flatten
()
-
expected_slice
).
abs
().
sum
().
item
()
assert
max_diff
<=
1e-3
# assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-1)
def
test_downsample_with_conv
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
64
)
downsample
=
Downsample1D
(
channels
=
32
,
use_conv
=
True
)
with
torch
.
no_grad
():
downsampled
=
downsample
(
sample
)
assert
downsampled
.
shape
==
(
1
,
32
,
32
)
output_slice
=
downsampled
[
0
,
-
1
,
-
8
:]
expected_slice
=
torch
.
tensor
(
[
0.1723
,
0.0811
,
-
0.6205
,
-
0.3045
,
0.0666
,
-
0.2381
,
-
0.0238
,
0.2834
],
)
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
def
test_downsample_with_conv_pad1
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
64
)
downsample
=
Downsample1D
(
channels
=
32
,
use_conv
=
True
,
padding
=
1
)
with
torch
.
no_grad
():
downsampled
=
downsample
(
sample
)
assert
downsampled
.
shape
==
(
1
,
32
,
32
)
output_slice
=
downsampled
[
0
,
-
1
,
-
8
:]
expected_slice
=
torch
.
tensor
([
0.1723
,
0.0811
,
-
0.6205
,
-
0.3045
,
0.0666
,
-
0.2381
,
-
0.0238
,
0.2834
])
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
def
test_downsample_with_conv_out_dim
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
64
)
downsample
=
Downsample1D
(
channels
=
32
,
use_conv
=
True
,
out_channels
=
16
)
with
torch
.
no_grad
():
downsampled
=
downsample
(
sample
)
assert
downsampled
.
shape
==
(
1
,
16
,
32
)
output_slice
=
downsampled
[
0
,
-
1
,
-
8
:]
expected_slice
=
torch
.
tensor
([
1.1067
,
-
0.5255
,
-
0.4451
,
0.0487
,
-
0.3664
,
-
0.7945
,
-
0.4495
,
-
0.3129
])
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment