Unverified commit 37d113cc, authored by Kashif Rasul and committed by GitHub

DiT Pipeline (#1806)



* added dit model

* import

* initial pipeline

* initial convert script

* initial pipeline

* make style

* raise ValueError

* single function

* rename classes

* use DDIMScheduler

* timesteps embedder

* samples to cpu

* fix var names

* fix numpy type

* use Timesteps class for projection

* fix typo

* fix arg name

* flip_sin_to_cos and better var names

* fix C shape calculation

* make style

* remove unused imports

* cleanup

* add back patch_size

* initial dit doc

* typo

* Update docs/source/api/pipelines/dit.mdx
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* added copyright license headers

* added example usage and toc

* fix variable names in asserts

* remove comment

* added docs

* fix typo

* upstream changes

* set proper device for drop_ids

* added initial dit pipeline test

* update docs

* fix imports

* make fix-copies

* isort

* fix imports

* get rid of more magic numbers

* fix code when guidance is off

* remove block_kwargs

* cleanup script

* removed to_2tuple

* use FeedForward class instead of another MLP

* style

* work on merging DiTBlock with BasicTransformerBlock

* added missing final_dropout and args to BasicTransformerBlock

* use norm from block

* fix arg

* remove unused arg

* fix call to class_embedder

* use timesteps

* make style

* attn_output gets multiplied

* removed commented code

* use Transformer2D

* use self.is_input_patches

* fix flags

* fixed conversion to use Transformer2DModel

* fixes for pipeline

* remove dit.py

* fix timesteps device

* use randn_tensor and fix fp16 inference

* timesteps_emb already the right dtype

* fix dit test class

* fix test and style

* fix norm2 usage in vq-diffusion

* added author names to pipeline and ImageNet labels link

* fix tests

* use norm_type as string

* rename dit to transformer

* fix name

* fix test

* set norm_type = "layer" by default

* fix tests

* do not skip common tests

* Update src/diffusers/models/attention.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* revert AdaLayerNorm API

* fix norm_type name

* make sure all components are in eval mode

* revert norm2 API

* compact

* finish deprecation

* add slow tests

* remove @

* refactor some stuff

* upload

* Update src/diffusers/pipelines/dit/pipeline_dit.py

* finish more

* finish docs

* improve docs

* finish docs
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: William Berman <WLBberman@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 7e29b747
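The PR adds a class-conditional DiTPipeline (Transformer2DModel + AutoencoderKL VAE + DDIMScheduler) for the ImageNet-trained DiT checkpoints. Before the file diffs, a rough usage sketch based on the example added in docs/source/api/pipelines/dit.mdx; the hub id "facebook/DiT-XL-2-256" and the get_label_ids helper are assumptions taken from that doc, so check them against the merged version:

import torch
from diffusers import DiTPipeline

# Load the class-conditional DiT checkpoint (hub id assumed)
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Map human-readable ImageNet labels to class ids (helper name assumed)
class_ids = pipe.get_label_ids(["white shark", "umbrella"])

# Sample with classifier-free guidance over the learned class embeddings
generator = torch.manual_seed(33)
output = pipe(class_labels=class_ids, guidance_scale=4.0, num_inference_steps=25, generator=generator)
output.images[0].save("white_shark.png")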
@@ -10,14 +10,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from ...configuration_utils import FrozenDict
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...schedulers import (
-    DDIMScheduler,
-    DPMSolverMultistepScheduler,
-    EulerAncestralDiscreteScheduler,
-    EulerDiscreteScheduler,
-    LMSDiscreteScheduler,
-    PNDMScheduler,
-)
+from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionSafePipelineOutput
@@ -65,14 +58,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
-        scheduler: Union[
-            DDIMScheduler,
-            DPMSolverMultistepScheduler,
-            EulerAncestralDiscreteScheduler,
-            EulerDiscreteScheduler,
-            LMSDiscreteScheduler,
-            PNDMScheduler,
-        ],
+        scheduler: KarrasDiffusionSchedulers,
         safety_checker: SafeStableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
         requires_safety_checker: bool = True,
...
@@ -7,7 +7,7 @@ import PIL.Image
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import logging
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
@@ -53,7 +53,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
     image_unet: UNet2DConditionModel
     text_unet: UNet2DConditionModel
     vae: AutoencoderKL
-    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+    scheduler: KarrasDiffusionSchedulers

     def __init__(
         self,
@@ -64,7 +64,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
         image_unet: UNet2DConditionModel,
         text_unet: UNet2DConditionModel,
         vae: AutoencoderKL,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: KarrasDiffusionSchedulers,
     ):
         super().__init__()
...
@@ -28,7 +28,7 @@ from transformers import (
 )
 from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .modeling_text_unet import UNetFlatConditionModel
@@ -62,7 +62,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
     image_unet: UNet2DConditionModel
     text_unet: UNetFlatConditionModel
     vae: AutoencoderKL
-    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+    scheduler: KarrasDiffusionSchedulers
     _optional_components = ["text_unet"]
@@ -75,7 +75,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
         image_unet: UNet2DConditionModel,
         text_unet: UNetFlatConditionModel,
         vae: AutoencoderKL,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: KarrasDiffusionSchedulers,
     ):
         super().__init__()
         self.register_modules(
...
@@ -23,7 +23,7 @@ import PIL
 from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -53,7 +53,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
     image_encoder: CLIPVisionModelWithProjection
     image_unet: UNet2DConditionModel
     vae: AutoencoderKL
-    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+    scheduler: KarrasDiffusionSchedulers

     def __init__(
         self,
@@ -61,7 +61,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
         image_encoder: CLIPVisionModelWithProjection,
         image_unet: UNet2DConditionModel,
         vae: AutoencoderKL,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: KarrasDiffusionSchedulers,
     ):
         super().__init__()
         self.register_modules(
...
@@ -21,7 +21,7 @@ import torch.utils.checkpoint
 from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer
 from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
-from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import is_accelerate_available, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from .modeling_text_unet import UNetFlatConditionModel
@@ -54,7 +54,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
     image_unet: UNet2DConditionModel
     text_unet: UNetFlatConditionModel
     vae: AutoencoderKL
-    scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
+    scheduler: KarrasDiffusionSchedulers
     _optional_components = ["text_unet"]
@@ -65,7 +65,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
         image_unet: UNet2DConditionModel,
         text_unet: UNetFlatConditionModel,
         vae: AutoencoderKL,
-        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        scheduler: KarrasDiffusionSchedulers,
     ):
         super().__init__()
         self.register_modules(
...
@@ -39,7 +39,7 @@ else:
     from .scheduling_sde_ve import ScoreSdeVeScheduler
     from .scheduling_sde_vp import ScoreSdeVpScheduler
     from .scheduling_unclip import UnCLIPScheduler
-    from .scheduling_utils import SchedulerMixin
+    from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
     from .scheduling_vq_diffusion import VQDiffusionScheduler

 try:
@@ -55,7 +55,12 @@ else:
     from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
     from .scheduling_pndm_flax import FlaxPNDMScheduler
     from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
-    from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
+    from .scheduling_utils_flax import (
+        FlaxKarrasDiffusionSchedulers,
+        FlaxSchedulerMixin,
+        FlaxSchedulerOutput,
+        broadcast_to_shape_from_left,
+    )

 try:
...
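The remaining scheduler diffs all apply one recurring change: the hardcoded _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS list (and its Flax twin) is replaced by a KarrasDiffusionSchedulers enum exported from scheduling_utils, and the long Union[...] type hints in the pipeline signatures above collapse to that single name. A minimal sketch of the pattern, assuming the enum simply enumerates scheduler class names; the member set below is illustrative, not the exact one in diffusers/schedulers/scheduling_utils.py:

from enum import Enum

class KarrasDiffusionSchedulers(Enum):
    # Illustrative members: the real enum lists every scheduler that is
    # interchangeable in Stable-Diffusion-style pipelines.
    DDIMScheduler = 1
    DDPMScheduler = 2
    PNDMScheduler = 3
    LMSDiscreteScheduler = 4
    EulerDiscreteScheduler = 5
    EulerAncestralDiscreteScheduler = 6
    DPMSolverMultistepScheduler = 7
    DPMSolverSinglestepScheduler = 8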
@@ -23,8 +23,8 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
-from .scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput, deprecate, randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin

 @dataclass
@@ -112,7 +112,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     _deprecated_kwargs = ["predict_epsilon"]
     order = 1
...
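Since an Enum member exposes its name as a string via .name, the new _compatibles assignment yields the same kind of class-name list the old constant held, just derived from a single source of truth. A quick check against the illustrative enum sketched above:

compatibles = [e.name for e in KarrasDiffusionSchedulers]
print(compatibles[:3])
# ['DDIMScheduler', 'DDPMScheduler', 'PNDMScheduler']

The scheduler mixin can then resolve those names back to classes at runtime, so compatibility queries keep working without each scheduler maintaining its own hardcoded list.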
@@ -24,8 +24,8 @@ import jax.numpy as jnp
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import deprecate
 from .scheduling_utils_flax import (
-    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
     CommonSchedulerState,
+    FlaxKarrasDiffusionSchedulers,
     FlaxSchedulerMixin,
     FlaxSchedulerOutput,
     add_noise_common,
@@ -102,7 +102,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
             the `dtype` used for params and computation.
     """

-    _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
     _deprecated_kwargs = ["predict_epsilon"]

     dtype: jnp.dtype
...
@@ -22,8 +22,8 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
-from .scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput, deprecate, randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin

 @dataclass
@@ -105,7 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     _deprecated_kwargs = ["predict_epsilon"]
     order = 1
...
@@ -24,8 +24,8 @@ import jax.numpy as jnp
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import deprecate
 from .scheduling_utils_flax import (
-    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
     CommonSchedulerState,
+    FlaxKarrasDiffusionSchedulers,
     FlaxSchedulerMixin,
     FlaxSchedulerOutput,
     add_noise_common,
@@ -85,7 +85,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
             the `dtype` used for params and computation.
     """

-    _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
     _deprecated_kwargs = ["predict_epsilon"]

     dtype: jnp.dtype
...
@@ -22,8 +22,7 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput

 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -106,7 +105,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 1

     @register_to_config
...
@@ -21,8 +21,8 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from ..utils import deprecate
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput

 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -117,7 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     _deprecated_kwargs = ["predict_epsilon"]
     order = 1
...
@@ -24,8 +24,8 @@ import jax.numpy as jnp
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import deprecate
 from .scheduling_utils_flax import (
-    _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
     CommonSchedulerState,
+    FlaxKarrasDiffusionSchedulers,
     FlaxSchedulerMixin,
     FlaxSchedulerOutput,
     add_noise_common,
@@ -140,7 +140,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
             the `dtype` used for params and computation.
     """

-    _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
     _deprecated_kwargs = ["predict_epsilon"]

     dtype: jnp.dtype
...
@@ -21,8 +21,7 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput

 def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -116,7 +115,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 1

     @register_to_config
...
@@ -19,8 +19,8 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
-from .scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput, logging, randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -71,7 +71,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 1

     @register_to_config
...
@@ -19,8 +19,8 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
-from .scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput, logging, randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -72,7 +72,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 1

     @register_to_config
...
@@ -18,8 +18,7 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput

 class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -48,7 +47,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 2

     @register_to_config
...
@@ -18,8 +18,8 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, randn_tensor
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from ..utils import randn_tensor
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput

 class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -49,7 +49,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 2

     @register_to_config
...
@@ -18,8 +18,7 @@ import numpy as np
 import torch

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
-from .scheduling_utils import SchedulerMixin, SchedulerOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput

 class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -49,7 +48,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 2

     @register_to_config
...
@@ -21,8 +21,8 @@ import torch
 from scipy import integrate

 from ..configuration_utils import ConfigMixin, register_to_config
-from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
-from .scheduling_utils import SchedulerMixin
+from ..utils import BaseOutput
+from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin

 @dataclass
@@ -70,7 +70,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
             https://imagen.research.google/video/paper.pdf)
     """

-    _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
+    _compatibles = [e.name for e in KarrasDiffusionSchedulers]
     order = 1

     @register_to_config
...