Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
dc7c49e4
Commit
dc7c49e4
authored
Jun 27, 2022
by
patil-suraj
Browse files
add tests for upsample blocks
parent
e13ee8b5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
59 additions
and
6 deletions
+59
-6
src/diffusers/models/resnet.py
src/diffusers/models/resnet.py
+8
-6
tests/test_layers_utils.py
tests/test_layers_utils.py
+51
-0
No files found.
src/diffusers/models/resnet.py
View file @
dc7c49e4
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
...
...
@@ -29,6 +28,7 @@ def conv_nd(dims, *args, **kwargs):
return
nn
.
Conv3d
(
*
args
,
**
kwargs
)
raise
ValueError
(
f
"unsupported dimensions:
{
dims
}
"
)
def
conv_transpose_nd
(
dims
,
*
args
,
**
kwargs
):
"""
Create a 1D, 2D, or 3D convolution module.
...
...
@@ -73,7 +73,7 @@ class Upsample(nn.Module):
self
.
use_conv_transpose
=
use_conv_transpose
if
use_conv_transpose
:
self
.
conv
=
conv_transpose_nd
(
dims
,
channels
,
out_channels
,
4
,
2
,
1
)
self
.
conv
=
conv_transpose_nd
(
dims
,
channels
,
self
.
out_channels
,
4
,
2
,
1
)
elif
use_conv
:
self
.
conv
=
conv_nd
(
dims
,
self
.
channels
,
self
.
out_channels
,
3
,
padding
=
1
)
...
...
@@ -81,15 +81,15 @@ class Upsample(nn.Module):
assert
x
.
shape
[
1
]
==
self
.
channels
if
self
.
use_conv_transpose
:
return
self
.
conv
(
x
)
if
self
.
dims
==
3
:
x
=
F
.
interpolate
(
x
,
(
x
.
shape
[
2
],
x
.
shape
[
3
]
*
2
,
x
.
shape
[
4
]
*
2
),
mode
=
"nearest"
)
else
:
x
=
F
.
interpolate
(
x
,
scale_factor
=
2.0
,
mode
=
"nearest"
)
if
self
.
use_conv
:
x
=
self
.
conv
(
x
)
return
x
...
...
@@ -138,6 +138,7 @@ class UNetUpsample(nn.Module):
x
=
self
.
conv
(
x
)
return
x
class
GlideUpsample
(
nn
.
Module
):
"""
An upsampling layer with an optional convolution.
...
...
@@ -199,13 +200,14 @@ class LDMUpsample(nn.Module):
class
GradTTSUpsample
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
(
Upsample
,
self
).
__init__
()
super
(
GradTTS
Upsample
,
self
).
__init__
()
self
.
conv
=
torch
.
nn
.
ConvTranspose2d
(
dim
,
dim
,
4
,
2
,
1
)
def
forward
(
self
,
x
):
return
self
.
conv
(
x
)
# TODO (patil-suraj): needs test
class
Upsample1d
(
nn
.
Module
):
def
__init__
(
self
,
dim
):
super
().
__init__
()
...
...
tests/test_layers_utils.py
View file @
dc7c49e4
...
...
@@ -22,6 +22,7 @@ import numpy as np
import
torch
from
diffusers.models.embeddings
import
get_timestep_embedding
from
diffusers.models.resnet
import
Upsample
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
...
...
@@ -113,3 +114,53 @@ class EmbeddingsTests(unittest.TestCase):
torch
.
tensor
([
-
0.9801
,
-
0.9464
,
-
0.9349
,
-
0.3952
,
0.8887
,
-
0.9709
,
0.5299
,
-
0.2853
,
-
0.9927
]),
1e-3
,
)
class
UpsampleBlockTests
(
unittest
.
TestCase
):
def
test_upsample_default
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
32
,
32
)
upsample
=
Upsample
(
channels
=
32
,
use_conv
=
False
)
with
torch
.
no_grad
():
upsampled
=
upsample
(
sample
)
assert
upsampled
.
shape
==
(
1
,
32
,
64
,
64
)
output_slice
=
upsampled
[
0
,
-
1
,
-
3
:,
-
3
:]
expected_slice
=
torch
.
tensor
([
-
0.2173
,
-
1.2079
,
-
1.2079
,
0.2952
,
1.1254
,
1.1254
,
0.2952
,
1.1254
,
1.1254
])
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
def
test_upsample_with_conv
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
32
,
32
)
upsample
=
Upsample
(
channels
=
32
,
use_conv
=
True
)
with
torch
.
no_grad
():
upsampled
=
upsample
(
sample
)
assert
upsampled
.
shape
==
(
1
,
32
,
64
,
64
)
output_slice
=
upsampled
[
0
,
-
1
,
-
3
:,
-
3
:]
expected_slice
=
torch
.
tensor
([
0.7145
,
1.3773
,
0.3492
,
0.8448
,
1.0839
,
-
0.3341
,
0.5956
,
0.1250
,
-
0.4841
])
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
def
test_upsample_with_conv_out_dim
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
32
,
32
)
upsample
=
Upsample
(
channels
=
32
,
use_conv
=
True
,
out_channels
=
64
)
with
torch
.
no_grad
():
upsampled
=
upsample
(
sample
)
assert
upsampled
.
shape
==
(
1
,
64
,
64
,
64
)
output_slice
=
upsampled
[
0
,
-
1
,
-
3
:,
-
3
:]
expected_slice
=
torch
.
tensor
([
0.2703
,
0.1656
,
-
0.2538
,
-
0.0553
,
-
0.2984
,
0.1044
,
0.1155
,
0.2579
,
0.7755
])
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
def
test_upsample_with_transpose
(
self
):
torch
.
manual_seed
(
0
)
sample
=
torch
.
randn
(
1
,
32
,
32
,
32
)
upsample
=
Upsample
(
channels
=
32
,
use_conv
=
False
,
use_conv_transpose
=
True
)
with
torch
.
no_grad
():
upsampled
=
upsample
(
sample
)
assert
upsampled
.
shape
==
(
1
,
32
,
64
,
64
)
output_slice
=
upsampled
[
0
,
-
1
,
-
3
:,
-
3
:]
expected_slice
=
torch
.
tensor
([
-
0.3028
,
-
0.1582
,
0.0071
,
0.0350
,
-
0.4799
,
-
0.1139
,
0.1056
,
-
0.1153
,
-
0.1046
])
assert
torch
.
allclose
(
output_slice
.
flatten
(),
expected_slice
,
atol
=
1e-3
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment