Unverified commit 836f3f35 authored by Patrick von Platen, committed by GitHub

Rename pipelines (#115)

up
parent 9c3820d0
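In short, this commit renames two public pipeline classes and updates every import, conversion script, dummy object, and test to match:

- LatentDiffusionUncondPipeline → LDMPipeline (unconditional latent diffusion)
- LatentDiffusionPipeline → LDMTextToImagePipeline (text-to-image; exported only when transformers is installed)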
@@ -3,7 +3,7 @@ import argparse
from omegaconf import OmegaConf
import torch
-from diffusers import UNetLDMModel, VQModel, LatentDiffusionUncondPipeline, DDIMScheduler
+from diffusers import UNetLDMModel, VQModel, LDMPipeline, DDIMScheduler
def convert_ldm_original(checkpoint_path, config_path, output_path):
    config = OmegaConf.load(config_path)
@@ -41,7 +41,7 @@ def convert_ldm_original(checkpoint_path, config_path, output_path):
        clip_sample=False,
    )
-    pipeline = LatentDiffusionUncondPipeline(vqvae, unet, noise_scheduler)
+    pipeline = LDMPipeline(vqvae, unet, noise_scheduler)
    pipeline.save_pretrained(output_path)
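For context, a hypothetical invocation of the converter updated above; the function signature comes from the hunk, but the paths are placeholders:

```python
# Placeholder paths; convert_ldm_original is the function shown in the hunk above.
convert_ldm_original(
    checkpoint_path="ldm-celebahq-256/model.ckpt",
    config_path="ldm-celebahq-256/config.yaml",
    output_path="./ldm-celebahq-256-diffusers",
)

# The saved directory can then be reloaded under the new class name:
# from diffusers import LDMPipeline
# pipe = LDMPipeline.from_pretrained("./ldm-celebahq-256-diffusers")
```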
@@ -17,7 +17,7 @@
import argparse
import json
import torch
-from diffusers import VQModel, DDPMScheduler, UNet2DModel, LatentDiffusionUncondPipeline
+from diffusers import VQModel, DDPMScheduler, UNet2DModel, LDMPipeline
def shave_segments(path, n_shave_prefix_segments=1):
@@ -326,7 +326,7 @@ if __name__ == "__main__":
        scheduler = DDPMScheduler.from_config("/".join(args.checkpoint_path.split("/")[:-1]))
        vqvae = VQModel.from_pretrained("/".join(args.checkpoint_path.split("/")[:-1]))
-        pipe = LatentDiffusionUncondPipeline(unet=model, scheduler=scheduler, vae=vqvae)
+        pipe = LDMPipeline(unet=model, scheduler=scheduler, vae=vqvae)
        pipe.save_pretrained(args.dump_path)
    except:
        model.save_pretrained(args.dump_path)
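The repeated "/".join(args.checkpoint_path.split("/")[:-1]) idiom above just strips the filename to recover the directory holding the checkpoint; a quick illustration with a placeholder path:

```python
import os

checkpoint_path = "checkpoints/ldm/diffusion_pytorch_model.bin"  # placeholder path
model_dir = "/".join(checkpoint_path.split("/")[:-1])
assert model_dir == "checkpoints/ldm"
assert model_dir == os.path.dirname(checkpoint_path)  # equivalent and more portable
```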
@@ -9,11 +9,11 @@ __version__ = "0.0.4"
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
from .pipeline_utils import DiffusionPipeline
-from .pipelines import DDIMPipeline, DDPMPipeline, LatentDiffusionUncondPipeline, PNDMPipeline, ScoreSdeVePipeline
+from .pipelines import DDIMPipeline, DDPMPipeline, LDMPipeline, PNDMPipeline, ScoreSdeVePipeline
from .schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, SchedulerMixin, ScoreSdeVeScheduler
if is_transformers_available():
-    from .pipelines import LatentDiffusionPipeline
+    from .pipelines import LDMTextToImagePipeline
else:
    from .utils.dummy_transformers_objects import *
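After this hunk, the old names disappear from the package namespace, so downstream code needs the same rename; a minimal before/after sketch:

```python
# Old (fails after this commit):
# from diffusers import LatentDiffusionPipeline, LatentDiffusionUncondPipeline

# New:
from diffusers import LDMPipeline             # unconditional latent diffusion
from diffusers import LDMTextToImagePipeline  # text-to-image; the real class when
                                              # transformers is installed, a dummy otherwise
```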
from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
-from .latent_diffusion_uncond import LatentDiffusionUncondPipeline
+from .latent_diffusion_uncond import LDMPipeline
from .pndm import PNDMPipeline
from .score_sde_ve import ScoreSdeVePipeline
if is_transformers_available():
-    from .latent_diffusion import LatentDiffusionPipeline
+    from .latent_diffusion import LDMTextToImagePipeline
@@ -2,4 +2,4 @@ from ...utils import is_transformers_available
if is_transformers_available():
-    from .pipeline_latent_diffusion import LatentDiffusionPipeline, LDMBertModel
+    from .pipeline_latent_diffusion import LDMBertModel, LDMTextToImagePipeline
@@ -14,7 +14,7 @@ from transformers.utils import logging
from ...pipeline_utils import DiffusionPipeline
-class LatentDiffusionPipeline(DiffusionPipeline):
+class LDMTextToImagePipeline(DiffusionPipeline):
    def __init__(self, vqvae, bert, tokenizer, unet, scheduler):
        super().__init__()
        scheduler = scheduler.set_format("pt")
-from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline
+from .pipeline_latent_diffusion_uncond import LDMPipeline
@@ -5,7 +5,7 @@ from tqdm.auto import tqdm
from ...pipeline_utils import DiffusionPipeline
-class LatentDiffusionUncondPipeline(DiffusionPipeline):
+class LDMPipeline(DiffusionPipeline):
    def __init__(self, vqvae, unet, scheduler):
        super().__init__()
        scheduler = scheduler.set_format("pt")
@@ -3,42 +3,7 @@
from ..utils import DummyObject, requires_backends
-class GlideSuperResUNetModel(metaclass=DummyObject):
-    _backends = ["transformers"]
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["transformers"])
-class GlideTextToImageUNetModel(metaclass=DummyObject):
-    _backends = ["transformers"]
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["transformers"])
-class GlideUNetModel(metaclass=DummyObject):
-    _backends = ["transformers"]
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["transformers"])
-class UNetGradTTSModel(metaclass=DummyObject):
-    _backends = ["transformers"]
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["transformers"])
-class GlidePipeline(metaclass=DummyObject):
-    _backends = ["transformers"]
-    def __init__(self, *args, **kwargs):
-        requires_backends(self, ["transformers"])
-class LatentDiffusionPipeline(metaclass=DummyObject):
+class LDMTextToImagePipeline(metaclass=DummyObject):
    _backends = ["transformers"]
    def __init__(self, *args, **kwargs):
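A sketch of what the surviving dummy class does in an environment without transformers; the ImportError type is an assumption based on how requires_backends behaves in the companion transformers utility:

```python
# Without transformers installed, diffusers exposes this placeholder in place of
# the real pipeline. Using it routes through requires_backends, which raises a
# helpful install hint instead of an opaque NameError at import time.
from diffusers import LDMTextToImagePipeline

try:
    LDMTextToImagePipeline()
except ImportError as err:  # assumed error type, mirroring transformers' helper
    print(err)  # e.g. a message asking the user to install transformers
```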
@@ -29,8 +29,8 @@ from diffusers import (
    DDIMScheduler,
    DDPMPipeline,
    DDPMScheduler,
-    LatentDiffusionPipeline,
-    LatentDiffusionUncondPipeline,
+    LDMPipeline,
+    LDMTextToImagePipeline,
    PNDMPipeline,
    PNDMScheduler,
    ScoreSdeVePipeline,
@@ -826,7 +826,7 @@ class PipelineTesterMixin(unittest.TestCase):
    @slow
    def test_ldm_text2img(self):
-        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
+        ldm = LDMTextToImagePipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
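Outside the test suite the same call works against a published checkpoint; a sketch assuming a public Hub id (the test above loads from a local path, and the call signature past the prompt is inferred from the surrounding tests):

```python
import torch
from diffusers import LDMTextToImagePipeline

# "CompVis/ldm-text2im-large-256" is assumed here; the test uses a local checkpoint.
ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
generator = torch.manual_seed(0)
image = ldm(
    "A painting of a squirrel eating a burger",
    generator=generator,
    num_inference_steps=20,
    output_type="numpy",
)["sample"]
```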
@@ -842,7 +842,7 @@ class PipelineTesterMixin(unittest.TestCase):
    @slow
    def test_ldm_text2img_fast(self):
-        ldm = LatentDiffusionPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
+        ldm = LDMTextToImagePipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-text2im-large-256")
        prompt = "A painting of a squirrel eating a burger"
        generator = torch.manual_seed(0)
@@ -877,7 +877,7 @@ class PipelineTesterMixin(unittest.TestCase):
    @slow
    def test_ldm_uncond(self):
-        ldm = LatentDiffusionUncondPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-celebahq-256")
+        ldm = LDMPipeline.from_pretrained("/home/patrick/google_checkpoints/ldm-celebahq-256")
        generator = torch.manual_seed(0)
        image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
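And the unconditional counterpart, mirroring test_ldm_uncond with an assumed public Hub id in place of the local path:

```python
import torch
from diffusers import LDMPipeline

# "CompVis/ldm-celebahq-256" is assumed here; the test uses a local checkpoint.
ldm = LDMPipeline.from_pretrained("CompVis/ldm-celebahq-256")
generator = torch.manual_seed(0)
image = ldm(generator=generator, num_inference_steps=5, output_type="numpy")["sample"]
print(image.shape)  # (batch, height, width, channels) for numpy output
```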