Commit 147d8e07 authored by patil-suraj's avatar patil-suraj
Browse files

add test for loading model from pipeline module

parent d81b56ba
......@@ -19,9 +19,10 @@ import unittest
import torch
from diffusers import DDIM, DDPM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler
from diffusers import DDIM, DDPM, BDDM, DDIMScheduler, DDPMScheduler, LatentDiffusion, UNetModel, PNDM, PNDMScheduler
from diffusers.configuration_utils import ConfigMixin
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.pipeline_bddm import DiffWave
from diffusers.testing_utils import floats_tensor, slow, torch_device
......@@ -212,3 +213,19 @@ class PipelineTesterMixin(unittest.TestCase):
assert image.shape == (1, 3, 256, 256)
expected_slice = torch.tensor([0.7295, 0.7358, 0.7256, 0.7435, 0.7095, 0.6884, 0.7325, 0.6921, 0.6458])
assert (image_slice.flatten() - expected_slice).abs().max() < 1e-2
def test_module_from_pipeline(self):
model = DiffWave(num_res_layers=4)
noise_scheduler = DDPMScheduler(timesteps=12)
bddm = BDDM(model, noise_scheduler)
# check if the library name for the diffwave moduel is set to pipeline module
self.assertTrue(bddm.config["diffwave"][0] == "pipeline_bddm")
# check if we can save and load the pipeline
with tempfile.TemporaryDirectory() as tmpdirname:
bddm.save_pretrained(tmpdirname)
_ = BDDM.from_pretrained(tmpdirname)
# check if the same works using the DifusionPipeline class
_ = DiffusionPipeline.from_pretrained(tmpdirname)
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment