Unverified Commit b41f809a authored by YiYi Xu, committed by GitHub

[Kandinsky 3.0] Follow-up TODOs (#5944)

clean up Kandinsky 3.0
parent 0f55c17e
...@@ -42,7 +42,7 @@ if is_torch_available(): ...@@ -42,7 +42,7 @@ if is_torch_available():
_import_structure["unet_2d"] = ["UNet2DModel"] _import_structure["unet_2d"] = ["UNet2DModel"]
_import_structure["unet_2d_condition"] = ["UNet2DConditionModel"] _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"] _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandi3"] = ["Kandinsky3UNet"] _import_structure["unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] _import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["vq_model"] = ["VQModel"] _import_structure["vq_model"] = ["VQModel"]
...@@ -72,7 +72,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -72,7 +72,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .unet_2d import UNet2DModel from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_condition import UNet3DConditionModel from .unet_3d_condition import UNet3DConditionModel
from .unet_kandi3 import Kandinsky3UNet from .unet_kandinsky3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .vq_model import VQModel from .vq_model import VQModel
......
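For reference, the module rename above does not change any public import; a quick sketch of the paths at this revision:

```py
# the renamed module and the unchanged public entry point both resolve to the same class
from diffusers.models import Kandinsky3UNet
from diffusers.models.unet_kandinsky3 import Kandinsky3UNet  # new module name (was unet_kandi3)
```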
...@@ -16,7 +16,7 @@ from typing import Callable, Optional, Union ...@@ -16,7 +16,7 @@ from typing import Callable, Optional, Union
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import einsum, nn from torch import nn
from ..utils import USE_PEFT_BACKEND, deprecate, logging from ..utils import USE_PEFT_BACKEND, deprecate, logging
from ..utils.import_utils import is_xformers_available from ..utils.import_utils import is_xformers_available
...@@ -109,15 +109,17 @@ class Attention(nn.Module): ...@@ -109,15 +109,17 @@ class Attention(nn.Module):
residual_connection: bool = False, residual_connection: bool = False,
_from_deprecated_attn_block: bool = False, _from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None, processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
): ):
super().__init__() super().__init__()
self.inner_dim = dim_head * heads self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax self.upcast_softmax = upcast_softmax
self.rescale_output_factor = rescale_output_factor self.rescale_output_factor = rescale_output_factor
self.residual_connection = residual_connection self.residual_connection = residual_connection
self.dropout = dropout self.dropout = dropout
self.out_dim = out_dim if out_dim is not None else query_dim
# we make use of this private variable to know whether this class is loaded # we make use of this private variable to know whether this class is loaded
# with a deprecated state dict so that we can convert it on the fly # with a deprecated state dict so that we can convert it on the fly
...@@ -126,7 +128,7 @@ class Attention(nn.Module): ...@@ -126,7 +128,7 @@ class Attention(nn.Module):
self.scale_qk = scale_qk self.scale_qk = scale_qk
self.scale = dim_head**-0.5 if self.scale_qk else 1.0 self.scale = dim_head**-0.5 if self.scale_qk else 1.0
self.heads = heads self.heads = out_dim // dim_head if out_dim is not None else heads
# for slice_size > 0 the attention score computation # for slice_size > 0 the attention score computation
# is split across the batch axis to save memory # is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice` # You can set slice_size with `set_attention_slice`
...@@ -193,7 +195,7 @@ class Attention(nn.Module): ...@@ -193,7 +195,7 @@ class Attention(nn.Module):
self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
self.to_out = nn.ModuleList([]) self.to_out = nn.ModuleList([])
self.to_out.append(linear_cls(self.inner_dim, query_dim, bias=out_bias)) self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout)) self.to_out.append(nn.Dropout(dropout))
# set attention processor # set attention processor
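A small sketch of what the new `out_dim` argument changes inside `Attention` (the names mirror the constructor above; the numbers are illustrative, not from any checkpoint):

```py
# Hedged sketch of the shape bookkeeping introduced above (illustrative numbers).
dim_head, heads, query_dim, out_dim = 64, 8, 320, 1024

inner_dim = out_dim if out_dim is not None else dim_head * heads   # 1024 instead of 512
num_heads = out_dim // dim_head if out_dim is not None else heads  # 16 heads are used internally
to_out_features = out_dim if out_dim is not None else query_dim    # to_out[0]: Linear(1024, 1024)

print(inner_dim, num_heads, to_out_features)  # 1024 16 1024
```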
...@@ -2219,44 +2221,6 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module): ...@@ -2219,44 +2221,6 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
return hidden_states return hidden_states
# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
# this way torch.compile and co. will work as well
class Kandi3AttnProcessor:
r"""
Default Kandinsky 3 processor for performing attention-related computations.
"""
@staticmethod
def _reshape(hid_states, h):
b, n, f = hid_states.shape
d = f // h
return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)
def __call__(
self,
attn,
x,
context,
context_mask=None,
):
query = self._reshape(attn.to_q(x), h=attn.num_heads)
key = self._reshape(attn.to_k(context), h=attn.num_heads)
value = self._reshape(attn.to_v(context), h=attn.num_heads)
attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)
if context_mask is not None:
max_neg_value = -torch.finfo(attention_matrix.dtype).max
context_mask = context_mask.unsqueeze(1).unsqueeze(1)
attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)
out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
out = attn.to_out[0](out)
return out
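The removed processor masked attention scores with `masked_fill`; the stock processors instead consume an additive bias like the one the cleaned-up UNet now builds in `forward` (`(1 - mask) * -10000.0`). A hedged sketch of why the two are numerically interchangeable (shapes and values are illustrative):

```py
import torch

# (batch, heads, query_tokens, key_tokens) attention scores and a binary context mask
scores = torch.randn(2, 8, 4, 7)
context_mask = torch.tensor([[1, 1, 1, 0, 0, 1, 0],
                             [1, 0, 1, 1, 1, 0, 1]])

# old behaviour: fill masked keys with the most negative representable value
filled = scores.masked_fill(~(context_mask[:, None, None, :] != 0),
                            -torch.finfo(scores.dtype).max)

# new behaviour: add a large negative bias, as Kandinsky3UNet.forward now does up front
bias = (1 - context_mask.float()) * -10000.0
biased = scores + bias[:, None, None, :]

# after softmax, both variants give (numerically) zero weight to the masked keys
assert torch.allclose(filled.softmax(-1), biased.softmax(-1), atol=1e-6)
```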
LORA_ATTENTION_PROCESSORS = ( LORA_ATTENTION_PROCESSORS = (
LoRAAttnProcessor, LoRAAttnProcessor,
LoRAAttnProcessor2_0, LoRAAttnProcessor2_0,
...@@ -2282,7 +2246,6 @@ CROSS_ATTENTION_PROCESSORS = ( ...@@ -2282,7 +2246,6 @@ CROSS_ATTENTION_PROCESSORS = (
LoRAXFormersAttnProcessor, LoRAXFormersAttnProcessor,
IPAdapterAttnProcessor, IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0, IPAdapterAttnProcessor2_0,
Kandi3AttnProcessor,
) )
AttentionProcessor = Union[ AttentionProcessor = Union[
......
import math # Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Tuple, Union from typing import Dict, Tuple, Union
import torch import torch
import torch.nn.functional as F
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import nn from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput, logging from ..utils import BaseOutput, logging
from .attention_processor import AttentionProcessor, Kandi3AttnProcessor from .attention_processor import Attention, AttentionProcessor, AttnProcessor
from .embeddings import TimestepEmbedding from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin from .modeling_utils import ModelMixin
...@@ -22,36 +34,6 @@ class Kandinsky3UNetOutput(BaseOutput): ...@@ -22,36 +34,6 @@ class Kandinsky3UNetOutput(BaseOutput):
sample: torch.FloatTensor = None sample: torch.FloatTensor = None
# TODO(Yiyi): This class needs to be removed
def set_default_item(condition, item_1, item_2=None):
if condition:
return item_1
else:
return item_2
# TODO(Yiyi): This class needs to be removed
def set_default_layer(condition, layer_1, args_1=[], kwargs_1={}, layer_2=torch.nn.Identity, args_2=[], kwargs_2={}):
if condition:
return layer_1(*args_1, **kwargs_1)
else:
return layer_2(*args_2, **kwargs_2)
# TODO(Yiyi): This class should be removed and be replaced by Timesteps
class SinusoidalPosEmb(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x, type_tensor=None):
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
emb = x[:, None] * emb[None, :]
return torch.cat((emb.sin(), emb.cos()), dim=-1)
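A hedged numerical check that the removed `SinusoidalPosEmb` really matches `Timesteps(dim, flip_sin_to_cos=False, downscale_freq_shift=1)`, assuming diffusers' `get_timestep_embedding` implementation (the `dim` and timestep values are illustrative):

```py
import math

import torch
from diffusers.models.embeddings import Timesteps

def sinusoidal_pos_emb(x, dim):
    # the removed implementation, reproduced for comparison
    half_dim = dim // 2
    emb = math.log(10000) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, device=x.device) * -emb)
    emb = x[:, None] * emb[None, :]
    return torch.cat((emb.sin(), emb.cos()), dim=-1)

timesteps = torch.tensor([0.0, 1.0, 25.0])
dim = 192  # e.g. init_channels = block_out_channels[0] // 2
reference = Timesteps(dim, flip_sin_to_cos=False, downscale_freq_shift=1)(timesteps)
assert torch.allclose(sinusoidal_pos_emb(timesteps, dim), reference, atol=1e-4)
```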
class Kandinsky3EncoderProj(nn.Module): class Kandinsky3EncoderProj(nn.Module):
def __init__(self, encoder_hid_dim, cross_attention_dim): def __init__(self, encoder_hid_dim, cross_attention_dim):
super().__init__() super().__init__()
...@@ -87,9 +69,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin): ...@@ -87,9 +69,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
out_channels = in_channels out_channels = in_channels
init_channels = block_out_channels[0] // 2 init_channels = block_out_channels[0] // 2
# TODO(Yiyi): Should be replaced with Timesteps class -> make sure that results are the same self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1)
# self.time_proj = Timesteps(init_channels, flip_sin_to_cos=False, downscale_freq_shift=1)
self.time_proj = SinusoidalPosEmb(init_channels)
self.time_embedding = TimestepEmbedding( self.time_embedding = TimestepEmbedding(
init_channels, init_channels,
...@@ -106,7 +86,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin): ...@@ -106,7 +86,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
hidden_dims = [init_channels] + list(block_out_channels) hidden_dims = [init_channels] + list(block_out_channels)
in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:])) in_out_dims = list(zip(hidden_dims[:-1], hidden_dims[1:]))
text_dims = [set_default_item(is_exist, cross_attention_dim) for is_exist in add_cross_attention] text_dims = [cross_attention_dim if is_exist else None for is_exist in add_cross_attention]
num_blocks = len(block_out_channels) * [layers_per_block] num_blocks = len(block_out_channels) * [layers_per_block]
layer_params = [num_blocks, text_dims, add_self_attention] layer_params = [num_blocks, text_dims, add_self_attention]
rev_layer_params = map(reversed, layer_params) rev_layer_params = map(reversed, layer_params)
...@@ -118,7 +98,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin): ...@@ -118,7 +98,7 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
zip(in_out_dims, *layer_params) zip(in_out_dims, *layer_params)
): ):
down_sample = level != (self.num_levels - 1) down_sample = level != (self.num_levels - 1)
cat_dims.append(set_default_item(level != (self.num_levels - 1), out_dim, 0)) cat_dims.append(out_dim if level != (self.num_levels - 1) else 0)
self.down_blocks.append( self.down_blocks.append(
Kandinsky3DownSampleBlock( Kandinsky3DownSampleBlock(
in_dim, in_dim,
...@@ -223,18 +203,16 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin): ...@@ -223,18 +203,16 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
""" """
Disables custom attention processors and sets the default attention implementation. Disables custom attention processors and sets the default attention implementation.
""" """
self.set_attn_processor(Kandi3AttnProcessor()) self.set_attn_processor(AttnProcessor())
def _set_gradient_checkpointing(self, module, value=False): def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"): if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value module.gradient_checkpointing = value
def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True): def forward(self, sample, timestep, encoder_hidden_states=None, encoder_attention_mask=None, return_dict=True):
# TODO(Yiyi): Clean up the following variables - these names should not be used if encoder_attention_mask is not None:
# but instead only the ones that we pass to forward encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
x = sample encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
context_mask = encoder_attention_mask
context = encoder_hidden_states
if not torch.is_tensor(timestep): if not torch.is_tensor(timestep):
dtype = torch.float32 if isinstance(timestep, float) else torch.int32 dtype = torch.float32 if isinstance(timestep, float) else torch.int32
...@@ -244,33 +222,33 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin): ...@@ -244,33 +222,33 @@ class Kandinsky3UNet(ModelMixin, ConfigMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = timestep.expand(sample.shape[0]) timestep = timestep.expand(sample.shape[0])
time_embed_input = self.time_proj(timestep).to(x.dtype) time_embed_input = self.time_proj(timestep).to(sample.dtype)
time_embed = self.time_embedding(time_embed_input) time_embed = self.time_embedding(time_embed_input)
context = self.encoder_hid_proj(context) encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
if context is not None: if encoder_hidden_states is not None:
time_embed = self.add_time_condition(time_embed, context, context_mask) time_embed = self.add_time_condition(time_embed, encoder_hidden_states, encoder_attention_mask)
hidden_states = [] hidden_states = []
x = self.conv_in(x) sample = self.conv_in(sample)
for level, down_sample in enumerate(self.down_blocks): for level, down_sample in enumerate(self.down_blocks):
x = down_sample(x, time_embed, context, context_mask) sample = down_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
if level != self.num_levels - 1: if level != self.num_levels - 1:
hidden_states.append(x) hidden_states.append(sample)
for level, up_sample in enumerate(self.up_blocks): for level, up_sample in enumerate(self.up_blocks):
if level != 0: if level != 0:
x = torch.cat([x, hidden_states.pop()], dim=1) sample = torch.cat([sample, hidden_states.pop()], dim=1)
x = up_sample(x, time_embed, context, context_mask) sample = up_sample(sample, time_embed, encoder_hidden_states, encoder_attention_mask)
x = self.conv_norm_out(x) sample = self.conv_norm_out(sample)
x = self.conv_act_out(x) sample = self.conv_act_out(sample)
x = self.conv_out(x) sample = self.conv_out(sample)
if not return_dict: if not return_dict:
return (x,) return (sample,)
return Kandinsky3UNetOutput(sample=x) return Kandinsky3UNetOutput(sample=sample)
class Kandinsky3UpSampleBlock(nn.Module): class Kandinsky3UpSampleBlock(nn.Module):
...@@ -290,7 +268,7 @@ class Kandinsky3UpSampleBlock(nn.Module): ...@@ -290,7 +268,7 @@ class Kandinsky3UpSampleBlock(nn.Module):
self_attention=True, self_attention=True,
): ):
super().__init__() super().__init__()
up_resolutions = [[None, set_default_item(up_sample, True), None, None]] + [[None] * 4] * (num_blocks - 1) up_resolutions = [[None, True if up_sample else None, None, None]] + [[None] * 4] * (num_blocks - 1)
hidden_channels = ( hidden_channels = (
[(in_channels + cat_dim, in_channels)] [(in_channels + cat_dim, in_channels)]
+ [(in_channels, in_channels)] * (num_blocks - 2) + [(in_channels, in_channels)] * (num_blocks - 2)
...@@ -303,27 +281,27 @@ class Kandinsky3UpSampleBlock(nn.Module): ...@@ -303,27 +281,27 @@ class Kandinsky3UpSampleBlock(nn.Module):
self.self_attention = self_attention self.self_attention = self_attention
self.context_dim = context_dim self.context_dim = context_dim
attentions.append( if self_attention:
set_default_layer( attentions.append(
self_attention, Kandinsky3AttentionBlock(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
Kandinsky3AttentionBlock,
(out_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
layer_2=nn.Identity,
) )
) else:
attentions.append(nn.Identity())
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions): for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
resnets_in.append( resnets_in.append(
Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution) Kandinsky3ResNetBlock(in_channel, in_channel, time_embed_dim, groups, compression_ratio, up_resolution)
) )
attentions.append(
set_default_layer( if context_dim is not None:
context_dim is not None, attentions.append(
Kandinsky3AttentionBlock, Kandinsky3AttentionBlock(
(in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio), in_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
layer_2=nn.Identity, )
) )
) else:
attentions.append(nn.Identity())
resnets_out.append( resnets_out.append(
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio) Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
) )
...@@ -367,29 +345,29 @@ class Kandinsky3DownSampleBlock(nn.Module): ...@@ -367,29 +345,29 @@ class Kandinsky3DownSampleBlock(nn.Module):
self.self_attention = self_attention self.self_attention = self_attention
self.context_dim = context_dim self.context_dim = context_dim
attentions.append( if self_attention:
set_default_layer( attentions.append(
self_attention, Kandinsky3AttentionBlock(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio)
Kandinsky3AttentionBlock,
(in_channels, time_embed_dim, None, groups, head_dim, expansion_ratio),
layer_2=nn.Identity,
) )
) else:
attentions.append(nn.Identity())
up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, set_default_item(down_sample, False), None]] up_resolutions = [[None] * 4] * (num_blocks - 1) + [[None, None, False if down_sample else None, None]]
hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1) hidden_channels = [(in_channels, out_channels)] + [(out_channels, out_channels)] * (num_blocks - 1)
for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions): for (in_channel, out_channel), up_resolution in zip(hidden_channels, up_resolutions):
resnets_in.append( resnets_in.append(
Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio) Kandinsky3ResNetBlock(in_channel, out_channel, time_embed_dim, groups, compression_ratio)
) )
attentions.append(
set_default_layer( if context_dim is not None:
context_dim is not None, attentions.append(
Kandinsky3AttentionBlock, Kandinsky3AttentionBlock(
(out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio), out_channel, time_embed_dim, context_dim, groups, head_dim, expansion_ratio
layer_2=nn.Identity, )
) )
) else:
attentions.append(nn.Identity())
resnets_out.append( resnets_out.append(
Kandinsky3ResNetBlock( Kandinsky3ResNetBlock(
out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution out_channel, out_channel, time_embed_dim, groups, compression_ratio, up_resolution
...@@ -431,68 +409,23 @@ class Kandinsky3ConditionalGroupNorm(nn.Module): ...@@ -431,68 +409,23 @@ class Kandinsky3ConditionalGroupNorm(nn.Module):
return x return x
# TODO(Yiyi): This class should ideally not even exist, it slows everything needlessly down. I'm pretty
# sure we can delete it and instead just pass an attention_mask
class Attention(nn.Module):
def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
super().__init__()
assert out_channels % head_dim == 0
self.num_heads = out_channels // head_dim
self.scale = head_dim**-0.5
# to_q
self.to_q = nn.Linear(in_channels, out_channels, bias=False)
# to_k
self.to_k = nn.Linear(context_dim, out_channels, bias=False)
# to_v
self.to_v = nn.Linear(context_dim, out_channels, bias=False)
processor = Kandi3AttnProcessor()
self.set_processor(processor)
# to_out
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(out_channels, out_channels, bias=False))
def set_processor(self, processor: "AttnProcessor"): # noqa: F821
# if current processor is in `self._modules` and if passed `processor` is not, we need to
# pop `processor` from `self._modules`
if (
hasattr(self, "processor")
and isinstance(self.processor, torch.nn.Module)
and not isinstance(processor, torch.nn.Module)
):
logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
self._modules.pop("processor")
self.processor = processor
def forward(self, x, context, context_mask=None, image_mask=None):
return self.processor(
self,
x,
context=context,
context_mask=context_mask,
)
class Kandinsky3Block(nn.Module): class Kandinsky3Block(nn.Module):
def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None): def __init__(self, in_channels, out_channels, time_embed_dim, kernel_size=3, norm_groups=32, up_resolution=None):
super().__init__() super().__init__()
self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim) self.group_norm = Kandinsky3ConditionalGroupNorm(norm_groups, in_channels, time_embed_dim)
self.activation = nn.SiLU() self.activation = nn.SiLU()
self.up_sample = set_default_layer( if up_resolution is not None and up_resolution:
up_resolution is not None and up_resolution, self.up_sample = nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
nn.ConvTranspose2d, else:
(in_channels, in_channels), self.up_sample = nn.Identity()
{"kernel_size": 2, "stride": 2},
)
padding = int(kernel_size > 1) padding = int(kernel_size > 1)
self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding) self.projection = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
self.down_sample = set_default_layer(
up_resolution is not None and not up_resolution, if up_resolution is not None and not up_resolution:
nn.Conv2d, self.down_sample = nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
(out_channels, out_channels), else:
{"kernel_size": 2, "stride": 2}, self.down_sample = nn.Identity()
)
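The same refactor is applied throughout this file: the removed `set_default_item` / `set_default_layer` helpers become plain conditionals, which keeps the module graph transparent to `torch.compile`. A minimal sketch of the pattern (the helper name here is hypothetical):

```py
import torch.nn as nn

def make_up_sample(in_channels: int, up_resolution):
    # mirrors the explicit branches above: real layer when requested, nn.Identity() otherwise
    if up_resolution is not None and up_resolution:
        return nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
    return nn.Identity()

print(make_up_sample(64, True))   # ConvTranspose2d(64, 64, kernel_size=(2, 2), stride=(2, 2))
print(make_up_sample(64, None))   # Identity()
```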
def forward(self, x, time_embed): def forward(self, x, time_embed):
x = self.group_norm(x, time_embed) x = self.group_norm(x, time_embed)
...@@ -521,14 +454,18 @@ class Kandinsky3ResNetBlock(nn.Module): ...@@ -521,14 +454,18 @@ class Kandinsky3ResNetBlock(nn.Module):
) )
] ]
) )
self.shortcut_up_sample = set_default_layer( self.shortcut_up_sample = (
True in up_resolutions, nn.ConvTranspose2d, (in_channels, in_channels), {"kernel_size": 2, "stride": 2} nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2)
if True in up_resolutions
else nn.Identity()
) )
self.shortcut_projection = set_default_layer( self.shortcut_projection = (
in_channels != out_channels, nn.Conv2d, (in_channels, out_channels), {"kernel_size": 1} nn.Conv2d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()
) )
self.shortcut_down_sample = set_default_layer( self.shortcut_down_sample = (
False in up_resolutions, nn.Conv2d, (out_channels, out_channels), {"kernel_size": 2, "stride": 2} nn.Conv2d(out_channels, out_channels, kernel_size=2, stride=2)
if False in up_resolutions
else nn.Identity()
) )
def forward(self, x, time_embed): def forward(self, x, time_embed):
...@@ -546,9 +483,16 @@ class Kandinsky3ResNetBlock(nn.Module): ...@@ -546,9 +483,16 @@ class Kandinsky3ResNetBlock(nn.Module):
class Kandinsky3AttentionPooling(nn.Module): class Kandinsky3AttentionPooling(nn.Module):
def __init__(self, num_channels, context_dim, head_dim=64): def __init__(self, num_channels, context_dim, head_dim=64):
super().__init__() super().__init__()
self.attention = Attention(context_dim, num_channels, context_dim, head_dim) self.attention = Attention(
context_dim,
context_dim,
dim_head=head_dim,
out_dim=num_channels,
out_bias=False,
)
def forward(self, x, context, context_mask=None): def forward(self, x, context, context_mask=None):
context_mask = context_mask.to(dtype=context.dtype)
context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask) context = self.attention(context.mean(dim=1, keepdim=True), context, context_mask)
return x + context.squeeze(1) return x + context.squeeze(1)
...@@ -557,7 +501,13 @@ class Kandinsky3AttentionBlock(nn.Module): ...@@ -557,7 +501,13 @@ class Kandinsky3AttentionBlock(nn.Module):
def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4): def __init__(self, num_channels, time_embed_dim, context_dim=None, norm_groups=32, head_dim=64, expansion_ratio=4):
super().__init__() super().__init__()
self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim) self.in_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
self.attention = Attention(num_channels, num_channels, context_dim or num_channels, head_dim) self.attention = Attention(
num_channels,
context_dim or num_channels,
dim_head=head_dim,
out_dim=num_channels,
out_bias=False,
)
hidden_channels = expansion_ratio * num_channels hidden_channels = expansion_ratio * num_channels
self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim) self.out_norm = Kandinsky3ConditionalGroupNorm(norm_groups, num_channels, time_embed_dim)
...@@ -572,14 +522,10 @@ class Kandinsky3AttentionBlock(nn.Module): ...@@ -572,14 +522,10 @@ class Kandinsky3AttentionBlock(nn.Module):
out = self.in_norm(x, time_embed) out = self.in_norm(x, time_embed)
out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1) out = out.reshape(x.shape[0], -1, height * width).permute(0, 2, 1)
context = context if context is not None else out context = context if context is not None else out
if context_mask is not None:
context_mask = context_mask.to(dtype=context.dtype)
if image_mask is not None: out = self.attention(out, context, context_mask)
mask_height, mask_width = image_mask.shape[-2:]
kernel_size = (mask_height // height, mask_width // width)
image_mask = F.max_pool2d(image_mask, kernel_size, kernel_size)
image_mask = image_mask.reshape(image_mask.shape[0], -1)
out = self.attention(out, context, context_mask, image_mask)
out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width) out = out.permute(0, 2, 1).unsqueeze(-1).reshape(out.shape[0], -1, height, width)
x = x + out x = x + out
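A hedged sketch of the reshape round-trip the attention block performs around `self.attention` (the image-mask pooling path was dropped along with the custom processor; sizes are illustrative):

```py
import torch

x = torch.randn(1, 64, 16, 16)                        # (batch, channels, height, width)
tokens = x.reshape(1, 64, 16 * 16).permute(0, 2, 1)   # (batch, height * width, channels)
restored = tokens.permute(0, 2, 1).unsqueeze(-1).reshape(1, 64, 16, 16)
assert torch.equal(x, restored)                        # lossless round-trip
```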
......
...@@ -21,8 +21,8 @@ except OptionalDependencyNotAvailable: ...@@ -21,8 +21,8 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else: else:
_import_structure["kandinsky3_pipeline"] = ["Kandinsky3Pipeline"] _import_structure["pipeline_kandinsky3"] = ["Kandinsky3Pipeline"]
_import_structure["kandinsky3img2img_pipeline"] = ["Kandinsky3Img2ImgPipeline"] _import_structure["pipeline_kandinsky3_img2img"] = ["Kandinsky3Img2ImgPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
...@@ -33,8 +33,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -33,8 +33,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * from ...utils.dummy_torch_and_transformers_objects import *
else: else:
from .kandinsky3_pipeline import Kandinsky3Pipeline from .pipeline_kandinsky3 import Kandinsky3Pipeline
from .kandinsky3img2img_pipeline import Kandinsky3Img2ImgPipeline from .pipeline_kandinsky3_img2img import Kandinsky3Img2ImgPipeline
else: else:
import sys import sys
......
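After the rename, the files follow the usual `pipeline_<name>.py` convention and the public classes are unchanged. A quick sketch of the resulting import paths at this revision:

```py
from diffusers.pipelines.kandinsky3.pipeline_kandinsky3 import Kandinsky3Pipeline
from diffusers.pipelines.kandinsky3.pipeline_kandinsky3_img2img import Kandinsky3Img2ImgPipeline

# the lazy top-level imports keep working as before
from diffusers import Kandinsky3Pipeline, Kandinsky3Img2ImgPipeline
```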
from typing import Callable, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
import torch import torch
from transformers import T5EncoderModel, T5Tokenizer from transformers import T5EncoderModel, T5Tokenizer
...@@ -7,8 +7,10 @@ from ...loaders import LoraLoaderMixin ...@@ -7,8 +7,10 @@ from ...loaders import LoraLoaderMixin
from ...models import Kandinsky3UNet, VQModel from ...models import Kandinsky3UNet, VQModel
from ...schedulers import DDPMScheduler from ...schedulers import DDPMScheduler
from ...utils import ( from ...utils import (
deprecate,
is_accelerate_available, is_accelerate_available,
logging, logging,
replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
...@@ -16,6 +18,23 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput ...@@ -16,6 +18,23 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusers import AutoPipelineForText2Image
>>> import torch
>>> pipe = AutoPipelineForText2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
>>> pipe.enable_model_cpu_offload()
>>> prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
>>> generator = torch.Generator(device="cpu").manual_seed(0)
>>> image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]
```
"""
def downscale_height_and_width(height, width, scale_factor=8): def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2 new_height = height // scale_factor**2
...@@ -29,6 +48,13 @@ def downscale_height_and_width(height, width, scale_factor=8): ...@@ -29,6 +48,13 @@ def downscale_height_and_width(height, width, scale_factor=8):
class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin): class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
model_cpu_offload_seq = "text_encoder->unet->movq" model_cpu_offload_seq = "text_encoder->unet->movq"
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"negative_attention_mask",
"attention_mask",
]
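A hedged usage sketch of the step-end callback machinery that `_callback_tensor_inputs` above feeds (the `pipe` object is an assumption, standing in for a loaded `Kandinsky3Pipeline`):

```py
# Hedged sketch -- assumes `pipe = Kandinsky3Pipeline.from_pretrained(...)` is already loaded.
def on_step_end(pipeline, step, timestep, callback_kwargs):
    # any tensor named in callback_on_step_end_tensor_inputs shows up in callback_kwargs;
    # whatever is returned replaces the pipeline's working copy of that tensor
    if step == pipeline.num_timesteps // 2:
        callback_kwargs["latents"] = callback_kwargs["latents"].clone()
    return callback_kwargs

# image = pipe(
#     prompt,
#     callback_on_step_end=on_step_end,
#     callback_on_step_end_tensor_inputs=["latents", "prompt_embeds"],
# ).images[0]
```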
def __init__( def __init__(
self, self,
...@@ -50,7 +76,7 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -50,7 +76,7 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
else: else:
raise ImportError("Please install accelerate via `pip install accelerate`") raise ImportError("Please install accelerate via `pip install accelerate`")
for model in [self.text_encoder, self.unet]: for model in [self.text_encoder, self.unet, self.movq]:
if model is not None: if model is not None:
remove_hook_from_module(model, recurse=True) remove_hook_from_module(model, recurse=True)
...@@ -77,6 +103,8 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -77,6 +103,8 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
_cut_context=False, _cut_context=False,
attention_mask: Optional[torch.FloatTensor] = None,
negative_attention_mask: Optional[torch.FloatTensor] = None,
): ):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
...@@ -101,6 +129,10 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -101,6 +129,10 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument. argument.
attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask. Must be provided if `prompt_embeds` is passed directly.
negative_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated negative attention mask. Must be provided if `negative_prompt_embeds` is passed directly.
""" """
if prompt is not None and negative_prompt is not None: if prompt is not None and negative_prompt is not None:
if type(prompt) is not type(negative_prompt): if type(prompt) is not type(negative_prompt):
...@@ -228,14 +260,21 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -228,14 +260,21 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
negative_prompt=None, negative_prompt=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
attention_mask=None,
negative_attention_mask=None,
): ):
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
...@@ -262,8 +301,42 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -262,8 +301,42 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}." f" {negative_prompt_embeds.shape}."
) )
if negative_prompt_embeds is not None and negative_attention_mask is None:
raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`")
if negative_prompt_embeds is not None and negative_attention_mask is not None:
if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape:
raise ValueError(
"`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but"
f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`"
f" {negative_attention_mask.shape}."
)
if prompt_embeds is not None and attention_mask is None:
raise ValueError("Please provide `attention_mask` along with `prompt_embeds`")
if prompt_embeds is not None and attention_mask is not None:
if prompt_embeds.shape[:2] != attention_mask.shape:
raise ValueError(
"`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`"
f" {attention_mask.shape}."
)
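A small sketch of the shape contract these new checks enforce when embeddings and masks are passed in directly (sizes are illustrative):

```py
import torch

prompt_embeds = torch.randn(2, 128, 4096)    # (batch, tokens, hidden)
attention_mask = torch.ones(2, 128)          # must equal prompt_embeds.shape[:2]

assert prompt_embeds.shape[:2] == attention_mask.shape  # otherwise check_inputs raises
```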
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
...@@ -276,11 +349,14 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -276,11 +349,14 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
negative_attention_mask: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
latents=None, latents=None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
): ):
""" """
Function invoked when calling the pipeline for generation. Function invoked when calling the pipeline for generation.
...@@ -324,6 +400,10 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -324,6 +400,10 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument. argument.
attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask. Must be provided if `prompt_embeds` is passed directly.
negative_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated negative attention mask. Must be provided if `negative_prompt_embeds` is passed directly.
output_type (`str`, *optional*, defaults to `"pil"`): output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
...@@ -343,12 +423,53 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -343,12 +423,53 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in `self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
""" """
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
cut_context = True cut_context = True
device = self._execution_device device = self._execution_device
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) self.check_inputs(
prompt,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
attention_mask,
negative_attention_mask,
)
self._guidance_scale = guidance_scale
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
...@@ -357,24 +478,21 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -357,24 +478,21 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
else: else:
batch_size = prompt_embeds.shape[0] batch_size = prompt_embeds.shape[0]
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt( prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt(
prompt, prompt,
do_classifier_free_guidance, self.do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
device=device, device=device,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
_cut_context=cut_context, _cut_context=cut_context,
attention_mask=attention_mask,
negative_attention_mask=negative_attention_mask,
) )
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool() attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool()
# 4. Prepare timesteps # 4. Prepare timesteps
...@@ -397,11 +515,11 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -397,11 +515,11 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
self.text_encoder_offload_hook.offload() self.text_encoder_offload_hook.offload()
# 7. Denoising loop # 7. Denoising loop
# TODO(Yiyi): Correct the following line and use correctly num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
# num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# predict the noise residual # predict the noise residual
noise_pred = self.unet( noise_pred = self.unet(
...@@ -412,7 +530,7 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -412,7 +530,7 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
return_dict=False, return_dict=False,
)[0] )[0]
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond
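The combination above is the standard classifier-free-guidance update written with an offset: a hedged check that `(w + 1) * text - w * uncond` equals `text + w * (text - uncond)` (tensors are illustrative):

```py
import torch

w = 3.0  # guidance_scale
noise_pred_uncond, noise_pred_text = torch.randn(2, 4, 8, 8).chunk(2)

combined = (w + 1.0) * noise_pred_text - w * noise_pred_uncond
assert torch.allclose(
    combined, noise_pred_text + w * (noise_pred_text - noise_pred_uncond), atol=1e-6
)
```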
...@@ -425,26 +543,45 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -425,26 +543,45 @@ class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
latents, latents,
generator=generator, generator=generator,
).prev_sample ).prev_sample
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# post-processing if callback_on_step_end is not None:
image = self.movq.decode(latents, force_not_quantize=True)["sample"] callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
attention_mask = callback_outputs.pop("attention_mask", attention_mask)
negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask)
if output_type not in ["pt", "np", "pil"]: if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# post-processing
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError( raise ValueError(
f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}" f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
) )
if output_type in ["np", "pil"]: if not output_type == "latent":
image = image * 0.5 + 0.5 image = self.movq.decode(latents, force_not_quantize=True)["sample"]
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
else:
image = latents
if output_type == "pil": self.maybe_free_model_hooks()
image = self.numpy_to_pil(image)
if not return_dict: if not return_dict:
return (image,) return (image,)
......
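With the new `"latent"` output type the MoVQ decode step is skipped entirely, which is useful when the latents are consumed elsewhere or decoded later. A hedged usage sketch (assumes a loaded `pipe`, as in the example docstring above):

```py
# Hedged sketch -- assumes `pipe = Kandinsky3Pipeline.from_pretrained(...)` is already loaded.
# latents = pipe(prompt, output_type="latent").images
# image = pipe.movq.decode(latents, force_not_quantize=True)["sample"]
```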
import inspect import inspect
from typing import Callable, List, Optional, Union from typing import Callable, Dict, List, Optional, Union
import numpy as np import numpy as np
import PIL import PIL
...@@ -11,8 +11,10 @@ from ...loaders import LoraLoaderMixin ...@@ -11,8 +11,10 @@ from ...loaders import LoraLoaderMixin
from ...models import Kandinsky3UNet, VQModel from ...models import Kandinsky3UNet, VQModel
from ...schedulers import DDPMScheduler from ...schedulers import DDPMScheduler
from ...utils import ( from ...utils import (
deprecate,
is_accelerate_available, is_accelerate_available,
logging, logging,
replace_example_docstring,
) )
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
...@@ -20,6 +22,24 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput ...@@ -20,6 +22,24 @@ from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusers import AutoPipelineForImage2Image
>>> from diffusers.utils import load_image
>>> import torch
>>> pipe = AutoPipelineForImage2Image.from_pretrained("kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16)
>>> pipe.enable_model_cpu_offload()
>>> prompt = "A painting of the inside of a subway train with tiny raccoons."
>>> image = load_image("https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png")
>>> generator = torch.Generator(device="cpu").manual_seed(0)
>>> image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0]
```
"""
def downscale_height_and_width(height, width, scale_factor=8): def downscale_height_and_width(height, width, scale_factor=8):
new_height = height // scale_factor**2 new_height = height // scale_factor**2
...@@ -40,7 +60,14 @@ def prepare_image(pil_image): ...@@ -40,7 +60,14 @@ def prepare_image(pil_image):
class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
model_cpu_offload_seq = "text_encoder->unet->movq" model_cpu_offload_seq = "text_encoder->movq->unet->movq"
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"negative_attention_mask",
"attention_mask",
]
def __init__( def __init__(
self, self,
...@@ -99,6 +126,8 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -99,6 +126,8 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
_cut_context=False, _cut_context=False,
attention_mask: Optional[torch.FloatTensor] = None,
negative_attention_mask: Optional[torch.FloatTensor] = None,
): ):
r""" r"""
Encodes the prompt into text encoder hidden states. Encodes the prompt into text encoder hidden states.
...@@ -123,6 +152,10 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -123,6 +152,10 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument. argument.
attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask. Must be provided if `prompt_embeds` is passed directly.
negative_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated negative attention mask. Must be provided if `negative_prompt_embeds` is passed directly.
""" """
if prompt is not None and negative_prompt is not None: if prompt is not None and negative_prompt is not None:
if type(prompt) is not type(negative_prompt): if type(prompt) is not type(negative_prompt):
...@@ -299,15 +332,23 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -299,15 +332,23 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
negative_prompt=None, negative_prompt=None,
prompt_embeds=None, prompt_embeds=None,
negative_prompt_embeds=None, negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
attention_mask=None,
negative_attention_mask=None,
): ):
if (callback_steps is None) or ( if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError( raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type" f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}." f" {type(callback_steps)}."
) )
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None: if prompt is not None and prompt_embeds is not None:
raise ValueError( raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
...@@ -334,7 +375,42 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -334,7 +375,42 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
f" {negative_prompt_embeds.shape}." f" {negative_prompt_embeds.shape}."
) )
if negative_prompt_embeds is not None and negative_attention_mask is None:
raise ValueError("Please provide `negative_attention_mask` along with `negative_prompt_embeds`")
if negative_prompt_embeds is not None and negative_attention_mask is not None:
if negative_prompt_embeds.shape[:2] != negative_attention_mask.shape:
raise ValueError(
"`negative_prompt_embeds` and `negative_attention_mask` must have the same batch_size and token length when passed directly, but"
f" got: `negative_prompt_embeds` {negative_prompt_embeds.shape[:2]} != `negative_attention_mask`"
f" {negative_attention_mask.shape}."
)
if prompt_embeds is not None and attention_mask is None:
raise ValueError("Please provide `attention_mask` along with `prompt_embeds`")
if prompt_embeds is not None and attention_mask is not None:
if prompt_embeds.shape[:2] != attention_mask.shape:
raise ValueError(
"`prompt_embeds` and `attention_mask` must have the same batch_size and token length when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape[:2]} != `attention_mask`"
f" {attention_mask.shape}."
)
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@torch.no_grad() @torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__( def __call__(
self, self,
prompt: Union[str, List[str]] = None, prompt: Union[str, List[str]] = None,
...@@ -347,15 +423,117 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -347,15 +423,117 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None, prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
negative_attention_mask: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil", output_type: Optional[str] = "pil",
return_dict: bool = True, return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_steps: int = 1, callback_on_step_end_tensor_inputs: List[str] = ["latents"],
latents=None, **kwargs,
): ):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, or tensor representing an image batch, that will be used as the starting point for the
process.
strength (`float`, *optional*, defaults to 0.8):
Indicates the extent to which to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point, and more noise is added the higher the `strength`. The number of denoising steps depends
on the amount of noise initially added. When `strength` is 1, the added noise is maximal and the denoising
process runs for the full number of iterations specified in `num_inference_steps`; a value of 1
essentially ignores `image` (a rough sketch of this mapping follows the docstring).
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 3.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask. Must be provided if `prompt_embeds` is passed directly.
negative_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated negative attention mask. Must be provided if `negative_prompt_embeds` is passed directly.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs` (a minimal example is sketched after the input checks below).
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
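# A minimal sketch of a `callback_on_step_end` function (illustrative only): any
# tensor named in `callback_on_step_end_tensor_inputs` is exposed through
# `callback_kwargs`, and returning it lets the callback override the value used
# for the next step, e.g.
#
#     def on_step_end(pipe, step, timestep, callback_kwargs):
#         latents = callback_kwargs["latents"]  # "latents" is in the default tensor inputs
#         return {"latents": latents}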
cut_context = True cut_context = True
# 1. Check inputs. Raise error if not correct # 1. Check inputs. Raise error if not correct
self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) self.check_inputs(
prompt,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
attention_mask,
negative_attention_mask,
)
self._guidance_scale = guidance_scale
if prompt is not None and isinstance(prompt, str): if prompt is not None and isinstance(prompt, str):
batch_size = 1 batch_size = 1
...@@ -366,24 +544,21 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -366,24 +544,21 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
device = self._execution_device device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt # 3. Encode input prompt
prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt( prompt_embeds, negative_prompt_embeds, attention_mask, negative_attention_mask = self.encode_prompt(
prompt, prompt,
do_classifier_free_guidance, self.do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt, num_images_per_prompt=num_images_per_prompt,
device=device, device=device,
negative_prompt=negative_prompt, negative_prompt=negative_prompt,
prompt_embeds=prompt_embeds, prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds, negative_prompt_embeds=negative_prompt_embeds,
_cut_context=cut_context, _cut_context=cut_context,
attention_mask=attention_mask,
negative_attention_mask=negative_attention_mask,
) )
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool() attention_mask = torch.cat([negative_attention_mask, attention_mask]).bool()
if not isinstance(image, list): if not isinstance(image, list):
...@@ -409,11 +584,11 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -409,11 +584,11 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
self.text_encoder_offload_hook.offload() self.text_encoder_offload_hook.offload()
# 7. Denoising loop # 7. Denoising loop
# TODO(Yiyi): Correct the following line and use correctly num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
# num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar: with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps): for i, t in enumerate(timesteps):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# predict the noise residual # predict the noise residual
noise_pred = self.unet( noise_pred = self.unet(
...@@ -422,7 +597,7 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -422,7 +597,7 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
encoder_attention_mask=attention_mask, encoder_attention_mask=attention_mask,
)[0] )[0]
if do_classifier_free_guidance: if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
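# Note: (guidance_scale + 1) * text - guidance_scale * uncond is algebraically
# text + guidance_scale * (text - uncond), i.e. the familiar classifier-free
# guidance update with `guidance_scale` acting as the guidance weight.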
noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond noise_pred = (guidance_scale + 1.0) * noise_pred_text - guidance_scale * noise_pred_uncond
...@@ -434,25 +609,44 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -434,25 +609,44 @@ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
latents, latents,
generator=generator, generator=generator,
).prev_sample ).prev_sample
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# post-processing
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
if output_type not in ["pt", "np", "pil"]: if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
attention_mask = callback_outputs.pop("attention_mask", attention_mask)
negative_attention_mask = callback_outputs.pop("negative_attention_mask", negative_attention_mask)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
# post-processing
if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError( raise ValueError(
f"Only the output types `pt`, `pil` and `np` are supported not output_type={output_type}" f"Only the output types `pt`, `pil`, `np` and `latent` are supported not output_type={output_type}"
) )
if not output_type == "latent":
image = self.movq.decode(latents, force_not_quantize=True)["sample"]
if output_type in ["np", "pil"]:
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type in ["np", "pil"]: if output_type == "pil":
image = image * 0.5 + 0.5 image = self.numpy_to_pil(image)
image = image.clamp(0, 1) else:
image = image.cpu().permute(0, 2, 3, 1).float().numpy() image = latents
if output_type == "pil": self.maybe_free_model_hooks()
image = self.numpy_to_pil(image)
if not return_dict: if not return_dict:
return (image,) return (image,)
......
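With these changes the img2img pipeline accepts `attention_mask`/`negative_attention_mask` alongside precomputed embeddings and reports per-step tensors through `callback_on_step_end` rather than the deprecated `callback`/`callback_steps` pair. A minimal usage sketch, adapted from the integration test further down (model id and image URL are taken from that test; the callback body is illustrative):

import torch
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image

pipe = AutoPipelineForImage2Image.from_pretrained(
    "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png"
).resize((512, 512))

def on_step_end(pipeline, step, timestep, callback_kwargs):
    # inspect (or replace) the latents after each denoising step
    return {"latents": callback_kwargs["latents"]}

image = pipe(
    "A painting of the inside of a subway train with tiny raccoons.",
    image=init_image,
    strength=0.75,
    num_inference_steps=25,
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]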
...@@ -165,10 +165,6 @@ class Kandinsky3PipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -165,10 +165,6 @@ class Kandinsky3PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def test_inference_batch_single_identical(self): def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1e-2) super().test_inference_batch_single_identical(expected_max_diff=1e-2)
def test_model_cpu_offload_forward_pass(self):
# TODO(Yiyi) - this test should work, skipped for time reasons for now
pass
@slow @slow
@require_torch_gpu @require_torch_gpu
......
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import random
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import (
AutoPipelineForImage2Image,
Kandinsky3Img2ImgPipeline,
Kandinsky3UNet,
VQModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.utils.testing_utils import (
enable_full_determinism,
floats_tensor,
load_image,
require_torch_gpu,
slow,
)
from ..pipeline_params import (
IMAGE_TO_IMAGE_IMAGE_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
TEXT_TO_IMAGE_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class Kandinsky3Img2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = Kandinsky3Img2ImgPipeline
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
test_xformers_attention = False
required_optional_params = frozenset(
[
"num_inference_steps",
"num_images_per_prompt",
"generator",
"output_type",
"return_dict",
]
)
@property
def dummy_movq_kwargs(self):
return {
"block_out_channels": [32, 64],
"down_block_types": ["DownEncoderBlock2D", "AttnDownEncoderBlock2D"],
"in_channels": 3,
"latent_channels": 4,
"layers_per_block": 1,
"norm_num_groups": 8,
"norm_type": "spatial",
"num_vq_embeddings": 12,
"out_channels": 3,
"up_block_types": [
"AttnUpDecoderBlock2D",
"UpDecoderBlock2D",
],
"vq_embed_dim": 4,
}
@property
def dummy_movq(self):
torch.manual_seed(0)
model = VQModel(**self.dummy_movq_kwargs)
return model
def get_dummy_components(self, time_cond_proj_dim=None):
torch.manual_seed(0)
unet = Kandinsky3UNet(
in_channels=4,
time_embedding_dim=4,
groups=2,
attention_head_dim=4,
layers_per_block=3,
block_out_channels=(32, 64),
cross_attention_dim=4,
encoder_hid_dim=32,
)
scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
steps_offset=1,
beta_schedule="squaredcos_cap_v2",
clip_sample=True,
thresholding=False,
)
torch.manual_seed(0)
movq = self.dummy_movq
torch.manual_seed(0)
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
torch.manual_seed(0)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
components = {
"unet": unet,
"scheduler": scheduler,
"movq": movq,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0):
# create init_image
image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
image = image.cpu().permute(0, 2, 3, 1)[0]
init_image = Image.fromarray(np.uint8(image)).convert("RGB")
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "A painting of a squirrel eating a burger",
"image": init_image,
"generator": generator,
"strength": 0.75,
"num_inference_steps": 10,
"guidance_scale": 6.0,
"output_type": "np",
}
return inputs
def test_kandinsky3_img2img(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
output = pipe(**self.get_dummy_inputs(device))
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array(
[0.576259, 0.6132097, 0.41703486, 0.603196, 0.62062526, 0.4655338, 0.5434324, 0.5660727, 0.65433365]
)
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
def test_float16_inference(self):
super().test_float16_inference(expected_max_diff=1e-1)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=1e-2)
@slow
@require_torch_gpu
class Kandinsky3Img2ImgPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_kandinskyV3_img2img(self):
pipe = AutoPipelineForImage2Image.from_pretrained(
"kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
generator = torch.Generator(device="cpu").manual_seed(0)
image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png"
)
w, h = 512, 512
image = image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
prompt = "A painting of the inside of a subway train with tiny raccoons."
image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0]
assert image.size == (512, 512)
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/i2i.png"
)
image_processor = VaeImageProcessor()
image_np = image_processor.pil_to_numpy(image)
expected_image_np = image_processor.pil_to_numpy(expected_image)
self.assertTrue(np.allclose(image_np, expected_image_np, atol=5e-2))
...@@ -377,6 +377,10 @@ class PipelineTesterMixin: ...@@ -377,6 +377,10 @@ class PipelineTesterMixin:
with CaptureLogger(logger) as cap_logger: with CaptureLogger(logger) as cap_logger:
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
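# use the default attention processors on the reloaded components so the outputs
# do not depend on any custom processors picked up during loading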
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
for name in pipe_loaded.components.keys(): for name in pipe_loaded.components.keys():
if name not in pipe_loaded._optional_components: if name not in pipe_loaded._optional_components:
assert name in str(cap_logger) assert name in str(cap_logger)
......