"torchvision/vscode:/vscode.git/clone" did not exist on "9b82df43341a6891f652be1803abd1d1d05bfbb2"
Unverified Commit 40aa47b9 authored by Kashif Rasul's avatar Kashif Rasul Committed by GitHub
Browse files

[Pipiline] Wuerstchen v3 aka Stable Cascasde pipeline (#6487)



* initial diffNext v3

* move to v3 folder

* imports

* dry up the unets

* no switch_level

* fix init

* add switch_level tp config

* Fixed some things

* Added pooled text embeddings

* Initial work on adding image encoder

* changes from @dome272

* Stuff for the image encoder processing and variable naming in decoder

* fix arg name

* inference fixes

* inference fixes

* default TimestepBlock without conds

* c_skip=0 by default

* fix bfloat16 to cpu

* use config

* undo temp change

* fix gen_c_embeddings args

* change text encoding

* text encoding

* undo print

* undo .gitignore change

* Allow WuerstchenV3PriorPipeline to use the base DDPM & DDIM schedulers

* use WuerstchenV3Unet in both pipelines

* fix imports

* initial failing tests

* cleanup

* use scheduler.timesterps

* some fixes to the tests, still not fully working

* fix tests

* fix prior tests

* add dropout to the model_kwargs

* more tests passing

* update expected_slice

* initial rename

* rename tests

* rename class names

* make fix-copies

* initial docs

* autodocs

* typos

* fix arg docs

* add text_encoder info

* combined pipeline has optional image arg

* fix documentation

* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* use self.config

* Update src/diffusers/pipelines/stable_cascade/modeling_stable_cascade_common.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* c_in -> in_channels

* removed kwargs from unet's forward

* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* remove older callback api

* removed kwargs and fixed decoder guidance > 1

* decoder takes emeds

* check and use image_embeds

* fixed all but one decoder test

* fix decoder tests

* update callback api

* fix some more combined tests

* push combined pipeline

* initial docs

* fix doc_string

* update combined api

* no test_callback_inputs test for combined pipeline

* add optional components

* fix ordering of components

* fix combined tests

* update convert script

* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>

* fix imports

* move effnet out of deniosing loop

* prompt_embeds_pooled only when doing guidance

* Fix repeat shape

* move StableCascadeUnet to models/unets/

* more descriptive names

* converted when numpy()

* StableCascadePriorPipelineOutput docs

* rename StableCascadeUNet

* add slow tests

* fix slow tests

* update

* update

* updated model_path

* add args for weights

* set push_to_hub to false

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

---------
Co-authored-by: default avatarDominic Rampas <d6582533@gmail.com>
Co-authored-by: default avatarPablo Pernias <pablo@pernias.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
Co-authored-by: default avatar99991 <99991@users.noreply.github.com>
Co-authored-by: default avatarDhruv Nair <dhruv.nair@gmail.com>
parent 1bc0d37f
...@@ -318,6 +318,8 @@ ...@@ -318,6 +318,8 @@
title: Semantic Guidance title: Semantic Guidance
- local: api/pipelines/shap_e - local: api/pipelines/shap_e
title: Shap-E title: Shap-E
- local: api/pipelines/stable_cascade
title: Stable Cascade
- sections: - sections:
- local: api/pipelines/stable_diffusion/overview - local: api/pipelines/stable_diffusion/overview
title: Overview title: Overview
......
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stable Cascade
This model is built upon the [Würstchen](https://openreview.net/forum?id=gU58d5QeGv) architecture and its main
difference to other models like Stable Diffusion is that it is working at a much smaller latent space. Why is this
important? The smaller the latent space, the **faster** you can run inference and the **cheaper** the training becomes.
How small is the latent space? Stable Diffusion uses a compression factor of 8, resulting in a 1024x1024 image being
encoded to 128x128. Stable Cascade achieves a compression factor of 42, meaning that it is possible to encode a
1024x1024 image to 24x24, while maintaining crisp reconstructions. The text-conditional model is then trained in the
highly compressed latent space. Previous versions of this architecture, achieved a 16x cost reduction over Stable
Diffusion 1.5.
Therefore, this kind of model is well suited for usages where efficiency is important. Furthermore, all known extensions
like finetuning, LoRA, ControlNet, IP-Adapter, LCM etc. are possible with this method as well.
The original codebase can be found at [Stability-AI/StableCascade](https://github.com/Stability-AI/StableCascade).
## Model Overview
Stable Cascade consists of three models: Stage A, Stage B and Stage C, representing a cascade to generate images,
hence the name "Stable Cascade".
Stage A & B are used to compress images, similar to what the job of the VAE is in Stable Diffusion.
However, with this setup, a much higher compression of images can be achieved. While the Stable Diffusion models use a
spatial compression factor of 8, encoding an image with resolution of 1024 x 1024 to 128 x 128, Stable Cascade achieves
a compression factor of 42. This encodes a 1024 x 1024 image to 24 x 24, while being able to accurately decode the
image. This comes with the great benefit of cheaper training and inference. Furthermore, Stage C is responsible
for generating the small 24 x 24 latents given a text prompt.
## Uses
### Direct Use
The model is intended for research purposes for now. Possible research areas and tasks include
- Research on generative models.
- Safe deployment of models which have the potential to generate harmful content.
- Probing and understanding the limitations and biases of generative models.
- Generation of artworks and use in design and other artistic processes.
- Applications in educational or creative tools.
Excluded uses are described below.
### Out-of-Scope Use
The model was not trained to be factual or true representations of people or events,
and therefore using the model to generate such content is out-of-scope for the abilities of this model.
The model should not be used in any way that violates Stability AI's [Acceptable Use Policy](https://stability.ai/use-policy).
## Limitations and Bias
### Limitations
- Faces and people in general may not be generated properly.
- The autoencoding part of the model is lossy.
## StableCascadeCombinedPipeline
[[autodoc]] StableCascadeCombinedPipeline
- all
- __call__
## StableCascadePriorPipeline
[[autodoc]] StableCascadePriorPipeline
- all
- __call__
## StableCascadePriorPipelineOutput
[[autodoc]] pipelines.stable_cascade.pipeline_stable_cascade_prior.StableCascadePriorPipelineOutput
## StableCascadeDecoderPipeline
[[autodoc]] StableCascadeDecoderPipeline
- all
- __call__
# Run this script to convert the Stable Cascade model weights to a diffusers pipeline.
import argparse
import accelerate
import torch
from safetensors.torch import load_file
from transformers import (
AutoTokenizer,
CLIPConfig,
CLIPImageProcessor,
CLIPTextModelWithProjection,
CLIPVisionModelWithProjection,
)
from diffusers import (
DDPMWuerstchenScheduler,
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
)
from diffusers.models import StableCascadeUNet
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.pipelines.wuerstchen import PaellaVQModel
parser = argparse.ArgumentParser(description="Convert Stable Cascade model weights to a diffusers pipeline")
parser.add_argument("--model_path", type=str, default="../StableCascade", help="Location of Stable Cascade weights")
parser.add_argument("--stage_c_name", type=str, default="stage_c.safetensors", help="Name of stage c checkpoint file")
parser.add_argument("--stage_b_name", type=str, default="stage_b.safetensors", help="Name of stage b checkpoint file")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument("--save_org", type=str, default="diffusers", help="Hub organization to save the pipelines to")
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
args = parser.parse_args()
model_path = args.model_path
device = "cpu"
# set paths to model weights
prior_checkpoint_path = f"{model_path}/{args.stage_c_name}"
decoder_checkpoint_path = f"{model_path}/{args.stage_b_name}"
# Clip Text encoder and tokenizer
config = CLIPConfig.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
config.text_config.projection_dim = config.projection_dim
text_encoder = CLIPTextModelWithProjection.from_pretrained(
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", config=config.text_config
)
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
# image processor
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
# Prior
if args.use_safetensors:
orig_state_dict = load_file(prior_checkpoint_path, device=device)
else:
orig_state_dict = torch.load(prior_checkpoint_path, map_location=device)
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
with accelerate.init_empty_weights():
prior_model = StableCascadeUNet(
in_channels=16,
out_channels=16,
timestep_ratio_embedding_dim=64,
patch_size=1,
conditioning_dim=2048,
block_out_channels=[2048, 2048],
num_attention_heads=[32, 32],
down_num_layers_per_block=[8, 24],
up_num_layers_per_block=[24, 8],
down_blocks_repeat_mappers=[1, 1],
up_blocks_repeat_mappers=[1, 1],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_in_channels=1280,
clip_text_pooled_in_channels=1280,
clip_image_in_channels=768,
clip_seq=4,
kernel_size=3,
dropout=[0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca", "crp"],
switch_level=[False],
)
load_model_dict_into_meta(prior_model, state_dict)
# scheduler for prior and decoder
scheduler = DDPMWuerstchenScheduler()
# Prior pipeline
prior_pipeline = StableCascadePriorPipeline(
prior=prior_model,
tokenizer=tokenizer,
text_encoder=text_encoder,
image_encoder=image_encoder,
scheduler=scheduler,
feature_extractor=feature_extractor,
)
prior_pipeline.save_pretrained(f"{args.save_org}/StableCascade-prior", push_to_hub=args.push_to_hub)
# Decoder
if args.use_safetensors:
orig_state_dict = load_file(decoder_checkpoint_path, device=device)
else:
orig_state_dict = torch.load(decoder_checkpoint_path, map_location=device)
state_dict = {}
for key in orig_state_dict.keys():
if key.endswith("in_proj_weight"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0]
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1]
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2]
elif key.endswith("in_proj_bias"):
weights = orig_state_dict[key].chunk(3, 0)
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0]
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1]
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2]
elif key.endswith("out_proj.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights
elif key.endswith("out_proj.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights
# rename clip_mapper to clip_txt_pooled_mapper
elif key.endswith("clip_mapper.weight"):
weights = orig_state_dict[key]
state_dict[key.replace("clip_mapper.weight", "clip_txt_pooled_mapper.weight")] = weights
elif key.endswith("clip_mapper.bias"):
weights = orig_state_dict[key]
state_dict[key.replace("clip_mapper.bias", "clip_txt_pooled_mapper.bias")] = weights
else:
state_dict[key] = orig_state_dict[key]
with accelerate.init_empty_weights():
decoder = StableCascadeUNet(
in_channels=4,
out_channels=4,
timestep_ratio_embedding_dim=64,
patch_size=2,
conditioning_dim=1280,
block_out_channels=[320, 640, 1280, 1280],
down_num_layers_per_block=[2, 6, 28, 6],
up_num_layers_per_block=[6, 28, 6, 2],
down_blocks_repeat_mappers=[1, 1, 1, 1],
up_blocks_repeat_mappers=[3, 3, 2, 2],
num_attention_heads=[0, 0, 20, 20],
block_types_per_layer=[
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
clip_text_pooled_in_channels=1280,
clip_seq=4,
effnet_in_channels=16,
pixel_mapper_in_channels=3,
kernel_size=3,
dropout=[0, 0, 0.1, 0.1],
self_attn=True,
timestep_conditioning_type=["sca"],
)
load_model_dict_into_meta(decoder, state_dict)
# VQGAN from Wuerstchen-V2
vqmodel = PaellaVQModel.from_pretrained("warp-ai/wuerstchen", subfolder="vqgan")
# Decoder pipeline
decoder_pipeline = StableCascadeDecoderPipeline(
decoder=decoder, text_encoder=text_encoder, tokenizer=tokenizer, vqgan=vqmodel, scheduler=scheduler
)
decoder_pipeline.save_pretrained(f"{args.save_org}/StableCascade-decoder", push_to_hub=args.push_to_hub)
# Stable Cascade combined pipeline
stable_cascade_pipeline = StableCascadeCombinedPipeline(
# Decoder
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqmodel,
# Prior
prior_text_encoder=text_encoder,
prior_tokenizer=tokenizer,
prior_prior=prior_model,
prior_scheduler=scheduler,
prior_image_encoder=image_encoder,
prior_feature_extractor=feature_extractor,
)
stable_cascade_pipeline.save_pretrained(f"{args.save_org}/StableCascade", push_to_hub=args.push_to_hub)
...@@ -86,6 +86,7 @@ else: ...@@ -86,6 +86,7 @@ else:
"MotionAdapter", "MotionAdapter",
"MultiAdapter", "MultiAdapter",
"PriorTransformer", "PriorTransformer",
"StableCascadeUNet",
"T2IAdapter", "T2IAdapter",
"T5FilmDecoder", "T5FilmDecoder",
"Transformer2DModel", "Transformer2DModel",
...@@ -259,6 +260,9 @@ else: ...@@ -259,6 +260,9 @@ else:
"SemanticStableDiffusionPipeline", "SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline", "ShapEImg2ImgPipeline",
"ShapEPipeline", "ShapEPipeline",
"StableCascadeCombinedPipeline",
"StableCascadeDecoderPipeline",
"StableCascadePriorPipeline",
"StableDiffusionAdapterPipeline", "StableDiffusionAdapterPipeline",
"StableDiffusionAttendAndExcitePipeline", "StableDiffusionAttendAndExcitePipeline",
"StableDiffusionControlNetImg2ImgPipeline", "StableDiffusionControlNetImg2ImgPipeline",
...@@ -626,6 +630,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -626,6 +630,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SemanticStableDiffusionPipeline, SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline, ShapEImg2ImgPipeline,
ShapEPipeline, ShapEPipeline,
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
StableDiffusionAdapterPipeline, StableDiffusionAdapterPipeline,
StableDiffusionAttendAndExcitePipeline, StableDiffusionAttendAndExcitePipeline,
StableDiffusionControlNetImg2ImgPipeline, StableDiffusionControlNetImg2ImgPipeline,
......
...@@ -47,6 +47,7 @@ if is_torch_available(): ...@@ -47,6 +47,7 @@ if is_torch_available():
_import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"] _import_structure["unets.unet_kandinsky3"] = ["Kandinsky3UNet"]
_import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"] _import_structure["unets.unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"] _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
_import_structure["unets.uvit_2d"] = ["UVit2DModel"] _import_structure["unets.uvit_2d"] = ["UVit2DModel"]
_import_structure["vq_model"] = ["VQModel"] _import_structure["vq_model"] = ["VQModel"]
...@@ -80,6 +81,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -80,6 +81,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
I2VGenXLUNet, I2VGenXLUNet,
Kandinsky3UNet, Kandinsky3UNet,
MotionAdapter, MotionAdapter,
StableCascadeUNet,
UNet1DModel, UNet1DModel,
UNet2DConditionModel, UNet2DConditionModel,
UNet2DModel, UNet2DModel,
......
...@@ -10,6 +10,7 @@ if is_torch_available(): ...@@ -10,6 +10,7 @@ if is_torch_available():
from .unet_kandinsky3 import Kandinsky3UNet from .unet_kandinsky3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .unet_stable_cascade import StableCascadeUNet
from .uvit_2d import UVit2DModel from .uvit_2d import UVit2DModel
......
This diff is collapsed.
...@@ -176,6 +176,11 @@ else: ...@@ -176,6 +176,11 @@ else:
_import_structure["pixart_alpha"] = ["PixArtAlphaPipeline"] _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline"]
_import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
_import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
_import_structure["stable_cascade"] = [
"StableCascadeCombinedPipeline",
"StableCascadeDecoderPipeline",
"StableCascadePriorPipeline",
]
_import_structure["stable_diffusion"].extend( _import_structure["stable_diffusion"].extend(
[ [
"CLIPImageProjection", "CLIPImageProjection",
...@@ -424,6 +429,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -424,6 +429,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pixart_alpha import PixArtAlphaPipeline from .pixart_alpha import PixArtAlphaPipeline
from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
from .stable_cascade import (
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
)
from .stable_diffusion import ( from .stable_diffusion import (
CLIPImageProjection, CLIPImageProjection,
StableDiffusionDepth2ImgPipeline, StableDiffusionDepth2ImgPipeline,
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_stable_cascade"] = ["StableCascadeDecoderPipeline"]
_import_structure["pipeline_stable_cascade_combined"] = ["StableCascadeCombinedPipeline"]
_import_structure["pipeline_stable_cascade_prior"] = ["StableCascadePriorPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_stable_cascade import StableCascadeDecoderPipeline
from .pipeline_stable_cascade_combined import StableCascadeCombinedPipeline
from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Dict, List, Optional, Union
import PIL
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from ...models import StableCascadeUNet
from ...schedulers import DDPMWuerstchenScheduler
from ...utils import replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from ..wuerstchen.modeling_paella_vq_model import PaellaVQModel
from .pipeline_stable_cascade import StableCascadeDecoderPipeline
from .pipeline_stable_cascade_prior import StableCascadePriorPipeline
TEXT2IMAGE_EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusions import StableCascadeCombinedPipeline
>>> pipe = StableCascadeCombinedPipeline.from_pretrained("stabilityai/stable-cascade-combined", torch_dtype=torch.bfloat16).to(
... "cuda"
... )
>>> prompt = "an image of a shiba inu, donning a spacesuit and helmet"
>>> images = pipe(prompt=prompt)
```
"""
class StableCascadeCombinedPipeline(DiffusionPipeline):
"""
Combined Pipeline for text-to-image generation using Stable Cascade.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
tokenizer (`CLIPTokenizer`):
The decoder tokenizer to be used for text inputs.
text_encoder (`CLIPTextModel`):
The decoder text encoder to be used for text inputs.
decoder (`StableCascadeUNet`):
The decoder model to be used for decoder image generation pipeline.
scheduler (`DDPMWuerstchenScheduler`):
The scheduler to be used for decoder image generation pipeline.
vqgan (`PaellaVQModel`):
The VQGAN model to be used for decoder image generation pipeline.
feature_extractor ([`~transformers.CLIPImageProcessor`]):
Model that extracts features from generated images to be used as inputs for the `image_encoder`.
image_encoder ([`CLIPVisionModelWithProjection`]):
Frozen CLIP image-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
prior_prior (`StableCascadeUNet`):
The prior model to be used for prior pipeline.
prior_scheduler (`DDPMWuerstchenScheduler`):
The scheduler to be used for prior pipeline.
"""
_load_connected_pipes = True
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: CLIPTextModel,
decoder: StableCascadeUNet,
scheduler: DDPMWuerstchenScheduler,
vqgan: PaellaVQModel,
prior_prior: StableCascadeUNet,
prior_text_encoder: CLIPTextModel,
prior_tokenizer: CLIPTokenizer,
prior_scheduler: DDPMWuerstchenScheduler,
prior_feature_extractor: Optional[CLIPImageProcessor] = None,
prior_image_encoder: Optional[CLIPVisionModelWithProjection] = None,
):
super().__init__()
self.register_modules(
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqgan,
prior_text_encoder=prior_text_encoder,
prior_tokenizer=prior_tokenizer,
prior_prior=prior_prior,
prior_scheduler=prior_scheduler,
prior_feature_extractor=prior_feature_extractor,
prior_image_encoder=prior_image_encoder,
)
self.prior_pipe = StableCascadePriorPipeline(
prior=prior_prior,
text_encoder=prior_text_encoder,
tokenizer=prior_tokenizer,
scheduler=prior_scheduler,
image_encoder=prior_image_encoder,
feature_extractor=prior_feature_extractor,
)
self.decoder_pipe = StableCascadeDecoderPipeline(
text_encoder=text_encoder,
tokenizer=tokenizer,
decoder=decoder,
scheduler=scheduler,
vqgan=vqgan,
)
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
def enable_model_cpu_offload(self, gpu_id=0):
r"""
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
`enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
"""
self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
def enable_sequential_cpu_offload(self, gpu_id=0):
r"""
Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
"""
self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
def progress_bar(self, iterable=None, total=None):
self.prior_pipe.progress_bar(iterable=iterable, total=total)
self.decoder_pipe.progress_bar(iterable=iterable, total=total)
def set_progress_bar_config(self, **kwargs):
self.prior_pipe.set_progress_bar_config(**kwargs)
self.decoder_pipe.set_progress_bar_config(**kwargs)
@torch.no_grad()
@replace_example_docstring(TEXT2IMAGE_EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
images: Union[torch.Tensor, PIL.Image.Image, List[torch.Tensor], List[PIL.Image.Image]] = None,
height: int = 512,
width: int = 512,
prior_num_inference_steps: int = 60,
prior_timesteps: Optional[List[float]] = None,
prior_guidance_scale: float = 4.0,
num_inference_steps: int = 12,
decoder_timesteps: Optional[List[float]] = None,
decoder_guidance_scale: float = 0.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
prior_callback_on_step_end_tensor_inputs: List[str] = ["latents"],
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation for the prior and decoder.
images (`torch.Tensor`, `PIL.Image.Image`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, *optional*):
The images to guide the image generation for the prior.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
input argument.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to 512):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to 512):
The width in pixels of the generated image.
prior_guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`prior_guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
`prior_guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked
to the text `prompt`, usually at the expense of lower image quality.
prior_num_inference_steps (`Union[int, Dict[float, int]]`, *optional*, defaults to 60):
The number of prior denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. For more specific timestep spacing, you can pass customized
`prior_timesteps`
num_inference_steps (`int`, *optional*, defaults to 12):
The number of decoder denoising steps. More denoising steps usually lead to a higher quality image at
the expense of slower inference. For more specific timestep spacing, you can pass customized
`timesteps`
decoder_guidance_scale (`float`, *optional*, defaults to 0.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
prior_callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `prior_callback_on_step_end(self: DiffusionPipeline, step: int, timestep:
int, callback_kwargs: Dict)`.
prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
the `._callback_tensor_inputs` attribute of your pipeine class.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeine class.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple` [`~pipelines.ImagePipelineOutput`] if `return_dict` is True,
otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
"""
prior_outputs = self.prior_pipe(
prompt=prompt if prompt_embeds is None else None,
images=images,
height=height,
width=width,
num_inference_steps=prior_num_inference_steps,
guidance_scale=prior_guidance_scale,
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
num_images_per_prompt=num_images_per_prompt,
generator=generator,
latents=latents,
output_type="pt",
return_dict=True,
callback_on_step_end=prior_callback_on_step_end,
callback_on_step_end_tensor_inputs=prior_callback_on_step_end_tensor_inputs,
)
image_embeddings = prior_outputs.image_embeddings
prompt_embeds = prior_outputs.get("prompt_embeds", None)
negative_prompt_embeds = prior_outputs.get("negative_prompt_embeds", None)
outputs = self.decoder_pipe(
image_embeddings=image_embeddings,
prompt=prompt if prompt_embeds is None else None,
num_inference_steps=num_inference_steps,
guidance_scale=decoder_guidance_scale,
negative_prompt=negative_prompt if negative_prompt_embeds is None else None,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
generator=generator,
output_type=output_type,
return_dict=return_dict,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
return outputs
# Copyright (c) 2023 Dominic Rampas MIT License
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch import torch
import torch.nn as nn import torch.nn as nn
......
...@@ -233,7 +233,7 @@ class WuerstchenDiffNeXt(ModelMixin, ConfigMixin): ...@@ -233,7 +233,7 @@ class WuerstchenDiffNeXt(ModelMixin, ConfigMixin):
class ResBlockStageB(nn.Module): class ResBlockStageB(nn.Module):
def __init__(self, c, c_skip=None, kernel_size=3, dropout=0.0): def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
super().__init__() super().__init__()
self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c) self.depthwise = nn.Conv2d(c, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6) self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
......
...@@ -349,6 +349,11 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -349,6 +349,11 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
text_encoder_hidden_states = ( text_encoder_hidden_states = (
torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds torch.cat([prompt_embeds, negative_prompt_embeds]) if negative_prompt_embeds is not None else prompt_embeds
) )
effnet = (
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
if self.do_classifier_free_guidance
else image_embeddings
)
# 3. Determine latent shape of latents # 3. Determine latent shape of latents
latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale) latent_height = int(image_embeddings.size(2) * self.config.latent_dim_scale)
...@@ -371,11 +376,6 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -371,11 +376,6 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
self._num_timesteps = len(timesteps[:-1]) self._num_timesteps = len(timesteps[:-1])
for i, t in enumerate(self.progress_bar(timesteps[:-1])): for i, t in enumerate(self.progress_bar(timesteps[:-1])):
ratio = t.expand(latents.size(0)).to(dtype) ratio = t.expand(latents.size(0)).to(dtype)
effnet = (
torch.cat([image_embeddings, torch.zeros_like(image_embeddings)])
if self.do_classifier_free_guidance
else image_embeddings
)
# 7. Denoise latents # 7. Denoise latents
predicted_latents = self.decoder( predicted_latents = self.decoder(
torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents, torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents,
...@@ -423,9 +423,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline): ...@@ -423,9 +423,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
latents = self.vqgan.config.scale_factor * latents latents = self.vqgan.config.scale_factor * latents
images = self.vqgan.decode(latents).sample.clamp(0, 1) images = self.vqgan.decode(latents).sample.clamp(0, 1)
if output_type == "np": if output_type == "np":
images = images.permute(0, 2, 3, 1).cpu().numpy() images = images.permute(0, 2, 3, 1).cpu().float().numpy()
elif output_type == "pil": elif output_type == "pil":
images = images.permute(0, 2, 3, 1).cpu().numpy() images = images.permute(0, 2, 3, 1).cpu().float().numpy()
images = self.numpy_to_pil(images) images = self.numpy_to_pil(images)
else: else:
images = latents images = latents
......
...@@ -508,7 +508,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin): ...@@ -508,7 +508,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
self.maybe_free_model_hooks() self.maybe_free_model_hooks()
if output_type == "np": if output_type == "np":
latents = latents.cpu().numpy() latents = latents.cpu().float().numpy()
if not return_dict: if not return_dict:
return (latents,) return (latents,)
......
...@@ -752,6 +752,51 @@ class ShapEPipeline(metaclass=DummyObject): ...@@ -752,6 +752,51 @@ class ShapEPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"]) requires_backends(cls, ["torch", "transformers"])
class StableCascadeCombinedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableCascadeDecoderPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableCascadePriorPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class StableDiffusionAdapterPipeline(metaclass=DummyObject): class StableDiffusionAdapterPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"] _backends = ["torch", "transformers"]
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, StableCascadeCombinedPipeline
from diffusers.models import StableCascadeUNet
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableCascadeCombinedPipeline
params = ["prompt"]
batch_params = ["prompt", "negative_prompt"]
required_optional_params = [
"generator",
"height",
"width",
"latents",
"prior_guidance_scale",
"decoder_guidance_scale",
"negative_prompt",
"num_inference_steps",
"return_dict",
"prior_num_inference_steps",
"output_type",
]
test_xformers_attention = True
@property
def text_embedder_hidden_size(self):
return 32
@property
def dummy_prior(self):
torch.manual_seed(0)
model_kwargs = {
"conditioning_dim": 128,
"block_out_channels": (128, 128),
"num_attention_heads": (2, 2),
"down_num_layers_per_block": (1, 1),
"up_num_layers_per_block": (1, 1),
"clip_image_in_channels": 768,
"switch_level": (False,),
"clip_text_in_channels": self.text_embedder_hidden_size,
"clip_text_pooled_in_channels": self.text_embedder_hidden_size,
}
model = StableCascadeUNet(**model_kwargs)
return model.eval()
@property
def dummy_tokenizer(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
return tokenizer
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
projection_dim=self.text_embedder_hidden_size,
hidden_size=self.text_embedder_hidden_size,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModelWithProjection(config).eval()
@property
def dummy_vqgan(self):
torch.manual_seed(0)
model_kwargs = {
"bottleneck_blocks": 1,
"num_vq_embeddings": 2,
}
model = PaellaVQModel(**model_kwargs)
return model.eval()
@property
def dummy_decoder(self):
torch.manual_seed(0)
model_kwargs = {
"in_channels": 4,
"out_channels": 4,
"conditioning_dim": 128,
"block_out_channels": (16, 32, 64, 128),
"num_attention_heads": (-1, -1, 1, 2),
"down_num_layers_per_block": (1, 1, 1, 1),
"up_num_layers_per_block": (1, 1, 1, 1),
"down_blocks_repeat_mappers": (1, 1, 1, 1),
"up_blocks_repeat_mappers": (3, 3, 2, 2),
"block_types_per_layer": (
("SDCascadeResBlock", "SDCascadeTimestepBlock"),
("SDCascadeResBlock", "SDCascadeTimestepBlock"),
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
("SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"),
),
"switch_level": None,
"clip_text_pooled_in_channels": 32,
"dropout": (0.1, 0.1, 0.1, 0.1),
}
model = StableCascadeUNet(**model_kwargs)
return model.eval()
def get_dummy_components(self):
prior = self.dummy_prior
scheduler = DDPMWuerstchenScheduler()
tokenizer = self.dummy_tokenizer
text_encoder = self.dummy_text_encoder
decoder = self.dummy_decoder
vqgan = self.dummy_vqgan
prior_text_encoder = self.dummy_text_encoder
prior_tokenizer = self.dummy_tokenizer
components = {
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"decoder": decoder,
"scheduler": scheduler,
"vqgan": vqgan,
"prior_text_encoder": prior_text_encoder,
"prior_tokenizer": prior_tokenizer,
"prior_prior": prior,
"prior_scheduler": scheduler,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "horse",
"generator": generator,
"prior_guidance_scale": 4.0,
"decoder_guidance_scale": 4.0,
"num_inference_steps": 2,
"prior_num_inference_steps": 2,
"output_type": "np",
"height": 128,
"width": 128,
}
return inputs
def test_stable_cascade(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
output = pipe(**self.get_dummy_inputs(device))
image = output.images
image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[-3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array([0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0])
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
assert (
np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_from_tuple_slice.flatten()}"
@require_torch_gpu
def test_offloads(self):
pipes = []
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components).to(torch_device)
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
sd_pipe.enable_sequential_cpu_offload()
pipes.append(sd_pipe)
components = self.get_dummy_components()
sd_pipe = self.pipeline_class(**components)
sd_pipe.enable_model_cpu_offload()
pipes.append(sd_pipe)
image_slices = []
for pipe in pipes:
inputs = self.get_dummy_inputs(torch_device)
image = pipe(**inputs).images
image_slices.append(image[0, -3:, -3:, -1].flatten())
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
assert np.abs(image_slices[0] - image_slices[2]).max() < 1e-3
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=2e-2)
@unittest.skip(reason="fp16 not supported")
def test_float16_inference(self):
super().test_float16_inference()
@unittest.skip(reason="no callback test for combined pipeline")
def test_callback_inputs(self):
super().test_callback_inputs()
# def test_callback_cfg(self):
# pass
# pass
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import DDPMWuerstchenScheduler, StableCascadeDecoderPipeline
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import StableCascadeUNet
from diffusers.pipelines.wuerstchen import PaellaVQModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
load_image,
load_pt,
require_torch_gpu,
skip_mps,
slow,
torch_device,
)
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class StableCascadeDecoderPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = StableCascadeDecoderPipeline
params = ["prompt"]
batch_params = ["image_embeddings", "prompt", "negative_prompt"]
required_optional_params = [
"num_images_per_prompt",
"num_inference_steps",
"latents",
"negative_prompt",
"guidance_scale",
"output_type",
"return_dict",
]
test_xformers_attention = False
callback_cfg_params = ["image_embeddings", "text_encoder_hidden_states"]
@property
def text_embedder_hidden_size(self):
return 32
@property
def time_input_dim(self):
return 32
@property
def block_out_channels_0(self):
return self.time_input_dim
@property
def time_embed_dim(self):
return self.time_input_dim * 4
@property
def dummy_tokenizer(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
return tokenizer
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
projection_dim=self.text_embedder_hidden_size,
hidden_size=self.text_embedder_hidden_size,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModelWithProjection(config).eval()
@property
def dummy_vqgan(self):
torch.manual_seed(0)
model_kwargs = {
"bottleneck_blocks": 1,
"num_vq_embeddings": 2,
}
model = PaellaVQModel(**model_kwargs)
return model.eval()
@property
def dummy_decoder(self):
torch.manual_seed(0)
model_kwargs = {
"in_channels": 4,
"out_channels": 4,
"conditioning_dim": 128,
"block_out_channels": [16, 32, 64, 128],
"num_attention_heads": [-1, -1, 1, 2],
"down_num_layers_per_block": [1, 1, 1, 1],
"up_num_layers_per_block": [1, 1, 1, 1],
"down_blocks_repeat_mappers": [1, 1, 1, 1],
"up_blocks_repeat_mappers": [3, 3, 2, 2],
"block_types_per_layer": [
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
["SDCascadeResBlock", "SDCascadeTimestepBlock", "SDCascadeAttnBlock"],
],
"switch_level": None,
"clip_text_pooled_in_channels": 32,
"dropout": [0.1, 0.1, 0.1, 0.1],
}
model = StableCascadeUNet(**model_kwargs)
return model.eval()
def get_dummy_components(self):
decoder = self.dummy_decoder
text_encoder = self.dummy_text_encoder
tokenizer = self.dummy_tokenizer
vqgan = self.dummy_vqgan
scheduler = DDPMWuerstchenScheduler()
components = {
"decoder": decoder,
"vqgan": vqgan,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"scheduler": scheduler,
"latent_dim_scale": 4.0,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"image_embeddings": torch.ones((1, 4, 4, 4), device=device),
"prompt": "horse",
"generator": generator,
"guidance_scale": 2.0,
"num_inference_steps": 2,
"output_type": "np",
}
return inputs
def test_wuerstchen_decoder(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
output = pipe(**self.get_dummy_inputs(device))
image = output.images
image_from_tuple = pipe(**self.get_dummy_inputs(device), return_dict=False)
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@skip_mps
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-2)
@skip_mps
def test_attention_slicing_forward_pass(self):
test_max_difference = torch_device == "cpu"
test_mean_pixel_difference = False
self._test_attention_slicing_forward_pass(
test_max_difference=test_max_difference,
test_mean_pixel_difference=test_mean_pixel_difference,
)
@unittest.skip(reason="fp16 not supported")
def test_float16_inference(self):
super().test_float16_inference()
@slow
@require_torch_gpu
class StableCascadeDecoderPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_stable_cascade_decoder(self):
pipe = StableCascadeDecoderPipeline.from_pretrained(
"diffusers/StableCascade-decoder", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
pipe.set_progress_bar_config(disable=None)
prompt = "A photograph of the inside of a subway train. There are raccoons sitting on the seats. One of them is reading a newspaper. The window shows the city in the background."
generator = torch.Generator(device="cpu").manual_seed(0)
image_embedding = load_pt(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/image_embedding.pt"
)
image = pipe(
prompt=prompt, image_embeddings=image_embedding, num_inference_steps=10, generator=generator
).images[0]
assert image.size == (1024, 1024)
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/stable_cascade/t2i.png"
)
image_processor = VaeImageProcessor()
image_np = image_processor.pil_to_numpy(image)
expected_image_np = image_processor.pil_to_numpy(expected_image)
self.assertTrue(np.allclose(image_np, expected_image_np, atol=53e-2))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment