Commit bdecc3cf authored by Patrick von Platen's avatar Patrick von Platen
Browse files

move pipelines into folders

parent 0efac0aa
......@@ -57,17 +57,19 @@ class DiffusionPipeline(ConfigMixin):
from diffusers import pipelines
for name, module in kwargs.items():
# check if the module is a pipeline module
is_pipeline_module = hasattr(pipelines, module.__module__.split(".")[-1])
# retrive library
library = module.__module__.split(".")[0]
# check if the module is a pipeline module
pipeline_file = module.__module__.split(".")[-1]
pipeline_dir = module.__module__.split(".")[-2]
is_pipeline_module = pipeline_file == "pipeline_" + pipeline_dir and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# so we set the library to module name.
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = module.__module__.split(".")[-1]
library = pipeline_dir
# retrive class_name
class_name = module.__class__.__name__
......
from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available
from .pipeline_bddm import BDDMPipeline
from .pipeline_ddim import DDIMPipeline
from .pipeline_ddpm import DDPMPipeline
from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline
from .pipeline_pndm import PNDMPipeline
from .pipeline_score_sde_ve import ScoreSdeVePipeline
from .pipeline_score_sde_vp import ScoreSdeVpPipeline
# from .pipeline_score_sde import ScoreSdeVePipeline
from .bddm import BDDMPipeline
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .latent_diffusion_uncond import LatentDiffusionUncondPipeline
from .pndm import PNDMPipeline
from .score_sde_ve import ScoreSdeVePipeline
from .score_sde_vp import ScoreSdeVpPipeline
if is_transformers_available():
from .pipeline_glide import GlidePipeline
from .pipeline_latent_diffusion import LatentDiffusionPipeline
from .glide import GlidePipeline
from .latent_diffusion import LatentDiffusionPipeline
if is_transformers_available() and is_unidecode_available() and is_inflect_available():
from .pipeline_grad_tts import GradTTSPipeline
from .grad_tts import GradTTSPipeline
from .pipeline_bddm import BDDMPipeline, DiffWave
......@@ -21,9 +21,9 @@ import torch.nn.functional as F
import tqdm
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from ..pipeline_utils import DiffusionPipeline
from ...configuration_utils import ConfigMixin
from ...modeling_utils import ModelMixin
from ...pipeline_utils import DiffusionPipeline
def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
......
from .pipeline_ddim import DDIMPipeline
......@@ -18,7 +18,7 @@ import torch
import tqdm
from ..pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline
class DDIMPipeline(DiffusionPipeline):
......
from .pipeline_ddpm import DDPMPipeline
......@@ -18,7 +18,7 @@ import torch
import tqdm
from ..pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline
class DDPMPipeline(DiffusionPipeline):
......
from ...utils import is_transformers_available
if is_transformers_available():
from .pipeline_glide import CLIPTextModel, GlidePipeline
......@@ -18,7 +18,6 @@ import math
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union
import numpy as np
import torch
import torch.utils.checkpoint
from torch import nn
......@@ -30,10 +29,10 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPo
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings
from ..models import GlideSuperResUNetModel, GlideTextToImageUNetModel
from ..pipeline_utils import DiffusionPipeline
from ..schedulers import DDIMScheduler, DDPMScheduler
from ..utils import logging
from ...models import GlideSuperResUNetModel, GlideTextToImageUNetModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import DDIMScheduler, DDPMScheduler
from ...utils import logging
#####################
......@@ -594,7 +593,7 @@ class CLIPTextTransformer(nn.Module):
bsz, seq_len = input_shape
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
# causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
# expand attention_mask
if attention_mask is not None:
......
from ...utils import is_inflect_available, is_transformers_available, is_unidecode_available
if is_transformers_available() and is_unidecode_available() and is_inflect_available():
from .grad_tts_utils import GradTTSTokenizer
from .pipeline_grad_tts import GradTTSPipeline, TextEncoder
......@@ -6,10 +6,10 @@ import torch
from torch import nn
import tqdm
from diffusers import DiffusionPipeline
from diffusers.configuration_utils import ConfigMixin
from diffusers.modeling_utils import ModelMixin
from ...configuration_utils import ConfigMixin
from ...modeling_utils import ModelMixin
from ...pipeline_utils import DiffusionPipeline
from .grad_tts_utils import GradTTSTokenizer # flake8: noqa
......
from ...utils import is_transformers_available
if is_transformers_available():
from .pipeline_latent_diffusion import AutoencoderKL, LatentDiffusionPipeline, LDMBertModel
......@@ -7,20 +7,15 @@ import torch.nn as nn
import torch.utils.checkpoint
import tqdm
try:
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
except ImportError:
raise ImportError("Please install the transformers.")
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from ..pipeline_utils import DiffusionPipeline
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from ...configuration_utils import ConfigMixin
from ...modeling_utils import ModelMixin
from ...pipeline_utils import DiffusionPipeline
################################################################################
......
from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline
......@@ -6,9 +6,9 @@ import torch.nn as nn
import tqdm
from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from ..pipeline_utils import DiffusionPipeline
from ...configuration_utils import ConfigMixin
from ...modeling_utils import ModelMixin
from ...pipeline_utils import DiffusionPipeline
def get_timestep_embedding(timesteps, embedding_dim):
......
from .pipeline_pndm import PNDMPipeline
......@@ -18,7 +18,7 @@ import torch
import tqdm
from ..pipeline_utils import DiffusionPipeline
from ...pipeline_utils import DiffusionPipeline
class PNDMPipeline(DiffusionPipeline):
......
from .pipeline_score_sde_ve import ScoreSdeVePipeline
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment