Unverified Commit 29b2c93c authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

Make repo structure consistent (#1862)



* move files a bit

* more refactors

* fix more

* more fixes

* fix more onnx

* make style

* upload

* fix

* up

* fix more

* up again

* up

* small fix

* Update src/diffusers/__init__.py
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>

* correct
Co-authored-by: default avatarPedro Cuenca <pedro@huggingface.co>
parent ab0e92fd
...@@ -20,11 +20,11 @@ import torch.nn.functional as F ...@@ -20,11 +20,11 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..models.embeddings import ImagePositionalEmbeddings from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput from ..utils import BaseOutput
from ..utils.import_utils import is_xformers_available from ..utils.import_utils import is_xformers_available
from .cross_attention import CrossAttention from .cross_attention import CrossAttention
from .modeling_utils import ModelMixin
@dataclass @dataclass
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
@dataclass
class AutoencoderKLOutput(BaseOutput):
"""
Output of AutoencoderKL encoding method.
Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "DiagonalGaussianDistribution"
class AutoencoderKL(ModelMixin, ConfigMixin):
r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
and Max Welling.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(64,)`): Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): TODO
"""
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
):
super().__init__()
# pass init params to Encoder
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
)
# pass init params to Decoder
self.decoder = Decoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
)
self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
self.use_slicing = False
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
z = self.post_quant_conv(z)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def enable_slicing(self):
r"""
Enable sliced VAE decoding.
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from torch import nn
from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
class DualTransformer2DModel(nn.Module):
"""
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input and output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
"""
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
):
super().__init__()
self.transformers = nn.ModuleList(
[
Transformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=in_channels,
num_layers=num_layers,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_bias=attention_bias,
sample_size=sample_size,
num_vector_embeds=num_vector_embeds,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
)
for _ in range(2)
]
)
# Variables that can be set by a pipeline:
# The ratio of transformer1 to transformer2's output states to be combined during inference
self.mix_ratio = 0.5
# The shape of `encoder_hidden_states` is expected to be
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
self.condition_lengths = [77, 257]
# Which transformer to use to encode which condition.
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
self.transformer_index_for_condition = [1, 0]
def forward(
self,
hidden_states,
encoder_hidden_states,
timestep=None,
attention_mask=None,
cross_attention_kwargs=None,
return_dict: bool = True,
):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
attention_mask (`torch.FloatTensor`, *optional*):
Optional attention mask to be applied in CrossAttention
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Returns:
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
tensor.
"""
input_states = hidden_states
encoded_states = []
tokens_start = 0
# attention_mask is not used yet
for i in range(2):
# for each of the two transformers, pass the corresponding condition tokens
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
transformer_index = self.transformer_index_for_condition[i]
encoded_state = self.transformers[transformer_index](
input_states,
encoder_hidden_states=condition_state,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
encoded_states.append(encoded_state - input_states)
tokens_start += self.condition_lengths[i]
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
output_states = output_states + input_states
if not return_dict:
return (output_states,)
return Transformer2DModelOutput(sample=output_states)
...@@ -19,7 +19,7 @@ import jax.numpy as jnp ...@@ -19,7 +19,7 @@ import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict from flax.traverse_util import flatten_dict, unflatten_dict
from jax.random import PRNGKey from jax.random import PRNGKey
from .utils import logging from ..utils import logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -27,9 +27,8 @@ from huggingface_hub import hf_hub_download ...@@ -27,9 +27,8 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError from requests import HTTPError
from . import __version__, is_torch_available from .. import __version__, is_torch_available
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax from ..utils import (
from .utils import (
CONFIG_NAME, CONFIG_NAME,
DIFFUSERS_CACHE, DIFFUSERS_CACHE,
FLAX_WEIGHTS_NAME, FLAX_WEIGHTS_NAME,
...@@ -37,6 +36,7 @@ from .utils import ( ...@@ -37,6 +36,7 @@ from .utils import (
WEIGHTS_NAME, WEIGHTS_NAME,
logging, logging,
) )
from .modeling_flax_pytorch_utils import convert_pytorch_state_dict_to_flax
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
......
...@@ -26,11 +26,11 @@ from huggingface_hub import hf_hub_download ...@@ -26,11 +26,11 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError from requests import HTTPError
from . import __version__ from .. import __version__
from .hub_utils import HF_HUB_OFFLINE from ..utils import (
from .utils import (
CONFIG_NAME, CONFIG_NAME,
DIFFUSERS_CACHE, DIFFUSERS_CACHE,
HF_HUB_OFFLINE,
HUGGINGFACE_CO_RESOLVE_ENDPOINT, HUGGINGFACE_CO_RESOLVE_ENDPOINT,
SAFETENSORS_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME,
WEIGHTS_NAME, WEIGHTS_NAME,
...@@ -149,7 +149,7 @@ class ModelMixin(torch.nn.Module): ...@@ -149,7 +149,7 @@ class ModelMixin(torch.nn.Module):
and saving models. and saving models.
- **config_name** ([`str`]) -- A filename under which the model should be stored when calling - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
[`~modeling_utils.ModelMixin.save_pretrained`]. [`~models.ModelMixin.save_pretrained`].
""" """
config_name = CONFIG_NAME config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
...@@ -231,7 +231,7 @@ class ModelMixin(torch.nn.Module): ...@@ -231,7 +231,7 @@ class ModelMixin(torch.nn.Module):
): ):
""" """
Save a model and its configuration file to a directory, so that it can be re-loaded using the Save a model and its configuration file to a directory, so that it can be re-loaded using the
`[`~modeling_utils.ModelMixin.from_pretrained`]` class method. `[`~models.ModelMixin.from_pretrained`]` class method.
Arguments: Arguments:
save_directory (`str` or `os.PathLike`): save_directory (`str` or `os.PathLike`):
......
...@@ -6,10 +6,10 @@ import torch.nn.functional as F ...@@ -6,10 +6,10 @@ import torch.nn.functional as F
from torch import nn from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput from ..utils import BaseOutput
from .attention import BasicTransformerBlock from .attention import BasicTransformerBlock
from .embeddings import TimestepEmbedding, Timesteps from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
@dataclass @dataclass
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .modeling_utils import ModelMixin
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
for the unnoised latent pixels.
"""
sample: torch.FloatTensor
class Transformer2DModel(ModelMixin, ConfigMixin):
"""
Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
embeddings) inputs.
When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
transformer action. Finally, reshape to image.
When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
classes of unnoised image.
Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input and output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
# 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = in_channels is not None
self.is_input_vectorized = num_vector_embeds is not None
if self.is_input_continuous and self.is_input_vectorized:
raise ValueError(
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
" sure that either `in_channels` or `num_vector_embeds` is None."
)
elif not self.is_input_continuous and not self.is_input_vectorized:
raise ValueError(
f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
" sure that either `in_channels` or `num_vector_embeds` is not None."
)
# 2. Define input layers
if self.is_input_continuous:
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
self.height = sample_size
self.width = sample_size
self.num_vector_embeds = num_vector_embeds
self.num_latent_pixels = self.height * self.width
self.latent_image_embedding = ImagePositionalEmbeddings(
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
)
for d in range(num_layers)
]
)
# 4. Define output layers
if self.is_input_continuous:
if use_linear_projection:
self.proj_out = nn.Linear(in_channels, inner_dim)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
def forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
cross_attention_kwargs=None,
return_dict: bool = True,
):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Returns:
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
tensor.
"""
# 1. Input
if self.is_input_continuous:
batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
)
# 3. Output
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
elif self.is_input_vectorized:
hidden_states = self.norm_out(hidden_states)
logits = self.out(hidden_states)
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
logits = logits.permute(0, 2, 1)
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
...@@ -19,9 +19,9 @@ import torch ...@@ -19,9 +19,9 @@ import torch
import torch.nn as nn import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block
......
...@@ -18,9 +18,9 @@ import torch ...@@ -18,9 +18,9 @@ import torch
import torch.nn as nn import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput from ..utils import BaseOutput
from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
......
...@@ -19,10 +19,10 @@ import torch.nn as nn ...@@ -19,10 +19,10 @@ import torch.nn as nn
import torch.utils.checkpoint import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput, logging from ..utils import BaseOutput, logging
from .cross_attention import AttnProcessor from .cross_attention import AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .unet_2d_blocks import ( from .unet_2d_blocks import (
CrossAttnDownBlock2D, CrossAttnDownBlock2D,
CrossAttnUpBlock2D, CrossAttnUpBlock2D,
......
...@@ -20,9 +20,9 @@ import jax.numpy as jnp ...@@ -20,9 +20,9 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from ..configuration_utils import ConfigMixin, flax_register_to_config from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..modeling_flax_utils import FlaxModelMixin
from ..utils import BaseOutput from ..utils import BaseOutput
from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps from .embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from .modeling_flax_utils import FlaxModelMixin
from .unet_2d_blocks_flax import ( from .unet_2d_blocks_flax import (
FlaxCrossAttnDownBlock2D, FlaxCrossAttnDownBlock2D,
FlaxCrossAttnUpBlock2D, FlaxCrossAttnUpBlock2D,
......
...@@ -12,14 +12,12 @@ ...@@ -12,14 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple, Union from typing import Optional
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput from ..utils import BaseOutput
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
...@@ -37,33 +35,6 @@ class DecoderOutput(BaseOutput): ...@@ -37,33 +35,6 @@ class DecoderOutput(BaseOutput):
sample: torch.FloatTensor sample: torch.FloatTensor
@dataclass
class VQEncoderOutput(BaseOutput):
"""
Output of VQModel encoding method.
Args:
latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Encoded output sample of the model. Output of the last layer of the model.
"""
latents: torch.FloatTensor
@dataclass
class AutoencoderKLOutput(BaseOutput):
"""
Output of AutoencoderKL encoding method.
Args:
latent_dist (`DiagonalGaussianDistribution`):
Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
`DiagonalGaussianDistribution` allows for sampling latents from the distribution.
"""
latent_dist: "DiagonalGaussianDistribution"
class Encoder(nn.Module): class Encoder(nn.Module):
def __init__( def __init__(
self, self,
...@@ -384,255 +355,3 @@ class DiagonalGaussianDistribution(object): ...@@ -384,255 +355,3 @@ class DiagonalGaussianDistribution(object):
def mode(self): def mode(self):
return self.mean return self.mean
class VQModel(ModelMixin, ConfigMixin):
r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray
Kavukcuoglu.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(64,)`): Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): TODO
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
"""
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 3,
sample_size: int = 32,
num_vq_embeddings: int = 256,
norm_num_groups: int = 32,
vq_embed_dim: Optional[int] = None,
):
super().__init__()
# pass init params to Encoder
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=False,
)
vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
self.quant_conv = torch.nn.Conv2d(latent_channels, vq_embed_dim, 1)
self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
self.post_quant_conv = torch.nn.Conv2d(vq_embed_dim, latent_channels, 1)
# pass init params to Decoder
self.decoder = Decoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
)
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
h = self.encoder(x)
h = self.quant_conv(h)
if not return_dict:
return (h,)
return VQEncoderOutput(latents=h)
def decode(
self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
) -> Union[DecoderOutput, torch.FloatTensor]:
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
else:
quant = h
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
h = self.encode(x).latents
dec = self.decode(h).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
class AutoencoderKL(ModelMixin, ConfigMixin):
r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
and Max Welling.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(64,)`): Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to `4`): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): TODO
"""
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 4,
norm_num_groups: int = 32,
sample_size: int = 32,
):
super().__init__()
# pass init params to Encoder
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=True,
)
# pass init params to Decoder
self.decoder = Decoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
)
self.quant_conv = torch.nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
self.use_slicing = False
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
z = self.post_quant_conv(z)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def enable_slicing(self):
r"""
Enable sliced VAE decoding.
When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
...@@ -25,8 +25,8 @@ import jax.numpy as jnp ...@@ -25,8 +25,8 @@ import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict from flax.core.frozen_dict import FrozenDict
from ..configuration_utils import ConfigMixin, flax_register_to_config from ..configuration_utils import ConfigMixin, flax_register_to_config
from ..modeling_flax_utils import FlaxModelMixin
from ..utils import BaseOutput from ..utils import BaseOutput
from .modeling_flax_utils import FlaxModelMixin
@flax.struct.dataclass @flax.struct.dataclass
......
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .modeling_utils import ModelMixin
from .vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
@dataclass
class VQEncoderOutput(BaseOutput):
"""
Output of VQModel encoding method.
Args:
latents (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
Encoded output sample of the model. Output of the last layer of the model.
"""
latents: torch.FloatTensor
class VQModel(ModelMixin, ConfigMixin):
r"""VQ-VAE model from the paper Neural Discrete Representation Learning by Aaron van den Oord, Oriol Vinyals and Koray
Kavukcuoglu.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the model (such as downloading or saving, etc.)
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
obj:`(64,)`): Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): TODO
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
"""
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
block_out_channels: Tuple[int] = (64,),
layers_per_block: int = 1,
act_fn: str = "silu",
latent_channels: int = 3,
sample_size: int = 32,
num_vq_embeddings: int = 256,
norm_num_groups: int = 32,
vq_embed_dim: Optional[int] = None,
):
super().__init__()
# pass init params to Encoder
self.encoder = Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
double_z=False,
)
vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
# pass init params to Decoder
self.decoder = Decoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
)
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> VQEncoderOutput:
h = self.encoder(x)
h = self.quant_conv(h)
if not return_dict:
return (h,)
return VQEncoderOutput(latents=h)
def decode(
self, h: torch.FloatTensor, force_not_quantize: bool = False, return_dict: bool = True
) -> Union[DecoderOutput, torch.FloatTensor]:
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
else:
quant = h
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(self, sample: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
h = self.encode(x).latents
dec = self.decode(h).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
# coding=utf-8 # Copyright 2022 The HuggingFace Team. All rights reserved.
# Copyright 2022 The HuggingFace Inc. team.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,869 +10,10 @@ ...@@ -12,869 +10,10 @@
# distributed under the License is distributed on an "AS IS" BASIS, # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import inspect
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import numpy as np
import torch
import diffusers
import PIL
from huggingface_hub import model_info, snapshot_download
from packaging import version
from PIL import Image
from tqdm.auto import tqdm
from .configuration_utils import ConfigMixin
from .dynamic_modules_utils import get_class_from_dynamic_module
from .hub_utils import HF_HUB_OFFLINE, http_user_agent
from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from .utils import (
CONFIG_NAME,
DIFFUSERS_CACHE,
ONNX_WEIGHTS_NAME,
WEIGHTS_NAME,
BaseOutput,
deprecate,
is_accelerate_available,
is_safetensors_available,
is_torch_version,
is_transformers_available,
logging,
)
if is_transformers_available():
import transformers
from transformers import PreTrainedModel
INDEX_FILE = "diffusion_pytorch_model.bin"
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
DUMMY_MODULES_FOLDER = "diffusers.utils"
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
"diffusers": {
"ModelMixin": ["save_pretrained", "from_pretrained"],
"SchedulerMixin": ["save_pretrained", "from_pretrained"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
},
"transformers": {
"PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
"PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
"PreTrainedModel": ["save_pretrained", "from_pretrained"],
"FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
"ProcessorMixin": ["save_pretrained", "from_pretrained"],
"ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
},
"onnxruntime.training": {
"ORTModule": ["save_pretrained", "from_pretrained"],
},
}
ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
@dataclass
class ImagePipelineOutput(BaseOutput):
"""
Output class for image pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
@dataclass
class AudioPipelineOutput(BaseOutput):
"""
Output class for audio pipelines.
Args:
audios (`np.ndarray`)
List of denoised samples of shape `(batch_size, num_channels, sample_rate)`. Numpy array present the
denoised audio samples of the diffusion pipeline.
"""
audios: np.ndarray
def is_safetensors_compatible(info) -> bool:
filenames = set(sibling.rfilename for sibling in info.siblings)
pt_filenames = set(filename for filename in filenames if filename.endswith(".bin"))
is_safetensors_compatible = any(file.endswith(".safetensors") for file in filenames)
for pt_filename in pt_filenames:
prefix, raw = os.path.split(pt_filename)
if raw == "pytorch_model.bin":
# transformers specific
sf_filename = os.path.join(prefix, "model.safetensors")
else:
sf_filename = pt_filename[: -len(".bin")] + ".safetensors"
if is_safetensors_compatible and sf_filename not in filenames:
logger.warning(f"{sf_filename} not found")
is_safetensors_compatible = False
return is_safetensors_compatible
class DiffusionPipeline(ConfigMixin):
r"""
Base class for all models.
[`DiffusionPipeline`] takes care of storing all components (models, schedulers, processors) for diffusion pipelines
and handles methods for loading, downloading and saving models as well as a few methods common to all pipelines to:
- move all PyTorch modules to the device of your choice
- enabling/disabling the progress bar for the denoising iteration
Class attributes:
- **config_name** (`str`) -- name of the config file that will store the class and module names of all
components of the diffusion pipeline.
- **_optional_components** (List[`str`]) -- list of all components that are optional so they don't have to be
passed for the pipeline to function (should be overridden by subclasses).
"""
config_name = "model_index.json"
_optional_components = []
def register_modules(self, **kwargs):
# import it here to avoid circular import
from diffusers import pipelines
for name, module in kwargs.items():
# retrieve library
if module is None:
register_dict = {name: (None, None)}
else:
library = module.__module__.split(".")[0]
# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
library = pipeline_dir
# retrieve class_name
class_name = module.__class__.__name__
register_dict = {name: (library, class_name)}
# save model index config
self.register_to_config(**register_dict)
# set models
setattr(self, name, module)
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
safe_serialization: bool = False,
):
"""
Save all variables of the pipeline that can be saved and loaded as well as the pipelines configuration file to
a directory. A pipeline variable can be saved and loaded if its class implements both a save and loading
method. The pipeline can easily be re-loaded using the `[`~DiffusionPipeline.from_pretrained`]` class method.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
safe_serialization (`bool`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
"""
self.save_config(save_directory)
model_index_dict = dict(self.config)
model_index_dict.pop("_class_name")
model_index_dict.pop("_diffusers_version")
model_index_dict.pop("_module", None)
expected_modules, optional_kwargs = self._get_signature_keys(self)
def is_saveable_module(name, value):
if name not in expected_modules:
return False
if name in self._optional_components and value[0] is None:
return False
return True
model_index_dict = {k: v for k, v in model_index_dict.items() if is_saveable_module(k, v)}
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
model_cls = sub_model.__class__
save_method_name = None
# search for the model's base class in LOADABLE_CLASSES
for library_name, library_classes in LOADABLE_CLASSES.items():
library = importlib.import_module(library_name)
for base_class, save_load_methods in library_classes.items():
class_candidate = getattr(library, base_class, None)
if class_candidate is not None and issubclass(model_cls, class_candidate):
# if we found a suitable base class in LOADABLE_CLASSES then grab its save method
save_method_name = save_load_methods[0]
break
if save_method_name is not None:
break
save_method = getattr(sub_model, save_method_name)
# Call the save method with the argument safe_serialization only if it's supported
save_method_signature = inspect.signature(save_method)
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
if save_method_accept_safe:
save_method(
os.path.join(save_directory, pipeline_component_name), safe_serialization=safe_serialization
)
else:
save_method(os.path.join(save_directory, pipeline_component_name))
def to(self, torch_device: Optional[Union[str, torch.device]] = None):
if torch_device is None:
return self
module_names, _, _ = self.extract_init_dict(dict(self.config))
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
if module.dtype == torch.float16 and str(torch_device) in ["cpu"]:
logger.warning(
"Pipelines loaded with `torch_dtype=torch.float16` cannot run with `cpu` device. It"
" is not recommended to move them to `cpu` as running them will fail. Please make"
" sure to use an accelerator to run the pipeline in inference, due to the lack of"
" support for`float16` operations on this device in PyTorch. Please, remove the"
" `torch_dtype=torch.float16` argument, or use another device for inference."
)
module.to(torch_device)
return self
@property
def device(self) -> torch.device:
r"""
Returns:
`torch.device`: The torch device on which the pipeline is located.
"""
module_names, _, _ = self.extract_init_dict(dict(self.config))
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
return module.device
return torch.device("cpu")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a PyTorch diffusion pipeline from pre-trained pipeline weights.
The pipeline is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated).
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
task.
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
weights are discarded.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
Can be either:
- A string, the *repo id* of a pretrained pipeline hosted inside a model repo on
https://huggingface.co/ Valid repo ids have to be located under a user or organization name, like
`CompVis/ldm-text2im-large-256`.
- A path to a *directory* containing pipeline weights saved using
[`~DiffusionPipeline.save_pretrained`], e.g., `./my_pipeline_directory/`.
torch_dtype (`str` or `torch.dtype`, *optional*):
Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
will be automatically derived from the model's weights.
custom_pipeline (`str`, *optional*):
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
Can be either:
- A string, the *repo id* of a custom pipeline hosted inside a model repo on
https://huggingface.co/. Valid repo ids have to be located under a user or organization name,
like `hf-internal-testing/diffusers-dummy-pipeline`.
<Tip>
It is required that the model repo has a file, called `pipeline.py` that defines the custom
pipeline.
</Tip>
- A string, the *file name* of a community pipeline hosted on GitHub under
https://github.com/huggingface/diffusers/tree/main/examples/community. Valid file names have to
match exactly the file name without `.py` located under the above link, *e.g.*
`clip_guided_stable_diffusion`.
<Tip>
Community pipelines are always loaded from the current `main` branch of GitHub.
</Tip>
- A path to a *directory* containing a custom pipeline, e.g., `./my_pipeline_directory/`.
<Tip>
It is required that the directory has a file, called `pipeline.py` that defines the custom
pipeline.
</Tip>
For more information on how to load and create custom pipelines, please have a look at [Loading and
Adding Custom
Pipelines](https://huggingface.co/docs/diffusers/using-diffusers/custom_pipeline_overview)
torch_dtype (`str` or `torch.dtype`, *optional*):
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
output_loading_info(`bool`, *optional*, defaults to `False`):
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `huggingface-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
custom_revision (`str`, *optional*, defaults to `"main"` when loading from the Hub and to local version of `diffusers` when loading from GitHub):
The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
`revision` when loading a custom pipeline from the Hub. It can be a diffusers version when loading a
custom pipeline from GitHub.
mirror (`str`, *optional*):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information. specify the folder name here.
device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
A map that specifies where each submodule should go. It doesn't need to be refined to each
parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
same device.
To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
more information about each option see [designing a device
map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
setting this argument to `True` will raise an error.
return_cached_folder (`bool`, *optional*, defaults to `False`):
If set to `True`, path to downloaded cached folder will be returned in addition to loaded pipeline.
kwargs (remaining dictionary of keyword arguments, *optional*):
Can be used to overwrite load - and saveable variables - *i.e.* the pipeline components - of the
specific pipeline class. The overwritten components are then directly passed to the pipelines
`__init__` method. See example below for more information.
<Tip>
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models), *e.g.* `"runwayml/stable-diffusion-v1-5"`
</Tip> # limitations under the License.
<Tip>
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
this method in a firewalled environment.
</Tip>
Examples:
```py
>>> from diffusers import DiffusionPipeline
>>> # Download pipeline from huggingface.co and cache.
>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
>>> # Download pipeline that requires an authorization token
>>> # For more information on access tokens, please refer to this section
>>> # of the documentation](https://huggingface.co/docs/hub/security-tokens)
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> # Use a different scheduler
>>> from diffusers import LMSDiscreteScheduler
>>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> pipeline.scheduler = scheduler
```
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
torch_dtype = kwargs.pop("torch_dtype", None)
custom_pipeline = kwargs.pop("custom_pipeline", None)
custom_revision = kwargs.pop("custom_revision", None)
provider = kwargs.pop("provider", None)
sess_options = kwargs.pop("sess_options", None)
device_map = kwargs.pop("device_map", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
return_cached_folder = kwargs.pop("return_cached_folder", False)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path):
config_dict = cls.load_config(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
)
# make sure we only download sub-folders and `diffusers` filenames
folder_names = [k for k in config_dict.keys() if not k.startswith("_")]
allow_patterns = [os.path.join(k, "*") for k in folder_names]
allow_patterns += [WEIGHTS_NAME, SCHEDULER_CONFIG_NAME, CONFIG_NAME, ONNX_WEIGHTS_NAME, cls.config_name]
# make sure we don't download flax weights
ignore_patterns = ["*.msgpack"]
if custom_pipeline is not None:
allow_patterns += [CUSTOM_PIPELINE_FILE_NAME]
if cls != DiffusionPipeline:
requested_pipeline_class = cls.__name__
else:
requested_pipeline_class = config_dict.get("_class_name", cls.__name__)
user_agent = {"pipeline_class": requested_pipeline_class}
if custom_pipeline is not None and not custom_pipeline.endswith(".py"):
user_agent["custom_pipeline"] = custom_pipeline
user_agent = http_user_agent(user_agent)
if is_safetensors_available():
info = model_info(
pretrained_model_name_or_path,
use_auth_token=use_auth_token,
revision=revision,
)
if is_safetensors_compatible(info):
ignore_patterns.append("*.bin")
else:
ignore_patterns.append("*.safetensors")
# download all allow_patterns
cached_folder = snapshot_download(
pretrained_model_name_or_path,
cache_dir=cache_dir,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
user_agent=user_agent,
)
else:
cached_folder = pretrained_model_name_or_path
config_dict = cls.load_config(cached_folder)
# 2. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
if custom_pipeline is not None:
if custom_pipeline.endswith(".py"):
path = Path(custom_pipeline)
# decompose into folder & file
file_name = path.name
custom_pipeline = path.parent.absolute()
else:
file_name = CUSTOM_PIPELINE_FILE_NAME
pipeline_class = get_class_from_dynamic_module(
custom_pipeline, module_file=file_name, cache_dir=cache_dir, revision=custom_revision
)
elif cls != DiffusionPipeline:
pipeline_class = cls
else:
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
# To be removed in 1.0.0
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
version.parse(config_dict["_diffusers_version"]).base_version
) <= version.parse("0.5.1"):
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
pipeline_class = StableDiffusionInpaintPipelineLegacy
deprecation_message = (
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
" checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
f" checkpoint {pretrained_model_name_or_path} to the format of"
" https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
)
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
# some modules can be passed directly to the init
# in this case they are already instantiated in `kwargs`
# extract them here
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
# define init kwargs
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
# remove `null` components
def load_module(name, value):
if value[0] is None:
return False
if name in passed_class_obj and passed_class_obj[name] is None:
return False
return True
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
if len(unused_kwargs) > 0:
logger.warning(
f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if device_map is not None and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `device_map=None`."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
if low_cpu_mem_usage is False and device_map is not None:
raise ValueError(
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
" dispatching. Please make sure to set `low_cpu_mem_usage=True`."
)
# import it here to avoid circular import
from diffusers import pipelines
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if class_name.startswith("Flax"):
class_name = class_name[4:]
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
# if the model is in a pipeline module, then we load it from the pipeline
if name in passed_class_obj:
# 1. check that passed_class_obj has correct parent class
if not is_pipeline_module:
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
expected_class_obj = None
for class_name, class_candidate in class_candidates.items():
if class_candidate is not None and issubclass(class_obj, class_candidate):
expected_class_obj = class_candidate
if not issubclass(passed_class_obj[name].__class__, expected_class_obj):
raise ValueError(
f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be"
f" {expected_class_obj}"
)
else:
logger.warning(
f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
" has the correct type"
)
# set passed class object
loaded_sub_model = passed_class_obj[name]
elif is_pipeline_module:
pipeline_module = getattr(pipelines, library_name)
class_obj = getattr(pipeline_module, class_name)
importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in importable_classes.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}
if loaded_sub_model is None:
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if class_candidate is not None and issubclass(class_obj, class_candidate):
load_method_name = importable_classes[class_name][1]
if load_method_name is None:
none_module = class_obj.__module__
is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
TRANSFORMERS_DUMMY_MODULES_FOLDER
)
if is_dummy_path and "dummy" in none_module:
# call class_obj for nice error message of missing requirements
class_obj()
raise ValueError(
f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
)
load_method = getattr(class_obj, load_method_name)
loading_kwargs = {}
if issubclass(class_obj, torch.nn.Module):
loading_kwargs["torch_dtype"] = torch_dtype
if issubclass(class_obj, diffusers.OnnxRuntimeModel):
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options
is_diffusers_model = issubclass(class_obj, diffusers.ModelMixin)
is_transformers_model = (
is_transformers_available()
and issubclass(class_obj, PreTrainedModel)
and version.parse(version.parse(transformers.__version__).base_version) >= version.parse("4.20.0")
)
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
# This makes sure that the weights won't be initialized which significantly speeds up loading.
if is_diffusers_model or is_transformers_model:
loading_kwargs["device_map"] = device_map
loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
else:
# else load from the root directory
loaded_sub_model = load_method(cached_folder, **loading_kwargs)
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
# 4. Potentially add passed objects if expected
missing_modules = set(expected_modules) - set(init_kwargs.keys())
passed_modules = list(passed_class_obj.keys())
optional_modules = pipeline_class._optional_components
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
for module in missing_modules:
init_kwargs[module] = passed_class_obj.get(module, None)
elif len(missing_modules) > 0:
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
raise ValueError(
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
)
# 5. Instantiate the pipeline
model = pipeline_class(**init_kwargs)
if return_cached_folder:
return model, cached_folder
return model
@staticmethod
def _get_signature_keys(obj):
parameters = inspect.signature(obj.__init__).parameters
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
expected_modules = set(required_parameters.keys()) - set(["self"])
return expected_modules, optional_parameters
@property
def components(self) -> Dict[str, Any]:
r"""
The `self.components` property can be useful to run different pipelines with the same weights and
configurations to not have to re-allocate memory.
Examples:
```py
>>> from diffusers import (
... StableDiffusionPipeline,
... StableDiffusionImg2ImgPipeline,
... StableDiffusionInpaintPipeline,
... )
>>> text2img = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
>>> img2img = StableDiffusionImg2ImgPipeline(**text2img.components)
>>> inpaint = StableDiffusionInpaintPipeline(**text2img.components)
```
Returns:
A dictionary containing all the modules needed to initialize the pipeline.
"""
expected_modules, optional_parameters = self._get_signature_keys(self)
components = {
k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
}
if set(components.keys()) != expected_modules:
raise ValueError(
f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
f" {expected_modules} to be defined, but {components} are defined."
)
return components
@staticmethod
def numpy_to_pil(images):
"""
Convert a numpy image or a batch of images to a PIL image.
"""
if images.ndim == 3:
images = images[None, ...]
images = (images * 255).round().astype("uint8")
if images.shape[-1] == 1:
# special case for grayscale (single channel) images
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
else:
pil_images = [Image.fromarray(image) for image in images]
return pil_images
def progress_bar(self, iterable=None, total=None):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
elif not isinstance(self._progress_bar_config, dict):
raise ValueError(
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)
if iterable is not None:
return tqdm(iterable, **self._progress_bar_config)
elif total is not None:
return tqdm(total=total, **self._progress_bar_config)
else:
raise ValueError("Either `total` or `iterable` has to be defined.")
def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs
def enable_xformers_memory_efficient_attention(self):
r"""
Enable memory efficient attention as implemented in xformers.
When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
time. Speed up at training time is not guaranteed.
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
"""
self.set_use_memory_efficient_attention_xformers(True)
def disable_xformers_memory_efficient_attention(self):
r"""
Disable memory efficient attention as implemented in xformers.
"""
self.set_use_memory_efficient_attention_xformers(False)
def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
# Recursively walk through all the children.
# Any children which exposes the set_use_memory_efficient_attention_xformers method
# gets the message
def fn_recursive_set_mem_eff(module: torch.nn.Module):
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
module.set_use_memory_efficient_attention_xformers(valid)
for child in module.children():
fn_recursive_set_mem_eff(child)
module_names, _, _ = self.extract_init_dict(dict(self.config))
for module_name in module_names:
module = getattr(self, module_name)
if isinstance(module, torch.nn.Module):
fn_recursive_set_mem_eff(module)
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
self.set_attention_slice(slice_size)
def disable_attention_slicing(self): # NOTE: This file is deprecated and will be removed in a future version.
r""" # It only exists so that temporarely `from diffusers.pipelines import DiffusionPipeline` works
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
back to computing attention in one step.
"""
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)
def set_attention_slice(self, slice_size: Optional[int]): from .pipelines import DiffusionPipeline, ImagePipelineOutput # noqa: F401
module_names, _, _ = self.extract_init_dict(dict(self.config))
for module_name in module_names:
module = getattr(self, module_name)
if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size)
...@@ -20,6 +20,7 @@ else: ...@@ -20,6 +20,7 @@ else:
from .ddpm import DDPMPipeline from .ddpm import DDPMPipeline
from .latent_diffusion import LDMSuperResolutionPipeline from .latent_diffusion import LDMSuperResolutionPipeline
from .latent_diffusion_uncond import LDMPipeline from .latent_diffusion_uncond import LDMPipeline
from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
from .pndm import PNDMPipeline from .pndm import PNDMPipeline
from .repaint import RePaintPipeline from .repaint import RePaintPipeline
from .score_sde_ve import ScoreSdeVePipeline from .score_sde_ve import ScoreSdeVePipeline
...@@ -62,6 +63,14 @@ else: ...@@ -62,6 +63,14 @@ else:
) )
from .vq_diffusion import VQDiffusionPipeline from .vq_diffusion import VQDiffusionPipeline
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_onnx_objects import * # noqa F403
else:
from .onnx_utils import OnnxRuntimeModel
try: try:
if not (is_torch_available() and is_transformers_available() and is_onnx_available()): if not (is_torch_available() and is_transformers_available() and is_onnx_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
...@@ -84,6 +93,14 @@ except OptionalDependencyNotAvailable: ...@@ -84,6 +93,14 @@ except OptionalDependencyNotAvailable:
else: else:
from .stable_diffusion import StableDiffusionKDiffusionPipeline from .stable_diffusion import StableDiffusionKDiffusionPipeline
try:
if not is_flax_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_flax_objects import * # noqa F403
else:
from .pipeline_flax_utils import FlaxDiffusionPipeline
try: try:
if not (is_flax_available() and is_transformers_available()): if not (is_flax_available() and is_transformers_available()):
......
...@@ -23,7 +23,6 @@ from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer ...@@ -23,7 +23,6 @@ from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import ( from ...schedulers import (
DDIMScheduler, DDIMScheduler,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
...@@ -33,6 +32,7 @@ from ...schedulers import ( ...@@ -33,6 +32,7 @@ from ...schedulers import (
PNDMScheduler, PNDMScheduler,
) )
from ...utils import deprecate, logging from ...utils import deprecate, logging
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
......
...@@ -25,7 +25,6 @@ from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer ...@@ -25,7 +25,6 @@ from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import DiffusionPipeline
from ...schedulers import ( from ...schedulers import (
DDIMScheduler, DDIMScheduler,
DPMSolverMultistepScheduler, DPMSolverMultistepScheduler,
...@@ -35,6 +34,7 @@ from ...schedulers import ( ...@@ -35,6 +34,7 @@ from ...schedulers import (
PNDMScheduler, PNDMScheduler,
) )
from ...utils import PIL_INTERPOLATION, deprecate, logging from ...utils import PIL_INTERPOLATION, deprecate, logging
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
......
...@@ -22,8 +22,8 @@ import torch ...@@ -22,8 +22,8 @@ import torch
from PIL import Image from PIL import Image
from ...models import AutoencoderKL, UNet2DConditionModel from ...models import AutoencoderKL, UNet2DConditionModel
from ...pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
from ...schedulers import DDIMScheduler, DDPMScheduler from ...schedulers import DDIMScheduler, DDPMScheduler
from ..pipeline_utils import AudioPipelineOutput, BaseOutput, DiffusionPipeline, ImagePipelineOutput
from .mel import Mel from .mel import Mel
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment