Unverified Commit 157c9011 authored by Ayush Mangal's avatar Ayush Mangal Committed by GitHub
Browse files

Add BLIP Diffusion (#4388)



* Add BLIP Diffusion skeleton

* Add other model components

* Add BLIP2, need to change it for now

* Fix pipeline imports

* Load pretrained ViT

* Make qformer fwd pass same

* Replicate fwd passes

* Fix device bug

* Add accelerate functions

* Remove extra functions from Blip2

* Minor bug

* Integrate initial review changes

* Refactoring

* Refactoring

* Refactor

* Add controlnet

* Refactor

* Update conversion script

* Add image processor

* Shift postprocessing to ImageProcessor

* Refactor

* Fix device

* Add fast tests

* Update conversion script

* Fix checkpoint conversion script

* Integrate review changes

* Integrate reivew changes

* Remove unused functions from test

* Reuse HF image processor in Cond image

* Create new BlipImageProcessor based on transfomers

* Fix image preprocessor

* Minor

* Minor

* Add canny preprocessing

* Fix controlnet preprocessing

* Fix blip diffusion test

* Add controlnet test

* Add initial doc strings

* Integrate review changes

* Refactor

* Update examples

* Remove DDIM comments

* Add copied from for prepare_latents

* Add type anotations

* Add docstrings

* Do black formatting

* Add batch support

* Make tests pass

* Make controlnet tests pass

* Black formatting

* Fix progress bar

* Fix some licensing comments

* Fix imports

* Refactor controlnet

* Make tests faster

* Edit examples

* Black formatting/Ruff

* Add doc

* Minor
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>

* Move controlnet pipeline

* Make tests faster

* Fix imports

* Fix formatting

* Fix make errors

* Fix make errors

* Minor

* Add suggested doc changes
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Edit docs

* Fix 16 bit loading

* Update examples

* Edit toctree

* Update docs/source/en/api/pipelines/blip_diffusion.md
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Minor

* Add tips

* Edit examples

* Update model paths

---------
Co-authored-by: default avatarPatrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
parent 24563ca6
...@@ -216,6 +216,8 @@ ...@@ -216,6 +216,8 @@
title: AudioLDM 2 title: AudioLDM 2
- local: api/pipelines/auto_pipeline - local: api/pipelines/auto_pipeline
title: AutoPipeline title: AutoPipeline
- local: api/pipelines/blip_diffusion
title: BLIP Diffusion
- local: api/pipelines/consistency_models - local: api/pipelines/consistency_models
title: Consistency Models title: Consistency Models
- local: api/pipelines/controlnet - local: api/pipelines/controlnet
......
# Blip Diffusion
Blip Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://arxiv.org/abs/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.
The abstract from the paper is:
*Subject-driven text-to-image generation models create novel renditions of an input subject based on text prompts. Existing models suffer from lengthy fine-tuning and difficulties preserving the subject fidelity. To overcome these limitations, we introduce BLIP-Diffusion, a new subject-driven image generation model that supports multimodal control which consumes inputs of subject images and text prompts. Unlike other subject-driven generation models, BLIP-Diffusion introduces a new multimodal encoder which is pre-trained to provide subject representation. We first pre-train the multimodal encoder following BLIP-2 to produce visual representation aligned with the text. Then we design a subject representation learning task which enables a diffusion model to leverage such visual representation and generates new subject renditions. Compared with previous methods such as DreamBooth, our model enables zero-shot subject-driven generation, and efficient fine-tuning for customized subject with up to 20x speedup. We also demonstrate that BLIP-Diffusion can be flexibly combined with existing techniques such as ControlNet and prompt-to-prompt to enable novel subject-driven generation and editing applications.*
The original codebase can be found at [salesforce/LAVIS](https://github.com/salesforce/LAVIS/tree/main/projects/blip-diffusion). You can find the official BLIP Diffusion checkpoints under the [hf.co/SalesForce](https://hf.co/SalesForce) organization.
`BlipDiffusionPipeline` and `BlipDiffusionControlNetPipeline` were contributed by [`ayushtues`](https://github.com/ayushtues/).
<Tip>
Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](/using-diffusers/loading#reuse-components-across-pipelines) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## BlipDiffusionPipeline
[[autodoc]] BlipDiffusionPipeline
- all
- __call__
## BlipDiffusionControlNetPipeline
[[autodoc]] BlipDiffusionControlNetPipeline
- all
- __call__
"""
This script requires you to build `LAVIS` from source, since the pip version doesn't have BLIP Diffusion. Follow instructions here: https://github.com/salesforce/LAVIS/tree/main.
"""
import argparse
import os
import tempfile
import torch
from lavis.models import load_model_and_preprocess
from transformers import CLIPTokenizer
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
from diffusers import (
AutoencoderKL,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.pipelines import BlipDiffusionPipeline
from diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
from diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
from diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
BLIP2_CONFIG = {
"vision_config": {
"hidden_size": 1024,
"num_hidden_layers": 23,
"num_attention_heads": 16,
"image_size": 224,
"patch_size": 14,
"intermediate_size": 4096,
"hidden_act": "quick_gelu",
},
"qformer_config": {
"cross_attention_frequency": 1,
"encoder_hidden_size": 1024,
"vocab_size": 30523,
},
"num_query_tokens": 16,
}
blip2config = Blip2Config(**BLIP2_CONFIG)
def qformer_model_from_original_config():
qformer = Blip2QFormerModel(blip2config)
return qformer
def embeddings_from_original_checkpoint(model, diffuser_embeddings_prefix, original_embeddings_prefix):
embeddings = {}
embeddings.update(
{
f"{diffuser_embeddings_prefix}.word_embeddings.weight": model[
f"{original_embeddings_prefix}.word_embeddings.weight"
]
}
)
embeddings.update(
{
f"{diffuser_embeddings_prefix}.position_embeddings.weight": model[
f"{original_embeddings_prefix}.position_embeddings.weight"
]
}
)
embeddings.update(
{f"{diffuser_embeddings_prefix}.LayerNorm.weight": model[f"{original_embeddings_prefix}.LayerNorm.weight"]}
)
embeddings.update(
{f"{diffuser_embeddings_prefix}.LayerNorm.bias": model[f"{original_embeddings_prefix}.LayerNorm.bias"]}
)
return embeddings
def proj_layer_from_original_checkpoint(model, diffuser_proj_prefix, original_proj_prefix):
proj_layer = {}
proj_layer.update({f"{diffuser_proj_prefix}.dense1.weight": model[f"{original_proj_prefix}.dense1.weight"]})
proj_layer.update({f"{diffuser_proj_prefix}.dense1.bias": model[f"{original_proj_prefix}.dense1.bias"]})
proj_layer.update({f"{diffuser_proj_prefix}.dense2.weight": model[f"{original_proj_prefix}.dense2.weight"]})
proj_layer.update({f"{diffuser_proj_prefix}.dense2.bias": model[f"{original_proj_prefix}.dense2.bias"]})
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.weight": model[f"{original_proj_prefix}.LayerNorm.weight"]})
proj_layer.update({f"{diffuser_proj_prefix}.LayerNorm.bias": model[f"{original_proj_prefix}.LayerNorm.bias"]})
return proj_layer
def attention_from_original_checkpoint(model, diffuser_attention_prefix, original_attention_prefix):
attention = {}
attention.update(
{
f"{diffuser_attention_prefix}.attention.query.weight": model[
f"{original_attention_prefix}.self.query.weight"
]
}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.query.bias": model[f"{original_attention_prefix}.self.query.bias"]}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.key.weight": model[f"{original_attention_prefix}.self.key.weight"]}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.key.bias": model[f"{original_attention_prefix}.self.key.bias"]}
)
attention.update(
{
f"{diffuser_attention_prefix}.attention.value.weight": model[
f"{original_attention_prefix}.self.value.weight"
]
}
)
attention.update(
{f"{diffuser_attention_prefix}.attention.value.bias": model[f"{original_attention_prefix}.self.value.bias"]}
)
attention.update(
{f"{diffuser_attention_prefix}.output.dense.weight": model[f"{original_attention_prefix}.output.dense.weight"]}
)
attention.update(
{f"{diffuser_attention_prefix}.output.dense.bias": model[f"{original_attention_prefix}.output.dense.bias"]}
)
attention.update(
{
f"{diffuser_attention_prefix}.output.LayerNorm.weight": model[
f"{original_attention_prefix}.output.LayerNorm.weight"
]
}
)
attention.update(
{
f"{diffuser_attention_prefix}.output.LayerNorm.bias": model[
f"{original_attention_prefix}.output.LayerNorm.bias"
]
}
)
return attention
def output_layers_from_original_checkpoint(model, diffuser_output_prefix, original_output_prefix):
output_layers = {}
output_layers.update({f"{diffuser_output_prefix}.dense.weight": model[f"{original_output_prefix}.dense.weight"]})
output_layers.update({f"{diffuser_output_prefix}.dense.bias": model[f"{original_output_prefix}.dense.bias"]})
output_layers.update(
{f"{diffuser_output_prefix}.LayerNorm.weight": model[f"{original_output_prefix}.LayerNorm.weight"]}
)
output_layers.update(
{f"{diffuser_output_prefix}.LayerNorm.bias": model[f"{original_output_prefix}.LayerNorm.bias"]}
)
return output_layers
def encoder_from_original_checkpoint(model, diffuser_encoder_prefix, original_encoder_prefix):
encoder = {}
for i in range(blip2config.qformer_config.num_hidden_layers):
encoder.update(
attention_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.attention", f"{original_encoder_prefix}.{i}.attention"
)
)
encoder.update(
attention_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.crossattention", f"{original_encoder_prefix}.{i}.crossattention"
)
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.weight": model[
f"{original_encoder_prefix}.{i}.intermediate.dense.weight"
]
}
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate.dense.bias": model[
f"{original_encoder_prefix}.{i}.intermediate.dense.bias"
]
}
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.weight": model[
f"{original_encoder_prefix}.{i}.intermediate_query.dense.weight"
]
}
)
encoder.update(
{
f"{diffuser_encoder_prefix}.{i}.intermediate_query.dense.bias": model[
f"{original_encoder_prefix}.{i}.intermediate_query.dense.bias"
]
}
)
encoder.update(
output_layers_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.output", f"{original_encoder_prefix}.{i}.output"
)
)
encoder.update(
output_layers_from_original_checkpoint(
model, f"{diffuser_encoder_prefix}.{i}.output_query", f"{original_encoder_prefix}.{i}.output_query"
)
)
return encoder
def visual_encoder_layer_from_original_checkpoint(model, diffuser_prefix, original_prefix):
visual_encoder_layer = {}
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.weight": model[f"{original_prefix}.ln_1.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm1.bias": model[f"{original_prefix}.ln_1.bias"]})
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.weight": model[f"{original_prefix}.ln_2.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.layer_norm2.bias": model[f"{original_prefix}.ln_2.bias"]})
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.qkv.weight": model[f"{original_prefix}.attn.in_proj_weight"]}
)
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.qkv.bias": model[f"{original_prefix}.attn.in_proj_bias"]}
)
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.projection.weight": model[f"{original_prefix}.attn.out_proj.weight"]}
)
visual_encoder_layer.update(
{f"{diffuser_prefix}.self_attn.projection.bias": model[f"{original_prefix}.attn.out_proj.bias"]}
)
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.weight": model[f"{original_prefix}.mlp.c_fc.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc1.bias": model[f"{original_prefix}.mlp.c_fc.bias"]})
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.weight": model[f"{original_prefix}.mlp.c_proj.weight"]})
visual_encoder_layer.update({f"{diffuser_prefix}.mlp.fc2.bias": model[f"{original_prefix}.mlp.c_proj.bias"]})
return visual_encoder_layer
def visual_encoder_from_original_checkpoint(model, diffuser_prefix, original_prefix):
visual_encoder = {}
visual_encoder.update(
{
f"{diffuser_prefix}.embeddings.class_embedding": model[f"{original_prefix}.class_embedding"]
.unsqueeze(0)
.unsqueeze(0)
}
)
visual_encoder.update(
{
f"{diffuser_prefix}.embeddings.position_embedding": model[
f"{original_prefix}.positional_embedding"
].unsqueeze(0)
}
)
visual_encoder.update(
{f"{diffuser_prefix}.embeddings.patch_embedding.weight": model[f"{original_prefix}.conv1.weight"]}
)
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.weight": model[f"{original_prefix}.ln_pre.weight"]})
visual_encoder.update({f"{diffuser_prefix}.pre_layernorm.bias": model[f"{original_prefix}.ln_pre.bias"]})
for i in range(blip2config.vision_config.num_hidden_layers):
visual_encoder.update(
visual_encoder_layer_from_original_checkpoint(
model, f"{diffuser_prefix}.encoder.layers.{i}", f"{original_prefix}.transformer.resblocks.{i}"
)
)
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.weight": model["blip.ln_vision.weight"]})
visual_encoder.update({f"{diffuser_prefix}.post_layernorm.bias": model["blip.ln_vision.bias"]})
return visual_encoder
def qformer_original_checkpoint_to_diffusers_checkpoint(model):
qformer_checkpoint = {}
qformer_checkpoint.update(embeddings_from_original_checkpoint(model, "embeddings", "blip.Qformer.bert.embeddings"))
qformer_checkpoint.update({"query_tokens": model["blip.query_tokens"]})
qformer_checkpoint.update(proj_layer_from_original_checkpoint(model, "proj_layer", "proj_layer"))
qformer_checkpoint.update(
encoder_from_original_checkpoint(model, "encoder.layer", "blip.Qformer.bert.encoder.layer")
)
qformer_checkpoint.update(visual_encoder_from_original_checkpoint(model, "visual_encoder", "blip.visual_encoder"))
return qformer_checkpoint
def get_qformer(model):
print("loading qformer")
qformer = qformer_model_from_original_config()
qformer_diffusers_checkpoint = qformer_original_checkpoint_to_diffusers_checkpoint(model)
load_checkpoint_to_model(qformer_diffusers_checkpoint, qformer)
print("done loading qformer")
return qformer
def load_checkpoint_to_model(checkpoint, model):
with tempfile.NamedTemporaryFile(delete=False) as file:
torch.save(checkpoint, file.name)
del checkpoint
model.load_state_dict(torch.load(file.name), strict=False)
os.remove(file.name)
def save_blip_diffusion_model(model, args):
qformer = get_qformer(model)
qformer.eval()
text_encoder = ContextCLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
vae.eval()
text_encoder.eval()
scheduler = PNDMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
set_alpha_to_one=False,
skip_prk_steps=True,
)
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
image_processor = BlipImageProcessor()
blip_diffusion = BlipDiffusionPipeline(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
unet=unet,
scheduler=scheduler,
qformer=qformer,
image_processor=image_processor,
)
blip_diffusion.save_pretrained(args.checkpoint_path)
def main(args):
model, _, _ = load_model_and_preprocess("blip_diffusion", "base", device="cpu", is_eval=True)
save_blip_diffusion_model(model.state_dict(), args)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
args = parser.parse_args()
main(args)
...@@ -197,6 +197,8 @@ else: ...@@ -197,6 +197,8 @@ else:
"AudioLDM2ProjectionModel", "AudioLDM2ProjectionModel",
"AudioLDM2UNet2DConditionModel", "AudioLDM2UNet2DConditionModel",
"AudioLDMPipeline", "AudioLDMPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"CLIPImageProjection", "CLIPImageProjection",
"CycleDiffusionPipeline", "CycleDiffusionPipeline",
"IFImg2ImgPipeline", "IFImg2ImgPipeline",
...@@ -458,6 +460,8 @@ if TYPE_CHECKING: ...@@ -458,6 +460,8 @@ if TYPE_CHECKING:
AutoPipelineForImage2Image, AutoPipelineForImage2Image,
AutoPipelineForInpainting, AutoPipelineForInpainting,
AutoPipelineForText2Image, AutoPipelineForText2Image,
BlipDiffusionControlNetPipeline,
BlipDiffusionPipeline,
CLIPImageProjection, CLIPImageProjection,
ConsistencyModelPipeline, ConsistencyModelPipeline,
DanceDiffusionPipeline, DanceDiffusionPipeline,
......
This diff is collapsed.
from dataclasses import dataclass
from typing import List, Optional, Union
import numpy as np
import PIL
from PIL import Image
from ...utils import OptionalDependencyNotAvailable, is_torch_available, is_transformers_available
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import ShapEPipeline
else:
from .blip_image_processing import BlipImageProcessor
from .modeling_blip2 import Blip2QFormerModel
from .modeling_ctx_clip import ContextCLIPTextModel
from .pipeline_blip_diffusion import BlipDiffusionPipeline
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processor class for BLIP."""
from typing import Dict, List, Optional, Union
import numpy as np
import torch
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict
from transformers.image_transforms import convert_to_rgb, resize, to_channel_dimension_format
from transformers.image_utils import (
OPENAI_CLIP_MEAN,
OPENAI_CLIP_STD,
ChannelDimension,
ImageInput,
PILImageResampling,
infer_channel_dimension_format,
is_scaled_image,
make_list_of_images,
to_numpy_array,
valid_images,
)
from transformers.utils import TensorType, is_vision_available, logging
from diffusers.utils import numpy_to_pil
if is_vision_available():
import PIL
logger = logging.get_logger(__name__)
# We needed some extra functions on top of the ones in transformers.image_processing_utils.BaseImageProcessor, namely center crop
# Copy-pasted from transformers.models.blip.image_processing_blip.BlipImageProcessor
class BlipImageProcessor(BaseImageProcessor):
r"""
Constructs a BLIP image processor.
Args:
do_resize (`bool`, *optional*, defaults to `True`):
Whether to resize the image's (height, width) dimensions to the specified `size`. Can be overridden by the
`do_resize` parameter in the `preprocess` method.
size (`dict`, *optional*, defaults to `{"height": 384, "width": 384}`):
Size of the output image after resizing. Can be overridden by the `size` parameter in the `preprocess`
method.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`. Can be
overridden by the `resample` parameter in the `preprocess` method.
do_rescale (`bool`, *optional*, defaults to `True`):
Wwhether to rescale the image by the specified scale `rescale_factor`. Can be overridden by the
`do_rescale` parameter in the `preprocess` method.
rescale_factor (`int` or `float`, *optional*, defaults to `1/255`):
Scale factor to use if rescaling the image. Only has an effect if `do_rescale` is set to `True`. Can be
overridden by the `rescale_factor` parameter in the `preprocess` method.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image. Can be overridden by the `do_normalize` parameter in the `preprocess`
method. Can be overridden by the `do_normalize` parameter in the `preprocess` method.
image_mean (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_MEAN`):
Mean to use if normalizing the image. This is a float or list of floats the length of the number of
channels in the image. Can be overridden by the `image_mean` parameter in the `preprocess` method. Can be
overridden by the `image_mean` parameter in the `preprocess` method.
image_std (`float` or `List[float]`, *optional*, defaults to `IMAGENET_STANDARD_STD`):
Standard deviation to use if normalizing the image. This is a float or list of floats the length of the
number of channels in the image. Can be overridden by the `image_std` parameter in the `preprocess` method.
Can be overridden by the `image_std` parameter in the `preprocess` method.
do_convert_rgb (`bool`, *optional*, defaults to `True`):
Whether to convert the image to RGB.
"""
model_input_names = ["pixel_values"]
def __init__(
self,
do_resize: bool = True,
size: Dict[str, int] = None,
resample: PILImageResampling = PILImageResampling.BICUBIC,
do_rescale: bool = True,
rescale_factor: Union[int, float] = 1 / 255,
do_normalize: bool = True,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
do_convert_rgb: bool = True,
do_center_crop: bool = True,
**kwargs,
) -> None:
super().__init__(**kwargs)
size = size if size is not None else {"height": 224, "width": 224}
size = get_size_dict(size, default_to_square=True)
self.do_resize = do_resize
self.size = size
self.resample = resample
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean if image_mean is not None else OPENAI_CLIP_MEAN
self.image_std = image_std if image_std is not None else OPENAI_CLIP_STD
self.do_convert_rgb = do_convert_rgb
self.do_center_crop = do_center_crop
# Copy-pasted from transformers.models.vit.image_processing_vit.ViTImageProcessor.resize with PILImageResampling.BILINEAR->PILImageResampling.BICUBIC
def resize(
self,
image: np.ndarray,
size: Dict[str, int],
resample: PILImageResampling = PILImageResampling.BICUBIC,
data_format: Optional[Union[str, ChannelDimension]] = None,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> np.ndarray:
"""
Resize an image to `(size["height"], size["width"])`.
Args:
image (`np.ndarray`):
Image to resize.
size (`Dict[str, int]`):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BICUBIC`):
`PILImageResampling` filter to use when resizing the image e.g. `PILImageResampling.BICUBIC`.
data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the output image. If unset, the channel dimension format of the input
image is used. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
Returns:
`np.ndarray`: The resized image.
"""
size = get_size_dict(size)
if "height" not in size or "width" not in size:
raise ValueError(f"The `size` dictionary must contain the keys `height` and `width`. Got {size.keys()}")
output_size = (size["height"], size["width"])
return resize(
image,
size=output_size,
resample=resample,
data_format=data_format,
input_data_format=input_data_format,
**kwargs,
)
def preprocess(
self,
images: ImageInput,
do_resize: Optional[bool] = None,
size: Optional[Dict[str, int]] = None,
resample: PILImageResampling = None,
do_rescale: Optional[bool] = None,
do_center_crop: Optional[bool] = None,
rescale_factor: Optional[float] = None,
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, List[float]]] = None,
image_std: Optional[Union[float, List[float]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
do_convert_rgb: bool = None,
data_format: ChannelDimension = ChannelDimension.FIRST,
input_data_format: Optional[Union[str, ChannelDimension]] = None,
**kwargs,
) -> PIL.Image.Image:
"""
Preprocess an image or batch of images.
Args:
images (`ImageInput`):
Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
passing in images with pixel values between 0 and 1, set `do_rescale=False`.
do_resize (`bool`, *optional*, defaults to `self.do_resize`):
Whether to resize the image.
size (`Dict[str, int]`, *optional*, defaults to `self.size`):
Controls the size of the image after `resize`. The shortest edge of the image is resized to
`size["shortest_edge"]` whilst preserving the aspect ratio. If the longest edge of this resized image
is > `int(size["shortest_edge"] * (1333 / 800))`, then the image is resized again to make the longest
edge equal to `int(size["shortest_edge"] * (1333 / 800))`.
resample (`PILImageResampling`, *optional*, defaults to `self.resample`):
Resampling filter to use if resizing the image. Only has an effect if `do_resize` is set to `True`.
do_rescale (`bool`, *optional*, defaults to `self.do_rescale`):
Whether to rescale the image values between [0 - 1].
rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`):
Rescale factor to rescale the image by if `do_rescale` is set to `True`.
do_normalize (`bool`, *optional*, defaults to `self.do_normalize`):
Whether to normalize the image.
image_mean (`float` or `List[float]`, *optional*, defaults to `self.image_mean`):
Image mean to normalize the image by if `do_normalize` is set to `True`.
image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
Image standard deviation to normalize the image by if `do_normalize` is set to `True`.
do_convert_rgb (`bool`, *optional*, defaults to `self.do_convert_rgb`):
Whether to convert the image to RGB.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`):
The channel dimension format for the output image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- Unset: Use the channel dimension format of the input image.
input_data_format (`ChannelDimension` or `str`, *optional*):
The channel dimension format for the input image. If unset, the channel dimension format is inferred
from the input image. Can be one of:
- `"channels_first"` or `ChannelDimension.FIRST`: image in (num_channels, height, width) format.
- `"channels_last"` or `ChannelDimension.LAST`: image in (height, width, num_channels) format.
- `"none"` or `ChannelDimension.NONE`: image in (height, width) format.
"""
do_resize = do_resize if do_resize is not None else self.do_resize
resample = resample if resample is not None else self.resample
do_rescale = do_rescale if do_rescale is not None else self.do_rescale
rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor
do_normalize = do_normalize if do_normalize is not None else self.do_normalize
image_mean = image_mean if image_mean is not None else self.image_mean
image_std = image_std if image_std is not None else self.image_std
do_convert_rgb = do_convert_rgb if do_convert_rgb is not None else self.do_convert_rgb
do_center_crop = do_center_crop if do_center_crop is not None else self.do_center_crop
size = size if size is not None else self.size
size = get_size_dict(size, default_to_square=False)
images = make_list_of_images(images)
if not valid_images(images):
raise ValueError(
"Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "
"torch.Tensor, tf.Tensor or jax.ndarray."
)
if do_resize and size is None or resample is None:
raise ValueError("Size and resample must be specified if do_resize is True.")
if do_rescale and rescale_factor is None:
raise ValueError("Rescale factor must be specified if do_rescale is True.")
if do_normalize and (image_mean is None or image_std is None):
raise ValueError("Image mean and std must be specified if do_normalize is True.")
# PIL RGBA images are converted to RGB
if do_convert_rgb:
images = [convert_to_rgb(image) for image in images]
# All transformations expect numpy arrays.
images = [to_numpy_array(image) for image in images]
if is_scaled_image(images[0]) and do_rescale:
logger.warning_once(
"It looks like you are trying to rescale already rescaled images. If the input"
" images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
)
if input_data_format is None:
# We assume that all images have the same channel dimension format.
input_data_format = infer_channel_dimension_format(images[0])
if do_resize:
images = [
self.resize(image=image, size=size, resample=resample, input_data_format=input_data_format)
for image in images
]
if do_rescale:
images = [
self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
for image in images
]
if do_normalize:
images = [
self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
for image in images
]
if do_center_crop:
images = [self.center_crop(image, size, input_data_format=input_data_format) for image in images]
images = [
to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
]
encoded_outputs = BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
return encoded_outputs
# Follows diffusers.VaeImageProcessor.postprocess
def postprocess(self, sample: torch.FloatTensor, output_type: str = "pil"):
if output_type not in ["pt", "np", "pil"]:
raise ValueError(
f"output_type={output_type} is not supported. Make sure to choose one of ['pt', 'np', or 'pil']"
)
# Equivalent to diffusers.VaeImageProcessor.denormalize
sample = (sample / 2 + 0.5).clamp(0, 1)
if output_type == "pt":
return sample
# Equivalent to diffusers.VaeImageProcessor.pt_to_numpy
sample = sample.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "np":
return sample
# Output_type must be 'pil'
sample = numpy_to_pil(sample)
return sample
This diff is collapsed.
# Copyright 2023 Salesforce.com, inc.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
from torch import nn
from transformers import CLIPPreTrainedModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.configuration_clip import CLIPTextConfig
from transformers.models.clip.modeling_clip import (
CLIPEncoder,
_expand_mask,
)
# This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip
# Which allows for an extra input of "context embeddings", which are the query embeddings used in Qformer
# They pass through the clip model, along with the text embeddings, and interact with them using self attention
class ContextCLIPTextModel(CLIPPreTrainedModel):
config_class = CLIPTextConfig
_no_split_modules = ["CLIPEncoderLayer"]
def __init__(self, config: CLIPTextConfig):
super().__init__(config)
self.text_model = ContextCLIPTextTransformer(config)
# Initialize weights and apply final processing
self.post_init()
def forward(
self,
ctx_embeddings: torch.Tensor = None,
ctx_begin_pos: list = None,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
return self.text_model(
ctx_embeddings=ctx_embeddings,
ctx_begin_pos=ctx_begin_pos,
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
class ContextCLIPTextTransformer(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
self.config = config
embed_dim = config.hidden_size
self.embeddings = ContextCLIPTextEmbeddings(config)
self.encoder = CLIPEncoder(config)
self.final_layer_norm = nn.LayerNorm(embed_dim)
def forward(
self,
ctx_embeddings: torch.Tensor,
ctx_begin_pos: list,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if input_ids is None:
raise ValueError("You have to specify either input_ids")
input_shape = input_ids.size()
input_ids = input_ids.view(-1, input_shape[-1])
hidden_states = self.embeddings(
input_ids=input_ids,
position_ids=position_ids,
ctx_embeddings=ctx_embeddings,
ctx_begin_pos=ctx_begin_pos,
)
bsz, seq_len = input_shape
if ctx_embeddings is not None:
seq_len += ctx_embeddings.size(1)
# CLIP's text model uses causal mask, prepare it here.
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
hidden_states.device
)
# expand attention_mask
if attention_mask is not None:
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
causal_attention_mask=causal_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
last_hidden_state = encoder_outputs[0]
last_hidden_state = self.final_layer_norm(last_hidden_state)
# text_embeds.shape = [batch_size, sequence_length, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
# casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=input_ids.device),
input_ids.to(torch.int).argmax(dim=-1),
]
if not return_dict:
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
return BaseModelOutputWithPooling(
last_hidden_state=last_hidden_state,
pooler_output=pooled_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)
def _build_causal_attention_mask(self, bsz, seq_len, dtype):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
mask.fill_(torch.tensor(torch.finfo(dtype).min))
mask.triu_(1) # zero out the lower diagonal
mask = mask.unsqueeze(1) # expand mask
return mask
class ContextCLIPTextEmbeddings(nn.Module):
def __init__(self, config: CLIPTextConfig):
super().__init__()
embed_dim = config.hidden_size
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim)
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))
def forward(
self,
ctx_embeddings: torch.Tensor,
ctx_begin_pos: list,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
) -> torch.Tensor:
if ctx_embeddings is None:
ctx_len = 0
else:
ctx_len = ctx_embeddings.shape[1]
seq_length = (input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]) + ctx_len
if position_ids is None:
position_ids = self.position_ids[:, :seq_length]
if inputs_embeds is None:
inputs_embeds = self.token_embedding(input_ids)
# for each input embeddings, add the ctx embeddings at the correct position
input_embeds_ctx = []
bsz = inputs_embeds.shape[0]
if ctx_embeddings is not None:
for i in range(bsz):
cbp = ctx_begin_pos[i]
prefix = inputs_embeds[i, :cbp]
# remove the special token embedding
suffix = inputs_embeds[i, cbp:]
input_embeds_ctx.append(torch.cat([prefix, ctx_embeddings[i], suffix], dim=0))
inputs_embeds = torch.stack(input_embeds_ctx, dim=0)
position_embeddings = self.position_embedding(position_ids)
embeddings = inputs_embeds + position_embeddings
return embeddings
# Copyright 2023 Salesforce.com, inc.
# Copyright 2023 The HuggingFace Team. All rights reserved.#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import PIL
import torch
from transformers import CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...utils import (
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .blip_image_processing import BlipImageProcessor
from .modeling_blip2 import Blip2QFormerModel
from .modeling_ctx_clip import ContextCLIPTextModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusers.pipelines import BlipDiffusionPipeline
>>> from diffusers.utils import load_image
>>> import torch
>>> blip_diffusion_pipe = BlipDiffusionPipeline.from_pretrained(
... "Salesforce/blipdiffusion", torch_dtype=torch.float16
... ).to("cuda")
>>> cond_subject = "dog"
>>> tgt_subject = "dog"
>>> text_prompt_input = "swimming underwater"
>>> cond_image = load_image(
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/dog.jpg"
... )
>>> guidance_scale = 7.5
>>> num_inference_steps = 25
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
>>> output = blip_diffusion_pipe(
... text_prompt_input,
... cond_image,
... cond_subject,
... tgt_subject,
... guidance_scale=guidance_scale,
... num_inference_steps=num_inference_steps,
... neg_prompt=negative_prompt,
... height=512,
... width=512,
... ).images
>>> output[0].save("image.png")
```
"""
class BlipDiffusionPipeline(DiffusionPipeline):
"""
Pipeline for Zero-Shot Subject Driven Generation using Blip Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
tokenizer ([`CLIPTokenizer`]):
Tokenizer for the text encoder
text_encoder ([`ContextCLIPTextModel`]):
Text encoder to encode the text prompt
vae ([`AutoencoderKL`]):
VAE model to map the latents to the image
unet ([`UNet2DConditionModel`]):
Conditional U-Net architecture to denoise the image embedding.
scheduler ([`PNDMScheduler`]):
A scheduler to be used in combination with `unet` to generate image latents.
qformer ([`Blip2QFormerModel`]):
QFormer model to get multi-modal embeddings from the text and image.
image_processor ([`BlipImageProcessor`]):
Image Processor to preprocess and postprocess the image.
ctx_begin_pos (int, `optional`, defaults to 2):
Position of the context token in the text encoder.
"""
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: ContextCLIPTextModel,
vae: AutoencoderKL,
unet: UNet2DConditionModel,
scheduler: PNDMScheduler,
qformer: Blip2QFormerModel,
image_processor: BlipImageProcessor,
ctx_begin_pos: int = 2,
mean: List[float] = None,
std: List[float] = None,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
unet=unet,
scheduler=scheduler,
qformer=qformer,
image_processor=image_processor,
)
self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
def get_query_embeddings(self, input_image, src_subject):
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
# from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
rv = []
for prompt, tgt_subject in zip(prompts, tgt_subjects):
prompt = f"a {tgt_subject} {prompt.strip()}"
# a trick to amplify the prompt
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
return rv
# Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels, height, width)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def encode_prompt(self, query_embeds, prompt):
# embeddings for prompt, with query_embeds as context
max_len = self.text_encoder.text_model.config.max_position_embeddings
max_len -= self.qformer.config.num_query_tokens
tokenized_prompt = self.tokenizer(
prompt,
padding="max_length",
truncation=True,
max_length=max_len,
return_tensors="pt",
).to(self.device)
batch_size = query_embeds.shape[0]
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
text_embeddings = self.text_encoder(
input_ids=tokenized_prompt.input_ids,
ctx_embeddings=query_embeds,
ctx_begin_pos=ctx_begin_pos,
)[0]
return text_embeddings
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: List[str],
reference_image: PIL.Image.Image,
source_subject_category: List[str],
target_subject_category: List[str],
latents: Optional[torch.FloatTensor] = None,
guidance_scale: float = 7.5,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
neg_prompt: Optional[str] = "",
prompt_strength: float = 1.0,
prompt_reps: int = 20,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`List[str]`):
The prompt or prompts to guide the image generation.
reference_image (`PIL.Image.Image`):
The reference image to condition the generation on.
source_subject_category (`List[str]`):
The source subject category.
target_subject_category (`List[str]`):
The target subject category.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by random sampling.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
height (`int`, *optional*, defaults to 512):
The height of the generated image.
width (`int`, *optional*, defaults to 512):
The width of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
neg_prompt (`str`, *optional*, defaults to ""):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
prompt_strength (`float`, *optional*, defaults to 1.0):
The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
to amplify the prompt.
prompt_reps (`int`, *optional*, defaults to 20):
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between: `"pil"` (`PIL.Image.Image`), `"np"`
(`np.array`) or `"pt"` (`torch.Tensor`).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
reference_image = self.image_processor.preprocess(
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
)["pixel_values"]
reference_image = reference_image.to(self.device)
if isinstance(prompt, str):
prompt = [prompt]
if isinstance(source_subject_category, str):
source_subject_category = [source_subject_category]
if isinstance(target_subject_category, str):
target_subject_category = [target_subject_category]
batch_size = len(prompt)
prompt = self._build_prompt(
prompts=prompt,
tgt_subjects=target_subject_category,
prompt_strength=prompt_strength,
prompt_reps=prompt_reps,
)
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
text_embeddings = self.encode_prompt(query_embeds, prompt)
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
max_length = self.text_encoder.text_model.config.max_position_embeddings
uncond_input = self.tokenizer(
[neg_prompt] * batch_size,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(self.device),
ctx_embeddings=None,
)[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
latents = self.prepare_latents(
batch_size=batch_size,
num_channels=self.unet.config.in_channels,
height=height // scale_down_factor,
width=width // scale_down_factor,
generator=generator,
latents=latents,
dtype=self.unet.dtype,
device=self.device,
)
# set timesteps
extra_set_kwargs = {}
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
do_classifier_free_guidance = guidance_scale > 1.0
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
noise_pred = self.unet(
latent_model_input,
timestep=t,
encoder_hidden_states=text_embeddings,
down_block_additional_residuals=None,
mid_block_additional_residual=None,
)["sample"]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(
noise_pred,
t,
latents,
)["prev_sample"]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import ( from ...utils import (
OptionalDependencyNotAvailable, OptionalDependencyNotAvailable,
_LazyModule, _LazyModule,
get_objects_from_module, get_objects_from_module,
is_flax_available, is_flax_available,
is_torch_available, is_torch_available,
is_transformers_available, is_transformers_available,
) )
_dummy_objects = {} _dummy_objects = {}
_import_structure = {} _import_structure = {}
try: try:
if not (is_transformers_available() and is_torch_available()): if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403 from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else: else:
_import_structure["multicontrolnet"] = ["MultiControlNetModel"] _import_structure["multicontrolnet"] = ["MultiControlNetModel"]
_import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"] _import_structure["pipeline_controlnet"] = ["StableDiffusionControlNetPipeline"]
_import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"] _import_structure["pipeline_controlnet_blip_diffusion"] = ["BlipDiffusionControlNetPipeline"]
_import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"] _import_structure["pipeline_controlnet_img2img"] = ["StableDiffusionControlNetImg2ImgPipeline"]
_import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"] _import_structure["pipeline_controlnet_inpaint"] = ["StableDiffusionControlNetInpaintPipeline"]
_import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"] _import_structure["pipeline_controlnet_inpaint_sd_xl"] = ["StableDiffusionXLControlNetInpaintPipeline"]
_import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"] _import_structure["pipeline_controlnet_sd_xl"] = ["StableDiffusionXLControlNetPipeline"]
try: _import_structure["pipeline_controlnet_sd_xl_img2img"] = ["StableDiffusionXLControlNetImg2ImgPipeline"]
if not (is_transformers_available() and is_flax_available()): try:
raise OptionalDependencyNotAvailable() if not (is_transformers_available() and is_flax_available()):
except OptionalDependencyNotAvailable: raise OptionalDependencyNotAvailable()
from ...utils import dummy_flax_and_transformers_objects # noqa F403 except OptionalDependencyNotAvailable:
from ...utils import dummy_flax_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
else: _dummy_objects.update(get_objects_from_module(dummy_flax_and_transformers_objects))
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"] else:
_import_structure["pipeline_flax_controlnet"] = ["FlaxStableDiffusionControlNetPipeline"]
if TYPE_CHECKING:
try: if TYPE_CHECKING:
if not (is_transformers_available() and is_torch_available()): try:
raise OptionalDependencyNotAvailable() if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * except OptionalDependencyNotAvailable:
else: from ...utils.dummy_torch_and_transformers_objects import *
from .multicontrolnet import MultiControlNetModel else:
from .pipeline_controlnet import StableDiffusionControlNetPipeline from .multicontrolnet import MultiControlNetModel
from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline from .pipeline_controlnet import StableDiffusionControlNetPipeline
from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline from .pipeline_controlnet_blip_diffusion import BlipDiffusionControlNetPipeline
from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline from .pipeline_controlnet_img2img import StableDiffusionControlNetImg2ImgPipeline
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline from .pipeline_controlnet_inpaint import StableDiffusionControlNetInpaintPipeline
from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline from .pipeline_controlnet_inpaint_sd_xl import StableDiffusionXLControlNetInpaintPipeline
from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline
try: from .pipeline_controlnet_sd_xl_img2img import StableDiffusionXLControlNetImg2ImgPipeline
if not (is_transformers_available() and is_flax_available()):
raise OptionalDependencyNotAvailable() try:
except OptionalDependencyNotAvailable: if not (is_transformers_available() and is_flax_available()):
from ...utils.dummy_flax_and_transformers_objects import * # noqa F403 raise OptionalDependencyNotAvailable()
else: except OptionalDependencyNotAvailable:
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline from ...utils.dummy_flax_and_transformers_objects import * # noqa F403
else:
from .pipeline_flax_controlnet import FlaxStableDiffusionControlNetPipeline
else:
import sys
else:
sys.modules[__name__] = _LazyModule( import sys
__name__,
globals()["__file__"], sys.modules[__name__] = _LazyModule(
_import_structure, __name__,
module_spec=__spec__, globals()["__file__"],
) _import_structure,
for name, value in _dummy_objects.items(): module_spec=__spec__,
setattr(sys.modules[__name__], name, value) )
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
# Copyright 2023 Salesforce.com, inc.
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import PIL
import torch
from transformers import CLIPTokenizer
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import PNDMScheduler
from ...utils import (
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..blip_diffusion.blip_image_processing import BlipImageProcessor
from ..blip_diffusion.modeling_blip2 import Blip2QFormerModel
from ..blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> from diffusers.pipelines import BlipDiffusionControlNetPipeline
>>> from diffusers.utils import load_image
>>> from controlnet_aux import CannyDetector
>>> import torch
>>> blip_diffusion_pipe = BlipDiffusionControlNetPipeline.from_pretrained(
... "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
... ).to("cuda")
>>> style_subject = "flower"
>>> tgt_subject = "teapot"
>>> text_prompt = "on a marble table"
>>> cldm_cond_image = load_image(
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/kettle.jpg"
... ).resize(512, 512)
>>> canny = CannyDetector()
>>> cldm_cond_image = canny(cldm_cond_image, 30, 70, output_type="pil")
>>> style_image = load_image(
... "https://huggingface.co/datasets/ayushtues/blipdiffusion_images/resolve/main/flower.jpg"
... )
>>> guidance_scale = 7.5
>>> num_inference_steps = 50
>>> negative_prompt = "over-exposure, under-exposure, saturated, duplicate, out of frame, lowres, cropped, worst quality, low quality, jpeg artifacts, morbid, mutilated, out of frame, ugly, bad anatomy, bad proportions, deformed, blurry, duplicate"
>>> output = blip_diffusion_pipe(
... text_prompt,
... style_image,
... cldm_cond_image,
... style_subject,
... tgt_subject,
... guidance_scale=guidance_scale,
... num_inference_steps=num_inference_steps,
... neg_prompt=negative_prompt,
... height=512,
... width=512,
... ).images
>>> output[0].save("image.png")
```
"""
class BlipDiffusionControlNetPipeline(DiffusionPipeline):
"""
Pipeline for Canny Edge based Controlled subject-driven generation using Blip Diffusion.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
tokenizer ([`CLIPTokenizer`]):
Tokenizer for the text encoder
text_encoder ([`ContextCLIPTextModel`]):
Text encoder to encode the text prompt
vae ([`AutoencoderKL`]):
VAE model to map the latents to the image
unet ([`UNet2DConditionModel`]):
Conditional U-Net architecture to denoise the image embedding.
scheduler ([`PNDMScheduler`]):
A scheduler to be used in combination with `unet` to generate image latents.
qformer ([`Blip2QFormerModel`]):
QFormer model to get multi-modal embeddings from the text and image.
controlnet ([`ControlNetModel`]):
ControlNet model to get the conditioning image embedding.
image_processor ([`BlipImageProcessor`]):
Image Processor to preprocess and postprocess the image.
ctx_begin_pos (int, `optional`, defaults to 2):
Position of the context token in the text encoder.
"""
def __init__(
self,
tokenizer: CLIPTokenizer,
text_encoder: ContextCLIPTextModel,
vae: AutoencoderKL,
unet: UNet2DConditionModel,
scheduler: PNDMScheduler,
qformer: Blip2QFormerModel,
controlnet: ControlNetModel,
image_processor: BlipImageProcessor,
ctx_begin_pos: int = 2,
mean: List[float] = None,
std: List[float] = None,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
unet=unet,
scheduler=scheduler,
qformer=qformer,
controlnet=controlnet,
image_processor=image_processor,
)
self.register_to_config(ctx_begin_pos=ctx_begin_pos, mean=mean, std=std)
def get_query_embeddings(self, input_image, src_subject):
return self.qformer(image_input=input_image, text_input=src_subject, return_dict=False)
# from the original Blip Diffusion code, speciefies the target subject and augments the prompt by repeating it
def _build_prompt(self, prompts, tgt_subjects, prompt_strength=1.0, prompt_reps=20):
rv = []
for prompt, tgt_subject in zip(prompts, tgt_subjects):
prompt = f"a {tgt_subject} {prompt.strip()}"
# a trick to amplify the prompt
rv.append(", ".join([prompt] * int(prompt_strength * prompt_reps)))
return rv
# Copied from diffusers.pipelines.consistency_models.pipeline_consistency_models.ConsistencyModelPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels, height, width, dtype, device, generator, latents=None):
shape = (batch_size, num_channels, height, width)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def encode_prompt(self, query_embeds, prompt):
# embeddings for prompt, with query_embeds as context
max_len = self.text_encoder.text_model.config.max_position_embeddings
max_len -= self.qformer.config.num_query_tokens
tokenized_prompt = self.tokenizer(
prompt,
padding="max_length",
truncation=True,
max_length=max_len,
return_tensors="pt",
).to(self.device)
batch_size = query_embeds.shape[0]
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
text_embeddings = self.text_encoder(
input_ids=tokenized_prompt.input_ids,
ctx_embeddings=query_embeds,
ctx_begin_pos=ctx_begin_pos,
)[0]
return text_embeddings
# Adapted from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
def prepare_control_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
):
image = self.image_processor.preprocess(
image,
size={"width": width, "height": height},
do_rescale=True,
do_center_crop=False,
do_normalize=False,
return_tensors="pt",
)["pixel_values"].to(self.device)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance:
image = torch.cat([image] * 2)
return image
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: List[str],
reference_image: PIL.Image.Image,
condtioning_image: PIL.Image.Image,
source_subject_category: List[str],
target_subject_category: List[str],
latents: Optional[torch.FloatTensor] = None,
guidance_scale: float = 7.5,
height: int = 512,
width: int = 512,
num_inference_steps: int = 50,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
neg_prompt: Optional[str] = "",
prompt_strength: float = 1.0,
prompt_reps: int = 20,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`List[str]`):
The prompt or prompts to guide the image generation.
reference_image (`PIL.Image.Image`):
The reference image to condition the generation on.
condtioning_image (`PIL.Image.Image`):
The conditioning canny edge image to condition the generation on.
source_subject_category (`List[str]`):
The source subject category.
target_subject_category (`List[str]`):
The target subject category.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by random sampling.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
height (`int`, *optional*, defaults to 512):
The height of the generated image.
width (`int`, *optional*, defaults to 512):
The width of the generated image.
seed (`int`, *optional*, defaults to 42):
The seed to use for random generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
neg_prompt (`str`, *optional*, defaults to ""):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
prompt_strength (`float`, *optional*, defaults to 1.0):
The strength of the prompt. Specifies the number of times the prompt is repeated along with prompt_reps
to amplify the prompt.
prompt_reps (`int`, *optional*, defaults to 20):
The number of times the prompt is repeated along with prompt_strength to amplify the prompt.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
reference_image = self.image_processor.preprocess(
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
)["pixel_values"]
reference_image = reference_image.to(self.device)
if isinstance(prompt, str):
prompt = [prompt]
if isinstance(source_subject_category, str):
source_subject_category = [source_subject_category]
if isinstance(target_subject_category, str):
target_subject_category = [target_subject_category]
batch_size = len(prompt)
prompt = self._build_prompt(
prompts=prompt,
tgt_subjects=target_subject_category,
prompt_strength=prompt_strength,
prompt_reps=prompt_reps,
)
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
text_embeddings = self.encode_prompt(query_embeds, prompt)
# 3. unconditional embedding
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
max_length = self.text_encoder.text_model.config.max_position_embeddings
uncond_input = self.tokenizer(
[neg_prompt] * batch_size,
padding="max_length",
max_length=max_length,
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(self.device),
ctx_embeddings=None,
)[0]
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
scale_down_factor = 2 ** (len(self.unet.config.block_out_channels) - 1)
latents = self.prepare_latents(
batch_size=batch_size,
num_channels=self.unet.config.in_channels,
height=height // scale_down_factor,
width=width // scale_down_factor,
generator=generator,
latents=latents,
dtype=self.unet.dtype,
device=self.device,
)
# set timesteps
extra_set_kwargs = {}
self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
cond_image = self.prepare_control_image(
image=condtioning_image,
width=width,
height=height,
batch_size=batch_size,
num_images_per_prompt=1,
device=self.device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
for i, t in enumerate(self.progress_bar(self.scheduler.timesteps)):
# expand the latents if we are doing classifier free guidance
do_classifier_free_guidance = guidance_scale > 1.0
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
down_block_res_samples, mid_block_res_sample = self.controlnet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings,
controlnet_cond=cond_image,
return_dict=False,
)
noise_pred = self.unet(
latent_model_input,
timestep=t,
encoder_hidden_states=text_embeddings,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
)["sample"]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
latents = self.scheduler.step(
noise_pred,
t,
latents,
)["prev_sample"]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
...@@ -315,6 +315,36 @@ class AutoPipelineForText2Image(metaclass=DummyObject): ...@@ -315,6 +315,36 @@ class AutoPipelineForText2Image(metaclass=DummyObject):
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class BlipDiffusionControlNetPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class BlipDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class CLIPImageProjection(metaclass=DummyObject): class CLIPImageProjection(metaclass=DummyObject):
_backends = ["torch"] _backends = ["torch"]
......
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import CLIPTokenizer
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
from transformers.models.clip.configuration_clip import CLIPTextConfig
from diffusers import AutoencoderKL, BlipDiffusionPipeline, PNDMScheduler, UNet2DConditionModel
from diffusers.utils.testing_utils import enable_full_determinism
from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class BlipDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = BlipDiffusionPipeline
params = [
"prompt",
"reference_image",
"source_subject_category",
"target_subject_category",
]
batch_params = [
"prompt",
"reference_image",
"source_subject_category",
"target_subject_category",
]
required_optional_params = [
"generator",
"height",
"width",
"latents",
"guidance_scale",
"num_inference_steps",
"neg_prompt",
"guidance_scale",
"prompt_strength",
"prompt_reps",
]
def get_dummy_components(self):
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
vocab_size=1000,
hidden_size=16,
intermediate_size=16,
projection_dim=16,
num_hidden_layers=1,
num_attention_heads=1,
max_position_embeddings=77,
)
text_encoder = ContextCLIPTextModel(text_encoder_config)
vae = AutoencoderKL(
in_channels=4,
out_channels=4,
down_block_types=("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(32,),
layers_per_block=1,
act_fn="silu",
latent_channels=4,
norm_num_groups=16,
sample_size=16,
)
blip_vision_config = {
"hidden_size": 16,
"intermediate_size": 16,
"num_hidden_layers": 1,
"num_attention_heads": 1,
"image_size": 224,
"patch_size": 14,
"hidden_act": "quick_gelu",
}
blip_qformer_config = {
"vocab_size": 1000,
"hidden_size": 16,
"num_hidden_layers": 1,
"num_attention_heads": 1,
"intermediate_size": 16,
"max_position_embeddings": 512,
"cross_attention_frequency": 1,
"encoder_hidden_size": 16,
}
qformer_config = Blip2Config(
vision_config=blip_vision_config,
qformer_config=blip_qformer_config,
num_query_tokens=16,
tokenizer="hf-internal-testing/tiny-random-bert",
)
qformer = Blip2QFormerModel(qformer_config)
unet = UNet2DConditionModel(
block_out_channels=(16, 32),
norm_num_groups=16,
layers_per_block=1,
sample_size=16,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=16,
)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
scheduler = PNDMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
set_alpha_to_one=False,
skip_prk_steps=True,
)
vae.eval()
qformer.eval()
text_encoder.eval()
image_processor = BlipImageProcessor()
components = {
"text_encoder": text_encoder,
"vae": vae,
"qformer": qformer,
"unet": unet,
"tokenizer": tokenizer,
"scheduler": scheduler,
"image_processor": image_processor,
}
return components
def get_dummy_inputs(self, device, seed=0):
np.random.seed(seed)
reference_image = np.random.rand(32, 32, 3) * 255
reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA")
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "swimming underwater",
"generator": generator,
"reference_image": reference_image,
"source_subject_category": "dog",
"target_subject_category": "dog",
"height": 32,
"width": 32,
"guidance_scale": 7.5,
"num_inference_steps": 2,
"output_type": "np",
}
return inputs
def test_blipdiffusion(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
image = pipe(**self.get_dummy_inputs(device))[0]
image_slice = image[0, -3:, -3:, 0]
assert image.shape == (1, 16, 16, 4)
expected_slice = np.array([0.7096, 0.5900, 0.6703, 0.4032, 0.7766, 0.3629, 0.5447, 0.4149, 0.8172])
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {image_slice.flatten()}, but got {image_slice.flatten()}"
# coding=utf-8
# Copyright 2023 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from PIL import Image
from transformers import CLIPTokenizer
from transformers.models.blip_2.configuration_blip_2 import Blip2Config
from transformers.models.clip.configuration_clip import CLIPTextConfig
from diffusers import (
AutoencoderKL,
BlipDiffusionControlNetPipeline,
ControlNetModel,
PNDMScheduler,
UNet2DConditionModel,
)
from diffusers.utils.testing_utils import enable_full_determinism
from src.diffusers.pipelines.blip_diffusion.blip_image_processing import BlipImageProcessor
from src.diffusers.pipelines.blip_diffusion.modeling_blip2 import Blip2QFormerModel
from src.diffusers.pipelines.blip_diffusion.modeling_ctx_clip import ContextCLIPTextModel
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class BlipDiffusionControlNetPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = BlipDiffusionControlNetPipeline
params = [
"prompt",
"reference_image",
"source_subject_category",
"target_subject_category",
"condtioning_image",
]
batch_params = [
"prompt",
"reference_image",
"source_subject_category",
"target_subject_category",
"condtioning_image",
]
required_optional_params = [
"generator",
"height",
"width",
"latents",
"guidance_scale",
"num_inference_steps",
"neg_prompt",
"guidance_scale",
"prompt_strength",
"prompt_reps",
]
def get_dummy_components(self):
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
vocab_size=1000,
hidden_size=16,
intermediate_size=16,
projection_dim=16,
num_hidden_layers=1,
num_attention_heads=1,
max_position_embeddings=77,
)
text_encoder = ContextCLIPTextModel(text_encoder_config)
vae = AutoencoderKL(
in_channels=4,
out_channels=4,
down_block_types=("DownEncoderBlock2D",),
up_block_types=("UpDecoderBlock2D",),
block_out_channels=(32,),
layers_per_block=1,
act_fn="silu",
latent_channels=4,
norm_num_groups=16,
sample_size=16,
)
blip_vision_config = {
"hidden_size": 16,
"intermediate_size": 16,
"num_hidden_layers": 1,
"num_attention_heads": 1,
"image_size": 224,
"patch_size": 14,
"hidden_act": "quick_gelu",
}
blip_qformer_config = {
"vocab_size": 1000,
"hidden_size": 16,
"num_hidden_layers": 1,
"num_attention_heads": 1,
"intermediate_size": 16,
"max_position_embeddings": 512,
"cross_attention_frequency": 1,
"encoder_hidden_size": 16,
}
qformer_config = Blip2Config(
vision_config=blip_vision_config,
qformer_config=blip_qformer_config,
num_query_tokens=16,
tokenizer="hf-internal-testing/tiny-random-bert",
)
qformer = Blip2QFormerModel(qformer_config)
unet = UNet2DConditionModel(
block_out_channels=(4, 16),
layers_per_block=1,
norm_num_groups=4,
sample_size=16,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
cross_attention_dim=16,
)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
scheduler = PNDMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
set_alpha_to_one=False,
skip_prk_steps=True,
)
controlnet = ControlNetModel(
block_out_channels=(4, 16),
layers_per_block=1,
in_channels=4,
norm_num_groups=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
cross_attention_dim=16,
conditioning_embedding_out_channels=(8, 16),
)
vae.eval()
qformer.eval()
text_encoder.eval()
image_processor = BlipImageProcessor()
components = {
"text_encoder": text_encoder,
"vae": vae,
"qformer": qformer,
"unet": unet,
"tokenizer": tokenizer,
"scheduler": scheduler,
"controlnet": controlnet,
"image_processor": image_processor,
}
return components
def get_dummy_inputs(self, device, seed=0):
np.random.seed(seed)
reference_image = np.random.rand(32, 32, 3) * 255
reference_image = Image.fromarray(reference_image.astype("uint8")).convert("RGBA")
cond_image = np.random.rand(32, 32, 3) * 255
cond_image = Image.fromarray(cond_image.astype("uint8")).convert("RGBA")
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "swimming underwater",
"generator": generator,
"reference_image": reference_image,
"condtioning_image": cond_image,
"source_subject_category": "dog",
"target_subject_category": "dog",
"height": 32,
"width": 32,
"guidance_scale": 7.5,
"num_inference_steps": 2,
"output_type": "np",
}
return inputs
def test_blipdiffusion_controlnet(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
image = pipe(**self.get_dummy_inputs(device))[0]
image_slice = image[0, -3:, -3:, 0]
assert image.shape == (1, 16, 16, 4)
expected_slice = np.array([0.7953, 0.7136, 0.6597, 0.4779, 0.7389, 0.4111, 0.5826, 0.4150, 0.8422])
assert (
np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
), f" expected_slice {expected_slice}, but got {image_slice.flatten()}"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment