Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
7dc71897
Commit
7dc71897
authored
Jun 17, 2022
by
patil-suraj
Browse files
add UnetModelTests
parent
800b2770
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
43 additions
and
27 deletions
+43
-27
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+43
-27
No files found.
tests/test_modeling_utils.py
View file @
7dc71897
...
@@ -82,44 +82,25 @@ class ConfigTester(unittest.TestCase):
...
@@ -82,44 +82,25 @@ class ConfigTester(unittest.TestCase):
assert
config
==
new_config
assert
config
==
new_config
class
ModelTesterMixin
(
unittest
.
TestCase
):
class
ModelTesterMixin
:
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
3
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
(
noise
,
time_step
)
def
test_from_pretrained_save_pretrained
(
self
):
def
test_from_pretrained_save_pretrained
(
self
):
model
=
UNetModel
(
ch
=
32
,
ch_mult
=
(
1
,
2
),
num_res_blocks
=
2
,
attn_resolutions
=
(
16
,),
resolution
=
32
)
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
eval
()
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_pretrained
(
tmpdirname
)
model
.
save_pretrained
(
tmpdirname
)
new_model
=
UNetModel
.
from_pretrained
(
tmpdirname
)
new_model
=
UNetModel
.
from_pretrained
(
tmpdirname
)
new_model
.
to
(
torch_device
)
new_model
.
to
(
torch_device
)
dummy_input
=
self
.
dummy_input
with
torch
.
no_grad
():
image
=
model
(
**
inputs_dict
)
image
=
model
(
*
dummy_input
)
new_image
=
new_model
(
**
inputs_dict
)
new_image
=
new_model
(
*
dummy_input
)
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
assert
(
image
-
new_image
).
abs
().
sum
()
<
1e-5
,
"Models don't give the same forward pass"
def
test_from_pretrained_hub
(
self
):
model
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
model
.
to
(
torch_device
)
image
=
model
(
*
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_save_load
(
self
):
pass
def
test_determinism
(
self
):
def
test_determinism
(
self
):
pass
pass
...
@@ -137,6 +118,41 @@ class ModelTesterMixin(unittest.TestCase):
...
@@ -137,6 +118,41 @@ class ModelTesterMixin(unittest.TestCase):
pass
pass
class
UnetModelTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
UNetModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
3
sizes
=
(
32
,
32
)
noise
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]).
to
(
torch_device
)
return
{
"x"
:
noise
,
"t"
:
time_step
}
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"ch"
:
32
,
"ch_mult"
:
(
1
,
2
),
"num_res_blocks"
:
2
,
"attn_resolutions"
:
(
16
,),
"resolution"
:
32
,
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_from_pretrained_hub
(
self
):
model
=
UNetModel
.
from_pretrained
(
"fusing/ddpm_dummy"
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
class
PipelineTesterMixin
(
unittest
.
TestCase
):
class
PipelineTesterMixin
(
unittest
.
TestCase
):
def
test_from_pretrained_save_pretrained
(
self
):
def
test_from_pretrained_save_pretrained
(
self
):
# 1. Load models
# 1. Load models
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment