Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
renzhc
diffusers_dcu
Commits
147d8e07
Commit
147d8e07
authored
Jun 14, 2022
by
patil-suraj
Browse files
add test for loading model from pipeline module
parent
d81b56ba
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
1 deletion
+18
-1
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+18
-1
No files found.
tests/test_modeling_utils.py
View file @
147d8e07
...
...
@@ -19,9 +19,10 @@ import unittest
import
torch
from
diffusers
import
DDIM
,
DDPM
,
DDIMScheduler
,
DDPMScheduler
,
LatentDiffusion
,
UNetModel
,
PNDM
,
PNDMScheduler
from
diffusers
import
DDIM
,
DDPM
,
BDDM
,
DDIMScheduler
,
DDPMScheduler
,
LatentDiffusion
,
UNetModel
,
PNDM
,
PNDMScheduler
from
diffusers.configuration_utils
import
ConfigMixin
from
diffusers.pipeline_utils
import
DiffusionPipeline
from
diffusers.pipelines.pipeline_bddm
import
DiffWave
from
diffusers.testing_utils
import
floats_tensor
,
slow
,
torch_device
...
...
@@ -212,3 +213,19 @@ class PipelineTesterMixin(unittest.TestCase):
assert
image
.
shape
==
(
1
,
3
,
256
,
256
)
expected_slice
=
torch
.
tensor
([
0.7295
,
0.7358
,
0.7256
,
0.7435
,
0.7095
,
0.6884
,
0.7325
,
0.6921
,
0.6458
])
assert
(
image_slice
.
flatten
()
-
expected_slice
).
abs
().
max
()
<
1e-2
def
test_module_from_pipeline
(
self
):
model
=
DiffWave
(
num_res_layers
=
4
)
noise_scheduler
=
DDPMScheduler
(
timesteps
=
12
)
bddm
=
BDDM
(
model
,
noise_scheduler
)
# check if the library name for the diffwave moduel is set to pipeline module
self
.
assertTrue
(
bddm
.
config
[
"diffwave"
][
0
]
==
"pipeline_bddm"
)
# check if we can save and load the pipeline
with
tempfile
.
TemporaryDirectory
()
as
tmpdirname
:
bddm
.
save_pretrained
(
tmpdirname
)
_
=
BDDM
.
from_pretrained
(
tmpdirname
)
# check if the same works using the DifusionPipeline class
_
=
DiffusionPipeline
.
from_pretrained
(
tmpdirname
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment