Unverified commit 836f3f35, authored by Patrick von Platen, committed by GitHub

Rename pipelines (#115)

up
parent 9c3820d0
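
Note: this commit renames the public pipeline classes: LatentDiffusionUncondPipeline becomes LDMPipeline (unconditional image generation) and LatentDiffusionPipeline becomes LDMTextToImagePipeline (text-to-image). A minimal usage sketch of the renamed unconditional pipeline, mirroring the updated slow test in this diff; the checkpoint path is a placeholder, not something shipped by this commit:

import torch
from diffusers import LDMPipeline

# Load a converted unconditional latent-diffusion checkpoint.
# "path/to/ldm-celebahq-256" is a placeholder directory, not a real path.
ldm = LDMPipeline.from_pretrained("path/to/ldm-celebahq-256")

# Sample with a fixed seed, exactly as in the updated test below.
generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
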
@@ -3,7 +3,7 @@ import argparse
 import OmegaConf
 import torch
-from diffusers import UNetLDMModel, VQModel, LatentDiffusionUncondPipeline, DDIMScheduler
+from diffusers import UNetLDMModel, VQModel, LDMPipeline, DDIMScheduler
 def convert_ldm_original(checkpoint_path, config_path, output_path):
     config = OmegaConf.load(config_path)
@@ -41,7 +41,7 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
         clip_sample=False,
     )
-    pipeline = LatentDiffusionUncondPipeline(vqvae, unet, noise_scheduler)
+    pipeline = LDMPipeline(vqvae, unet, noise_scheduler)
     pipeline.save_pretrained(output_path)

@@ -17,7 +17,7 @@
 import argparse
 import json
 import torch
-from diffusers import VQModel, DDPMScheduler, UNet2DModel, LatentDiffusionUncondPipeline
+from diffusers import VQModel, DDPMScheduler, UNet2DModel, LDMPipeline
 def shave_segments(path, n_shave_prefix_segments=1):
@@ -326,7 +326,7 @@ if __name__ == "__main__":
         scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
         vqvae = VQModel.from_pretrained("/".join(args.checkpoint_path.split("/")[:-1]))
-        pipe = LatentDiffusionUncondPipeline(unet=model, scheduler=scheduler, vae=vqvae)
+        pipe = LDMPipeline(unet=model, scheduler=scheduler, vae=vqvae)
         pipe.save_pretrained(args.dump_path)
     except:
         model.save_pretrained(args.dump_path)

@@ -9,11 +9,11 @@ __version__ = "0.0.4"
 from .modeling_utils import ModelMixin
 from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
 from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline
+from .pipelines import DDIMPipeline, DDPMPipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
 from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler
 if is_transformers_available():
-    from .pipelines import LatentDiffusionPipeline
+    from .pipelines import LDMTextToImagePipeline
 else:
     from .utils.dummy_transformers_objects import *

 from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available
 from .ddim import DDIMPipeline
 from .ddpm import DDPMPipeline
-from .latent_diffusion_uncond import LatentDiffusionUncondPipeline
+from .latent_diffusion_uncond import LDMPipeline
 from .pndm import PNDMPipeline
 from .score_sde_ve import ScoreSdeVePipeline
 if is_transformers_available():
-    from .latent_diffusion import LatentDiffusionPipeline
+    from .latent_diffusion import LDMTextToImagePipeline

@@ -2,4 +2,4 @@ from ...utils import is_transformers_available
 if is_transformers_available():
-    from .pipeline_latent_diffusion import LatentDiffusionPipeline, LDMBertModel
+    from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline

@@ -14,7 +14,7 @@ from transformers.utils import logging
 from ...pipeline_utils import DiffusionPipeline
-class LatentDiffusionPipeline(DiffusionPipeline):
+class LDMTextToImagePipeline(DiffusionPipeline):
     def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
         super().__init__()
         scheduler = scheduler.set_format("pt")

-from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline
+from .pipeline_latent_diffusion_uncond import LDMPipeline

@@ -5,7 +5,7 @@ from tqdm.auto import tqdm
 from ...pipeline_utils import DiffusionPipeline
-class LatentDiffusionUncondPipeline(DiffusionPipeline):
+class LDMPipeline(DiffusionPipeline):
     def __init__(self, vqvae, unet, scheduler):
         super().__init__()
         scheduler = scheduler.set_format("pt")

@@ -3,42 +3,7 @@
 from ..utils import DummyObject, requires_backends
-class GlideSuperResUNetModel(metaclass=DummyObject):
-    _backends = ["transformers"]
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["transformers"])
-class GlideTextToImageUNetModel(metaclass=DummyObject):
-    _backends = ["transformers"]
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["transformers"])
-class GlideUNetModel(metaclass=DummyObject):
-    _backends = ["transformers"]
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["transformers"])
-class UNetGradTTSModel(metaclass=DummyObject):
-    _backends = ["transformers"]
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["transformers"])
-class GlidePipeline(metaclass=DummyObject):
-    _backends = ["transformers"]
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["transformers"])
-class LatentDiffusionPipeline(metaclass=DummyObject):
+class LDMTextToImagePipeline(metaclass=DummyObject):
     _backends = ["transformers"]
     def __init__(self, *args, **kwargs):

@@ -29,8 +29,8 @@ from diffusers import (
     DDIMScheduler,
     DDPMPipeline,
     DDPMScheduler,
-    LatentDiffusionPipeline,
-    LatentDiffusionUncondPipeline,
+    LDMPipeline,
+    LDMTextToImagePipeline,
     PNDMPipeline,
     PNDMScheduler,
     ScoreSdeVePipeline,
@@ -826,7 +826,7 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_ldm_text2img(self):
-        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
+        ldm = LDMTextToImagePipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
@@ -842,7 +842,7 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_ldm_text2img_fast(self):
-        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
+        ldm = LDMTextToImagePipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
         prompt = "A painting of a squirrel eating a burger"
         generator = torch.manual_seed(0)
@@ -877,7 +877,7 @@ class PipelineTesterMixin(unittest.TestCase):
     @slow
     def test_ldm_uncond(self):
-        ldm = LatentDiffusionUncondPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-celebahq-256")
+        ldm = LDMPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-celebahq-256")
         generator = torch.manual_seed(0)
         image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
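
The conversion scripts above also show the other direction: assembling the renamed pipeline from its components and serializing it. A minimal sketch along those lines, under the assumption that the unet, vqvae, and scheduler have already been converted into a single placeholder checkpoint directory; the component classes are the ones imported by the scripts above, and the paths are illustrative only:

from diffusers import DDIMScheduler, LDMPipeline, UNet2DModel, VQModel

checkpoint_dir = "path/to/converted-ldm-checkpoint"  # placeholder directory

# Load the already-converted components from the checkpoint directory
# (assumed layout; the scripts above build some of these from state dicts instead).
unet = UNet2DModel.from_pretrained(checkpoint_dir)
vqvae = VQModel.from_pretrained(checkpoint_dir)
scheduler = DDIMScheduler.from_config(checkpoint_dir)

# Assemble and save the renamed pipeline, mirroring
# pipeline = LDMPipeline(vqvae, unet, noise_scheduler) in the scripts above.
pipeline = LDMPipeline(vqvae, unet, scheduler)
pipeline.save_pretrained("path/to/output")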