Unverified Commit b978334d authored by Patrick von Platen, committed by GitHub

[@cene555][Kandinsky 3.0] Add Kandinsky 3.0 (#5913)

* finalize

* finalize

* finalize

* add slow test

* add slow test

* add slow test

* Fix more

* add slow test

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* fix more

* Better

* Fix more

* Fix more

* add slow test

* Add auto pipelines

* add slow test

* Add all

* add slow test

* add slow test

* add slow test

* add slow test

* add slow test

* Apply suggestions from code review

* add slow test

* add slow test
parent e5f232f7
@@ -278,6 +278,8 @@
       title: Kandinsky 2.1
     - local: api/pipelines/kandinsky_v22
       title: Kandinsky 2.2
+    - local: api/pipelines/kandinsky3
+      title: Kandinsky 3
     - local: api/pipelines/latent_consistency_models
       title: Latent Consistency Models
     - local: api/pipelines/latent_diffusion
......
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Kandinsky 3

TODO
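
While the overview above is still a TODO, here is a minimal text-to-image sketch. It is not part of this commit's docs; it simply mirrors the slow test added below and assumes the `kandinsky-community/kandinsky-3` checkpoint:

```py
import torch
from diffusers import AutoPipelineForText2Image

# Checkpoint, variant, and settings are taken from the slow test added in this PR.
pipe = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
image = pipe(prompt, num_inference_steps=25).images[0]
```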
## Kandinsky3Pipeline

[[autodoc]] Kandinsky3Pipeline
	- all
	- __call__

## Kandinsky3Img2ImgPipeline

[[autodoc]] Kandinsky3Img2ImgPipeline
	- all
	- __call__
#!/usr/bin/env python3
import argparse
import fnmatch

from safetensors.torch import load_file

from diffusers import Kandinsky3UNet


MAPPING = {
    "to_time_embed.1": "time_embedding.linear_1",
    "to_time_embed.3": "time_embedding.linear_2",
    "in_layer": "conv_in",
    "out_layer.0": "conv_norm_out",
    "out_layer.2": "conv_out",
    "down_samples": "down_blocks",
    "up_samples": "up_blocks",
    "projection_lin": "encoder_hid_proj.projection_linear",
    "projection_ln": "encoder_hid_proj.projection_norm",
    "feature_pooling": "add_time_condition",
    "to_query": "to_q",
    "to_key": "to_k",
    "to_value": "to_v",
    "output_layer": "to_out.0",
    "self_attention_block": "attentions.0",
}

# Patterns with a wildcard block index. A tuple value means the block index is
# shifted by the given offset in the converted key.
DYNAMIC_MAP = {
    "resnet_attn_blocks.*.0": "resnets_in.*",
    "resnet_attn_blocks.*.1": ("attentions.*", 1),
    "resnet_attn_blocks.*.2": "resnets_out.*",
}


def convert_state_dict(unet_state_dict):
    """
    Convert the state dict of the original Kandinsky 3 U-Net to the key format expected by the Kandinsky3UNet model.

    Args:
        unet_state_dict (dict): State dict of the original U-Net checkpoint.

    Returns:
        dict: The converted state dictionary.
    """
    converted_state_dict = {}
    for key in unet_state_dict:
        new_key = key
        # Apply the static renames first.
        for pattern, new_pattern in MAPPING.items():
            new_key = new_key.replace(pattern, new_pattern)
        # Then apply the renames that depend on the block index captured by the wildcard.
        for dyn_pattern, dyn_new_pattern in DYNAMIC_MAP.items():
            if fnmatch.fnmatch(new_key, f"*.{dyn_pattern}.*"):
                star = int(new_key.split(dyn_pattern.split(".")[0])[-1].split(".")[1])
                if isinstance(dyn_new_pattern, tuple):
                    new_star = star + dyn_new_pattern[-1]
                    dyn_new_pattern = dyn_new_pattern[0]
                else:
                    new_star = star
                pattern = dyn_pattern.replace("*", str(star))
                new_pattern = dyn_new_pattern.replace("*", str(new_star))
                new_key = new_key.replace(pattern, new_pattern)
        converted_state_dict[new_key] = unet_state_dict[key]
    return converted_state_dict


def main(model_path, output_path):
    # Load the original U-Net weights.
    unet_state_dict = load_file(model_path)

    # Convert the state dict and load it into a freshly initialized Kandinsky3UNet
    # (default config; add entries to `config` to override it).
    config = {}
    converted_state_dict = convert_state_dict(unet_state_dict)
    unet = Kandinsky3UNet(**config)
    unet.load_state_dict(converted_state_dict)

    unet.save_pretrained(output_path)
    print(f"Converted model saved to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert U-Net PyTorch model to Kandinsky3UNet format")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the original U-Net PyTorch model")
    parser.add_argument("--output_path", type=str, required=True, help="Path to save the converted model")
    args = parser.parse_args()
    main(args.model_path, args.output_path)
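
A quick sanity check of the renaming logic above, using hypothetical key names chosen only to exercise MAPPING and one DYNAMIC_MAP pattern (real checkpoint keys may differ). It can be appended to the script or run anywhere convert_state_dict is importable:

# Hypothetical keys, for illustration only; the values are placeholders.
sample = {
    "to_time_embed.1.weight": 0.0,
    "down_samples.0.resnet_attn_blocks.1.0.some_param": 0.0,
}
print(list(convert_state_dict(sample).keys()))
# -> ['time_embedding.linear_1.weight', 'down_blocks.0.resnets_in.1.some_param']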
@@ -79,6 +79,7 @@ else:
         "AutoencoderTiny",
         "ConsistencyDecoderVAE",
         "ControlNetModel",
+        "Kandinsky3UNet",
         "ModelMixin",
         "MotionAdapter",
         "MultiAdapter",
@@ -214,6 +215,8 @@ else:
         "IFPipeline",
         "IFSuperResolutionPipeline",
         "ImageTextPipelineOutput",
+        "Kandinsky3Img2ImgPipeline",
+        "Kandinsky3Pipeline",
         "KandinskyCombinedPipeline",
         "KandinskyImg2ImgCombinedPipeline",
         "KandinskyImg2ImgPipeline",
@@ -446,6 +449,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         AutoencoderTiny,
         ConsistencyDecoderVAE,
         ControlNetModel,
+        Kandinsky3UNet,
         ModelMixin,
         MotionAdapter,
         MultiAdapter,
@@ -560,6 +564,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         IFPipeline,
         IFSuperResolutionPipeline,
         ImageTextPipelineOutput,
+        Kandinsky3Img2ImgPipeline,
+        Kandinsky3Pipeline,
         KandinskyCombinedPipeline,
         KandinskyImg2ImgCombinedPipeline,
         KandinskyImg2ImgPipeline,
......
@@ -36,6 +36,7 @@ if is_torch_available():
     _import_structure["unet_2d"] = ["UNet2DModel"]
     _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
     _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
+    _import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
     _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
     _import_structure["vq_model"] = ["VQModel"]
@@ -63,6 +64,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .unet_2d import UNet2DModel
     from .unet_2d_condition import UNet2DConditionModel
     from .unet_3d_condition import UNet3DConditionModel
+    from .unet_kandi3 import Kandinsky3UNet
     from .unet_motion_model import MotionAdapter, UNetMotionModel
     from .vq_model import VQModel
......
@@ -16,7 +16,7 @@ from typing import Callable, Optional, Union
 import torch
 import torch.nn.functional as F
-from torch import nn
+from torch import einsum, nn
 from ..utils import USE_PEFT_BACKEND, deprecate, logging
 from ..utils.import_utils import is_xformers_available
@@ -2219,6 +2219,44 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
         return hidden_states
+# TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
+# this way torch.compile and co. will work as well
+class Kandi3AttnProcessor:
+    r"""
+    Default Kandinsky 3 processor for performing attention-related computations.
+    """
+
+    @staticmethod
+    def _reshape(hid_states, h):
+        b, n, f = hid_states.shape
+        d = f // h
+        return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)
+
+    def __call__(
+        self,
+        attn,
+        x,
+        context,
+        context_mask=None,
+    ):
+        query = self._reshape(attn.to_q(x), h=attn.num_heads)
+        key = self._reshape(attn.to_k(context), h=attn.num_heads)
+        value = self._reshape(attn.to_v(context), h=attn.num_heads)
+
+        attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)
+
+        if context_mask is not None:
+            max_neg_value = -torch.finfo(attention_matrix.dtype).max
+            context_mask = context_mask.unsqueeze(1).unsqueeze(1)
+            attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
+        attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)
+
+        out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
+        out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
+        out = attn.to_out[0](out)
+        return out
+
+
 LORA_ATTENTION_PROCESSORS = (
     LoRAAttnProcessor,
     LoRAAttnProcessor2_0,
@@ -2244,6 +2282,7 @@ CROSS_ATTENTION_PROCESSORS = (
     LoRAXFormersAttnProcessor,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
+    Kandi3AttnProcessor,
 )
 AttentionProcessor = Union[
......
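
To make the TODO above concrete: Kandi3AttnProcessor computes plain scaled dot-product attention with einsum, so with attn.scale == 1/sqrt(head_dim) and no context_mask it should match torch.nn.functional.scaled_dot_product_attention. A minimal numerical sketch, not part of the diff:

import torch
import torch.nn.functional as F
from torch import einsum

b, h, n, d = 1, 2, 4, 8
q, k, v = (torch.randn(b, h, n, d) for _ in range(3))

# Same math as Kandi3AttnProcessor.__call__ with attn.scale = d ** -0.5 and no mask.
weights = (einsum("b h i d, b h j d -> b h i j", q, k) * d**-0.5).softmax(dim=-1)
manual = einsum("b h i j, b h j d -> b h i d", weights, v)

fused = F.scaled_dot_product_attention(q, k, v)  # default scale is 1/sqrt(d)
print(torch.allclose(manual, fused, atol=1e-5))  # True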
@@ -110,6 +110,7 @@ else:
         "KandinskyV22PriorEmb2EmbPipeline",
         "KandinskyV22PriorPipeline",
     ]
+    _import_structure["kandinsky3"] = ["Kandinsky3Img2ImgPipeline", "Kandinsky3Pipeline"]
     _import_structure["latent_consistency_models"] = [
         "LatentConsistencyModelImg2ImgPipeline",
         "LatentConsistencyModelPipeline",
@@ -338,6 +339,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         KandinskyV22PriorEmb2EmbPipeline,
         KandinskyV22PriorPipeline,
     )
+    from .kandinsky3 import (
+        Kandinsky3Img2ImgPipeline,
+        Kandinsky3Pipeline,
+    )
     from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
     from .latent_diffusion import LDMTextToImagePipeline
     from .musicldm import MusicLDMPipeline
......
@@ -42,6 +42,7 @@ from .kandinsky2_2 import (
     KandinskyV22InpaintPipeline,
     KandinskyV22Pipeline,
 )
+from .kandinsky3 import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
 from .latent_consistency_models import LatentConsistencyModelImg2ImgPipeline, LatentConsistencyModelPipeline
 from .pixart_alpha import PixArtAlphaPipeline
 from .stable_diffusion import (
@@ -64,6 +65,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
         ("if", IFPipeline),
         ("kandinsky", KandinskyCombinedPipeline),
         ("kandinsky22", KandinskyV22CombinedPipeline),
+        ("kandinsky3", Kandinsky3Pipeline),
         ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
         ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
         ("wuerstchen", WuerstchenCombinedPipeline),
@@ -79,6 +81,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
         ("if", IFImg2ImgPipeline),
         ("kandinsky", KandinskyImg2ImgCombinedPipeline),
         ("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
+        ("kandinsky3", Kandinsky3Img2ImgPipeline),
         ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
         ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
         ("lcm", LatentConsistencyModelImg2ImgPipeline),
......
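
With both mappings above registered under the "kandinsky3" key, the auto classes resolve a Kandinsky 3 checkpoint to the new pipelines. A hedged sketch (the repo id is the one used in the slow tests below; the isinstance checks follow from the mappings in this diff):

import torch
from diffusers import (
    AutoPipelineForImage2Image,
    AutoPipelineForText2Image,
    Kandinsky3Img2ImgPipeline,
    Kandinsky3Pipeline,
)

# Text-to-image dispatch via AUTO_TEXT2IMAGE_PIPELINES_MAPPING["kandinsky3"].
t2i = AutoPipelineForText2Image.from_pretrained(
    "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
)
assert isinstance(t2i, Kandinsky3Pipeline)

# Image-to-image dispatch via AUTO_IMAGE2IMAGE_PIPELINES_MAPPING["kandinsky3"].
i2i = AutoPipelineForImage2Image.from_pretrained(
    "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
)
assert isinstance(i2i, Kandinsky3Img2ImgPipeline)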
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["kandinsky3_pipeline"] = ["Kandinsky3Pipeline"]
    _import_structure["kandinsky3img2img_pipeline"] = ["Kandinsky3Img2ImgPipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .kandinsky3_pipeline import Kandinsky3Pipeline
        from .kandinsky3img2img_pipeline import Kandinsky3Img2ImgPipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
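
Whichever branch runs, the public import surface is the same. A small sketch (not part of the diff) checking that the lazily loaded classes are the same objects as the top-level re-exports:

from diffusers import Kandinsky3Img2ImgPipeline, Kandinsky3Pipeline
from diffusers.pipelines.kandinsky3 import Kandinsky3Img2ImgPipeline as LazyImg2Img
from diffusers.pipelines.kandinsky3 import Kandinsky3Pipeline as LazyText2Img

# Attribute access on the _LazyModule triggers the real import, so both routes
# should resolve to the same class objects.
assert LazyText2Img is Kandinsky3Pipeline
assert LazyImg2Img is Kandinsky3Img2ImgPipeline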
@@ -77,6 +77,21 @@ class ControlNetModel(metaclass=DummyObject):
         requires_backends(cls, ["torch"])
+class Kandinsky3UNet(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class ModelMixin(metaclass=DummyObject):
     _backends = ["torch"]
......
@@ -242,6 +242,36 @@ class ImageTextPipelineOutput(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])
+class Kandinsky3Img2ImgPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
+class Kandinsky3Pipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class KandinskyCombinedPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
......
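
The dummy classes above keep `import diffusers` working when the optional backends are missing; any attempt to use them then fails with an actionable error. A hedged sketch of that behavior (only meaningful in an environment without torch/transformers installed):

from diffusers.utils import is_torch_available, is_transformers_available

if not (is_torch_available() and is_transformers_available()):
    # Resolves to the DummyObject subclasses defined above.
    from diffusers import Kandinsky3Pipeline

    try:
        Kandinsky3Pipeline()
    except ImportError as err:
        print(err)  # names the backends that need to be installed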
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import (
    AutoPipelineForImage2Image,
    AutoPipelineForText2Image,
    Kandinsky3Pipeline,
    Kandinsky3UNet,
    VQModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    load_image,
    require_torch_gpu,
    slow,
)

from ..pipeline_params import (
    TEXT_TO_IMAGE_BATCH_PARAMS,
    TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
    TEXT_TO_IMAGE_IMAGE_PARAMS,
    TEXT_TO_IMAGE_PARAMS,
)
from ..test_pipelines_common import PipelineTesterMixin


enable_full_determinism()
class Kandinsky3PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = Kandinsky3Pipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
    test_xformers_attention = False

    @property
    def dummy_movq_kwargs(self):
        return {
            "block_out_channels": [32, 64],
            "down_block_types": ["DownEncoderBlock2D", "AttnDownEncoderBlock2D"],
            "in_channels": 3,
            "latent_channels": 4,
            "layers_per_block": 1,
            "norm_num_groups": 8,
            "norm_type": "spatial",
            "num_vq_embeddings": 12,
            "out_channels": 3,
            "up_block_types": [
                "AttnUpDecoderBlock2D",
                "UpDecoderBlock2D",
            ],
            "vq_embed_dim": 4,
        }

    @property
    def dummy_movq(self):
        torch.manual_seed(0)
        model = VQModel(**self.dummy_movq_kwargs)
        return model

    def get_dummy_components(self, time_cond_proj_dim=None):
        torch.manual_seed(0)
        unet = Kandinsky3UNet(
            in_channels=4,
            time_embedding_dim=4,
            groups=2,
            attention_head_dim=4,
            layers_per_block=3,
            block_out_channels=(32, 64),
            cross_attention_dim=4,
            encoder_hid_dim=32,
        )
        scheduler = DDPMScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            steps_offset=1,
            beta_schedule="squaredcos_cap_v2",
            clip_sample=True,
            thresholding=False,
        )
        torch.manual_seed(0)
        movq = self.dummy_movq
        torch.manual_seed(0)
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        torch.manual_seed(0)
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        components = {
            "unet": unet,
            "scheduler": scheduler,
            "movq": movq,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "np",
            "width": 16,
            "height": 16,
        }
        return inputs

    def test_kandinsky3(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        output = pipe(**self.get_dummy_inputs(device))
        image = output.images
        image_slice = image[0, -3:, -3:, -1]

        assert image.shape == (1, 16, 16, 3)

        expected_slice = np.array([0.3768, 0.4373, 0.4865, 0.4890, 0.4299, 0.5122, 0.4921, 0.4924, 0.5599])
        assert (
            np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        ), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"

    def test_float16_inference(self):
        super().test_float16_inference(expected_max_diff=1e-1)

    def test_inference_batch_single_identical(self):
        super().test_inference_batch_single_identical(expected_max_diff=1e-2)

    def test_model_cpu_offload_forward_pass(self):
        # TODO(Yiyi) - this test should work, skipped for time reasons for now
        pass


@slow
@require_torch_gpu
class Kandinsky3PipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_kandinskyV3(self):
        pipe = AutoPipelineForText2Image.from_pretrained(
            "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
        )
        pipe.enable_model_cpu_offload()
        pipe.set_progress_bar_config(disable=None)

        prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."

        generator = torch.Generator(device="cpu").manual_seed(0)

        image = pipe(prompt, num_inference_steps=25, generator=generator).images[0]

        assert image.size == (1024, 1024)

        expected_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png"
        )

        image_processor = VaeImageProcessor()

        image_np = image_processor.pil_to_numpy(image)
        expected_image_np = image_processor.pil_to_numpy(expected_image)

        self.assertTrue(np.allclose(image_np, expected_image_np, atol=5e-2))

    def test_kandinskyV3_img2img(self):
        pipe = AutoPipelineForImage2Image.from_pretrained(
            "kandinsky-community/kandinsky-3", variant="fp16", torch_dtype=torch.float16
        )
        pipe.enable_model_cpu_offload()
        pipe.set_progress_bar_config(disable=None)

        generator = torch.Generator(device="cpu").manual_seed(0)

        image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/t2i.png"
        )
        w, h = 512, 512
        image = image.resize((w, h), resample=Image.BICUBIC, reducing_gap=1)
        prompt = "A painting of the inside of a subway train with tiny raccoons."

        image = pipe(prompt, image=image, strength=0.75, num_inference_steps=25, generator=generator).images[0]

        assert image.size == (512, 512)

        expected_image = load_image(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky3/i2i.png"
        )

        image_processor = VaeImageProcessor()

        image_np = image_processor.pil_to_numpy(image)
        expected_image_np = image_processor.pil_to_numpy(expected_image)

        self.assertTrue(np.allclose(image_np, expected_image_np, atol=5e-2))