Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
333a8da6
"vscode:/vscode.git/clone" did not exist on "a1fad8286f86c46821f8038d86e358e9cc62d20f"
Commit
333a8da6
authored
Jun 29, 2022
by
patil-suraj
Browse files
add tests for AutoencoderKL
parent
976173a4
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
76 additions
and
3 deletions
+76
-3
src/diffusers/models/vae.py
src/diffusers/models/vae.py
+3
-3
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+73
-0
No files found.
src/diffusers/models/vae.py
View file @
333a8da6
...
@@ -626,11 +626,11 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
...
@@ -626,11 +626,11 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
dec
=
self
.
decoder
(
z
)
dec
=
self
.
decoder
(
z
)
return
dec
return
dec
def
forward
(
self
,
input
,
sample_posterior
=
Tru
e
):
def
forward
(
self
,
x
,
sample_posterior
=
Fals
e
):
posterior
=
self
.
encode
(
input
)
posterior
=
self
.
encode
(
x
)
if
sample_posterior
:
if
sample_posterior
:
z
=
posterior
.
sample
()
z
=
posterior
.
sample
()
else
:
else
:
z
=
posterior
.
mode
()
z
=
posterior
.
mode
()
dec
=
self
.
decode
(
z
)
dec
=
self
.
decode
(
z
)
return
dec
,
posterior
return
dec
tests/test_modeling_utils.py
View file @
333a8da6
...
@@ -46,6 +46,7 @@ from diffusers import (
...
@@ -46,6 +46,7 @@ from diffusers import (
UNetLDMModel
,
UNetLDMModel
,
UNetModel
,
UNetModel
,
VQModel
,
VQModel
,
AutoencoderKL
,
)
)
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipeline_utils
import
DiffusionPipeline
...
@@ -883,6 +884,78 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
...
@@ -883,6 +884,78 @@ class VQModelTests(ModelTesterMixin, unittest.TestCase):
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
AutoEncoderKLTests
(
ModelTesterMixin
,
unittest
.
TestCase
):
model_class
=
AutoencoderKL
@
property
def
dummy_input
(
self
):
batch_size
=
4
num_channels
=
3
sizes
=
(
32
,
32
)
image
=
floats_tensor
((
batch_size
,
num_channels
)
+
sizes
).
to
(
torch_device
)
return
{
"x"
:
image
}
@
property
def
input_shape
(
self
):
return
(
3
,
32
,
32
)
@
property
def
output_shape
(
self
):
return
(
3
,
32
,
32
)
def
prepare_init_args_and_inputs_for_common
(
self
):
init_dict
=
{
"ch"
:
64
,
"ch_mult"
:
(
1
,),
"embed_dim"
:
4
,
"in_channels"
:
3
,
"num_res_blocks"
:
1
,
"out_ch"
:
3
,
"resolution"
:
32
,
"z_channels"
:
4
,
"attn_resolutions"
:
[]
}
inputs_dict
=
self
.
dummy_input
return
init_dict
,
inputs_dict
def
test_forward_signature
(
self
):
pass
def
test_training
(
self
):
pass
def
test_from_pretrained_hub
(
self
):
model
,
loading_info
=
AutoencoderKL
.
from_pretrained
(
"fusing/autoencoder-kl-dummy"
,
output_loading_info
=
True
)
self
.
assertIsNotNone
(
model
)
self
.
assertEqual
(
len
(
loading_info
[
"missing_keys"
]),
0
)
model
.
to
(
torch_device
)
image
=
model
(
**
self
.
dummy_input
)
assert
image
is
not
None
,
"Make sure output is not None"
def
test_output_pretrained
(
self
):
model
=
AutoencoderKL
.
from_pretrained
(
"fusing/autoencoder-kl-dummy"
)
model
.
eval
()
torch
.
manual_seed
(
0
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
manual_seed_all
(
0
)
image
=
torch
.
randn
(
1
,
model
.
config
.
in_channels
,
model
.
config
.
resolution
,
model
.
config
.
resolution
)
with
torch
.
no_grad
():
output
=
model
(
image
,
sample_posterior
=
True
)
output_slice
=
output
[
0
,
-
1
,
-
3
:,
-
3
:].
flatten
()
# fmt: off
expected_output_slice
=
torch
.
tensor
([
-
0.0814
,
-
0.0229
,
-
0.1320
,
-
0.4123
,
-
0.0366
,
-
0.3473
,
0.0438
,
-
0.1662
,
0.1750
])
# fmt: on
self
.
assertTrue
(
torch
.
allclose
(
output_slice
,
expected_output_slice
,
atol
=
1e-3
))
class
PipelineTesterMixin
(
unittest
.
TestCase
):
class
PipelineTesterMixin
(
unittest
.
TestCase
):
def
test_from_pretrained_save_pretrained
(
self
):
def
test_from_pretrained_save_pretrained
(
self
):
# 1. Load models
# 1. Load models
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment