Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
c1c4dea9
Commit
c1c4dea9
authored
Jun 30, 2022
by
Patrick von Platen
Browse files
correct tests ncsnpp
parent
f4cd5a20
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
13 deletions
+13
-13
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+13
-13
No files found.
tests/test_modeling_utils.py
View file @
c1c4dea9
...
...
@@ -742,18 +742,18 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
num_channels
=
3
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
(
batch_size
*
[
1
0
]).
to
(
torch_device
)
noise
=
torch
.
ones
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
(
batch_size
*
[
1
e-4
]).
to
(
torch_device
)
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output_slice
=
output
[
0
,
-
3
:,
-
3
:,
-
1
].
flatten
().
cpu
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
3
.1
909e-07
,
-
8.5393e-08
,
4.8460e-07
,
-
4.5550e-07
,
-
1.3205e-06
,
-
6.3475e-07
,
9.7837e-07
,
2.9974e-07
,
1.2345e-0
6
])
expected_output_slice
=
torch
.
tensor
([
0
.1
315
,
0.0741
,
0.0393
,
0.0455
,
0.0556
,
0.0180
,
-
0.0832
,
-
0.0644
,
-
0.085
6
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
rtol
=
1e-
3
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
rtol
=
1e-
2
))
def
test_output_pretrained_ve_large
(
self
):
model
=
NCSNpp
.
from_pretrained
(
"fusing/ncsnpp-ffhq-ve-dummy"
)
...
...
@@ -768,21 +768,21 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
num_channels
=
3
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
(
batch_size
*
[
1
0
]).
to
(
torch_device
)
noise
=
torch
.
ones
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
(
batch_size
*
[
1
e-4
]).
to
(
torch_device
)
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output_slice
=
output
[
0
,
-
3
:,
-
3
:,
-
1
].
flatten
().
cpu
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
8.3299e-07
,
-
9.0431e-07
,
4.0585e-08
,
9.7563e-07
,
1.0280e-06
,
1.0133e-06
,
1.4979e-06
,
-
2.9716e-07
,
-
6.1817e-07
])
expected_output_slice
=
torch
.
tensor
([
-
0.0325
,
-
0.0900
,
-
0.0869
,
-
0.0332
,
-
0.0725
,
-
0.0270
,
-
0.0101
,
0.0227
,
0.0256
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
rtol
=
1e-
3
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
rtol
=
1e-
2
))
def
test_output_pretrained_vp
(
self
):
model
=
NCSNpp
.
from_pretrained
(
"fusing/
ddpm-
cifar10-
vp-dummy
"
)
model
=
NCSNpp
.
from_pretrained
(
"fusing/cifar10-
ddpmpp-vp
"
)
model
.
eval
()
model
.
to
(
torch_device
)
...
...
@@ -794,18 +794,18 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
num_channels
=
3
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
(
batch_size
*
[
10
]).
to
(
torch_device
)
noise
=
torch
.
randn
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
(
batch_size
*
[
9.
]).
to
(
torch_device
)
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
)
output_slice
=
output
[
0
,
-
3
:,
-
3
:,
-
1
].
flatten
().
cpu
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
3.9086e-07
,
-
1.1001e-0
5
,
1
.88
81e-06
,
1.1106e-05
,
1.
6629e-06
,
2.9820e-06
,
8
.4
978e-06
,
8.0253e-07
,
1.
5435e-0
6
])
expected_output_slice
=
torch
.
tensor
([
0.3303
,
-
0.227
5
,
-
2
.88
72
,
-
0.1309
,
-
1.
2861
,
3
.4
567
,
-
1.0083
,
2.5325
,
-
1.
386
6
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
rtol
=
1e-
3
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
rtol
=
1e-
2
))
class
VQModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment