Unverified commit be207099 authored by hlky, committed by GitHub

Support Flux IP Adapter (#10261)



* Flux IP-Adapter

* test cfg

* make style

* temp remove copied from

* fix test

* fix test

* v2

* fix

* make style

* temp remove copied from

* Apply suggestions from code review
Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Move encoder_hid_proj to inside FluxTransformer2DModel

* merge

* separate encode_prompt, add copied from, image_encoder offload

* make

* fix test

* fix

* Update src/diffusers/pipelines/flux/pipeline_flux.py

* test_flux_prompt_embeds change not needed

* true_cfg -> true_cfg_scale

* fix merge conflict

* test_flux_ip_adapter_inference

* add fast test

* FluxIPAdapterMixin not test mixin

* Update pipeline_flux.py
Co-authored-by: YiYi Xu <yixu310@gmail.com>

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent bf9a641f
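For context, the new API can be exercised roughly as follows. This is a minimal sketch based on the `FluxIPAdapterPipelineSlowTests` added below; the checkpoint and image-encoder IDs come from that test, while the reference-image path and prompt are placeholders:

```py
import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image

# Attach the XLabs Flux IP-Adapter weights plus the CLIP image encoder used to embed the reference image.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipe.set_ip_adapter_scale(1.0)
pipe.enable_model_cpu_offload()

ip_adapter_image = load_image("style_reference.png")  # placeholder path

image = pipe(
    prompt="a cat sitting by the sea",
    ip_adapter_image=ip_adapter_image,
    negative_prompt="",      # together with true_cfg_scale > 1 this enables the new true-CFG branch
    true_cfg_scale=4.0,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux_ip_adapter.png")
```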
import argparse
from contextlib import nullcontext
import safetensors.torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available
if is_transformers_available():
from transformers import CLIPVisionModelWithProjection
vision = True
else:
vision = False
"""
python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
--filename "flux-ip-adapter.safetensors" \
--output_path "flux-ip-adapter-hf/"
"""
CTX = init_empty_weights if is_accelerate_available() else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--filename", default="flux.safetensors", type=str)
parser.add_argument("--checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", type=str)
parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)
args = parser.parse_args()
def load_original_checkpoint(args):
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`.")
original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict
def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
converted_state_dict = {}
# image_proj
## norm
converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
## proj
converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.proj.weight")
converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.proj.bias")
# double transformer blocks
for i in range(num_layers):
block_prefix = f"ip_adapter.{i}."
# to_k_ip
converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
)
converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
)
# to_v_ip
converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
)
converted_state_dict[f"{block_prefix}to_v_ip.weight"] = original_state_dict.pop(
f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
)
return converted_state_dict
def main(args):
original_ckpt = load_original_checkpoint(args)
num_layers = 19
converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)
print("Saving Flux IP-Adapter in Diffusers format.")
safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")
if vision:
model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
model.save_pretrained(f"{args.output_path}/image_encoder")
if __name__ == "__main__":
main(args)
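Once converted, the weights in `--output_path` can be consumed by the new `FluxIPAdapterMixin.load_ip_adapter` below. A rough sketch, assuming the output path `flux-ip-adapter-hf/` from the usage string above (that directory then contains the `model.safetensors` file and the `image_encoder/` folder written by `main()`):

```py
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
# Point at the locally converted adapter weights and the locally saved CLIP image encoder.
pipe.load_ip_adapter(
    "flux-ip-adapter-hf/",
    weight_name="model.safetensors",
    image_encoder_pretrained_model_name_or_path="flux-ip-adapter-hf/image_encoder",
)
```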
@@ -55,7 +55,7 @@ _import_structure = {}
if is_torch_available():
_import_structure["single_file_model"] = ["FromOriginalModelMixin"]
_import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
_import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
_import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
_import_structure["utils"] = ["AttnProcsLayers"]
@@ -77,6 +77,7 @@ if is_torch_available():
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = [
"IPAdapterMixin",
"FluxIPAdapterMixin",
"SD3IPAdapterMixin",
]
@@ -86,12 +87,14 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
if is_torch_available():
from .single_file_model import FromOriginalModelMixin
from .transformer_flux import FluxTransformer2DLoadersMixin
from .transformer_sd3 import SD3Transformer2DLoadersMixin
from .unet import UNet2DConditionLoadersMixin
from .utils import AttnProcsLayers
if is_transformers_available():
from .ip_adapter import (
FluxIPAdapterMixin,
IPAdapterMixin,
SD3IPAdapterMixin,
)
...
@@ -38,6 +38,8 @@ if is_transformers_available():
from ..models.attention_processor import (
AttnProcessor,
AttnProcessor2_0,
FluxAttnProcessor2_0,
FluxIPAdapterJointAttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor,
@@ -353,6 +355,290 @@ class IPAdapterMixin:
self.unet.set_attn_processor(attn_procs)
class FluxIPAdapterMixin:
"""Mixin for handling Flux IP Adapters."""
@validate_hf_hub_args
def load_ip_adapter(
self,
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
weight_name: Union[str, List[str]],
subfolder: Optional[Union[str, List[str]]] = "",
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
image_encoder_subfolder: Optional[str] = "",
image_encoder_dtype: torch.dtype = torch.float16,
**kwargs,
):
"""
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
subfolder (`str` or `List[str]`):
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
list is passed, it should have the same length as `weight_name`.
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`pretrained_model_name_or_path_or_dict`.
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `"image_encoder"`):
Can be either:
- A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
hosted on the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
with [`ModelMixin.save_pretrained`].
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only (`bool`, *optional*, defaults to `False`):
Whether to only load local model weights and configuration files or not. If set to `True`, the model
won't be downloaded from the Hub.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
`diffusers-cli login` (stored in `~/.huggingface`) is used.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
allowed by Git.
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
argument to `True` will raise an error.
"""
# handle the list inputs for multiple IP Adapters
if not isinstance(weight_name, list):
weight_name = [weight_name]
if not isinstance(pretrained_model_name_or_path_or_dict, list):
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
if len(pretrained_model_name_or_path_or_dict) == 1:
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
if not isinstance(subfolder, list):
subfolder = [subfolder]
if len(subfolder) == 1:
subfolder = subfolder * len(weight_name)
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
if len(weight_name) != len(subfolder):
raise ValueError("`weight_name` and `subfolder` must have the same length.")
# Load the main state dict first.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
if low_cpu_mem_usage and not is_accelerate_available():
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
state_dicts = []
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
pretrained_model_name_or_path_or_dict, weight_name, subfolder
):
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
local_files_only=local_files_only,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
if weight_name.endswith(".safetensors"):
state_dict = {"image_proj": {}, "ip_adapter": {}}
with safe_open(model_file, framework="pt", device="cpu") as f:
image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
ip_adapter_keys = ["double_blocks.", "ip_adapter."]
for key in f.keys():
if any(key.startswith(prefix) for prefix in image_proj_keys):
diffusers_name = ".".join(key.split(".")[1:])
state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
diffusers_name = (
".".join(key.split(".")[1:])
.replace("ip_adapter_double_stream_k_proj", "to_k_ip")
.replace("ip_adapter_double_stream_v_proj", "to_v_ip")
.replace("processor.", "")
)
state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
else:
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict
keys = list(state_dict.keys())
if keys != ["image_proj", "ip_adapter"]:
raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")
state_dicts.append(state_dict)
# load CLIP image encoder here if it has not been registered to the pipeline yet
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
if image_encoder_pretrained_model_name_or_path is not None:
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
image_encoder = (
CLIPVisionModelWithProjection.from_pretrained(
image_encoder_pretrained_model_name_or_path,
subfolder=image_encoder_subfolder,
low_cpu_mem_usage=low_cpu_mem_usage,
cache_dir=cache_dir,
local_files_only=local_files_only,
)
.to(self.device, dtype=image_encoder_dtype)
.eval()
)
self.register_modules(image_encoder=image_encoder)
else:
raise ValueError(
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
)
else:
logger.warning(
"image_encoder is not loaded since `image_encoder_pretrained_model_name_or_path=None` was passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter. "
"Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
)
# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
# FaceID IP adapters don't need the image encoder, so it may not be present; in that case we default to 224
default_clip_size = 224
clip_image_size = (
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
)
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
self.register_modules(feature_extractor=feature_extractor)
# load ip-adapter into transformer
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
"""
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
granular control over each IP-Adapter behavior. A config can be a float or a list.
A `float` is converted to a list and repeated for the number of blocks and the number of IP adapters. A
`List[float]` must match the number of blocks and is repeated for each IP adapter. A `List[List[float]]` must
match the number of IP adapters, and each inner list must match the number of blocks.
Example:
```py
# To use original IP-Adapter
scale = 1.0
pipeline.set_ip_adapter_scale(scale)
def LinearStrengthModel(start, finish, size):
return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
pipeline.set_ip_adapter_scale(ip_strengths)
```
"""
transformer = self.transformer
if not isinstance(scale, list):
scale = [[scale] * transformer.config.num_layers]
elif isinstance(scale, list) and isinstance(scale[0], (int, float)):
if len(scale) != transformer.config.num_layers:
raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")
scale = [scale]
scale_configs = scale
key_id = 0
for attn_name, attn_processor in transformer.attn_processors.items():
if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
if len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
elif len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to "
f"{len(attn_processor.scale)} IP-Adapters."
)
for i, scale_config in enumerate(scale_configs):
attn_processor.scale[i] = scale_config[key_id]
key_id += 1
def unload_ip_adapter(self):
"""
Unloads the IP Adapter weights
Examples:
```python
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
>>> pipeline.unload_ip_adapter()
>>> ...
```
"""
# remove CLIP image encoder
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
self.image_encoder = None
self.register_to_config(image_encoder=[None, None])
# remove feature extractor only when safety_checker is None as safety_checker uses
# the feature_extractor later
if not hasattr(self, "safety_checker"):
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
self.feature_extractor = None
self.register_to_config(feature_extractor=[None, None])
# remove hidden encoder
self.transformer.encoder_hid_proj = None
self.transformer.config.encoder_hid_dim_type = None
# restore original Transformer attention processors layers
attn_procs = {}
for name, value in self.transformer.attn_processors.items():
attn_processor_class = FluxAttnProcessor2_0()
attn_procs[name] = (
attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
)
self.transformer.set_attn_processor(attn_procs)
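As a concrete illustration of the scale semantics documented in `set_ip_adapter_scale` above: with a `List[List[float]]`, the outer list indexes the loaded IP-Adapters and each inner list provides one value per double-stream block. A hypothetical sketch, assuming two Flux IP-Adapters have been loaded into the default 19-block transformer:

```py
# Per-block ramp for the first adapter, constant strength for the second
# (hypothetical: assumes two IP-Adapters are loaded and transformer.config.num_layers == 19).
first_adapter_scales = [0.3 + (0.92 - 0.3) * i / 18 for i in range(19)]
second_adapter_scales = [0.5] * 19
pipeline.set_ip_adapter_scale([first_adapter_scales, second_adapter_scales])
```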
class SD3IPAdapterMixin:
"""Mixin for handling StableDiffusion 3 IP Adapters."""
...
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from ..models.embeddings import (
ImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import load_model_dict_into_meta
from ..utils import (
is_accelerate_available,
is_torch_version,
logging,
)
if is_accelerate_available():
pass
logger = logging.get_logger(__name__)
class FluxTransformer2DLoadersMixin:
"""
Load layers into a [`FluxTransformer2DModel`].
"""
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=False):
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
updated_state_dict = {}
image_projection = None
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
if "proj.weight" in state_dict:
# IP-Adapter
num_image_text_embeds = 4
if state_dict["proj.weight"].shape[0] == 65536:
num_image_text_embeds = 16
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
with init_context():
image_projection = ImageProjection(
cross_attention_dim=cross_attention_dim,
image_embed_dim=clip_embeddings_dim,
num_image_text_embeds=num_image_text_embeds,
)
for key, value in state_dict.items():
diffusers_name = key.replace("proj", "image_embeds")
updated_state_dict[diffusers_name] = value
if not low_cpu_mem_usage:
image_projection.load_state_dict(updated_state_dict, strict=True)
else:
load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)
return image_projection
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=False):
from ..models.attention_processor import (
FluxIPAdapterJointAttnProcessor2_0,
)
if low_cpu_mem_usage:
if is_accelerate_available():
from accelerate import init_empty_weights
else:
low_cpu_mem_usage = False
logger.warning(
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
" install accelerate\n```\n."
)
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
raise NotImplementedError(
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
" `low_cpu_mem_usage=False`."
)
# set ip-adapter cross-attention processors & load state_dict
attn_procs = {}
key_id = 0
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
for name in self.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
attn_processor_class = self.attn_processors[name].__class__
attn_procs[name] = attn_processor_class()
else:
cross_attention_dim = self.config.joint_attention_dim
hidden_size = self.inner_dim
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
num_image_text_embed = 4
if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
num_image_text_embed = 16
# IP-Adapter
num_image_text_embeds += [num_image_text_embed]
with init_context():
attn_procs[name] = attn_processor_class(
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
num_tokens=num_image_text_embeds,
dtype=self.dtype,
device=self.device,
)
value_dict = {}
for i, state_dict in enumerate(state_dicts):
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
if not low_cpu_mem_usage:
attn_procs[name].load_state_dict(value_dict)
else:
device = self.device
dtype = self.dtype
load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype)
key_id += 1
return attn_procs
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
if not isinstance(state_dicts, list):
state_dicts = [state_dicts]
self.encoder_hid_proj = None
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
self.set_attn_processor(attn_procs)
image_projection_layers = []
for state_dict in state_dicts:
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
)
image_projection_layers.append(image_projection_layer)
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
self.config.encoder_hid_dim_type = "ip_image_proj"
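For reference, `_load_ip_adapter_weights` above expects each per-adapter state dict to be split into an `image_proj` part (consumed by `_convert_ip_adapter_image_proj_to_diffusers`) and an `ip_adapter` part with one `to_k_ip`/`to_v_ip` pair per double-stream block (consumed by `_convert_ip_adapter_attn_to_diffusers`); this is the layout produced by `FluxIPAdapterMixin.load_ip_adapter` and by the conversion script above. A sketch of that layout with zero tensors, assuming FLUX.1-dev-like dimensions (19 double blocks, inner_dim 3072, joint_attention_dim 4096, a 768-dim CLIP projection, 4 image tokens):

```py
import torch

num_double_blocks, inner_dim, joint_dim, clip_dim, num_tokens = 19, 3072, 4096, 768, 4

state_dict = {
    "image_proj": {
        # mapped onto diffusers' ImageProjection (proj -> image_embeds)
        "proj.weight": torch.zeros(num_tokens * joint_dim, clip_dim),
        "proj.bias": torch.zeros(num_tokens * joint_dim),
        "norm.weight": torch.ones(joint_dim),
        "norm.bias": torch.zeros(joint_dim),
    },
    "ip_adapter": {},
}
for i in range(num_double_blocks):
    # one key/value projection pair per double-stream attention processor
    state_dict["ip_adapter"][f"{i}.to_k_ip.weight"] = torch.zeros(inner_dim, joint_dim)
    state_dict["ip_adapter"][f"{i}.to_k_ip.bias"] = torch.zeros(inner_dim)
    state_dict["ip_adapter"][f"{i}.to_v_ip.weight"] = torch.zeros(inner_dim, joint_dim)
    state_dict["ip_adapter"][f"{i}.to_v_ip.bias"] = torch.zeros(inner_dim)

# transformer._load_ip_adapter_weights([state_dict]) would then install
# FluxIPAdapterJointAttnProcessor2_0 on every double-stream block and register a
# MultiIPAdapterImageProjection as transformer.encoder_hid_proj.
```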
@@ -575,7 +575,7 @@ class Attention(nn.Module):
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
unused_kwargs = [
k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
]
@@ -2653,6 +2653,149 @@ class FusedFluxAttnProcessor2_0_NPU:
return hidden_states
class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
"""Flux Attention processor for IP-Adapter."""
def __init__(
self, hidden_size: int, cross_attention_dim: int, num_tokens=(4,), scale=1.0, device=None, dtype=None
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
)
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list with the same length as `num_tokens`.")
self.scale = scale
self.to_k_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
self.to_v_ip = nn.ModuleList(
[
nn.Linear(cross_attention_dim, hidden_size, bias=True, device=device, dtype=dtype)
for _ in range(len(num_tokens))
]
)
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
ip_hidden_states: Optional[List[torch.Tensor]] = None,
ip_adapter_masks: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
# `sample` projections.
hidden_states_query_proj = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
hidden_states_query_proj = hidden_states_query_proj.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
if attn.norm_q is not None:
hidden_states_query_proj = attn.norm_q(hidden_states_query_proj)
if attn.norm_k is not None:
key = attn.norm_k(key)
# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
# attention
query = torch.cat([encoder_hidden_states_query_proj, hidden_states_query_proj], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
if image_rotary_emb is not None:
from .embeddings import apply_rotary_emb
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
if encoder_hidden_states is not None:
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
# IP-adapter
ip_query = hidden_states_query_proj
ip_attn_output = None
# for ip-adapter
# TODO: support for multiple adapters
for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_attn_output = F.scaled_dot_product_attention(
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_attn_output = scale * ip_attn_output
ip_attn_output = ip_attn_output.to(ip_query.dtype)
return hidden_states, encoder_hidden_states, ip_attn_output
else:
return hidden_states
class CogVideoXAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
@@ -5896,6 +6039,7 @@ CROSS_ATTENTION_PROCESSORS = (
SlicedAttnProcessor,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
FluxIPAdapterJointAttnProcessor2_0,
)
AttentionProcessor = Union[
...
@@ -1535,7 +1535,7 @@ class ImageProjection(nn.Module):
batch_size = image_embeds.shape[0]
# image
image_embeds = self.image_embeds(image_embeds.to(self.image_embeds.weight.dtype))
image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
image_embeds = self.norm(image_embeds)
return image_embeds
...
@@ -21,7 +21,7 @@ import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...models.attention import FeedForward
from ...models.attention_processor import (
Attention,
@@ -177,13 +177,18 @@ class FluxTransformerBlock(nn.Module):
)
joint_attention_kwargs = joint_attention_kwargs or {}
# Attention.
attention_outputs = self.attn(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
)
if len(attention_outputs) == 2:
attn_output, context_attn_output = attention_outputs
elif len(attention_outputs) == 3:
attn_output, context_attn_output, ip_attn_output = attention_outputs
# Process attention outputs for the `hidden_states`.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output
@@ -195,6 +200,8 @@ class FluxTransformerBlock(nn.Module):
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = hidden_states + ff_output
if len(attention_outputs) == 3:
hidden_states = hidden_states + ip_attn_output
# Process attention outputs for the `encoder_hidden_states`.
@@ -212,7 +219,9 @@ class FluxTransformerBlock(nn.Module):
return encoder_hidden_states, hidden_states
class FluxTransformer2DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
):
"""
The Transformer model introduced in Flux.
@@ -482,6 +491,11 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
ids = torch.cat((txt_ids, img_ids), dim=0)
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
...
@@ -17,10 +17,17 @@ from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTokenizer,
CLIPVisionModelWithProjection,
T5EncoderModel,
T5TokenizerFast,
)
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import FluxTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -142,6 +149,7 @@ class FluxPipeline(
FluxLoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
FluxIPAdapterMixin,
):
r"""
The Flux pipeline for text-to-image generation.
@@ -169,8 +177,8 @@ class FluxPipeline(
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
_optional_components = ["image_encoder", "feature_extractor"]
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
@@ -182,6 +190,8 @@ class FluxPipeline(
text_encoder_2: T5EncoderModel,
tokenizer_2: T5TokenizerFast,
transformer: FluxTransformer2DModel,
image_encoder: CLIPVisionModelWithProjection = None,
feature_extractor: CLIPImageProcessor = None,
):
super().__init__()
@@ -193,6 +203,8 @@ class FluxPipeline(
tokenizer_2=tokenizer_2,
transformer=transformer,
scheduler=scheduler,
image_encoder=image_encoder,
feature_extractor=feature_extractor,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -377,14 +389,60 @@ class FluxPipeline(
return prompt_embeds, pooled_prompt_embeds, text_ids
def encode_image(self, image, device, num_images_per_prompt):
dtype = next(self.image_encoder.parameters()).dtype
if not isinstance(image, torch.Tensor):
image = self.feature_extractor(image, return_tensors="pt").pixel_values
image = image.to(device=device, dtype=dtype)
image_embeds = self.image_encoder(image).image_embeds
image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
return image_embeds
def prepare_ip_adapter_image_embeds(
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
):
image_embeds = []
if ip_adapter_image_embeds is None:
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]
if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have the same length as the number of IP-Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP-Adapters."
)
for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
):
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
image_embeds.append(single_image_embeds[None, :])
else:
for single_image_embeds in ip_adapter_image_embeds:
image_embeds.append(single_image_embeds)
ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
return ip_adapter_image_embeds
def check_inputs(
self,
prompt,
prompt_2,
height,
width,
negative_prompt=None,
negative_prompt_2=None,
prompt_embeds=None,
negative_prompt_embeds=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
@@ -419,10 +477,33 @@ class FluxPipeline(
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds is not None and pooled_prompt_embeds is None:
raise ValueError(
"If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
)
if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
raise ValueError(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
if max_sequence_length is not None and max_sequence_length > 512:
raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
@@ -551,6 +632,9 @@ class FluxPipeline(
self,
prompt: Union[str, List[str]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt: Union[str, List[str]] = None,
negative_prompt_2: Optional[Union[str, List[str]]] = None,
true_cfg_scale: float = 1.0,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
@@ -561,6 +645,12 @@ class FluxPipeline(
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_ip_adapter_image: Optional[PipelineImageInput] = None,
negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -610,6 +700,17 @@ class FluxPipeline(
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
ip_adapter_image (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number
of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
negative_ip_adapter_image (`PipelineImageInput`, *optional*):
Optional image input to work with IP Adapters.
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated negative image embeddings for IP-Adapter. It should be a list with the same length as the
number of IP-Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If
not provided, embeddings are computed from the `negative_ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -647,8 +748,12 @@ class FluxPipeline(
prompt_2,
height,
width,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
@@ -670,6 +775,7 @@ class FluxPipeline(
lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)
do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
(
prompt_embeds,
pooled_prompt_embeds,
@@ -684,6 +790,21 @@ class FluxPipeline(
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
if do_true_cfg:
(
negative_prompt_embeds,
negative_pooled_prompt_embeds,
_,
) = self.encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels // 4
@@ -725,12 +846,43 @@ class FluxPipeline(
else:
guidance = None
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
):
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
):
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
image_embeds = None
negative_image_embeds = None
if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
negative_image_embeds = self.prepare_ip_adapter_image_embeds(
negative_ip_adapter_image,
negative_ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
if image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
@@ -746,6 +898,22 @@ class FluxPipeline(
return_dict=False,
)[0]
if do_true_cfg:
if negative_image_embeds is not None:
self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
neg_noise_pred = self.transformer(
hidden_states=latents,
timestep=timestep / 1000,
guidance=guidance,
pooled_projections=negative_pooled_prompt_embeds,
encoder_hidden_states=negative_prompt_embeds,
txt_ids=text_ids,
img_ids=latent_image_ids,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]
noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
...
@@ -403,7 +403,6 @@ class FluxControlPipeline(
return prompt_embeds, pooled_prompt_embeds, text_ids
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.check_inputs
def check_inputs(
self,
prompt,
...
@@ -18,6 +18,8 @@ import unittest
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
from diffusers.models.embeddings import ImageProjection
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
@@ -26,6 +28,56 @@ from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
def create_flux_ip_adapter_state_dict(model):
# "ip_adapter" (cross-attention weights)
ip_cross_attn_state_dict = {}
key_id = 0
for name in model.attn_processors.keys():
if name.startswith("single_transformer_blocks"):
continue
joint_attention_dim = model.config["joint_attention_dim"]
hidden_size = model.config["num_attention_heads"] * model.config["attention_head_dim"]
sd = FluxIPAdapterJointAttnProcessor2_0(
hidden_size=hidden_size, cross_attention_dim=joint_attention_dim, scale=1.0
).state_dict()
ip_cross_attn_state_dict.update(
{
f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
f"{key_id}.to_k_ip.bias": sd["to_k_ip.0.bias"],
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
}
)
key_id += 1
# "image_proj" (ImageProjection layer weights)
image_projection = ImageProjection(
cross_attention_dim=model.config["joint_attention_dim"],
image_embed_dim=model.config["pooled_projection_dim"],
num_image_text_embeds=4,
)
ip_image_projection_state_dict = {}
sd = image_projection.state_dict()
ip_image_projection_state_dict.update(
{
"proj.weight": sd["image_embeds.weight"],
"proj.bias": sd["image_embeds.bias"],
"norm.weight": sd["norm.weight"],
"norm.bias": sd["norm.bias"],
}
)
del sd
ip_state_dict = {}
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
return ip_state_dict
class FluxTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = FluxTransformer2DModel
main_input_name = "hidden_states"
...
@@ -16,13 +16,14 @@ from diffusers.utils.testing_utils import (
)
from ..test_pipelines_common import (
FluxIPAdapterTesterMixin,
PipelineTesterMixin,
check_qkv_fusion_matches_attn_procs_length,
check_qkv_fusion_processors_exist,
)
class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin, FluxIPAdapterTesterMixin):
pipeline_class = FluxPipeline
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
batch_params = frozenset(["prompt"])
@@ -91,6 +92,8 @@ class FluxPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
"tokenizer_2": tokenizer_2,
"transformer": transformer,
"vae": vae,
"image_encoder": None,
"feature_extractor": None,
}
def get_dummy_inputs(self, device, seed=0):
@@ -296,3 +299,112 @@ class FluxPipelineSlowTests(unittest.TestCase):
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4
@slow
@require_big_gpu_with_torch_cuda
@pytest.mark.big_gpu_with_torch_cuda
class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
pipeline_class = FluxPipeline
repo_id = "black-forest-labs/FLUX.1-dev"
image_encoder_pretrained_model_name_or_path = "openai/clip-vit-large-patch14"
weight_name = "ip_adapter.safetensors"
ip_adapter_repo_id = "XLabs-AI/flux-ip-adapter"
def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()
def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)
prompt_embeds = torch.load(
hf_hub_download(repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/prompt_embeds.pt")
)
pooled_prompt_embeds = torch.load(
hf_hub_download(
repo_id="diffusers/test-slices", repo_type="dataset", filename="flux/pooled_prompt_embeds.pt"
)
)
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
ip_adapter_image = np.zeros((1024, 1024, 3), dtype=np.uint8)
return {
"prompt_embeds": prompt_embeds,
"pooled_prompt_embeds": pooled_prompt_embeds,
"negative_prompt_embeds": negative_prompt_embeds,
"negative_pooled_prompt_embeds": negative_pooled_prompt_embeds,
"ip_adapter_image": ip_adapter_image,
"num_inference_steps": 2,
"guidance_scale": 3.5,
"true_cfg_scale": 4.0,
"max_sequence_length": 256,
"output_type": "np",
"generator": generator,
}
def test_flux_ip_adapter_inference(self):
pipe = self.pipeline_class.from_pretrained(
self.repo_id, torch_dtype=torch.bfloat16, text_encoder=None, text_encoder_2=None
)
pipe.load_ip_adapter(
self.ip_adapter_repo_id,
weight_name=self.weight_name,
image_encoder_pretrained_model_name_or_path=self.image_encoder_pretrained_model_name_or_path,
)
pipe.set_ip_adapter_scale(1.0)
pipe.enable_model_cpu_offload()
inputs = self.get_inputs(torch_device)
image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
expected_slice = np.array(
[
0.1855,
0.1680,
0.1406,
0.1953,
0.1699,
0.1465,
0.2012,
0.1738,
0.1484,
0.2051,
0.1797,
0.1523,
0.2012,
0.1719,
0.1445,
0.2070,
0.1777,
0.1465,
0.2090,
0.1836,
0.1484,
0.2129,
0.1875,
0.1523,
0.2090,
0.1816,
0.1484,
0.2110,
0.1836,
0.1543,
],
dtype=np.float32,
)
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4, f"{image_slice} != {expected_slice}"
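Outside the test harness, the same flow maps onto the public API roughly as below. This is a hedged sketch assembled from the slow test above; the prompt, step count, and reference image are placeholders rather than values from this PR.
# End-user sketch mirroring FluxIPAdapterPipelineSlowTests: load FLUX.1-dev,
# attach the XLabs Flux IP-Adapter, and run with true CFG enabled.
import numpy as np
import torch

from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipe.set_ip_adapter_scale(1.0)
pipe.enable_model_cpu_offload()

# Replace with a real reference image (PIL.Image or HxWx3 uint8 array); the slow
# test above uses an all-zeros array purely for determinism.
ip_adapter_image = np.zeros((1024, 1024, 3), dtype=np.uint8)

image = pipe(
    prompt="a photo of a robot holding a sign that says hello",  # placeholder prompt
    negative_prompt="",
    ip_adapter_image=ip_adapter_image,
    true_cfg_scale=4.0,
    guidance_scale=3.5,
    num_inference_steps=28,  # placeholder step count
    max_sequence_length=256,
    generator=torch.Generator(device="cpu").manual_seed(0),
).images[0]
image.save("flux_ip_adapter.png")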
@@ -29,7 +29,7 @@ from diffusers import (
UNet2DConditionModel,
)
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FluxIPAdapterMixin, IPAdapterMixin
from diffusers.models.attention_processor import AttnProcessor
from diffusers.models.controlnets.controlnet_xs import UNetControlNetXSModel
from diffusers.models.unets.unet_3d_condition import UNet3DConditionModel
@@ -54,6 +54,7 @@ from ..models.autoencoders.vae import (
get_autoencoder_tiny_config,
get_consistency_vae_config,
)
from ..models.transformers.test_models_transformer_flux import create_flux_ip_adapter_state_dict
from ..models.unets.test_models_unet_2d_condition import (
create_ip_adapter_faceid_state_dict,
create_ip_adapter_state_dict,
@@ -483,6 +484,94 @@ class IPAdapterTesterMixin:
)
class FluxIPAdapterTesterMixin:
"""
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
It provides a set of common tests for Flux pipelines that support IP-Adapters.
"""
def test_pipeline_signature(self):
parameters = inspect.signature(self.pipeline_class.__call__).parameters
assert issubclass(self.pipeline_class, FluxIPAdapterMixin)
self.assertIn(
"ip_adapter_image",
parameters,
"`ip_adapter_image` argument must be supported by the `__call__` method",
)
self.assertIn(
"ip_adapter_image_embeds",
parameters,
"`ip_adapter_image_embeds` argument must be supported by the `__call__` method",
)
def _get_dummy_image_embeds(self, image_embed_dim: int = 768):
return torch.randn((1, 1, image_embed_dim), device=torch_device)
def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
inputs["negative_prompt"] = ""
inputs["true_cfg_scale"] = 4.0
inputs["output_type"] = "np"
inputs["return_dict"] = False
return inputs
def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=None):
r"""Tests for IP-Adapter.
The following scenarios are tested:
- Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
- Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
"""
# Raising the tolerance for this test when it's run on a CPU because we
# compare against static slices and that can be flaky (with a very low probability).
expected_max_diff = 9e-4 if torch_device == "cpu" else expected_max_diff
components = self.get_dummy_components()
pipe = self.pipeline_class(**components).to(torch_device)
pipe.set_progress_bar_config(disable=None)
image_embed_dim = pipe.transformer.config.pooled_projection_dim
# forward pass without ip adapter
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
if expected_pipe_slice is None:
output_without_adapter = pipe(**inputs)[0]
else:
output_without_adapter = expected_pipe_slice
adapter_state_dict = create_flux_ip_adapter_state_dict(pipe.transformer)
pipe.transformer._load_ip_adapter_weights(adapter_state_dict)
# forward pass with single ip adapter, but scale=0 which should have no effect
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
pipe.set_ip_adapter_scale(0.0)
output_without_adapter_scale = pipe(**inputs)[0]
if expected_pipe_slice is not None:
output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
# forward pass with single ip adapter, but with scale of adapter weights
inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)]
pipe.set_ip_adapter_scale(42.0)
output_with_adapter_scale = pipe(**inputs)[0]
if expected_pipe_slice is not None:
output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
self.assertLess(
max_diff_without_adapter_scale,
expected_max_diff,
"Output without ip-adapter must be same as normal inference",
)
self.assertGreater(
max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
)
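The mixin also exercises the precomputed-embedding path: `ip_adapter_image_embeds` takes a list with one tensor per loaded adapter, shaped like the output of `_get_dummy_image_embeds`. A hedged sketch, assuming `pipe` is a FluxPipeline with the XLabs IP-Adapter already loaded and running in bfloat16 on CUDA as in the earlier usage sketch:
# Precomputed image-embedding path, as exercised by the mixin above. The prompt
# and step count are placeholders; device/dtype assume the bfloat16 CUDA setup
# from the previous sketch.
import torch

image_embed_dim = pipe.transformer.config.pooled_projection_dim  # 768 for FLUX.1-dev
ip_embeds = [torch.randn(1, 1, image_embed_dim, device="cuda", dtype=torch.bfloat16)]
neg_ip_embeds = [torch.zeros(1, 1, image_embed_dim, device="cuda", dtype=torch.bfloat16)]

image = pipe(
    prompt="a watercolor landscape",  # placeholder prompt
    negative_prompt="",
    true_cfg_scale=4.0,
    ip_adapter_image_embeds=ip_embeds,
    negative_ip_adapter_image_embeds=neg_ip_embeds,
    num_inference_steps=28,  # placeholder step count
).images[0]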
class PipelineLatentTesterMixin:
"""
This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
...