Commit bdecc3cf authored by Patrick von Platen

move pipelines into folders

parent 0efac0aa
@@ -57,17 +57,19 @@ class DiffusionPipeline(ConfigMixin):
         from diffusers import pipelines

         for name, module in kwargs.items():
-            # check if the module is a pipeline module
-            is_pipeline_module = hasattr(pipelines, module.__module__.split(".")[-1])
             # retrieve library
             library = module.__module__.split(".")[0]
+            # check if the module is a pipeline module
+            pipeline_file = module.__module__.split(".")[-1]
+            pipeline_dir = module.__module__.split(".")[-2]
+            is_pipeline_module = pipeline_file == "pipeline_" + pipeline_dir and hasattr(pipelines, pipeline_dir)
             # if library is not in LOADABLE_CLASSES, then it is a custom module.
             # Or if it's a pipeline module, then the module is inside the pipeline
-            # so we set the library to module name.
+            # folder so we set the library to module name.
             if library not in LOADABLE_CLASSES or is_pipeline_module:
-                library = module.__module__.split(".")[-1]
+                library = pipeline_dir
             # retrieve class_name
             class_name = module.__class__.__name__
...
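The new check is deliberately stricter than the one it replaces: a component only counts as a pipeline module when its file follows the `pipeline_<folder>` naming convention introduced by this commit and the folder is a known attribute of `diffusers.pipelines`. A standalone sketch of that detection logic, where the `known_pipeline_dirs` set stands in for the real `hasattr(pipelines, pipeline_dir)` probe:

```python
# Sketch of the pipeline-module detection above; `module_path` plays the
# role of `module.__module__`, `known_pipeline_dirs` the role of the
# attributes of the diffusers.pipelines package.
def is_pipeline_module(module_path: str, known_pipeline_dirs: set) -> bool:
    parts = module_path.split(".")
    if len(parts) < 2:  # a top-level module can never match the convention
        return False
    pipeline_file = parts[-1]  # e.g. "pipeline_ddpm"
    pipeline_dir = parts[-2]   # e.g. "ddpm"
    return pipeline_file == "pipeline_" + pipeline_dir and pipeline_dir in known_pipeline_dirs

# The module path of a moved pipeline class matches the convention:
assert is_pipeline_module("diffusers.pipelines.ddpm.pipeline_ddpm", {"ddpm"})
# A module living elsewhere in the library does not:
assert not is_pipeline_module("diffusers.schedulers.scheduling_ddpm", {"ddpm"})
```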
 from ..utils import is_inflect_available, is_transformers_available, is_unidecode_available

-from .pipeline_bddm import BDDMPipeline
-from .pipeline_ddim import DDIMPipeline
-from .pipeline_ddpm import DDPMPipeline
-from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline
-from .pipeline_pndm import PNDMPipeline
-from .pipeline_score_sde_ve import ScoreSdeVePipeline
-from .pipeline_score_sde_vp import ScoreSdeVpPipeline
-# from .pipeline_score_sde import ScoreSdeVePipeline
+from .bddm import BDDMPipeline
+from .ddim import DDIMPipeline
+from .ddpm import DDPMPipeline
+from .latent_diffusion_uncond import LatentDiffusionUncondPipeline
+from .pndm import PNDMPipeline
+from .score_sde_ve import ScoreSdeVePipeline
+from .score_sde_vp import ScoreSdeVpPipeline

 if is_transformers_available():
-    from .pipeline_glide import GlidePipeline
-    from .pipeline_latent_diffusion import LatentDiffusionPipeline
+    from .glide import GlidePipeline
+    from .latent_diffusion import LatentDiffusionPipeline

 if is_transformers_available() and is_unidecode_available() and is_inflect_available():
-    from .pipeline_grad_tts import GradTTSPipeline
+    from .grad_tts import GradTTSPipeline
+from .pipeline_bddm import BDDMPipeline, DiffWave
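Each new folder ships a small `__init__.py` like the one above that re-exports its pipeline classes, so the move does not change the public import surface. A quick sanity check, assuming a diffusers checkout at this commit (DDPM needs no optional extras):

```python
# All three spellings must resolve to the same class object, because the
# folder __init__ simply re-exports the class from its pipeline_*.py file.
from diffusers.pipelines import DDPMPipeline as via_package
from diffusers.pipelines.ddpm import DDPMPipeline as via_folder
from diffusers.pipelines.ddpm.pipeline_ddpm import DDPMPipeline as via_file

assert via_package is via_folder is via_file
print("public import surface unchanged")
```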
@@ -21,9 +21,9 @@ import torch.nn.functional as F
 import tqdm

-from ..configuration_utils import ConfigMixin
-from ..modeling_utils import ModelMixin
-from ..pipeline_utils import DiffusionPipeline
+from ...configuration_utils import ConfigMixin
+from ...modeling_utils import ModelMixin
+from ...pipeline_utils import DiffusionPipeline


 def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in):
...
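This hunk shows the mechanical half of the commit, repeated for every pipeline below: each `pipeline_*.py` file moves one directory deeper (from `pipelines/` into `pipelines/<name>/`), so every relative import needs one extra leading dot to climb back to the `diffusers` package root. The dot arithmetic can be checked with the standard library alone:

```python
from importlib.util import resolve_name

# Before the move, pipeline_bddm.py sat directly in diffusers/pipelines/,
# so two dots reached modules at the diffusers root.
assert resolve_name("..pipeline_utils", "diffusers.pipelines") == "diffusers.pipeline_utils"

# After the move it sits in diffusers/pipelines/bddm/, one level deeper,
# so the same target now takes three dots.
assert resolve_name("...pipeline_utils", "diffusers.pipelines.bddm") == "diffusers.pipeline_utils"
```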
+from .pipeline_ddim import DDIMPipeline
@@ -18,7 +18,7 @@ import torch
 import tqdm

-from ..pipeline_utils import DiffusionPipeline
+from ...pipeline_utils import DiffusionPipeline


 class DDIMPipeline(DiffusionPipeline):
...
+from .pipeline_ddpm import DDPMPipeline
@@ -18,7 +18,7 @@ import torch
 import tqdm

-from ..pipeline_utils import DiffusionPipeline
+from ...pipeline_utils import DiffusionPipeline


 class DDPMPipeline(DiffusionPipeline):
...
+from ...utils import is_transformers_available
+
+if is_transformers_available():
+    from .pipeline_glide import CLIPTextModel, GlidePipeline
@@ -18,7 +18,6 @@ import math
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple, Union

-import numpy as np
 import torch
 import torch.utils.checkpoint
 from torch import nn
@@ -30,10 +29,10 @@ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import ModelOutput, add_start_docstrings_to_model_forward, replace_return_docstrings

-from ..models import GlideSuperResUNetModel, GlideTextToImageUNetModel
-from ..pipeline_utils import DiffusionPipeline
-from ..schedulers import DDIMScheduler, DDPMScheduler
-from ..utils import logging
+from ...models import GlideSuperResUNetModel, GlideTextToImageUNetModel
+from ...pipeline_utils import DiffusionPipeline
+from ...schedulers import DDIMScheduler, DDPMScheduler
+from ...utils import logging


 #####################
@@ -594,7 +593,7 @@ class CLIPTextTransformer(nn.Module):
         bsz, seq_len = input_shape
         # CLIP's text model uses causal mask, prepare it here.
         # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
-        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)
+        # causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device)

         # expand attention_mask
         if attention_mask is not None:
...
+from ...utils import is_inflect_available, is_transformers_available, is_unidecode_available
+
+if is_transformers_available() and is_unidecode_available() and is_inflect_available():
+    from .grad_tts_utils import GradTTSTokenizer
+    from .pipeline_grad_tts import GradTTSPipeline, TextEncoder
@@ -6,10 +6,10 @@ import torch
 from torch import nn

 import tqdm

-from diffusers import DiffusionPipeline
-from diffusers.configuration_utils import ConfigMixin
-from diffusers.modeling_utils import ModelMixin
-from .grad_tts_utils import GradTTSTokenizer  # flake8: noqa
+from ...configuration_utils import ConfigMixin
+from ...modeling_utils import ModelMixin
+from ...pipeline_utils import DiffusionPipeline
+from .grad_tts_utils import GradTTSTokenizer
...
+from ...utils import is_transformers_available
+
+if is_transformers_available():
+    from .pipeline_latent_diffusion import AutoencoderKL, LatentDiffusionPipeline, LDMBertModel
@@ -7,20 +7,15 @@ import torch.nn as nn
 import torch.utils.checkpoint

 import tqdm

-try:
-    from transformers.activations import ACT2FN
-    from transformers.configuration_utils import PretrainedConfig
-    from transformers.modeling_outputs import BaseModelOutput
-    from transformers.modeling_utils import PreTrainedModel
-    from transformers.utils import logging
-except ImportError:
-    raise ImportError("Please install the transformers.")
-
-from ..configuration_utils import ConfigMixin
-from ..modeling_utils import ModelMixin
-from ..pipeline_utils import DiffusionPipeline
+from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutput
+from transformers.modeling_utils import PreTrainedModel
+from transformers.utils import logging
+
+from ...configuration_utils import ConfigMixin
+from ...modeling_utils import ModelMixin
+from ...pipeline_utils import DiffusionPipeline


 ################################################################################
...
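The try/except removal is the flip side of the guarded `__init__.py` shown above: the availability check moved up to the folder level, so this module is only imported once `transformers` is known to be present and its imports can be plain. A sketch of the two failure modes (hypothetical minimal example, not the library's code):

```python
import importlib.util

# Old style (inside the pipeline module): importing the module itself
# raised whenever the extra was missing.
def import_old_style():
    try:
        import transformers  # noqa: F401
    except ImportError:
        raise ImportError("Please install the transformers.")

# New style (in the folder's __init__.py): probe first, import second,
# so a missing extra just leaves the pipeline name undefined.
if importlib.util.find_spec("transformers") is not None:
    print("extras present: the guarded pipeline import would run")
else:
    print("extras missing: the pipeline is skipped without raising")
```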
+from .pipeline_latent_diffusion_uncond import LatentDiffusionUncondPipeline
@@ -6,9 +6,9 @@ import torch.nn as nn
 import tqdm

-from ..configuration_utils import ConfigMixin
-from ..modeling_utils import ModelMixin
-from ..pipeline_utils import DiffusionPipeline
+from ...configuration_utils import ConfigMixin
+from ...modeling_utils import ModelMixin
+from ...pipeline_utils import DiffusionPipeline


 def get_timestep_embedding(timesteps, embedding_dim):
...
+from .pipeline_pndm import PNDMPipeline
@@ -18,7 +18,7 @@ import torch
 import tqdm

-from ..pipeline_utils import DiffusionPipeline
+from ...pipeline_utils import DiffusionPipeline


 class PNDMPipeline(DiffusionPipeline):
...
+from .pipeline_score_sde_ve import ScoreSdeVePipeline