Unverified commit 06b7a518 authored by SMG, committed by GitHub

feat: enable IP-Adapter (XLabs-AI/flux-ip-adapter-v2) support (#418)



* feat: support IP-adapter

* FBCache and ComfyUI

* fixing conflicts

* update

* update example

* update example

* style: make linter happy

* update

* update ipa test

* add docs and rename IP to ip

* docs: add docs for ipa

* docs: add docs for ipa

* add an example for pulid

* update

* save gpu memory

* change the threshold to 0.8

---------
Co-authored-by: Muyang Li <lmxyy1999@foxmail.com>
parent 24c2f925
"""
IP-Adapter utility functions and classes for FluxTransformer2DModel.
This module provides the core implementation for integrating IP-Adapter
conditioning into Flux-based transformer models, including block modification,
weight loading, and image embedding support.
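
Example (a sketch mirroring the accompanying test; ``apply_IPA_on_pipe`` is
exported from ``nunchaku.models.ip_adapter.diffusers_adapters``)::

    from nunchaku.models.ip_adapter.diffusers_adapters import apply_IPA_on_pipe

    pipeline.load_ip_adapter(
        pretrained_model_name_or_path_or_dict="XLabs-AI/flux-ip-adapter-v2",
        weight_name="ip_adapter.safetensors",
        image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
    )
    apply_IPA_on_pipe(pipeline, ip_adapter_scale=1.15, repo_id="XLabs-AI/flux-ip-adapter-v2")
    image = pipeline(prompt="...", ip_adapter_image=ref_image, num_inference_steps=50).images[0]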
"""
import cv2
import torch
import torch.nn.functional as F
from diffusers import FluxTransformer2DModel
from huggingface_hub import hf_hub_download
from safetensors import safe_open
from torch import nn
from nunchaku.caching.utils import FluxCachedTransformerBlocks, check_and_apply_cache
from nunchaku.models.transformers.utils import pad_tensor
num_transformer_blocks = 19  # FIXME: hard-coded for FLUX.1 (19 double-stream blocks)
num_single_transformer_blocks = 38  # FIXME: hard-coded for FLUX.1 (38 single-stream blocks)
class IPA_TransformerBlocks(FluxCachedTransformerBlocks):
"""
Transformer block wrapper for IP-Adapter integration.
This class extends FluxCachedTransformerBlocks to enable per-layer
IP-Adapter conditioning, efficient caching, and flexible output control.
Parameters
----------
transformer : nn.Module, optional
The base transformer module to wrap.
ip_adapter_scale : float, default=1.0
Scaling factor for the IP-Adapter output.
return_hidden_states_first : bool, default=True
If True, return hidden states before encoder states.
return_hidden_states_only : bool, default=False
If True, return only hidden states.
verbose : bool, default=False
If True, print verbose debug information.
device : str or torch.device
Device to use for computation.
Attributes
----------
ip_adapter_scale : float
Scaling factor for IP-Adapter output.
image_embeds : torch.Tensor or None
Image embeddings for IP-Adapter conditioning.
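
Examples
--------
A minimal low-level sketch, assuming ``transformer`` is the wrapped Nunchaku
FLUX transformer and ``image_embeds`` are pre-computed IP-Adapter image
embeddings::

    blocks = IPA_TransformerBlocks(transformer=transformer, ip_adapter_scale=1.0, device="cuda")
    blocks.load_ip_adapter_weights_per_layer(repo_id="XLabs-AI/flux-ip-adapter-v2")
    blocks.set_ip_hidden_states(image_embeds)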
"""
def __init__(
self,
*,
transformer: nn.Module = None,
ip_adapter_scale: float = 1.0,
return_hidden_states_first: bool = True,
return_hidden_states_only: bool = False,
verbose: bool = False,
device: str | torch.device,
):
super().__init__(
transformer=transformer,
use_double_fb_cache=False,
residual_diff_threshold_multi=-1,
residual_diff_threshold_single=-1,
return_hidden_states_first=return_hidden_states_first,
return_hidden_states_only=return_hidden_states_only,
verbose=verbose,
)
self.ip_adapter_scale = ip_adapter_scale
self.image_embeds = None
def forward(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
encoder_hidden_states: torch.Tensor,
image_rotary_emb: torch.Tensor,
id_embeddings=None,
id_weight=None,
joint_attention_kwargs=None,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
skip_first_layer=False,
):
"""
Forward pass with IP-Adapter conditioning.
Parameters
----------
hidden_states : torch.Tensor
Input hidden states.
temb : torch.Tensor
Conditioning (timestep/guidance) embedding tensor.
encoder_hidden_states : torch.Tensor
Encoder hidden states.
image_rotary_emb : torch.Tensor
Rotary embeddings for the full (text + image) token sequence.
id_embeddings : optional
Not used.
id_weight : optional
Not used.
joint_attention_kwargs : dict, optional
Additional attention arguments, may include 'ip_hidden_states'.
controlnet_block_samples : list, optional
ControlNet block samples for multi-blocks.
controlnet_single_block_samples : list, optional
ControlNet block samples for single blocks.
skip_first_layer : bool, default=False
If True, skip the first transformer block.
Returns
-------
tuple or torch.Tensor
Final hidden states and encoder states, or only hidden states if
`return_hidden_states_only` is True.
"""
batch_size = hidden_states.shape[0]
txt_tokens = encoder_hidden_states.shape[1]
img_tokens = hidden_states.shape[1]
original_dtype = hidden_states.dtype
original_device = hidden_states.device
hidden_states = hidden_states.to(self.dtype).to(original_device)
encoder_hidden_states = encoder_hidden_states.to(self.dtype).to(original_device)
temb = temb.to(self.dtype).to(original_device)
image_rotary_emb = image_rotary_emb.to(original_device)
if controlnet_block_samples is not None:
controlnet_block_samples = (
torch.stack(controlnet_block_samples).to(original_device) if len(controlnet_block_samples) > 0 else None
)
if controlnet_single_block_samples is not None:
controlnet_single_block_samples = (
torch.stack(controlnet_single_block_samples).to(original_device)
if len(controlnet_single_block_samples) > 0
else None
)
assert image_rotary_emb.ndim == 6
assert image_rotary_emb.shape[0] == 1
assert image_rotary_emb.shape[1] == 1
# [1, tokens, head_dim/2, 1, 2] (sincos)
total_tokens = txt_tokens + img_tokens
assert image_rotary_emb.shape[2] == 1 * total_tokens
image_rotary_emb = image_rotary_emb.reshape([1, txt_tokens + img_tokens, *image_rotary_emb.shape[3:]])
rotary_emb_txt = image_rotary_emb[:, :txt_tokens, ...]
rotary_emb_img = image_rotary_emb[:, txt_tokens:, ...]
rotary_emb_single = image_rotary_emb
rotary_emb_txt = self.pack_rotemb(pad_tensor(rotary_emb_txt, 256, 1))
rotary_emb_img = self.pack_rotemb(pad_tensor(rotary_emb_img, 256, 1))
rotary_emb_single = self.pack_rotemb(pad_tensor(rotary_emb_single, 256, 1))
if joint_attention_kwargs is not None and "ip_hidden_states" in joint_attention_kwargs:
    ip_hidden_states = joint_attention_kwargs.pop("ip_hidden_states")
elif self.image_embeds is not None:
    ip_hidden_states = self.image_embeds
else:
    ip_hidden_states = None
remaining_kwargs = {
"temb": temb,
"rotary_emb_img": rotary_emb_img,
"rotary_emb_txt": rotary_emb_txt,
"rotary_emb_single": rotary_emb_single,
"controlnet_block_samples": controlnet_block_samples,
"controlnet_single_block_samples": controlnet_single_block_samples,
"txt_tokens": txt_tokens,
"ip_hidden_states": ip_hidden_states if ip_hidden_states is not None else None,
}
torch._dynamo.graph_break()
if (self.residual_diff_threshold_multi <= 0.0) or (batch_size > 1):
updated_h, updated_enc, _, _ = self.call_IPA_multi_transformer_blocks(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
skip_block=False,
**remaining_kwargs,
)
remaining_kwargs.pop("ip_hidden_states", None)
cat_hidden_states = torch.cat([updated_enc, updated_h], dim=1)
updated_cat, _ = self.call_remaining_single_transformer_blocks(
hidden_states=cat_hidden_states, encoder_hidden_states=None, start_idx=0, **remaining_kwargs
)
# torch._dynamo.graph_break()
final_enc = updated_cat[:, :txt_tokens, ...]
final_h = updated_cat[:, txt_tokens:, ...]
final_h = final_h.to(original_dtype).to(original_device)
final_enc = final_enc.to(original_dtype).to(original_device)
if self.return_hidden_states_only:
return final_h
if self.return_hidden_states_first:
return final_h, final_enc
return final_enc, final_h
original_hidden_states = hidden_states
first_hidden_states, first_encoder_hidden_states, _, _ = self.call_IPA_multi_transformer_blocks(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
first_block=True,
skip_block=False,
**remaining_kwargs,
)
hidden_states = first_hidden_states
encoder_hidden_states = first_encoder_hidden_states
first_hidden_states_residual_multi = hidden_states - original_hidden_states
del original_hidden_states
call_remaining_fn = self.call_IPA_multi_transformer_blocks
torch._dynamo.graph_break()
updated_h, updated_enc, threshold = check_and_apply_cache(
first_residual=first_hidden_states_residual_multi,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
threshold=self.residual_diff_threshold_multi,
parallelized=False,
mode="multi",
verbose=self.verbose,
call_remaining_fn=call_remaining_fn,
remaining_kwargs=remaining_kwargs,
)
self.residual_diff_threshold_multi = threshold
# Single layer
remaining_kwargs.pop("ip_hidden_states", None)
cat_hidden_states = torch.cat([updated_enc, updated_h], dim=1)
original_cat = cat_hidden_states
if not self.use_double_fb_cache:
# No FBCache: run all single-stream blocks directly
updated_cat, _ = self.call_remaining_single_transformer_blocks(
hidden_states=cat_hidden_states, encoder_hidden_states=None, start_idx=0, **remaining_kwargs
)
else:
# FBCache enabled: run the first single-stream block, then reuse or recompute based on the residual diff
cat_hidden_states = self.m.forward_single_layer(0, cat_hidden_states, temb, rotary_emb_single)
first_hidden_states_residual_single = cat_hidden_states - original_cat
del original_cat
call_remaining_fn_single = self.call_remaining_single_transformer_blocks
updated_cat, _, threshold = check_and_apply_cache(
first_residual=first_hidden_states_residual_single,
hidden_states=cat_hidden_states,
encoder_hidden_states=None,
threshold=self.residual_diff_threshold_single,
parallelized=False,
mode="single",
verbose=self.verbose,
call_remaining_fn=call_remaining_fn_single,
remaining_kwargs=remaining_kwargs,
)
self.residual_diff_threshold_single = threshold
# torch._dynamo.graph_break()
final_enc = updated_cat[:, :txt_tokens, ...]
final_h = updated_cat[:, txt_tokens:, ...]
final_h = final_h.to(original_dtype).to(original_device)
final_enc = final_enc.to(original_dtype).to(original_device)
if self.return_hidden_states_only:
return final_h
if self.return_hidden_states_first:
return final_h, final_enc
return final_enc, final_h
def call_IPA_multi_transformer_blocks(
self,
hidden_states: torch.Tensor,
temb: torch.Tensor,
encoder_hidden_states: torch.Tensor,
rotary_emb_img: torch.Tensor,
rotary_emb_txt: torch.Tensor,
rotary_emb_single: torch.Tensor,
controlnet_block_samples=None,
controlnet_single_block_samples=None,
skip_first_layer=False,
txt_tokens=None,
ip_hidden_states=None,
first_block: bool = False,
skip_block: bool = True,
):
"""
Apply IP-Adapter conditioning to multiple transformer blocks.
Parameters
----------
hidden_states : torch.Tensor
Input hidden states.
temb : torch.Tensor
Conditioning (timestep/guidance) embedding tensor.
encoder_hidden_states : torch.Tensor
Encoder hidden states.
rotary_emb_img : torch.Tensor
Rotary embedding for image tokens.
rotary_emb_txt : torch.Tensor
Rotary embedding for text tokens.
rotary_emb_single : torch.Tensor
Rotary embedding for the single-stream blocks (concatenated text and image tokens).
controlnet_block_samples : list, optional
ControlNet block samples for multi-blocks.
controlnet_single_block_samples : list, optional
ControlNet block samples for single blocks.
skip_first_layer : bool, default=False
If True, skip the first transformer block.
txt_tokens : int, optional
Number of text tokens.
ip_hidden_states : torch.Tensor, optional
Image prompt hidden states.
first_block : bool, default=False
If True, only process the first block.
skip_block : bool, default=True
If True, skip the first block.
Returns
-------
tuple
(hidden_states, encoder_hidden_states, hidden_states_residual, encoder_hidden_states_residual)
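
Notes
-----
For each double-stream block ``idx``, the adapter adds an extra image-prompt
cross-attention on top of the block output (a sketch of the update applied
below)::

    k_img = self.ip_k_projs[idx](ip_hidden_states[0])
    v_img = self.ip_v_projs[idx](ip_hidden_states[0])
    hidden_states = hidden_states + self.ip_adapter_scale * sdpa(ip_query, k_img, v_img)

where ``ip_query`` are the block's image-token attention queries and ``sdpa``
denotes ``torch.nn.functional.scaled_dot_product_attention``.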
"""
if first_block and skip_block:
raise ValueError("`first_block` and `skip_block` cannot both be True.")
start_idx = 1 if skip_block else 0
end_idx = 1 if first_block else num_transformer_blocks
original_hidden_states = hidden_states.clone()
original_encoder_hidden_states = encoder_hidden_states.clone()
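# ip_hidden_states is expected to be a list of image-prompt embedding tensors; entry 0 is shared by all double-stream blocks.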
ip_hidden_states[0] = ip_hidden_states[0].to(self.dtype).to(self.device)
for idx in range(start_idx, end_idx):
k_img = self.ip_k_projs[idx](ip_hidden_states[0])
v_img = self.ip_v_projs[idx](ip_hidden_states[0])
hidden_states, encoder_hidden_states, ip_query = self.m.forward_layer_ip_adapter(
idx,
hidden_states,
encoder_hidden_states,
temb,
rotary_emb_img,
rotary_emb_txt,
controlnet_block_samples,
controlnet_single_block_samples,
)
ip_query = ip_query.contiguous().to(self.dtype)
# Reshape to [batch, heads, tokens, head_dim]; FLUX.1 uses 24 attention heads of dim 128
# (inner_dim 3072 = 24 * 128), and the leading batch dimension is hard-coded to 1 here.
ip_query = ip_query.view(1, -1, 24, 128).transpose(1, 2)
k_img = k_img.view(1, -1, 24, 128).transpose(1, 2)
v_img = v_img.view(1, -1, 24, 128).transpose(1, 2)
real_ip_attn_output = F.scaled_dot_product_attention(
ip_query, k_img, v_img, attn_mask=None, dropout_p=0.0, is_causal=False
)
real_ip_attn_output = real_ip_attn_output.transpose(1, 2).reshape(1, -1, 24 * 128)
hidden_states = hidden_states + self.ip_adapter_scale * real_ip_attn_output
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
hs_res = hidden_states - original_hidden_states
enc_res = encoder_hidden_states - original_encoder_hidden_states
return hidden_states, encoder_hidden_states, hs_res, enc_res
def load_ip_adapter_weights_per_layer(
self,
repo_id: str,
filename: str = "ip_adapter.safetensors",
prefix: str = "double_blocks.",
joint_attention_dim: int = 4096,
inner_dim: int = 3072,
):
"""
Load per-layer IP-Adapter weights from a HuggingFace Hub repository.
Parameters
----------
repo_id : str
HuggingFace Hub repository ID.
filename : str, default="ip_adapter.safetensors"
Name of the safetensors file.
prefix : str, default="double_blocks."
Prefix for block keys in the file.
joint_attention_dim : int, default=4096
Input dimension for joint attention.
inner_dim : int, default=3072
Output dimension for projections.
Returns
-------
None
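
Examples
--------
A minimal sketch, using the XLabs-AI adapter weights exercised by the tests::

    blocks.load_ip_adapter_weights_per_layer(
        repo_id="XLabs-AI/flux-ip-adapter-v2",
        filename="ip_adapter.safetensors",
    )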
"""
path = hf_hub_download(repo_id=repo_id, filename=filename)
raw_cpu = {}
with safe_open(path, framework="pt", device="cpu") as f:
for key in f.keys():
if key.startswith(prefix):
raw_cpu[key] = f.get_tensor(key)
raw = {k: v.to(self.device) for k, v in raw_cpu.items()}
layer_ids = sorted({int(k.split(".")[1]) for k in raw.keys()})
layers = []
for i in layer_ids:
base = f"double_blocks.{i}.processor.ip_adapter_double_stream"
layers.append(
{
"k_weight": raw[f"{base}_k_proj.weight"],
"k_bias": raw[f"{base}_k_proj.bias"],
"v_weight": raw[f"{base}_v_proj.weight"],
"v_bias": raw[f"{base}_v_proj.bias"],
}
)
cross_dim = joint_attention_dim
hidden_dim = inner_dim
self.ip_k_projs = nn.ModuleList()
self.ip_v_projs = nn.ModuleList()
for layer in layers:
k_proj = nn.Linear(cross_dim, hidden_dim, bias=True, device=self.device, dtype=self.dtype)
v_proj = nn.Linear(cross_dim, hidden_dim, bias=True, device=self.device, dtype=self.dtype)
k_proj.weight.data.copy_(layer["k_weight"])
k_proj.bias.data.copy_(layer["k_bias"])
v_proj.weight.data.copy_(layer["v_weight"])
v_proj.bias.data.copy_(layer["v_bias"])
self.ip_k_projs.append(k_proj)
self.ip_v_projs.append(v_proj)
def set_ip_hidden_states(self, image_embeds, negative_image_embeds=None):
"""
Set the image embeddings for IP-Adapter conditioning.
Parameters
----------
image_embeds : torch.Tensor
Image embeddings to use.
negative_image_embeds : optional
Not used.
Returns
-------
None
"""
self.image_embeds = image_embeds
def resize_numpy_image_long(image, resize_long_edge=768):
"""
Resize a numpy image so its longest edge matches a target size.
Parameters
----------
image : np.ndarray
Input image as a numpy array.
resize_long_edge : int, default=768
Target size for the longest edge.
Returns
-------
np.ndarray
Resized image.
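
Examples
--------
>>> import numpy as np
>>> img = np.zeros((1500, 1000, 3), dtype=np.uint8)
>>> resize_numpy_image_long(img, 768).shape
(768, 512, 3)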
"""
h, w = image.shape[:2]
if max(h, w) <= resize_long_edge:
return image
k = resize_long_edge / max(h, w)
h = int(h * k)
w = int(w * k)
image = cv2.resize(image, (w, h), interpolation=cv2.INTER_LANCZOS4)
return image
def undo_all_mods_on_transformer(transformer: FluxTransformer2DModel):
"""
Restore a FluxTransformer2DModel to its original, unmodified state.
This function undoes any modifications made for IP-Adapter integration,
restoring the original forward method and transformer blocks.
Parameters
----------
transformer : FluxTransformer2DModel
The transformer model to restore.
Returns
-------
FluxTransformer2DModel
The restored transformer model.
"""
if hasattr(transformer, "_original_forward"):
transformer.forward = transformer._original_forward
del transformer._original_forward
if hasattr(transformer, "_original_blocks"):
transformer.transformer_blocks = transformer._original_blocks
del transformer._original_blocks
return transformer
@@ -3,3 +3,5 @@ insightface
opencv-python
facexlib
onnxruntime
# ip-adapter
timm
@@ -775,6 +775,453 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
return {hidden_states, encoder_hidden_states};
}
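// IP-Adapter helper: recompute the per-head attention queries for this block's image tokens
// so the Python side can run the extra image-prompt cross-attention.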
Tensor JointTransformerBlock::get_q_heads(Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb,
Tensor rotary_emb_context,
float sparsityRatio) {
int batch_size = hidden_states.shape[0];
int num_tokens_img = hidden_states.shape[1];
int num_tokens_txt = encoder_hidden_states.shape[1];
// Apply AdaNorm.
auto norm1_output = norm1.forward(hidden_states, temb);
auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);
Tensor concat = Tensor::allocate(
{batch_size, num_tokens_img + num_tokens_txt, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device());
const bool blockSparse = sparsityRatio > 0;
constexpr int POOL_SIZE = Attention::POOL_SIZE;
const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
Tensor pool =
blockSparse
? Tensor::allocate({batch_size, poolTokens, dim * 3}, norm1_output.x.scalar_type(), norm1_output.x.device())
: Tensor{};
// QKV Projection.
for (int i = 0; i < batch_size; i++) {
Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
Tensor qkv_context = concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);
Tensor pool_qkv = pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
Tensor pool_qkv_context =
pool.valid() ? pool.slice(0, i, i + 1).slice(1, num_tokens_img / POOL_SIZE, poolTokens) : Tensor{};
qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
qkv_context,
pool_qkv_context,
norm_added_q.weight,
norm_added_k.weight,
rotary_emb_context);
}
// Extract and return q_heads.
Tensor q_all = concat.slice(2, 0, num_heads * dim_head);
Tensor q_img = q_all.slice(1, 0, num_tokens_img);
auto make_contiguous = [&](const Tensor &t) {
int B = t.shape.dataExtent[0];
int R = t.shape.dataExtent[1];
int C = t.shape.dataExtent[2];
size_t E = t.scalar_size();
size_t src_pitch = t.stride(1) * E;
size_t dst_pitch = C * E;
size_t width = C * E;
size_t height = R;
Tensor out = Tensor::allocate({B, R, C}, t.scalarType, t.device());
auto stream = getCurrentCUDAStream();
for (int b = 0; b < B; ++b) {
const void *src = (const char *)t.data_ptr<char>() + t.stride(0) * b * E;
void *dst = (char *)out.data_ptr<char>() + out.stride(0) * b * E;
checkCUDA(
cudaMemcpy2DAsync(dst, dst_pitch, src, src_pitch, width, height, cudaMemcpyDeviceToDevice, stream));
}
return out;
};
return make_contiguous(q_img);
}
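// Variant of forward() that additionally returns the image-token attention queries (q_heads)
// in a single pass, avoiding a separate get_q_heads() call.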
std::tuple<Tensor, Tensor, Tensor> JointTransformerBlock::forward_ip_adapter_branch(Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb,
Tensor rotary_emb_context,
float sparsityRatio) {
int batch_size = hidden_states.shape[0];
assert(encoder_hidden_states.shape[0] == batch_size);
nvtxRangePushA("JointTransformerBlock");
nvtxRangePushA("AdaNorm");
int num_tokens_img = hidden_states.shape[1];
int num_tokens_txt = encoder_hidden_states.shape[1];
assert(hidden_states.shape[2] == dim);
assert(encoder_hidden_states.shape[2] == dim);
Tensor q_heads;
auto make_contiguous = [&](const Tensor &t) {
int B = t.shape.dataExtent[0];
int R = t.shape.dataExtent[1];
int C = t.shape.dataExtent[2];
size_t E = t.scalar_size();
size_t src_pitch = t.stride(1) * E;
size_t dst_pitch = C * E;
size_t width = C * E;
size_t height = R;
Tensor out = Tensor::allocate({B, R, C}, t.scalarType, t.device());
auto stream = getCurrentCUDAStream();
for (int b = 0; b < B; ++b) {
const void *src = (const char *)t.data_ptr<char>() + t.stride(0) * b * E;
void *dst = (char *)out.data_ptr<char>() + out.stride(0) * b * E;
checkCUDA(
cudaMemcpy2DAsync(dst, dst_pitch, src, src_pitch, width, height, cudaMemcpyDeviceToDevice, stream));
}
return out;
};
spdlog::debug("hidden_states={} encoder_hidden_states={} temb={}",
hidden_states.shape.str(),
encoder_hidden_states.shape.str(),
temb.shape.str());
spdlog::debug("batch_size={} num_tokens_img={} num_tokens_txt={}", batch_size, num_tokens_img, num_tokens_txt);
auto norm1_output = norm1.forward(hidden_states, temb);
auto norm1_context_output = norm1_context.forward(encoder_hidden_states, temb);
#if 0
norm1_output.x = hidden_states;
norm1_context_output.x = encoder_hidden_states;
#endif
debug("norm_hidden_states", norm1_output.x);
debug("norm_encoder_hidden_states", norm1_context_output.x);
constexpr int POOL_SIZE = Attention::POOL_SIZE;
nvtxRangePop();
auto stream = getCurrentCUDAStream();
int num_tokens_img_pad = 0, num_tokens_txt_pad = 0;
Tensor raw_attn_output;
if (attnImpl == AttentionImpl::FlashAttention2) {
num_tokens_img_pad = num_tokens_img;
num_tokens_txt_pad = num_tokens_txt;
Tensor concat;
Tensor pool;
{
nvtxRangePushA("qkv_proj");
const bool blockSparse = sparsityRatio > 0;
const int poolTokens = num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE;
concat = Tensor::allocate({batch_size, num_tokens_img + num_tokens_txt, dim * 3},
norm1_output.x.scalar_type(),
norm1_output.x.device());
pool = blockSparse ? Tensor::allocate({batch_size, poolTokens, dim * 3},
norm1_output.x.scalar_type(),
norm1_output.x.device())
: Tensor{};
for (int i = 0; i < batch_size; i++) {
// img first
Tensor qkv = concat.slice(0, i, i + 1).slice(1, 0, num_tokens_img);
Tensor qkv_context =
concat.slice(0, i, i + 1).slice(1, num_tokens_img, num_tokens_img + num_tokens_txt);
Tensor pool_qkv =
pool.valid() ? pool.slice(0, i, i + 1).slice(1, 0, num_tokens_img / POOL_SIZE) : Tensor{};
Tensor pool_qkv_context = pool.valid()
? pool.slice(0, i, i + 1)
.slice(1,
num_tokens_img / POOL_SIZE,
num_tokens_img / POOL_SIZE + num_tokens_txt / POOL_SIZE)
: Tensor{};
// qkv_proj.forward(norm1_output.x.slice(0, i, i + 1), qkv);
// debug("qkv_raw", qkv);
debug("rotary_emb", rotary_emb);
qkv_proj.forward(
norm1_output.x.slice(0, i, i + 1), qkv, pool_qkv, norm_q.weight, norm_k.weight, rotary_emb);
debug("qkv", qkv);
// qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1), qkv_context);
// debug("qkv_context_raw", qkv_context);
debug("rotary_emb_context", rotary_emb_context);
qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
qkv_context,
pool_qkv_context,
norm_added_q.weight,
norm_added_k.weight,
rotary_emb_context);
debug("qkv_context", qkv_context);
}
nvtxRangePop();
}
spdlog::debug("concat={}", concat.shape.str());
debug("concat", concat);
assert(concat.shape[2] == num_heads * dim_head * 3);
nvtxRangePushA("Attention");
if (pool.valid()) {
raw_attn_output = attn.forward(concat, pool, sparsityRatio);
} else {
raw_attn_output = attn.forward(concat);
}
nvtxRangePop();
spdlog::debug("raw_attn_output={}", raw_attn_output.shape.str());
raw_attn_output = raw_attn_output.view({batch_size, num_tokens_img + num_tokens_txt, num_heads, dim_head});
// IP-Adapter: keep the image-token attention queries for the adapter cross-attention
Tensor q_all = concat.slice(2, 0, num_heads * dim_head); // [B, N_total, dim]
Tensor q_img = q_all.slice(1, 0, num_tokens_img); // [B, N_img, dim]
q_heads = make_contiguous(q_img);
} else if (attnImpl == AttentionImpl::NunchakuFP16) {
num_tokens_img_pad = ceilDiv(num_tokens_img, 256) * 256;
num_tokens_txt_pad = ceilDiv(num_tokens_txt, 256) * 256;
Tensor concat_q, concat_k, concat_v;
{
nvtxRangePushA("qkv_proj");
concat_q = Tensor::allocate({batch_size, num_heads, num_tokens_img_pad + num_tokens_txt_pad, dim_head},
Tensor::FP16,
norm1_output.x.device());
concat_k = Tensor::empty_like(concat_q);
concat_v = Tensor::empty_like(concat_q);
for (int i = 0; i < batch_size; i++) {
// img first
auto sliceImg = [&](Tensor x) { return x.slice(0, i, i + 1).slice(2, 0, num_tokens_img_pad); };
auto sliceTxt = [&](Tensor x) {
return x.slice(0, i, i + 1).slice(2, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt_pad);
};
qkv_proj.forward(norm1_output.x.slice(0, i, i + 1),
{},
{},
norm_q.weight,
norm_k.weight,
rotary_emb,
sliceImg(concat_q),
sliceImg(concat_k),
sliceImg(concat_v),
num_tokens_img);
qkv_proj_context.forward(norm1_context_output.x.slice(0, i, i + 1),
{},
{},
norm_added_q.weight,
norm_added_k.weight,
rotary_emb_context,
sliceTxt(concat_q),
sliceTxt(concat_k),
sliceTxt(concat_v),
num_tokens_txt);
}
debug("concat_q", concat_q);
debug("concat_k", concat_k);
debug("concat_v", concat_v);
nvtxRangePop();
}
raw_attn_output = Tensor::allocate({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads * dim_head},
norm1_output.x.scalar_type(),
norm1_output.x.device());
nvtxRangePushA("Attention");
kernels::attention_fp16(concat_q, concat_k, concat_v, raw_attn_output, pow(dim_head, (-0.5)));
nvtxRangePop();
raw_attn_output =
raw_attn_output.view({batch_size, num_tokens_img_pad + num_tokens_txt_pad, num_heads, dim_head});
q_heads = concat_q;
} else {
assert(false);
}
debug("raw_attn_output", raw_attn_output);
{
nvtxRangePushA("o_proj");
auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_output;
// raw_attn_output: [batch_size, num_tokens_img + num_tokens_txt, num_heads * dim_head]
Tensor raw_attn_output_split;
if (batch_size == 1) {
raw_attn_output_split =
raw_attn_output.slice(1, 0, num_tokens_img).reshape({batch_size, num_tokens_img, num_heads * dim_head});
} else {
raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_img, num_heads * dim_head},
raw_attn_output.scalar_type(),
raw_attn_output.device());
checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
raw_attn_output.data_ptr(),
(num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
raw_attn_output.scalar_size(),
num_tokens_img * num_heads * dim_head * raw_attn_output_split.scalar_size(),
batch_size,
cudaMemcpyDeviceToDevice,
stream));
}
spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
debug("img.raw_attn_output_split", raw_attn_output_split);
Tensor attn_output =
forward_fc(out_proj, raw_attn_output_split); // std::get<Tensor>(out_proj.forward(raw_attn_output_split));
debug("img.attn_output", attn_output);
#if 1
// kernels::mul_add(attn_output, gate_msa, hidden_states);
kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, hidden_states, true);
hidden_states = std::move(attn_output);
nvtxRangePop();
nvtxRangePushA("MLP");
spdlog::debug("attn_output={}", hidden_states.shape.str());
Tensor norm_hidden_states = norm2.forward(hidden_states);
debug("scale_mlp", scale_mlp);
debug("shift_mlp", shift_mlp);
// kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);
spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
Tensor norm_hidden_states = hidden_states;
#endif
// Tensor ff_output = mlp_fc2.forward(GELU::forward(mlp_fc1.forward(norm_hidden_states)));
debug("img.ff_input", norm_hidden_states);
Tensor ff_output = forward_mlp(mlp_fc1, mlp_fc2, norm_hidden_states);
debug("img.ff_output", ff_output);
debug("gate_mlp", gate_mlp);
// kernels::mul_add(ff_output, gate_mlp, hidden_states);
kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, hidden_states, true);
hidden_states = std::move(ff_output);
nvtxRangePop();
spdlog::debug("ff_output={}", hidden_states.shape.str());
}
if (context_pre_only) {
return {hidden_states, encoder_hidden_states, q_heads};
}
{
nvtxRangePushA("o_proj_context");
auto &&[_, gate_msa, shift_mlp, scale_mlp, gate_mlp] = norm1_context_output;
Tensor raw_attn_output_split;
if (batch_size == 1) {
raw_attn_output_split = raw_attn_output.slice(1, num_tokens_img_pad, num_tokens_img_pad + num_tokens_txt)
.reshape({batch_size, num_tokens_txt, num_heads * dim_head});
} else {
raw_attn_output_split = Tensor::allocate({batch_size, num_tokens_txt, num_heads * dim_head},
raw_attn_output.scalar_type(),
raw_attn_output.device());
checkCUDA(cudaMemcpy2DAsync(raw_attn_output_split.data_ptr(),
num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
raw_attn_output.data_ptr<char>() + num_tokens_img_pad * num_heads * dim_head *
raw_attn_output_split.scalar_size(),
(num_tokens_img_pad + num_tokens_txt_pad) * num_heads * dim_head *
raw_attn_output.scalar_size(),
num_tokens_txt * num_heads * dim_head * raw_attn_output_split.scalar_size(),
batch_size,
cudaMemcpyDeviceToDevice,
stream));
}
spdlog::debug("raw_attn_output_split={}", raw_attn_output_split.shape.str());
debug("context.raw_attn_output_split", raw_attn_output_split);
Tensor attn_output =
forward_fc(out_proj_context,
raw_attn_output_split); // std::get<Tensor>(out_proj_context.forward(raw_attn_output_split));
debug("context.attn_output", attn_output);
#if 1
// kernels::mul_add(attn_output, gate_msa, encoder_hidden_states);
kernels::mul_add_batch(attn_output, gate_msa, true, 0.0, encoder_hidden_states, true);
encoder_hidden_states = std::move(attn_output);
nvtxRangePop();
nvtxRangePushA("MLP");
spdlog::debug("attn_output={}", encoder_hidden_states.shape.str());
Tensor norm_hidden_states = norm2_context.forward(encoder_hidden_states);
debug("c_scale_mlp", scale_mlp);
debug("c_shift_mlp", shift_mlp);
// kernels::mul_add(norm_hidden_states, scale_mlp, shift_mlp);
kernels::mul_add_batch(norm_hidden_states, scale_mlp, true, 0.0, shift_mlp, true);
spdlog::debug("norm_hidden_states={}", norm_hidden_states.shape.str());
#else
auto norm_hidden_states = encoder_hidden_states;
#endif
// Tensor ff_output = mlp_context_fc2.forward(GELU::forward(mlp_context_fc1.forward(norm_hidden_states)));
// Tensor ff_output =
// mlp_context_fc2.forward_quant(quant_static_fuse_gelu(mlp_context_fc1.forward(norm_hidden_states), 1.0));
debug("context.ff_input", norm_hidden_states);
Tensor ff_output = forward_mlp(mlp_context_fc1, mlp_context_fc2, norm_hidden_states);
debug("context.ff_output", ff_output);
debug("c_gate_mlp", gate_mlp);
// kernels::mul_add(ff_output, gate_mlp, encoder_hidden_states);
kernels::mul_add_batch(ff_output, gate_mlp, true, 0.0, encoder_hidden_states, true);
encoder_hidden_states = std::move(ff_output);
nvtxRangePop();
spdlog::debug("ff_output={}", encoder_hidden_states.shape.str());
}
nvtxRangePop();
return {hidden_states, encoder_hidden_states, q_heads};
}
FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device)
: dtype(dtype), offload(offload) {
@@ -969,6 +1416,43 @@ std::tuple<Tensor, Tensor> FluxModel::forward_layer(size_t layer,
return {hidden_states, encoder_hidden_states};
}
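// Run double-stream block `layer` and also return the image-token attention queries (ip_query)
// needed for IP-Adapter conditioning; mirrors forward_layer() with an extra get_q_heads() pass.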
std::tuple<Tensor, Tensor, Tensor> FluxModel::forward_ip_adapter(size_t layer,
Tensor hidden_states, // [B, Nq, dim]
Tensor encoder_hidden_states, // [B, Nt, dim]
Tensor temb,
Tensor rotary_emb_img, // [B, Nq, dim_head]
Tensor rotary_emb_context,
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples) {
if (offload && layer > 0) {
if (layer < transformer_blocks.size()) {
transformer_blocks.at(layer)->loadLazyParams();
} else {
transformer_blocks.at(layer - transformer_blocks.size())->loadLazyParams();
}
}
std::tie(hidden_states, encoder_hidden_states) = transformer_blocks.at(layer)->forward(
hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
Tensor ip_query = transformer_blocks.at(layer)->get_q_heads(
hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context, 0.0f);
if (controlnet_block_samples.valid()) {
const int num_controlnet_block_samples = controlnet_block_samples.shape[0];
int interval_control = ceilDiv(transformer_blocks.size(), static_cast<size_t>(num_controlnet_block_samples));
int block_index = layer / interval_control;
hidden_states = kernels::add(hidden_states, controlnet_block_samples[block_index]);
}
if (offload && layer > 0) {
transformer_blocks.at(layer)->releaseLazyParams();
}
return {hidden_states, encoder_hidden_states, ip_query};
}
void FluxModel::setAttentionImpl(AttentionImpl impl) {
for (auto &&block : this->transformer_blocks) {
block->attnImpl = impl;
......
@@ -133,6 +133,18 @@ public:
Tensor rotary_emb,
Tensor rotary_emb_context,
float sparsityRatio);
std::tuple<Tensor, Tensor, Tensor> forward_ip_adapter_branch(Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb,
Tensor rotary_emb_context,
float sparsityRatio);
Tensor get_q_heads(Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb,
Tensor rotary_emb_context,
float sparsityRatio);
public:
const int dim;
@@ -178,6 +190,16 @@ public:
Tensor rotary_emb_context,
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples);
std::tuple<Tensor, Tensor, Tensor> forward_ip_adapter(size_t layer,
Tensor hidden_states,
Tensor encoder_hidden_states,
Tensor temb,
Tensor rotary_emb_img,
Tensor rotary_emb_context,
Tensor controlnet_block_samples,
Tensor controlnet_single_block_samples);
void setAttentionImpl(AttentionImpl impl);
void set_residual_callback(std::function<Tensor(const Tensor &)> cb);
......
import gc
import numpy as np
import pytest
import torch
import torch.nn.functional as F
from diffusers import FluxPipeline
from diffusers.utils import load_image
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.ip_adapter.diffusers_adapters import apply_IPA_on_pipe
from nunchaku.models.ip_adapter.utils import resize_numpy_image_long
from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline
from nunchaku.utils import get_precision, is_turing
@pytest.mark.skipif(is_turing(), reason="Skip tests due to using Turing GPUs")
def test_flux_dev_IPA():
precision = get_precision()  # auto-detects whether your GPU should use 'int4' or 'fp4'
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
pipeline.load_ip_adapter(
pretrained_model_name_or_path_or_dict="XLabs-AI/flux-ip-adapter-v2",
weight_name="ip_adapter.safetensors",
image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
apply_IPA_on_pipe(pipeline, ip_adapter_scale=1.15, repo_id="XLabs-AI/flux-ip-adapter-v2")
id_image = load_image(
"https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/ComfyUI-nunchaku/inputs/monalisa.jpg"
)
image = pipeline(
prompt="holding an sign saying 'SVDQuant is fast!'",
ip_adapter_image=id_image.convert("RGB"),
num_inference_steps=50,
).images[0]
del pipeline
del transformer
gc.collect()
torch.cuda.empty_cache()
# Use the PuLID pipeline to extract face ID embeddings for the similarity check
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors", offload=True
)
pipeline = PuLIDFluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
transformer=transformer,
torch_dtype=torch.bfloat16,
)
id_image = id_image.convert("RGB")
id_image_numpy = np.array(id_image)
id_image = resize_numpy_image_long(id_image_numpy, 1024)
id_embeddings, _ = pipeline.pulid_model.get_id_embedding(id_image)
output_image = image.convert("RGB")
output_image_numpy = np.array(output_image)
output_image = resize_numpy_image_long(output_image_numpy, 1024)
output_id_embeddings, _ = pipeline.pulid_model.get_id_embedding(output_image)
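# Identity preservation check: PuLID yields 32 ID-embedding tokens of dim 2048; require a mean
# cosine similarity above 0.8 between the reference face and the generated image.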
cosine_similarities = (
F.cosine_similarity(id_embeddings.view(32, 2048), output_id_embeddings.view(32, 2048), dim=1).mean().item()
)
print(cosine_similarities)
assert cosine_similarities > 0.80
del pipeline
del transformer
gc.collect()
torch.cuda.empty_cache()
if __name__ == "__main__":
test_flux_dev_IPA()