Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
18ef809c
"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "87e45f776bac962a738db341e7996a247222b9d3"
Commit
18ef809c
authored
May 31, 2022
by
Patrick von Platen
Browse files
add another test
parent
e779b250
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
7 deletions
+22
-7
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+22
-7
No files found.
tests/test_modeling_utils.py
View file @
18ef809c
...
@@ -42,6 +42,18 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
...
@@ -42,6 +42,18 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
class
ModelTesterMixin
(
unittest
.
TestCase
):
class
ModelTesterMixin
(
unittest
.
TestCase
):
@
property
def
dummy_input
(
self
):
batch_size
=
1
num_channels
=
3
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
)
time_step
=
torch
.
tensor
([
10
])
return
(
noise
,
time_step
)
def
test_from_pretrained_save_pretrained
(
self
):
def
test_from_pretrained_save_pretrained
(
self
):
config
=
UNetConfig
(
dim
=
8
,
dim_mults
=
(
1
,
2
),
resnet_block_groups
=
2
)
config
=
UNetConfig
(
dim
=
8
,
dim_mults
=
(
1
,
2
),
resnet_block_groups
=
2
)
model
=
UNetModel
(
config
)
model
=
UNetModel
(
config
)
...
@@ -50,13 +62,16 @@ class ModelTesterMixin(unittest.TestCase):
...
@@ -50,13 +62,16 @@ class ModelTesterMixin(unittest.TestCase):
model
.
save_pretrained
(
tmpdirname
)
model
.
save_pretrained
(
tmpdirname
)
new_model
=
UNetModel
.
from_pretrained
(
tmpdirname
)
new_model
=
UNetModel
.
from_pretrained
(
tmpdirname
)
batch_size
=
1
dummy_input
=
self
.
dummy_input
num_channels
=
3
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
)
time_step
=
torch
.
tensor
([
10
])
image
=
model
(
noise
,
time_step
)
image
=
model
(
*
dummy_input
)
new_image
=
new_model
(
noise
,
time_step
)
new_image
=
new_model
(
*
dummy_input
)
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
def
test_from_pretrained_hub
(
self
):
model
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
image
=
model
(
*
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment