Unverified commit 3651b14c authored by Pedro Cuenca, committed by GitHub

SDXL flax (#4254)



* support transformer_layers_per_block in flax UNet

* add support for text_time additional embeddings to Flax UNet

* rename attention layers for VAE

* add shape asserts when renaming attention layers

* transpose VAE attention layers

* add pipeline flax SDXL code [WIP]

* continue add pipeline flax SDXL code [WIP]

* cleanup

* Working on JIT support

Fixed prompt embedding shapes so they work in parallel mode. Assuming we
always have both text encoders for now, for simplicity.
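
A rough sketch of the resulting data-parallel call (model loading is
omitted and the checkpoint id is an assumption; `prepare_inputs` and the
`jit=True` path are defined in the new pipeline below):

    import jax
    from flax.jax_utils import replicate
    from flax.training.common_utils import shard

    num_devices = jax.device_count()
    prompt_ids = pipeline.prepare_inputs(["a photo of an astronaut"] * num_devices)
    prompt_ids = shard(prompt_ids)  # (devices, batch/device, 2 encoders, seq_len)
    rngs = jax.random.split(jax.random.PRNGKey(0), num_devices)
    images = pipeline(prompt_ids, replicate(params), rngs, jit=True).images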

* Fixing embeddings (untested)

* Remove spurious line

* Shard guidance_scale when jitting.

* Decode images

* Fix sharding

* style

* Refiner UNet can be loaded.

* Refiner / img2img pipeline

* Allow latent outputs from base and latent inputs in refiner

This makes it possible to chain base + refiner without using the vae
decoder in the base model or the vae encoder in the refiner, skipping
conversions to/from PIL and avoiding TPU <-> CPU memory copies.
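
A hypothetical sketch of that chaining (the refiner argument names are
assumptions since the img2img pipeline is still WIP; `output_type="latent"`
is the new base-pipeline option):

    latents = base(prompt_ids, params, rng, output_type="latent", jit=True).images  # stays on device, undecoded
    images = refiner(prompt_ids, image=latents, params=refiner_params, prng_seed=rng, jit=True).images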

* Adapt to FlaxCLIPTextModelOutput

* Update Flax XL pipeline to FlaxCLIPTextModelOutput

* make fix-copies

* make style

* add euler scheduler

* Fix import

* Fix copies, comment unused code.

* Fix SDXL Flax imports

* Fix euler discrete begin

* improve init import

* finish

* put discrete euler in init

* fix flax euler

* Fix more

* make style

* correct init

* correct init

* Temporarily remove FlaxStableDiffusionXLImg2ImgPipeline

* correct pipelines

* finish

---------
Co-authored-by: Martin Müller <martin.muller.me@gmail.com>
Co-authored-by: patil-suraj <surajp815@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
parent 2e860e89
@@ -368,6 +368,7 @@ else:
            "FlaxDDIMScheduler",
            "FlaxDDPMScheduler",
            "FlaxDPMSolverMultistepScheduler",
+           "FlaxEulerDiscreteScheduler",
            "FlaxKarrasVeScheduler",
            "FlaxLMSDiscreteScheduler",
            "FlaxPNDMScheduler",
@@ -395,6 +396,7 @@ else:
            "FlaxStableDiffusionImg2ImgPipeline",
            "FlaxStableDiffusionInpaintPipeline",
            "FlaxStableDiffusionPipeline",
+           "FlaxStableDiffusionXLPipeline",
        ]
    )
@@ -673,6 +675,7 @@ if TYPE_CHECKING:
        FlaxDDIMScheduler,
        FlaxDDPMScheduler,
        FlaxDPMSolverMultistepScheduler,
+       FlaxEulerDiscreteScheduler,
        FlaxKarrasVeScheduler,
        FlaxLMSDiscreteScheduler,
        FlaxPNDMScheduler,
@@ -691,6 +694,7 @@ if TYPE_CHECKING:
        FlaxStableDiffusionImg2ImgPipeline,
        FlaxStableDiffusionInpaintPipeline,
        FlaxStableDiffusionPipeline,
+       FlaxStableDiffusionXLPipeline,
    )
    try:
...
@@ -42,9 +42,25 @@ def rename_key(key):
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
    """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
    # conv norm or layer norm
    renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
+
+   # rename attention layers
+   if len(pt_tuple_key) > 1:
+       for rename_from, rename_to in (
+           ("to_out_0", "proj_attn"),
+           ("to_k", "key"),
+           ("to_v", "value"),
+           ("to_q", "query"),
+       ):
+           if pt_tuple_key[-2] == rename_from:
+               weight_name = pt_tuple_key[-1]
+               weight_name = "kernel" if weight_name == "weight" else weight_name
+               renamed_pt_tuple_key = pt_tuple_key[:-2] + (rename_to, weight_name)
+               if renamed_pt_tuple_key in random_flax_state_dict:
+                   assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
+                   return renamed_pt_tuple_key, pt_tensor.T
+
    if (
        any("norm" in str_ for str_ in pt_tuple_key)
        and (pt_tuple_key[-1] == "bias")
...
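
For example, a PT attention projection weight would be renamed and transposed
roughly as follows (illustrative key; actual tuples depend on the checkpoint):

    pt_key = ("decoder", "mid_block", "attentions", "0", "to_q", "weight")     # PT Linear weight, (out, in)
    flax_key = ("decoder", "mid_block", "attentions", "0", "query", "kernel")  # Flax Dense kernel, (in, out)
    # the tensor is returned as pt_tensor.T, and the shape assert guards the rename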
@@ -303,23 +303,23 @@ class FlaxModelMixin(PushToHubMixin):
            "framework": "flax",
        }

-       # Load config if we don't provide a configuration
-       config_path = config if config is not None else pretrained_model_name_or_path
-       model, model_kwargs = cls.from_config(
-           config_path,
-           cache_dir=cache_dir,
-           return_unused_kwargs=True,
-           force_download=force_download,
-           resume_download=resume_download,
-           proxies=proxies,
-           local_files_only=local_files_only,
-           use_auth_token=use_auth_token,
-           revision=revision,
-           subfolder=subfolder,
-           # model args
-           dtype=dtype,
-           **kwargs,
-       )
+       # Load config if we don't provide one
+       if config is None:
+           config, unused_kwargs = cls.load_config(
+               pretrained_model_name_or_path,
+               cache_dir=cache_dir,
+               return_unused_kwargs=True,
+               force_download=force_download,
+               resume_download=resume_download,
+               proxies=proxies,
+               local_files_only=local_files_only,
+               use_auth_token=use_auth_token,
+               revision=revision,
+               subfolder=subfolder,
+               **kwargs,
+           )
+
+       model, model_kwargs = cls.from_config(config, dtype=dtype, return_unused_kwargs=True, **unused_kwargs)

        # Load model
        pretrained_path_with_subfolder = (
...
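
A sketch of what the refactor enables: `load_config` now runs only when no
config is supplied, so a preloaded config can be passed straight through
(the repo id here is an assumption):

    config, _ = FlaxUNet2DConditionModel.load_config(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", return_unused_kwargs=True
    )
    model, params = FlaxUNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", config=config
    )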
@@ -52,6 +52,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32
+   transformer_layers_per_block: int = 1

    def setup(self):
        resnets = []
@@ -72,7 +73,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
-               depth=1,
+               depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -192,6 +193,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
    only_cross_attention: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32
+   transformer_layers_per_block: int = 1

    def setup(self):
        resnets = []
@@ -213,7 +215,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
-               depth=1,
+               depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
@@ -331,6 +333,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
    use_linear_projection: bool = False
    use_memory_efficient_attention: bool = False
    dtype: jnp.dtype = jnp.float32
+   transformer_layers_per_block: int = 1

    def setup(self):
        # there is always at least one resnet
@@ -350,7 +353,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
                in_channels=self.in_channels,
                n_heads=self.num_attention_heads,
                d_head=self.in_channels // self.num_attention_heads,
-               depth=1,
+               depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                dtype=self.dtype,
...
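
SDXL is the motivation for this parameter: its UNet uses a different
transformer depth at each resolution instead of the previously hard-coded
depth=1. For illustration, the SDXL base UNet config sets:

    transformer_layers_per_block = (1, 2, 10)  # 1 layer at 320 ch, 2 at 640 ch, 10 at 1280 ch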
@@ -883,7 +883,6 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
-
            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
...
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

import flax
import flax.linen as nn
@@ -116,6 +116,11 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    use_memory_efficient_attention: bool = False
+   transformer_layers_per_block: Union[int, Tuple[int]] = 1
+   addition_embed_type: Optional[str] = None
+   addition_time_embed_dim: Optional[int] = None
+   addition_embed_type_num_heads: int = 64
+   projection_class_embeddings_input_dim: Optional[int] = None

    def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
        # init input tensors
@@ -127,7 +132,17 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

-       return self.init(rngs, sample, timesteps, encoder_hidden_states)["params"]
+       added_cond_kwargs = None
+       if self.addition_embed_type == "text_time":
+           # TODO: how to get this from the config? It's no longer cross_attention_dim
+           text_embeds_dim = 1280
+           time_ids_channels = self.projection_class_embeddings_input_dim - text_embeds_dim
+           time_ids_dims = time_ids_channels // self.addition_time_embed_dim
+           added_cond_kwargs = {
+               "text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
+               "time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
+           }
+       return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]

    def setup(self):
        block_out_channels = self.block_out_channels
@@ -168,6 +183,24 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(self.down_block_types)

+       # transformer layers per block
+       transformer_layers_per_block = self.transformer_layers_per_block
+       if isinstance(transformer_layers_per_block, int):
+           transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)
+
+       # addition embed types
+       if self.addition_embed_type is None:
+           self.add_embedding = None
+       elif self.addition_embed_type == "text_time":
+           if self.addition_time_embed_dim is None:
+               raise ValueError(
+                   f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None"
+               )
+           self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift)
+           self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
+       else:
+           raise ValueError(f"addition_embed_type: {self.addition_embed_type} must be None or `text_time`.")
+
        # down
        down_blocks = []
        output_channel = block_out_channels[0]
@@ -182,6 +215,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
+                   transformer_layers_per_block=transformer_layers_per_block[i],
                    num_attention_heads=num_attention_heads[i],
                    add_downsample=not is_final_block,
                    use_linear_projection=self.use_linear_projection,
@@ -207,6 +241,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
            in_channels=block_out_channels[-1],
            dropout=self.dropout,
            num_attention_heads=num_attention_heads[-1],
+           transformer_layers_per_block=transformer_layers_per_block[-1],
            use_linear_projection=self.use_linear_projection,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            dtype=self.dtype,
@@ -218,6 +253,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        only_cross_attention = list(reversed(only_cross_attention))
        output_channel = reversed_block_out_channels[0]
+       reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
        for i, up_block_type in enumerate(self.up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
@@ -231,6 +267,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    num_layers=self.layers_per_block + 1,
+                   transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                    num_attention_heads=reversed_num_attention_heads[i],
                    add_upsample=not is_final_block,
                    dropout=self.dropout,
@@ -269,6 +306,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
        sample,
        timesteps,
        encoder_hidden_states,
+       added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
        down_block_additional_residuals=None,
        mid_block_additional_residual=None,
        return_dict: bool = True,
@@ -300,6 +338,31 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)

+       # additional embeddings
+       aug_emb = None
+       if self.addition_embed_type == "text_time":
+           if added_cond_kwargs is None:
+               raise ValueError(
+                   f"Need to provide argument `added_cond_kwargs` for {self.__class__} when using `addition_embed_type={self.addition_embed_type}`"
+               )
+           text_embeds = added_cond_kwargs.get("text_embeds")
+           if text_embeds is None:
+               raise ValueError(
+                   f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+               )
+           time_ids = added_cond_kwargs.get("time_ids")
+           if time_ids is None:
+               raise ValueError(
+                   f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+               )
+           # compute time embeds
+           time_embeds = self.add_time_proj(jnp.ravel(time_ids))  # (1, 6) => (6,) => (6, 256)
+           time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))
+           add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)
+           aug_emb = self.add_embedding(add_embeds)
+
+       t_emb = t_emb + aug_emb if aug_emb is not None else t_emb
+
        # 2. pre-process
        sample = jnp.transpose(sample, (0, 2, 3, 1))
        sample = self.conv_in(sample)
...
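
With these changes, an SDXL-style forward pass feeds the pooled text embedding
and the six micro-conditioning values through `added_cond_kwargs`. A minimal
sketch (the variable names are assumptions):

    added_cond_kwargs = {
        "text_embeds": pooled_text_embeds,  # (batch, 1280), pooled output of the second text encoder
        "time_ids": jnp.array([[1024, 1024, 0, 0, 1024, 1024]], dtype=jnp.float32),  # original size + crop top-left + target size
    }
    out = unet.apply({"params": params}, latents, timestep, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs)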
from typing import TYPE_CHECKING

from ..utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_flax_available,
    is_k_diffusion_available,
    is_librosa_available,
    is_note_seq_available,
    is_onnx_available,
    is_torch_available,
    is_transformers_available,
)

# These modules contain pipelines from multiple libraries/frameworks
_dummy_objects = {}
-_import_structure = {"stable_diffusion": [], "latent_diffusion": [], "controlnet": []}
+_import_structure = {"stable_diffusion": [], "stable_diffusion_xl": [], "latent_diffusion": [], "controlnet": []}

try:
    if not is_torch_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_pt_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_pt_objects))
else:
    _import_structure["auto_pipeline"] = [
        "AutoPipelineForImage2Image",
        "AutoPipelineForInpainting",
        "AutoPipelineForText2Image",
    ]
    _import_structure["consistency_models"] = ["ConsistencyModelPipeline"]
    _import_structure["dance_diffusion"] = ["DanceDiffusionPipeline"]
    _import_structure["ddim"] = ["DDIMPipeline"]
    _import_structure["ddpm"] = ["DDPMPipeline"]
    _import_structure["dit"] = ["DiTPipeline"]
    _import_structure["latent_diffusion"].extend(["LDMSuperResolutionPipeline"])
    _import_structure["latent_diffusion_uncond"] = ["LDMPipeline"]
    _import_structure["pipeline_utils"] = ["AudioPipelineOutput", "DiffusionPipeline", "ImagePipelineOutput"]
    _import_structure["pndm"] = ["PNDMPipeline"]
    _import_structure["repaint"] = ["RePaintPipeline"]
    _import_structure["score_sde_ve"] = ["ScoreSdeVePipeline"]
    _import_structure["stochastic_karras_ve"] = ["KarrasVePipeline"]
try:
    if not (is_torch_available() and is_librosa_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_torch_and_librosa_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_librosa_objects))
else:
    _import_structure["audio_diffusion"] = ["AudioDiffusionPipeline", "Mel"]
try:
    if not (is_torch_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["alt_diffusion"] = ["AltDiffusionImg2ImgPipeline", "AltDiffusionPipeline"]
    _import_structure["audioldm"] = ["AudioLDMPipeline"]
    _import_structure["audioldm2"] = [
        "AudioLDM2Pipeline",
        "AudioLDM2ProjectionModel",
        "AudioLDM2UNet2DConditionModel",
    ]
    _import_structure["blip_diffusion"] = ["BlipDiffusionPipeline"]
    _import_structure["controlnet"].extend(
        [
            "BlipDiffusionControlNetPipeline",
            "StableDiffusionControlNetImg2ImgPipeline",
            "StableDiffusionControlNetInpaintPipeline",
            "StableDiffusionControlNetPipeline",
            "StableDiffusionXLControlNetImg2ImgPipeline",
            "StableDiffusionXLControlNetInpaintPipeline",
            "StableDiffusionXLControlNetPipeline",
        ]
    )
    _import_structure["deepfloyd_if"] = [
        "IFImg2ImgPipeline",
        "IFImg2ImgSuperResolutionPipeline",
        "IFInpaintingPipeline",
        "IFInpaintingSuperResolutionPipeline",
        "IFPipeline",
        "IFSuperResolutionPipeline",
    ]
    _import_structure["kandinsky"] = [
        "KandinskyCombinedPipeline",
        "KandinskyImg2ImgCombinedPipeline",
        "KandinskyImg2ImgPipeline",
        "KandinskyInpaintCombinedPipeline",
        "KandinskyInpaintPipeline",
        "KandinskyPipeline",
        "KandinskyPriorPipeline",
    ]
    _import_structure["kandinsky2_2"] = [
        "KandinskyV22CombinedPipeline",
        "KandinskyV22ControlnetImg2ImgPipeline",
        "KandinskyV22ControlnetPipeline",
        "KandinskyV22Img2ImgCombinedPipeline",
        "KandinskyV22Img2ImgPipeline",
        "KandinskyV22InpaintCombinedPipeline",
        "KandinskyV22InpaintPipeline",
        "KandinskyV22Pipeline",
        "KandinskyV22PriorEmb2EmbPipeline",
        "KandinskyV22PriorPipeline",
    ]
    _import_structure["latent_diffusion"].extend(["LDMTextToImagePipeline"])
    _import_structure["musicldm"] = ["MusicLDMPipeline"]
    _import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
    _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
    _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
    _import_structure["stable_diffusion"].extend(
        [
            "CLIPImageProjection",
            "CycleDiffusionPipeline",
            "StableDiffusionAttendAndExcitePipeline",
            "StableDiffusionDepth2ImgPipeline",
            "StableDiffusionDiffEditPipeline",
            "StableDiffusionGLIGENPipeline",
            "StableDiffusionGLIGENPipeline",
            "StableDiffusionGLIGENTextImagePipeline",
            "StableDiffusionImageVariationPipeline",
            "StableDiffusionImg2ImgPipeline",
            "StableDiffusionInpaintPipeline",
            "StableDiffusionInpaintPipelineLegacy",
            "StableDiffusionInstructPix2PixPipeline",
            "StableDiffusionLatentUpscalePipeline",
            "StableDiffusionLDM3DPipeline",
            "StableDiffusionModelEditingPipeline",
            "StableDiffusionPanoramaPipeline",
            "StableDiffusionParadigmsPipeline",
            "StableDiffusionPipeline",
            "StableDiffusionPix2PixZeroPipeline",
            "StableDiffusionSAGPipeline",
            "StableDiffusionUpscalePipeline",
            "StableUnCLIPImg2ImgPipeline",
            "StableUnCLIPPipeline",
        ]
    )
    _import_structure["stable_diffusion_safe"] = ["StableDiffusionPipelineSafe"]
-    _import_structure["stable_diffusion_xl"] = [
-        "StableDiffusionXLImg2ImgPipeline",
-        "StableDiffusionXLInpaintPipeline",
-        "StableDiffusionXLInstructPix2PixPipeline",
-        "StableDiffusionXLPipeline",
-    ]
+    _import_structure["stable_diffusion_xl"].extend(
+        [
+            "StableDiffusionXLImg2ImgPipeline",
+            "StableDiffusionXLInpaintPipeline",
+            "StableDiffusionXLInstructPix2PixPipeline",
+            "StableDiffusionXLPipeline",
+        ]
+    )
    _import_structure["t2i_adapter"] = ["StableDiffusionAdapterPipeline", "StableDiffusionXLAdapterPipeline"]
    _import_structure["text_to_video_synthesis"] = [
        "TextToVideoSDPipeline",
        "TextToVideoZeroPipeline",
        "VideoToVideoSDPipeline",
    ]
    _import_structure["unclip"] = ["UnCLIPImageVariationPipeline", "UnCLIPPipeline"]
    _import_structure["unidiffuser"] = [
        "ImageTextPipelineOutput",
        "UniDiffuserModel",
        "UniDiffuserPipeline",
        "UniDiffuserTextDecoder",
    ]
    _import_structure["versatile_diffusion"] = [
        "VersatileDiffusionDualGuidedPipeline",
        "VersatileDiffusionImageVariationPipeline",
        "VersatileDiffusionPipeline",
        "VersatileDiffusionTextToImagePipeline",
    ]
    _import_structure["vq_diffusion"] = ["VQDiffusionPipeline"]
    _import_structure["wuerstchen"] = [
        "WuerstchenCombinedPipeline",
        "WuerstchenDecoderPipeline",
        "WuerstchenPriorPipeline",
    ]
try:
    if not is_onnx_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_onnx_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_onnx_objects))
else:
    _import_structure["onnx_utils"] = ["OnnxRuntimeModel"]
try:
    if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_torch_and_transformers_and_onnx_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_onnx_objects))
else:
    _import_structure["stable_diffusion"].extend(
        [
            "OnnxStableDiffusionImg2ImgPipeline",
            "OnnxStableDiffusionInpaintPipeline",
            "OnnxStableDiffusionInpaintPipelineLegacy",
            "OnnxStableDiffusionPipeline",
            "OnnxStableDiffusionUpscalePipeline",
            "StableDiffusionOnnxPipeline",
        ]
    )

try:
    if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_torch_and_transformers_and_k_diffusion_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
    _import_structure["stable_diffusion"].extend(["StableDiffusionKDiffusionPipeline"])
try:
    if not is_flax_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_flax_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_flax_objects))
else:
    _import_structure["pipeline_flax_utils"] = ["FlaxDiffusionPipeline"]
try:
    if not (is_flax_available() and is_transformers_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_flax_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else:
    _import_structure["controlnet"].extend(["FlaxStableDiffusionControlNetPipeline"])
    _import_structure["stable_diffusion"].extend(
        [
            "FlaxStableDiffusionImg2ImgPipeline",
            "FlaxStableDiffusionInpaintPipeline",
            "FlaxStableDiffusionPipeline",
        ]
    )
+    _import_structure["stable_diffusion_xl"].extend(
+        [
+            "FlaxStableDiffusionXLPipeline",
+        ]
+    )
try:
    if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ..utils import dummy_transformers_and_torch_and_note_seq_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_transformers_and_torch_and_note_seq_objects))
else:
    _import_structure["spectrogram_diffusion"] = ["MidiProcessor", "SpectrogramDiffusionPipeline"]

if TYPE_CHECKING:
    try:
        if not is_torch_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_pt_objects import *  # noqa F403
    else:
        from .auto_pipeline import AutoPipelineForImage2Image, AutoPipelineForInpainting, AutoPipelineForText2Image
        from .consistency_models import ConsistencyModelPipeline
        from .dance_diffusion import DanceDiffusionPipeline
        from .ddim import DDIMPipeline
        from .ddpm import DDPMPipeline
        from .dit import DiTPipeline
        from .latent_diffusion import LDMSuperResolutionPipeline
        from .latent_diffusion_uncond import LDMPipeline
        from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
        from .pndm import PNDMPipeline
        from .repaint import RePaintPipeline
        from .score_sde_ve import ScoreSdeVePipeline
        from .stochastic_karras_ve import KarrasVePipeline

    try:
        if not (is_torch_available() and is_librosa_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_torch_and_librosa_objects import *
    else:
        from .audio_diffusion import AudioDiffusionPipeline, Mel

    try:
        if not (is_torch_available() and is_transformers_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_torch_and_transformers_objects import *
    else:
        from .alt_diffusion import AltDiffusionImg2ImgPipeline, AltDiffusionPipeline
        from .audioldm import AudioLDMPipeline
        from .audioldm2 import AudioLDM2Pipeline, AudioLDM2ProjectionModel, AudioLDM2UNet2DConditionModel
        from .blip_diffusion import BlipDiffusionPipeline
        from .controlnet import (
            BlipDiffusionControlNetPipeline,
            StableDiffusionControlNetImg2ImgPipeline,
            StableDiffusionControlNetInpaintPipeline,
            StableDiffusionControlNetPipeline,
            StableDiffusionXLControlNetImg2ImgPipeline,
            StableDiffusionXLControlNetInpaintPipeline,
            StableDiffusionXLControlNetPipeline,
        )
        from .deepfloyd_if import (
            IFImg2ImgPipeline,
            IFImg2ImgSuperResolutionPipeline,
            IFInpaintingPipeline,
            IFInpaintingSuperResolutionPipeline,
            IFPipeline,
            IFSuperResolutionPipeline,
        )
        from .kandinsky import (
            KandinskyCombinedPipeline,
            KandinskyImg2ImgCombinedPipeline,
            KandinskyImg2ImgPipeline,
            KandinskyInpaintCombinedPipeline,
            KandinskyInpaintPipeline,
            KandinskyPipeline,
            KandinskyPriorPipeline,
        )
        from .kandinsky2_2 import (
            KandinskyV22CombinedPipeline,
            KandinskyV22ControlnetImg2ImgPipeline,
            KandinskyV22ControlnetPipeline,
            KandinskyV22Img2ImgCombinedPipeline,
            KandinskyV22Img2ImgPipeline,
            KandinskyV22InpaintCombinedPipeline,
            KandinskyV22InpaintPipeline,
            KandinskyV22Pipeline,
            KandinskyV22PriorEmb2EmbPipeline,
            KandinskyV22PriorPipeline,
        )
        from .latent_diffusion import LDMTextToImagePipeline
        from .musicldm import MusicLDMPipeline
        from .paint_by_example import PaintByExamplePipeline
        from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
        from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
        from .stable_diffusion import (
            CLIPImageProjection,
            CycleDiffusionPipeline,
            StableDiffusionAttendAndExcitePipeline,
            StableDiffusionDepth2ImgPipeline,
            StableDiffusionDiffEditPipeline,
            StableDiffusionGLIGENPipeline,
            StableDiffusionGLIGENTextImagePipeline,
            StableDiffusionImageVariationPipeline,
            StableDiffusionImg2ImgPipeline,
            StableDiffusionInpaintPipeline,
            StableDiffusionInpaintPipelineLegacy,
            StableDiffusionInstructPix2PixPipeline,
            StableDiffusionLatentUpscalePipeline,
            StableDiffusionLDM3DPipeline,
            StableDiffusionModelEditingPipeline,
            StableDiffusionPanoramaPipeline,
            StableDiffusionParadigmsPipeline,
            StableDiffusionPipeline,
            StableDiffusionPix2PixZeroPipeline,
            StableDiffusionSAGPipeline,
            StableDiffusionUpscalePipeline,
            StableUnCLIPImg2ImgPipeline,
            StableUnCLIPPipeline,
        )
        from .stable_diffusion_safe import StableDiffusionPipelineSafe
        from .stable_diffusion_xl import (
            StableDiffusionXLImg2ImgPipeline,
            StableDiffusionXLInpaintPipeline,
            StableDiffusionXLInstructPix2PixPipeline,
            StableDiffusionXLPipeline,
        )
        from .t2i_adapter import StableDiffusionAdapterPipeline, StableDiffusionXLAdapterPipeline
        from .text_to_video_synthesis import (
            TextToVideoSDPipeline,
            TextToVideoZeroPipeline,
            VideoToVideoSDPipeline,
        )
        from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
        from .unidiffuser import (
            ImageTextPipelineOutput,
            UniDiffuserModel,
            UniDiffuserPipeline,
            UniDiffuserTextDecoder,
        )
        from .versatile_diffusion import (
            VersatileDiffusionDualGuidedPipeline,
            VersatileDiffusionImageVariationPipeline,
            VersatileDiffusionPipeline,
            VersatileDiffusionTextToImagePipeline,
        )
        from .vq_diffusion import VQDiffusionPipeline
        from .wuerstchen import (
            WuerstchenCombinedPipeline,
            WuerstchenDecoderPipeline,
            WuerstchenPriorPipeline,
        )

    try:
        if not is_onnx_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_onnx_objects import *  # noqa F403
    else:
        from .onnx_utils import OnnxRuntimeModel

    try:
        if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_torch_and_transformers_and_onnx_objects import *
    else:
        from .stable_diffusion import (
            OnnxStableDiffusionImg2ImgPipeline,
            OnnxStableDiffusionInpaintPipeline,
            OnnxStableDiffusionInpaintPipelineLegacy,
            OnnxStableDiffusionPipeline,
            OnnxStableDiffusionUpscalePipeline,
            StableDiffusionOnnxPipeline,
        )

    try:
        if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
    else:
        from .stable_diffusion import StableDiffusionKDiffusionPipeline

    try:
        if not is_flax_available():
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_flax_objects import *  # noqa F403
    else:
        from .pipeline_flax_utils import FlaxDiffusionPipeline

    try:
        if not (is_flax_available() and is_transformers_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_flax_and_transformers_objects import *
    else:
        from .controlnet import FlaxStableDiffusionControlNetPipeline
        from .stable_diffusion import (
            FlaxStableDiffusionImg2ImgPipeline,
            FlaxStableDiffusionInpaintPipeline,
            FlaxStableDiffusionPipeline,
        )
+        from .stable_diffusion_xl import (
+            FlaxStableDiffusionXLPipeline,
+        )

    try:
        if not (is_transformers_available() and is_torch_available() and is_note_seq_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ..utils.dummy_transformers_and_torch_and_note_seq_objects import *  # noqa F403
    else:
        from .spectrogram_diffusion import MidiProcessor, SpectrogramDiffusionPipeline

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
@@ -394,10 +394,29 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
        # extract them here
        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+       passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}

-       init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+       init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

-       init_kwargs = {}
+       # define init kwargs
+       init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
+       init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
+
+       # remove `null` components
+       def load_module(name, value):
+           if value[0] is None:
+               return False
+           if name in passed_class_obj and passed_class_obj[name] is None:
+               return False
+           return True
+
+       init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
+
+       # Throw nice warnings / errors for fast accelerate loading
+       if len(unused_kwargs) > 0:
+           logger.warning(
+               f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
+           )

        # inference_params
        params = {}
...
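
One effect of the new `load_module` filter: components that are `null` in the
model index, or explicitly passed as None, are skipped instead of loaded. A
hypothetical sketch (the repo id is an assumption):

    pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", safety_checker=None
    )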
@@ -4,14 +4,18 @@ from ...utils import (
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
+   is_flax_available,
    is_torch_available,
    is_transformers_available,
)

_dummy_objects = {}
+_additional_imports = {}
_import_structure = {"pipeline_output": ["StableDiffusionXLPipelineOutput"]}

+if is_transformers_available() and is_flax_available():
+   _import_structure["pipeline_output"].extend(["FlaxStableDiffusionXLPipelineOutput"])

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
@@ -25,6 +29,12 @@ else:
    _import_structure["pipeline_stable_diffusion_xl_inpaint"] = ["StableDiffusionXLInpaintPipeline"]
    _import_structure["pipeline_stable_diffusion_xl_instruct_pix2pix"] = ["StableDiffusionXLInstructPix2PixPipeline"]

+if is_transformers_available() and is_flax_available():
+   from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
+
+   _additional_imports.update({"PNDMSchedulerState": PNDMSchedulerState})
+   _import_structure["pipeline_flax_stable_diffusion_xl"] = ["FlaxStableDiffusionXLPipeline"]
+

if TYPE_CHECKING:
    try:
@@ -38,6 +48,17 @@ if TYPE_CHECKING:
        from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
        from .pipeline_stable_diffusion_xl_instruct_pix2pix import StableDiffusionXLInstructPix2PixPipeline

+   try:
+       if not (is_transformers_available() and is_flax_available()):
+           raise OptionalDependencyNotAvailable()
+   except OptionalDependencyNotAvailable:
+       from ...utils.dummy_flax_objects import *
+   else:
+       from .pipeline_flax_stable_diffusion_xl import (
+           FlaxStableDiffusionXLPipeline,
+       )
+       from .pipeline_output import FlaxStableDiffusionXLPipelineOutput
+
else:
    import sys
@@ -50,3 +71,5 @@ else:
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
+   for name, value in _additional_imports.items():
+       setattr(sys.modules[__name__], name, value)
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Dict, List, Optional, Union

import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict
from transformers import CLIPTokenizer, FlaxCLIPTextModel

from diffusers.utils import logging

from ...models import FlaxAutoencoderKL, FlaxUNet2DConditionModel
from ...schedulers import (
    FlaxDDIMScheduler,
    FlaxDPMSolverMultistepScheduler,
    FlaxLMSDiscreteScheduler,
    FlaxPNDMScheduler,
)
from ..pipeline_flax_utils import FlaxDiffusionPipeline
from . import FlaxStableDiffusionXLPipelineOutput


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Set to True to use python for loop instead of jax.fori_loop for easier debugging
DEBUG = False


class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
    def __init__(
        self,
        text_encoder: FlaxCLIPTextModel,
        text_encoder_2: FlaxCLIPTextModel,
        vae: FlaxAutoencoderKL,
        tokenizer: CLIPTokenizer,
        tokenizer_2: CLIPTokenizer,
        unet: FlaxUNet2DConditionModel,
        scheduler: Union[
            FlaxDDIMScheduler, FlaxPNDMScheduler, FlaxLMSDiscreteScheduler, FlaxDPMSolverMultistepScheduler
        ],
        dtype: jnp.dtype = jnp.float32,
    ):
        super().__init__()
        self.dtype = dtype

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            text_encoder_2=text_encoder_2,
            tokenizer=tokenizer,
            tokenizer_2=tokenizer_2,
            unet=unet,
            scheduler=scheduler,
        )
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
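# block_out_channels has one entry per resolution level, so this is 2 ** 3 = 8 for the SD/SDXL VAE.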
def prepare_inputs(self, prompt: Union[str, List[str]]):
if not isinstance(prompt, (str, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
# Assume we have the two encoders
inputs = []
for tokenizer in [self.tokenizer, self.tokenizer_2]:
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="np",
)
inputs.append(text_inputs.input_ids)
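# Stack so every example carries ids from both tokenizers: (batch_size, 2, max_length)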
inputs = jnp.stack(inputs, axis=1)
return inputs
def __call__(
self,
prompt_ids: jax.Array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
num_inference_steps: int = 50,
guidance_scale: Union[float, jax.Array] = 7.5,
height: Optional[int] = None,
width: Optional[int] = None,
latents: Optional[jnp.ndarray] = None,
neg_prompt_ids: Optional[jnp.ndarray] = None,
return_dict: bool = True,
output_type: Optional[str] = None,
jit: bool = False,
):
# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
if isinstance(guidance_scale, float) and jit:
# Convert to a tensor so each device gets a copy.
guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
guidance_scale = guidance_scale[:, None]
return_latents = output_type == "latent"
if jit:
images = _p_generate(
self,
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
neg_prompt_ids,
return_latents,
)
else:
images = self._generate(
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
neg_prompt_ids,
return_latents,
)
if not return_dict:
return (images,)
return FlaxStableDiffusionXLPipelineOutput(images=images)
def get_embeddings(self, prompt_ids: jnp.ndarray, params):
# We assume we have the two encoders
# bs, encoder_input, seq_length
te_1_inputs = prompt_ids[:, 0, :]
te_2_inputs = prompt_ids[:, 1, :]
prompt_embeds = self.text_encoder(te_1_inputs, params=params["text_encoder"], output_hidden_states=True)
prompt_embeds = prompt_embeds["hidden_states"][-2]
prompt_embeds_2_out = self.text_encoder_2(
te_2_inputs, params=params["text_encoder_2"], output_hidden_states=True
)
prompt_embeds_2 = prompt_embeds_2_out["hidden_states"][-2]
text_embeds = prompt_embeds_2_out["text_embeds"]
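# Concatenate the penultimate hidden states of both encoders (768 + 1280 = 2048 channels for SDXL base).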
prompt_embeds = jnp.concatenate([prompt_embeds, prompt_embeds_2], axis=-1)
return prompt_embeds, text_embeds
def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, bs, dtype):
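# SDXL micro-conditioning: six scalars per sample, original (h, w) + crop top-left + target (h, w).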
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = jnp.array([add_time_ids] * bs, dtype=dtype)
return add_time_ids
def _generate(
self,
prompt_ids: jnp.ndarray,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
num_inference_steps: int,
height: int,
width: int,
guidance_scale: float,
latents: Optional[jnp.ndarray] = None,
neg_prompt_ids: Optional[jnp.ndarray] = None,
return_latents: bool = False,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
# Encode input prompt
prompt_embeds, pooled_embeds = self.get_embeddings(prompt_ids, params)
# Get unconditional embeddings
batch_size = prompt_embeds.shape[0]
if neg_prompt_ids is None:
neg_prompt_ids = self.prepare_inputs([""] * batch_size)
neg_prompt_embeds, negative_pooled_embeds = self.get_embeddings(neg_prompt_ids, params)
add_time_ids = self._get_add_time_ids(
(height, width), (0, 0), (height, width), prompt_embeds.shape[0], dtype=prompt_embeds.dtype
)
prompt_embeds = jnp.concatenate([neg_prompt_embeds, prompt_embeds], axis=0) # (2, 77, 2048)
add_text_embeds = jnp.concatenate([negative_pooled_embeds, pooled_embeds], axis=0)
add_time_ids = jnp.concatenate([add_time_ids, add_time_ids], axis=0)
# Ensure model output will be `float32` before going into the scheduler
guidance_scale = jnp.array([guidance_scale], dtype=jnp.float32)
# Create random latents
latents_shape = (
batch_size,
self.unet.config.in_channels,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
)
if latents is None:
latents = jax.random.normal(prng_seed, shape=latents_shape, dtype=jnp.float32)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * params["scheduler"].init_noise_sigma
# Prepare scheduler state
scheduler_state = self.scheduler.set_timesteps(
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
)
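# Extra conditioning consumed by the SDXL UNet's "text_time" additional embeddings (added to the Flax UNet in this PR).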
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
# Denoising loop
def loop_body(step, args):
latents, scheduler_state = args
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
latents_input = jnp.concatenate([latents] * 2)
t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step]
timestep = jnp.broadcast_to(t, latents_input.shape[0])
latents_input = self.scheduler.scale_model_input(scheduler_state, latents_input, t)
# predict the noise residual
noise_pred = self.unet.apply(
{"params": params["unet"]},
jnp.array(latents_input),
jnp.array(timestep, dtype=jnp.int32),
encoder_hidden_states=prompt_embeds,
added_cond_kwargs=added_cond_kwargs,
).sample
# perform guidance
noise_pred_uncond, noise_prediction_text = jnp.split(noise_pred, 2, axis=0)
noise_pred = noise_pred_uncond + guidance_scale * (noise_prediction_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
return latents, scheduler_state
if DEBUG:
# run with python for loop
for i in range(num_inference_steps):
latents, scheduler_state = loop_body(i, (latents, scheduler_state))
else:
latents, _ = jax.lax.fori_loop(0, num_inference_steps, loop_body, (latents, scheduler_state))
if return_latents:
return latents
# Decode latents
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.apply({"params": params["vae"]}, latents, method=self.vae.decode).sample
image = (image / 2 + 0.5).clip(0, 1).transpose(0, 2, 3, 1)
return image
# Static argnums are pipe, num_inference_steps, height, width, return_latents. A change would trigger recompilation.
# Non-static args are (sharded) input tensors mapped over their first dimension (hence, `0`).
@partial(
jax.pmap,
in_axes=(None, 0, 0, 0, None, None, None, 0, 0, 0, None),
static_broadcasted_argnums=(0, 4, 5, 6, 10),
)
def _p_generate(
pipe,
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
neg_prompt_ids,
return_latents,
):
return pipe._generate(
prompt_ids,
params,
prng_seed,
num_inference_steps,
height,
width,
guidance_scale,
latents,
neg_prompt_ids,
return_latents,
)
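For reference, a minimal usage sketch of the new pipeline follows. It is illustrative only: the repo id, the `dtype`, and the replicate/shard device setup are assumptions about a typical Flax environment, not something this diff guarantees.

import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionXLPipeline

# Assumed: a Flax-format SDXL checkpoint is published under this repo id.
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", dtype=jnp.bfloat16
)

num_devices = jax.device_count()
prompt_ids = pipeline.prepare_inputs(["a watercolor fox in a forest"] * num_devices)

# Replicate params and shard inputs so `jit=True` can pmap over devices.
params = replicate(params)
prompt_ids = shard(prompt_ids)
rng = jax.random.split(jax.random.PRNGKey(0), num_devices)

images = pipeline(prompt_ids, params, rng, num_inference_steps=25, jit=True).images
# images: (num_devices, batch_per_device, height, width, 3), values in [0, 1]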
...@@ -4,7 +4,11 @@ from typing import List, Union
import numpy as np
import PIL
from ...utils import (
BaseOutput,
is_flax_available,
is_transformers_available,
)
@dataclass
...@@ -19,3 +23,19 @@ class StableDiffusionXLPipelineOutput(BaseOutput):
"""
images: Union[List[PIL.Image.Image], np.ndarray]
if is_transformers_available() and is_flax_available():
import flax
@flax.struct.dataclass
class FlaxStableDiffusionXLPipelineOutput(BaseOutput):
"""
Output class for Flax Stable Diffusion XL pipelines.
Args:
images (`np.ndarray`)
Array of shape `(batch_size, height, width, num_channels)` with images from the diffusion pipeline.
"""
images: np.ndarray
...@@ -1094,7 +1094,6 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
time_ids = added_cond_kwargs.get("time_ids")
time_embeds = self.add_time_proj(time_ids.flatten())
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
add_embeds = add_embeds.to(emb.dtype)
aug_emb = self.add_embedding(add_embeds)
...
...@@ -76,6 +76,7 @@ else:
_import_structure["scheduling_ddim_flax"] = ["FlaxDDIMScheduler"]
_import_structure["scheduling_ddpm_flax"] = ["FlaxDDPMScheduler"]
_import_structure["scheduling_dpmsolver_multistep_flax"] = ["FlaxDPMSolverMultistepScheduler"]
_import_structure["scheduling_euler_discrete_flax"] = ["FlaxEulerDiscreteScheduler"]
_import_structure["scheduling_karras_ve_flax"] = ["FlaxKarrasVeScheduler"]
_import_structure["scheduling_lms_discrete_flax"] = ["FlaxLMSDiscreteScheduler"]
_import_structure["scheduling_pndm_flax"] = ["FlaxPNDMScheduler"]
...
# Copyright 2023 Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import flax
import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)
@flax.struct.dataclass
class EulerDiscreteSchedulerState:
common: CommonSchedulerState
# setable values
init_noise_sigma: jnp.ndarray
timesteps: jnp.ndarray
sigmas: jnp.ndarray
num_inference_steps: Optional[int] = None
@classmethod
def create(
cls, common: CommonSchedulerState, init_noise_sigma: jnp.ndarray, timesteps: jnp.ndarray, sigmas: jnp.ndarray
):
return cls(common=common, init_noise_sigma=init_noise_sigma, timesteps=timesteps, sigmas=sigmas)
@dataclass
class FlaxEulerDiscreteSchedulerOutput(FlaxSchedulerOutput):
state: EulerDiscreteSchedulerState
class FlaxEulerDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
"""
Euler scheduler (Algorithm 2) from Karras et al. (2022) https://arxiv.org/abs/2206.00364. Based on the original
k-diffusion implementation by Katherine Crowson:
https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
[`~SchedulerMixin.from_pretrained`] functions.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
beta_start (`float`): the starting `beta` value of inference.
beta_end (`float`): the final `beta` value.
beta_schedule (`str`):
the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
`linear` or `scaled_linear`.
trained_betas (`jnp.ndarray`, optional):
option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
process), `sample` (directly predicting the noisy sample) or `v_prediction` (see section 2.4
https://imagen.research.google/video/paper.pdf)
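timestep_spacing (`str`, default `"linspace"`, optional):
how the inference timesteps are distributed, one of `linspace` or `leading` (see `set_timesteps`).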
dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
the `dtype` used for params and computation.
"""
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
dtype: jnp.dtype
@property
def has_state(self):
return True
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
trained_betas: Optional[jnp.ndarray] = None,
prediction_type: str = "epsilon",
timestep_spacing: str = "linspace",
dtype: jnp.dtype = jnp.float32,
):
self.dtype = dtype
def create_state(self, common: Optional[CommonSchedulerState] = None) -> EulerDiscreteSchedulerState:
if common is None:
common = CommonSchedulerState.create(self)
timesteps = jnp.arange(0, self.config.num_train_timesteps).round()[::-1]
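# sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), the noise scale implied by the trained betas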
sigmas = ((1 - common.alphas_cumprod) / common.alphas_cumprod) ** 0.5
sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas)
sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
init_noise_sigma = sigmas.max()
else:
init_noise_sigma = (sigmas.max() ** 2 + 1) ** 0.5
return EulerDiscreteSchedulerState.create(
common=common,
init_noise_sigma=init_noise_sigma,
timesteps=timesteps,
sigmas=sigmas,
)
def scale_model_input(self, state: EulerDiscreteSchedulerState, sample: jnp.ndarray, timestep: int) -> jnp.ndarray:
"""
Scales the denoising model input by `1 / ((sigma**2 + 1) ** 0.5)` to match the Euler algorithm.
Args:
state (`EulerDiscreteSchedulerState`):
the `FlaxEulerDiscreteScheduler` state data class instance.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
timestep (`int`):
current discrete timestep in the diffusion chain.
Returns:
`jnp.ndarray`: scaled input sample
"""
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
step_index = step_index[0]
sigma = state.sigmas[step_index]
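# c_in = 1 / sqrt(sigma**2 + 1): input preconditioning for variance-preserving models (Karras et al., Table 1)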
sample = sample / ((sigma**2 + 1) ** 0.5)
return sample
def set_timesteps(
self, state: EulerDiscreteSchedulerState, num_inference_steps: int, shape: Tuple = ()
) -> EulerDiscreteSchedulerState:
"""
Sets the timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
state (`EulerDiscreteSchedulerState`):
the `FlaxEulerDiscreteScheduler` state data class instance.
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
if self.config.timestep_spacing == "linspace":
timesteps = jnp.linspace(self.config.num_train_timesteps - 1, 0, num_inference_steps, dtype=self.dtype)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // num_inference_steps
timesteps = (jnp.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)
timesteps += 1
else:
raise ValueError(
f"timestep_spacing must be one of ['linspace', 'leading'], got {self.config.timestep_spacing}"
)
sigmas = ((1 - state.common.alphas_cumprod) / state.common.alphas_cumprod) ** 0.5
sigmas = jnp.interp(timesteps, jnp.arange(0, len(sigmas)), sigmas)
sigmas = jnp.concatenate([sigmas, jnp.array([0.0], dtype=self.dtype)])
# standard deviation of the initial noise distribution
if self.config.timestep_spacing in ["linspace", "trailing"]:
init_noise_sigma = sigmas.max()
else:
init_noise_sigma = (sigmas.max() ** 2 + 1) ** 0.5
return state.replace(
timesteps=timesteps,
sigmas=sigmas,
num_inference_steps=num_inference_steps,
init_noise_sigma=init_noise_sigma,
)
def step(
self,
state: EulerDiscreteSchedulerState,
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
return_dict: bool = True,
) -> Union[FlaxEulerDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
state (`EulerDiscreteSchedulerState`):
the `FlaxEulerDiscreteScheduler` state data class instance.
model_output (`jnp.ndarray`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
return_dict (`bool`): option for returning tuple rather than a `FlaxEulerDiscreteSchedulerOutput` class
Returns:
[`FlaxEulerDiscreteSchedulerOutput`] or `tuple`: [`FlaxEulerDiscreteSchedulerOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is the sample tensor.
"""
if state.num_inference_steps is None:
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
(step_index,) = jnp.where(state.timesteps == timestep, size=1)
step_index = step_index[0]
sigma = state.sigmas[step_index]
# 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
if self.config.prediction_type == "epsilon":
pred_original_sample = sample - sigma * model_output
elif self.config.prediction_type == "v_prediction":
# * c_out + input * c_skip
pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, or `v_prediction`"
)
# 2. Convert to an ODE derivative
derivative = (sample - pred_original_sample) / sigma
# dt = sigma_next - sigma
dt = state.sigmas[step_index + 1] - sigma
prev_sample = sample + derivative * dt
if not return_dict:
return (prev_sample, state)
return FlaxEulerDiscreteSchedulerOutput(prev_sample=prev_sample, state=state)
def add_noise(
self,
state: EulerDiscreteSchedulerState,
original_samples: jnp.ndarray,
noise: jnp.ndarray,
timesteps: jnp.ndarray,
) -> jnp.ndarray:
sigma = state.sigmas[timesteps].flatten()
sigma = broadcast_to_shape_from_left(sigma, noise.shape)
noisy_samples = original_samples + noise * sigma
return noisy_samples
def __len__(self):
return self.config.num_train_timesteps
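For reference, a short standalone sketch of the scheduler's functional, state-passing API. `apply_model` is a hypothetical stand-in for a real denoiser; everything else follows the methods defined above.

import jax
import jax.numpy as jnp
from diffusers import FlaxEulerDiscreteScheduler

scheduler = FlaxEulerDiscreteScheduler(num_train_timesteps=1000, prediction_type="epsilon")
state = scheduler.create_state()
state = scheduler.set_timesteps(state, num_inference_steps=30, shape=(1, 4, 64, 64))

# Start from noise scaled to the scheduler's initial sigma.
sample = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 64, 64)) * state.init_noise_sigma

def apply_model(x, t):
    # Hypothetical denoiser; a real UNet would predict the noise residual here.
    return jnp.zeros_like(x)

for t in state.timesteps:
    x_in = scheduler.scale_model_input(state, sample, t)
    noise_pred = apply_model(x_in, t)
    sample, state = scheduler.step(state, noise_pred, t, sample).to_tuple()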
...@@ -37,6 +37,7 @@ class FlaxKarrasDiffusionSchedulers(Enum):
FlaxPNDMScheduler = 3
FlaxLMSDiscreteScheduler = 4
FlaxDPMSolverMultistepScheduler = 5
FlaxEulerDiscreteScheduler = 6
@dataclass
...
...@@ -60,3 +60,18 @@ class FlaxStableDiffusionPipeline(metaclass=DummyObject):
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax", "transformers"])
class FlaxStableDiffusionXLPipeline(metaclass=DummyObject):
_backends = ["flax", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax", "transformers"])
...@@ -122,6 +122,21 @@ class FlaxDPMSolverMultistepScheduler(metaclass=DummyObject):
requires_backends(cls, ["flax"])
class FlaxEulerDiscreteScheduler(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxKarrasVeScheduler(metaclass=DummyObject):
_backends = ["flax"]
...