Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
1c953bc3
Unverified
Commit
1c953bc3
authored
Jun 17, 2022
by
Suraj Patil
Committed by
GitHub
Jun 17, 2022
Browse files
Add tests for GLIDESuperResUNetModel # 22
Add tests for GLIDESuperResUNetModel
parents
d182a6ad
e007c797
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
107 additions
and
6 deletions
+107
-6
src/diffusers/modeling_utils.py
src/diffusers/modeling_utils.py
+3
-3
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+104
-3
No files found.
src/diffusers/modeling_utils.py
View file @
1c953bc3
...
@@ -490,7 +490,7 @@ class ModelMixin(torch.nn.Module):
...
@@ -490,7 +490,7 @@ class ModelMixin(torch.nn.Module):
raise
RuntimeError
(
f
"Error(s) in loading state_dict for
{
model
.
__class__
.
__name__
}
:
\n\t
{
error_msg
}
"
)
raise
RuntimeError
(
f
"Error(s) in loading state_dict for
{
model
.
__class__
.
__name__
}
:
\n\t
{
error_msg
}
"
)
if
len
(
unexpected_keys
)
>
0
:
if
len
(
unexpected_keys
)
>
0
:
logger
.
warning
ing
(
logger
.
warning
(
f
"Some weights of the model checkpoint at
{
pretrained_model_name_or_path
}
were not used when"
f
"Some weights of the model checkpoint at
{
pretrained_model_name_or_path
}
were not used when"
f
" initializing
{
model
.
__class__
.
__name__
}
:
{
unexpected_keys
}
\n
- This IS expected if you are"
f
" initializing
{
model
.
__class__
.
__name__
}
:
{
unexpected_keys
}
\n
- This IS expected if you are"
f
" initializing
{
model
.
__class__
.
__name__
}
from the checkpoint of a model trained on another task or"
f
" initializing
{
model
.
__class__
.
__name__
}
from the checkpoint of a model trained on another task or"
...
@@ -502,7 +502,7 @@ class ModelMixin(torch.nn.Module):
...
@@ -502,7 +502,7 @@ class ModelMixin(torch.nn.Module):
else
:
else
:
logger
.
info
(
f
"All model checkpoint weights were used when initializing
{
model
.
__class__
.
__name__
}
.
\n
"
)
logger
.
info
(
f
"All model checkpoint weights were used when initializing
{
model
.
__class__
.
__name__
}
.
\n
"
)
if
len
(
missing_keys
)
>
0
:
if
len
(
missing_keys
)
>
0
:
logger
.
warning
ing
(
logger
.
warning
(
f
"Some weights of
{
model
.
__class__
.
__name__
}
were not initialized from the model checkpoint at"
f
"Some weights of
{
model
.
__class__
.
__name__
}
were not initialized from the model checkpoint at"
f
"
{
pretrained_model_name_or_path
}
and are newly initialized:
{
missing_keys
}
\n
You should probably"
f
"
{
pretrained_model_name_or_path
}
and are newly initialized:
{
missing_keys
}
\n
You should probably"
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
...
@@ -521,7 +521,7 @@ class ModelMixin(torch.nn.Module):
...
@@ -521,7 +521,7 @@ class ModelMixin(torch.nn.Module):
for
key
,
shape1
,
shape2
in
mismatched_keys
for
key
,
shape1
,
shape2
in
mismatched_keys
]
]
)
)
logger
.
warning
ing
(
logger
.
warning
(
f
"Some weights of
{
model
.
__class__
.
__name__
}
were not initialized from the model checkpoint at"
f
"Some weights of
{
model
.
__class__
.
__name__
}
were not initialized from the model checkpoint at"
f
"
{
pretrained_model_name_or_path
}
and are newly initialized because the shapes did not"
f
"
{
pretrained_model_name_or_path
}
and are newly initialized because the shapes did not"
f
" match:
\n
{
mismatched_warning
}
\n
You should probably TRAIN this model on a down-stream task to be able"
f
" match:
\n
{
mismatched_warning
}
\n
You should probably TRAIN this model on a down-stream task to be able"
...
...
tests/test_modeling_utils.py
View file @
1c953bc3
...
@@ -18,6 +18,7 @@ import inspect
...
@@ -18,6 +18,7 @@ import inspect
import
tempfile
import
tempfile
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
import
pytest
import
torch
import
torch
...
@@ -32,6 +33,7 @@ from diffusers import (
...
@@ -32,6 +33,7 @@ from diffusers import (
LatentDiffusion
,
LatentDiffusion
,
PNDMScheduler
,
PNDMScheduler
,
UNetModel
,
UNetModel
,
GLIDESuperResUNetModel
)
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipeline_utils
import
DiffusionPipeline
...
@@ -94,7 +96,7 @@ class ModelTesterMixin:
...
@@ -94,7 +96,7 @@ class ModelTesterMixin:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
model
.
save_pretrained
(
tmpdirname
)
model
.
save_pretrained
(
tmpdirname
)
new_model
=
UNetModel
.
from_pretrained
(
tmpdirname
)
new_model
=
self
.
model_class
.
from_pretrained
(
tmpdirname
)
new_model
.
to
(
torch_device
)
new_model
.
to
(
torch_device
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
@@ -178,7 +180,7 @@ class ModelTesterMixin:
...
@@ -178,7 +180,7 @@ class ModelTesterMixin:
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
model
.
train
()
model
.
train
()
output
=
model
(
**
inputs_dict
)
output
=
model
(
**
inputs_dict
)
noise
=
torch
.
randn
(
inputs_dict
[
"x"
].
shape
).
to
(
torch_device
)
noise
=
torch
.
randn
(
(
inputs_dict
[
"x"
].
shape
[
0
],
)
+
self
.
get_output_shape
).
to
(
torch_device
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
=
torch
.
nn
.
functional
.
mse_loss
(
output
,
noise
)
loss
.
backward
()
loss
.
backward
()
...
@@ -197,6 +199,14 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -197,6 +199,14 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
return
{
"x"
:
noise
,
"timesteps"
:
time_step
}
@
property
def
get_input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
get_output_shape
(
self
):
return
(
3
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
init_dict
=
{
"ch"
:
32
,
"ch"
:
32
,
...
@@ -227,7 +237,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -227,7 +237,6 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
torch
.
cuda
.
manual_seed_all
(
0
)
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
resolution
,
model
.
config
.
resolution
)
noise
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
resolution
,
model
.
config
.
resolution
)
print
(
noise
.
shape
)
time_step
=
torch
.
tensor
([
10
])
time_step
=
torch
.
tensor
([
10
])
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
@@ -240,6 +249,98 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -240,6 +249,98 @@ class UnetModelTests(ModelTesterMixin, unittest.TestCase):
print
(
output_slice
)
print
(
output_slice
)
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
GLIDESuperResUNetTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
GLIDESuperResUNetModel
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
6
sizes
=
(
32
,
32
)
low_res_size
=
(
4
,
4
)
torch_device
=
"cpu"
noise
=
torch
.
randn
((
batch_size
,
num_channels
//
2
)
+
sizes
).
to
(
torch_device
)
low_res
=
torch
.
randn
((
batch_size
,
3
)
+
low_res_size
).
to
(
torch_device
)
time_step
=
torch
.
tensor
([
10
]
*
noise
.
shape
[
0
],
device
=
torch_device
)
return
{
"x"
:
noise
,
"timesteps"
:
time_step
,
"low_res"
:
low_res
}
@
property
def
get_input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
get_output_shape
(
self
):
return
(
6
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"attention_resolutions"
:
(
2
,),
"channel_mult"
:
(
1
,
2
),
"in_channels"
:
6
,
"out_channels"
:
6
,
"model_channels"
:
32
,
"num_head_channels"
:
8
,
"num_heads_upsample"
:
1
,
"num_res_blocks"
:
2
,
"resblock_updown"
:
True
,
"resolution"
:
32
,
"use_scale_shift_norm"
:
True
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_output
(
self
):
init_dict
,
inputs_dict
=
self
.
prepare_init_args_and_inputs_for_common
()
model
=
self
.
model_class
(
**
init_dict
)
model
.
to
(
torch_device
)
model
.
eval
()
with
torch
.
no_grad
():
output
=
model
(
**
inputs_dict
)
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
self
.
assertIsNotNone
(
output
)
expected_shape
=
inputs_dict
[
"x"
].
shape
self
.
assertEqual
(
output
.
shape
,
expected_shape
,
"Input and output shapes do not match"
)
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
GLIDESuperResUNetModel
.
from_pretrained
(
"fusing/glide-super-res-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
# TODO (patil-suraj): Check why GLIDESuperResUNetModel always outputs zero
@
unittest
.
skip
(
"GLIDESuperResUNetModel always outputs zero"
)
def
test_output_pretrained
(
self
):
model
=
GLIDESuperResUNetModel
.
from_pretrained
(
"fusing/glide-super-res-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
noise
=
torch
.
randn
(
1
,
3
,
32
,
32
)
low_res
=
torch
.
randn
(
1
,
3
,
4
,
4
)
time_step
=
torch
.
tensor
([
42
]
*
noise
.
shape
[
0
])
with
torch
.
no_grad
():
output
=
model
(
noise
,
time_step
,
low_res
)
output
,
_
=
torch
.
split
(
output
,
3
,
dim
=
1
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
0.2891
,
-
0.1899
,
0.2595
,
-
0.6214
,
0.0968
,
-
0.2622
,
0.4688
,
0.1311
,
0.0053
])
# fmt: on
print
(
output_slice
)
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
PipelineTesterMixin
(
unittest
.
TestCase
):
class
PipelineTesterMixin
(
unittest
.
TestCase
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment