Unverified Commit 37d113cc authored by Kashif Rasul, committed by GitHub

DiT Pipeline (#1806)



* added dit model

* import

* initial pipeline

* initial convert script

* initial pipeline

* make style

* raise ValueError

* single function

* rename classes

* use DDIMScheduler

* timesteps embedder

* samples to cpu

* fix var names

* fix numpy type

* use timesteps class for proj

* fix typo

* fix arg name

* flip_sin_to_cos and better var names

* fix C shape cal

* make style

* remove unused imports

* cleanup

* add back patch_size

* initial dit doc

* typo

* Update docs/source/api/pipelines/dit.mdx
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* added copyright license headers

* added example usage and toc

* fix variable names asserts

* remove comment

* added docs

* fix typo

* upstream changes

* set proper device for drop_ids

* added initial dit pipeline test

* update docs

* fix imports

* make fix-copies

* isort

* fix imports

* get rid of more magic numbers

* fix code when guidance is off

* remove block_kwargs

* cleanup script

* removed to_2tuple

* use FeedForward class instead of another MLP

* style

* work on merging DiTBlock with BasicTransformerBlock

* added missing final_dropout and args to BasicTransformerBlock

* use norm from block

* fix arg

* remove unused arg

* fix call to class_embedder

* use timesteps

* make style

* attn_output gets multiplied

* removed commented code

* use Transformer2D

* use self.is_input_patches

* fix flags

* fixed conversion to use Transformer2DModel

* fixes for pipeline

* remove dit.py

* fix timesteps device

* use randn_tensor and fix fp16 inf.

* timesteps_emb already the right dtype

* fix dit test class

* fix test and style

* fix norm2 usage in vq-diffusion

* added author names to pipeline and ImageNet labels link

* fix tests

* use norm_type as string

* rename dit to transformer

* fix name

* fix test

* set norm_type = "layer" by default

* fix tests

* do not skip common tests

* Update src/diffusers/models/attention.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* revert AdaLayerNorm API

* fix norm_type name

* make sure all components are in eval mode

* revert norm2 API

* compact

* finish deprecation

* add slow tests

* remove @

* refactor some stuff

* upload

* Update src/diffusers/pipelines/dit/pipeline_dit.py

* finish more

* finish docs

* improve docs

* finish docs
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: William Berman <WLBberman@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 7e29b747
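For orientation before the diff: the pipeline added here is exposed as `DiTPipeline` and is conditioned on ImageNet class labels rather than text prompts. A minimal usage sketch follows; the checkpoint id `facebook/DiT-XL-2-256` and the `get_label_ids` helper are assumptions matching the DiT docs added in this PR, so treat the snippet as illustrative rather than authoritative.

```python
import torch

from diffusers import DiTPipeline

# Load the class-conditional DiT transformer, VAE and scheduler as one pipeline
# (checkpoint id assumed; see the DiT docs added in this PR).
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Map human-readable ImageNet label names to class ids (helper assumed from the pipeline).
class_ids = pipe.get_label_ids(["white shark", "umbrella"])

generator = torch.Generator(device="cuda").manual_seed(33)
output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator)
image = output.images[0]
```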
@@ -10,14 +10,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...schedulers import (
-    DDIMScheduler,
-    DPMSolverMultistepScheduler,
-    EulerAncestralDiscreteScheduler,
-    EulerDiscreteScheduler,
-    LMSDiscreteScheduler,
-    PNDMScheduler,
-)
+from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionSafePipelineOutput
@@ -65,14 +58,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[
-            DDIMScheduler,
-            DPMSolverMultistepScheduler,
-            EulerAncestralDiscreteScheduler,
-            EulerDiscreteScheduler,
-            LMSDiscreteScheduler,
-            PNDMScheduler,
-        ],
+        scheduler: KarrasDiffusionSchedulers,
         safety_checker: SafeStableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
         requires_safety_checker: bool = True,
......
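The change above repeats throughout the pipeline files below: the hand-enumerated `Union[...]` of scheduler classes is replaced by the single `KarrasDiffusionSchedulers` annotation. In practice any scheduler covered by that enum can be swapped into a pipeline; a hedged sketch of the usual idiom, with an illustrative checkpoint id:

```python
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

# The scheduler slot accepts any Karras-style scheduler; the annotation no
# longer has to list each concrete class. Checkpoint id is illustrative.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
```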
@@ -7,7 +7,7 @@ import PIL.Image
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import logging
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
@@ -53,7 +53,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
     image_unet: UNet2DConditionModel
     text_unet: UNet2DConditionModel
     vae: AutoencoderKL
-    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+    scheduler: KarrasDiffusionSchedulers
     def __init__(
         self,
@@ -64,7 +64,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
         image_unet: UNet2DConditionModel,
         text_unet: UNet2DConditionModel,
         vae: AutoencoderKL,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: KarrasDiffusionSchedulers,
     ):
         super().__init__()
......
@@ -28,7 +28,7 @@ from transformers import (
 )
 from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .modeling_text_unet import UNetFlatConditionModel
@@ -62,7 +62,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
     image_unet: UNet2DConditionModel
     text_unet: UNetFlatConditionModel
     vae: AutoencoderKL
-    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+    scheduler: KarrasDiffusionSchedulers
     _optional_components = ["text_unet"]
@@ -75,7 +75,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
         image_unet: UNet2DConditionModel,
         text_unet: UNetFlatConditionModel,
         vae: AutoencoderKL,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: KarrasDiffusionSchedulers,
     ):
         super().__init__()
         self.register_modules(
......
@@ -23,7 +23,7 @@ import PIL
 from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -53,7 +53,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
     image_encoder: CLIPVisionModelWithProjection
     image_unet: UNet2DConditionModel
     vae: AutoencoderKL
-    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+    scheduler: KarrasDiffusionSchedulers
     def __init__(
         self,
@@ -61,7 +61,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
         image_encoder: CLIPVisionModelWithProjection,
         image_unet: UNet2DConditionModel,
         vae: AutoencoderKL,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: KarrasDiffusionSchedulers,
     ):
         super().__init__()
         self.register_modules(
......
@@ -21,7 +21,7 @@ import torch.utils.checkpoint
 from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer
 from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .modeling_text_unet import UNetFlatConditionModel
@@ -54,7 +54,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
     image_unet: UNet2DConditionModel
     text_unet: UNetFlatConditionModel
     vae: AutoencoderKL
-    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+    scheduler: KarrasDiffusionSchedulers
     _optional_components = ["text_unet"]
@@ -65,7 +65,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
         image_unet: UNet2DConditionModel,
         text_unet: UNetFlatConditionModel,
         vae: AutoencoderKL,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: KarrasDiffusionSchedulers,
     ):
         super().__init__()
         self.register_modules(
......
@@ -39,7 +39,7 @@ else:
     from .scheduling_sde_ve import ScoreSdeVeScheduler
     from .scheduling_sde_vp import ScoreSdeVpScheduler
     from .scheduling_unclip import UnCLIPScheduler
-    from .scheduling_utils import SchedulerMixin
+    from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
     from .scheduling_vq_diffusion import VQDiffusionScheduler
 try:
@@ -55,7 +55,12 @@ else:
     from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
     from .scheduling_pndm_flax import FlaxPNDMScheduler
     from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
-    from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+    from .scheduling_utils_flax import (
+        FlaxKarrasDiffusionSchedulers,
+        FlaxSchedulerMixin,
+        FlaxSchedulerOutput,
+        broadcast_to_shape_from_left,
+    )
 try:
......
@@ -23,8 +23,8 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
-from .scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput, deprecate, randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 @dataclass
@@ -112,7 +112,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """
-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     _deprecated_kwargs = ["predict_epsilon"]
     order = 1
......
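In the scheduler files that follow, the per-class `_compatibles` lists stop copying the removed `_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS` constant and are instead derived from the shared `KarrasDiffusionSchedulers` enum (with a `FlaxKarrasDiffusionSchedulers` counterpart for the Flax schedulers). A sketch of the pattern, assuming an abbreviated member list that is illustrative rather than a verbatim copy of `scheduling_utils.py`:

```python
from enum import Enum


class KarrasDiffusionSchedulers(Enum):
    # Illustrative, abbreviated member list; the real enum lives in
    # scheduling_utils.py and covers all Karras-style schedulers.
    DDIMScheduler = 1
    DDPMScheduler = 2
    PNDMScheduler = 3
    LMSDiscreteScheduler = 4
    EulerDiscreteScheduler = 5
    EulerAncestralDiscreteScheduler = 6
    DPMSolverMultistepScheduler = 7
    HeunDiscreteScheduler = 8


# Each scheduler class then advertises compatibility by name instead of
# maintaining its own hard-coded list:
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
```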
@@ -24,8 +24,8 @@ import jax.numpy as jnp
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import deprecate
 from .scheduling_utils_flax import (
-    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
     CommonSchedulerState,
+    FlaxKarrasDiffusionSchedulers,
     FlaxSchedulerMixin,
     FlaxSchedulerOutput,
     add_noise_common,
@@ -102,7 +102,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
             the `dtype` used for params and computation.
     """
-    _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
     _deprecated_kwargs = ["predict_epsilon"]
     dtype: jnp.dtype
......
@@ -22,8 +22,8 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
-from .scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput, deprecate, randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 @dataclass
@@ -105,7 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """
-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     _deprecated_kwargs = ["predict_epsilon"]
     order = 1
......
@@ -24,8 +24,8 @@ import jax.numpy as jnp
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import deprecate
 from .scheduling_utils_flax import (
-    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
     CommonSchedulerState,
+    FlaxKarrasDiffusionSchedulers,
     FlaxSchedulerMixin,
     FlaxSchedulerOutput,
     add_noise_common,
@@ -85,7 +85,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
             the `dtype` used for params and computation.
     """
-    _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
     _deprecated_kwargs = ["predict_epsilon"]
     dtype: jnp.dtype
......
@@ -22,8 +22,7 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -106,7 +105,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     """
-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 1
     @register_to_config
......
@@ -21,8 +21,8 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from ..utils import deprecate
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -117,7 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     """
-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     _deprecated_kwargs = ["predict_epsilon"]
     order = 1
......
@@ -24,8 +24,8 @@ import jax.numpy as jnp
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import deprecate
 from .scheduling_utils_flax import (
-    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
     CommonSchedulerState,
+    FlaxKarrasDiffusionSchedulers,
     FlaxSchedulerMixin,
     FlaxSchedulerOutput,
     add_noise_common,
@@ -140,7 +140,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
             the `dtype` used for params and computation.
     """
-    _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
     _deprecated_kwargs = ["predict_epsilon"]
     dtype: jnp.dtype
......
@@ -21,8 +21,7 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -116,7 +115,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
     """
-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 1
     @register_to_config
......
@@ -19,8 +19,8 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
-from .scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput, logging, randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -71,7 +71,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """
-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 1
     @register_to_config
......
@@ -19,8 +19,8 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
-from .scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput, logging, randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -72,7 +72,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """
-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 1
     @register_to_config
......
@@ -18,8 +18,7 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -48,7 +47,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """
-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 2
     @register_to_config
......
@@ -18,8 +18,8 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, randn_tensor
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from ..utils import randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -49,7 +49,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """
-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 2
     @register_to_config
......
@@ -18,8 +18,7 @@ import numpy as np
 import torch
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
 class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -49,7 +48,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """
-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 2
     @register_to_config
......
@@ -21,8 +21,8 @@ import torch
 from scipy import integrate
 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
-from .scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
 @dataclass
@@ -70,7 +70,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """
-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 1
     @register_to_config
......