Unverified Commit 06b7a518 authored by SMG, committed by GitHub

feat: enable IP-Adapter (XLabs-AI/flux-ip-adapter-v2) support (#418)



* feat: support IP-adapter

* FBCache and comfyUI

* fixing conflicts

* update

* update example

* update example

* style: make linter happy

* update

* update ipa test

* add docs and rename IP to ip

* docs: add docs for ipa

* docs: add docs for ipa

* add an example for pulid

* update

* save gpu memory

* change the threshold to 0.8

---------
Co-authored-by: Muyang Li <lmxyy1999@foxmail.com>
parent 24c2f925
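
For orientation, here is a minimal end-to-end sketch of the feature this commit enables. The import path nunchaku.models.ip_adapter.diffusers_adapters is inferred from the relative imports in the new modules below (it is not shown in this diff); everything else follows the example scripts touched by this commit.

import torch
from diffusers import FluxPipeline

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.models.ip_adapter.diffusers_adapters import apply_IPA_on_pipe  # path inferred, see modules below
from nunchaku.utils import get_precision

precision = get_precision()  # 'int4' or 'fp4' depending on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Attach the XLabs IP-Adapter named in the commit title; keyword arguments are
# forwarded to apply_IPA_on_transformer(ip_adapter_scale=..., repo_id=...).
apply_IPA_on_pipe(pipeline, ip_adapter_scale=1.0, repo_id="XLabs-AI/flux-ip-adapter-v2")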
@@ -6,7 +6,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors", offload=True
+    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors", offload=True
 )  # set offload to False if you want to disable offloading
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
...
@@ -10,7 +10,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
 )
 pipeline = PuLIDFluxPipeline.from_pretrained(
...
@@ -6,7 +6,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
 )
 text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors")
 pipeline = FluxPipeline.from_pretrained(
...
@@ -9,7 +9,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
 )
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
...
@@ -6,7 +6,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
+    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
     offload=True,
     torch_dtype=torch.float16,  # Turing GPUs only support fp16 precision
 )  # set offload to False if you want to disable offloading
...
@@ -6,7 +6,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
 )
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
...
@@ -10,7 +10,7 @@ mask = load_image("https://huggingface.co/mit-han-lab/svdq-int4-flux.1-fill-dev/
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-fill-dev/svdq-{precision}_r32-flux.1-fill-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-fill-dev/svdq-{precision}_r32-flux.1-fill-dev.safetensors"
 )
 pipe = FluxFillPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-Fill-dev", transformer=transformer, torch_dtype=torch.bfloat16
...
@@ -6,7 +6,7 @@ from nunchaku import NunchakuFluxTransformer2dModel
 from nunchaku.utils import get_precision
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-kontext-dev/svdq-{get_precision()}_r32-flux.1-kontext-dev.safetensors"
 )
 pipeline = FluxKontextPipeline.from_pretrained(
...
@@ -10,7 +10,7 @@ pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-Redux-dev", torch_dtype=torch.bfloat16
 ).to("cuda")
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
 )
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
...
@@ -6,7 +6,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"
 )
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
...
@@ -5,7 +5,7 @@ from nunchaku import NunchakuSanaTransformer2DModel
 from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
 transformer = NunchakuSanaTransformer2DModel.from_pretrained(
-    "mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
+    "nunchaku-tech/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
 )
 pipe = SanaPipeline.from_pretrained(
     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
...
@@ -4,7 +4,7 @@ from diffusers import SanaPipeline
 from nunchaku import NunchakuSanaTransformer2DModel
 transformer = NunchakuSanaTransformer2DModel.from_pretrained(
-    "mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
+    "nunchaku-tech/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors"
 )
 pipe = SanaPipeline.from_pretrained(
     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
...
@@ -4,7 +4,7 @@ from diffusers import SanaPAGPipeline
 from nunchaku import NunchakuSanaTransformer2DModel
 transformer = NunchakuSanaTransformer2DModel.from_pretrained(
-    "mit-han-lab/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors", pag_layers=8
+    "nunchaku-tech/nunchaku-sana/svdq-int4_r32-sana1.6b.safetensors", pag_layers=8
 )
 pipe = SanaPAGPipeline.from_pretrained(
     "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
...
@@ -24,7 +24,7 @@ def apply_cache_on_transformer(
     use_double_fb_cache: bool = False,
     residual_diff_threshold: float = 0.12,
     residual_diff_threshold_multi: float | None = None,
-    residual_diff_threshold_single: float = 0.1,
+    residual_diff_threshold_single: float | None = None,
 ):
     """
     Enable caching for a ``FluxTransformer2DModel``.
@@ -43,7 +43,7 @@ def apply_cache_on_transformer(
     residual_diff_threshold_multi : float, optional
         Threshold for multi-head (double) blocks. If None, uses ``residual_diff_threshold``.
     residual_diff_threshold_single : float, optional
-        Threshold for single-head blocks (default: 0.1).
+        Threshold for single-head blocks (default: None).

     Returns
     -------
@@ -54,6 +54,11 @@ def apply_cache_on_transformer(
     -----
     If already cached, only updates thresholds. Caching is only active within a cache context.
     """
+    if not hasattr(transformer, "_original_forward"):
+        transformer._original_forward = transformer.forward
+    if not hasattr(transformer, "_original_blocks"):
+        transformer._original_blocks = transformer.transformer_blocks
+
     if residual_diff_threshold_multi is None:
         residual_diff_threshold_multi = residual_diff_threshold
@@ -94,6 +99,10 @@ def apply_cache_on_transformer(
         return original_forward(*args, **kwargs)

     transformer.forward = new_forward.__get__(transformer)
+    transformer._is_cached = True
+    transformer.use_double_fb_cache = use_double_fb_cache
+    transformer.residual_diff_threshold_multi = residual_diff_threshold_multi
+    transformer.residual_diff_threshold_single = residual_diff_threshold_single
     return transformer
...
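A short sketch of what the new bookkeeping enables. The import path is an assumption (this hunk does not show the file header); the call pattern and the stored attributes come straight from the diff above. After caching is applied, the chosen thresholds live on the transformer, which is exactly what the IP-Adapter wrapper later in this commit reads back.

from nunchaku.caching.diffusers_adapters.flux import apply_cache_on_transformer  # path assumed

transformer = apply_cache_on_transformer(
    transformer,
    use_double_fb_cache=True,
    residual_diff_threshold=0.12,        # falls through to the multi-block threshold
    residual_diff_threshold_single=0.1,  # the old default, now opt-in
)
# Attributes recorded by this commit:
assert transformer._is_cached
print(transformer.use_double_fb_cache,
      transformer.residual_diff_threshold_multi,
      transformer.residual_diff_threshold_single)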
@@ -583,7 +583,7 @@ class SanaCachedTransformerBlocks(nn.Module):
             can_use_cache, _ = get_can_use_cache(
                 first_hidden_states_residual,
                 threshold=self.residual_diff_threshold,
-                parallelized=self.transformer is not None and getattr(self.transformer, "_is_parallelized", False),
+                parallelized=False,
             )
             torch._dynamo.graph_break()
@@ -729,7 +729,7 @@ class FluxCachedTransformerBlocks(nn.Module):
         verbose: bool = False,
     ):
         super().__init__()
-        self.transformer = transformer
+        # self.transformer = transformer
         self.transformer_blocks = transformer.transformer_blocks
         self.single_transformer_blocks = transformer.single_transformer_blocks
@@ -924,7 +924,7 @@ class FluxCachedTransformerBlocks(nn.Module):
             hidden_states=hidden_states,
             encoder_hidden_states=encoder_hidden_states,
             threshold=self.residual_diff_threshold_multi,
-            parallelized=(self.transformer is not None and getattr(self.transformer, "_is_parallelized", False)),
+            parallelized=False,
             mode="multi",
             verbose=self.verbose,
             call_remaining_fn=call_remaining_fn,
@@ -953,7 +953,7 @@ class FluxCachedTransformerBlocks(nn.Module):
             hidden_states=cat_hidden_states,
             encoder_hidden_states=None,
             threshold=self.residual_diff_threshold_single,
-            parallelized=(self.transformer is not None and getattr(self.transformer, "_is_parallelized", False)),
+            parallelized=False,
             mode="single",
             verbose=self.verbose,
             call_remaining_fn=call_remaining_fn_single,
@@ -1066,6 +1066,7 @@ class FluxCachedTransformerBlocks(nn.Module):
         controlnet_single_block_samples=None,
         skip_first_layer=False,
         txt_tokens=None,
+        start_idx=1,
     ):
         """
         Call remaining Flux double blocks.
@@ -1104,7 +1105,6 @@ class FluxCachedTransformerBlocks(nn.Module):
         enc_residual : torch.Tensor
             Residual of encoder hidden states.
         """
-        start_idx = 1
         original_hidden_states = hidden_states.clone()
         original_encoder_hidden_states = encoder_hidden_states.clone()
@@ -1139,6 +1139,7 @@ class FluxCachedTransformerBlocks(nn.Module):
         controlnet_single_block_samples=None,
         skip_first_layer=False,
         txt_tokens=None,
+        start_idx=1,
     ):
         """
         Call remaining Flux single blocks.
@@ -1173,7 +1174,6 @@ class FluxCachedTransformerBlocks(nn.Module):
         hidden_states_residual : torch.Tensor
             Residual of hidden states.
         """
-        start_idx = 1
        original_hidden_states = hidden_states.clone()
        for idx in range(start_idx, num_single_transformer_blocks):
...
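The start_idx refactor keeps the old behavior by default (skip only the probe block) while letting a caller resume the loop from a different block index, which the IP-Adapter path can use. A toy illustration of the pattern, not the library's code:

from torch import nn

def call_remaining(blocks: nn.ModuleList, x, start_idx: int = 1):
    # Previously the offset was a hard-coded local (start_idx = 1); promoting it
    # to a parameter with the same default preserves existing call sites.
    for idx in range(start_idx, len(blocks)):
        x = blocks[idx](x)
    return x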
@@ -212,4 +212,43 @@ public:
             throw std::invalid_argument(spdlog::fmt_lib::format("Invalid attention implementation {}", name));
         }
     }
+
+    std::tuple<torch::Tensor, torch::Tensor, torch::Tensor>
+    forward_layer_ip_adapter(int64_t idx,
+                             torch::Tensor hidden_states,
+                             torch::Tensor encoder_hidden_states,
+                             torch::Tensor temb,
+                             torch::Tensor rotary_emb_img,
+                             torch::Tensor rotary_emb_context,
+                             std::optional<torch::Tensor> controlnet_block_samples = std::nullopt,
+                             std::optional<torch::Tensor> controlnet_single_block_samples = std::nullopt) {
+        CUDADeviceContext ctx(deviceId);
+
+        spdlog::debug("QuantizedFluxModel forward_layer_ip_adapter {}", idx);
+
+        hidden_states = hidden_states.contiguous();
+        encoder_hidden_states = encoder_hidden_states.contiguous();
+        temb = temb.contiguous();
+        rotary_emb_img = rotary_emb_img.contiguous();
+        rotary_emb_context = rotary_emb_context.contiguous();
+
+        auto &&[hidden_states_, encoder_hidden_states_, ip_query_] = net->forward_ip_adapter(
+            idx,
+            from_torch(hidden_states),
+            from_torch(encoder_hidden_states),
+            from_torch(temb),
+            from_torch(rotary_emb_img),
+            from_torch(rotary_emb_context),
+            controlnet_block_samples.has_value() ? from_torch(controlnet_block_samples.value().contiguous())
+                                                 : Tensor{},
+            controlnet_single_block_samples.has_value()
+                ? from_torch(controlnet_single_block_samples.value().contiguous())
+                : Tensor{});
+
+        hidden_states = to_torch(hidden_states_);
+        encoder_hidden_states = to_torch(encoder_hidden_states_);
+        torch::Tensor ip_query = to_torch(ip_query_);
+        Tensor::synchronizeDevice();
+
+        return {hidden_states, encoder_hidden_states, ip_query};
+    }
 };
@@ -49,6 +49,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
              py::arg("rotary_emb_context"),
              py::arg("controlnet_block_samples") = py::none(),
              py::arg("controlnet_single_block_samples") = py::none())
+        .def("forward_layer_ip_adapter",
+             &QuantizedFluxModel::forward_layer_ip_adapter,
+             py::arg("idx"),
+             py::arg("hidden_states"),
+             py::arg("encoder_hidden_states"),
+             py::arg("temb"),
+             py::arg("rotary_emb_img"),
+             py::arg("rotary_emb_context"),
+             py::arg("controlnet_block_samples") = py::none(),
+             py::arg("controlnet_single_block_samples") = py::none())
         .def("forward_single_layer", &QuantizedFluxModel::forward_single_layer)
         .def("norm_one_forward", &QuantizedFluxModel::norm_one_forward)
         .def("startDebug", &QuantizedFluxModel::startDebug)
...
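Once built, the binding is callable from Python like the existing forward_layer. A hedged wrapper sketch; the helper name is hypothetical, `model` is a QuantizedFluxModel instance, and the tensor shapes follow whatever the regular double-block call uses (not shown in this diff):

import torch

def run_double_block_with_ip(model, idx: int, hidden_states: torch.Tensor,
                             encoder_hidden_states: torch.Tensor, temb: torch.Tensor,
                             rotary_emb_img: torch.Tensor, rotary_emb_context: torch.Tensor):
    # Mirrors forward_layer, but additionally returns the block's IP-Adapter
    # attention query (ip_query) for the Python-side image-conditioning path.
    hidden_states, encoder_hidden_states, ip_query = model.forward_layer_ip_adapter(
        idx, hidden_states, encoder_hidden_states, temb, rotary_emb_img, rotary_emb_context
    )
    return hidden_states, encoder_hidden_states, ip_query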
"""
IP-Adapter integration for Diffusers pipelines.
This module provides utilities to apply IP-Adapter modifications to compatible
Diffusers pipelines, such as Flux and PuLID pipelines.
"""
from diffusers import DiffusionPipeline
def apply_IPA_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
"""
Apply IP-Adapter modifications to a supported Diffusers pipeline.
Parameters
----------
pipe : DiffusionPipeline
The pipeline instance to modify. Must be a Flux or PuLID pipeline.
*args
Additional positional arguments passed to the underlying implementation.
**kwargs
Additional keyword arguments passed to the underlying implementation.
Returns
-------
DiffusionPipeline
The modified pipeline with IP-Adapter applied.
Raises
------
ValueError
If the pipeline class is not supported.
"""
assert isinstance(pipe, DiffusionPipeline)
pipe_cls_name = pipe.__class__.__name__
if pipe_cls_name.startswith("Flux") or pipe_cls_name.startswith("PuLID"):
from .flux import apply_IPA_on_pipe as apply_IPA_on_pipe_fn
else:
raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
return apply_IPA_on_pipe_fn(pipe, *args, **kwargs)
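
Dispatch is by class-name prefix, so PuLID pipelines reuse the Flux implementation and anything else fails fast. A hedged one-liner, assuming `pipeline` is a PuLIDFluxPipeline built as in the example diffs above:

# "PuLID..." prefix routes to the Flux implementation; an unsupported class
# (e.g. StableDiffusionPipeline) raises ValueError per the dispatcher above.
apply_IPA_on_pipe(pipeline, ip_adapter_scale=1.0, repo_id="XLabs-AI/flux-ip-adapter-v2")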
"""
IP-Adapter integration for Flux pipelines in Diffusers.
This module provides functions to apply IP-Adapter modifications to
FluxTransformer2DModel and DiffusionPipeline objects, enabling image prompt
conditioning for generative models.
"""
import functools
import unittest
from diffusers import DiffusionPipeline, FluxTransformer2DModel
from torch import nn
from nunchaku.caching.utils import cache_context, create_cache_context
from nunchaku.models.ip_adapter.utils import undo_all_mods_on_transformer
from ...ip_adapter import utils
def apply_IPA_on_transformer(
transformer: FluxTransformer2DModel,
*,
ip_adapter_scale: float = 1.0,
repo_id: str,
):
"""
Apply IP-Adapter modifications to a FluxTransformer2DModel.
This function replaces the transformer's blocks with IP-Adapter-enabled blocks,
loads per-layer IP-Adapter weights, and wraps the forward method to use the new blocks.
Parameters
----------
transformer : FluxTransformer2DModel
The transformer model to modify.
ip_adapter_scale : float, optional
Scaling factor for the IP-Adapter (default is 1.0).
repo_id : str
HuggingFace Hub repository ID containing the IP-Adapter weights.
Returns
-------
FluxTransformer2DModel
The modified transformer with IP-Adapter support.
"""
IPA_transformer_blocks = nn.ModuleList(
[
utils.IPA_TransformerBlocks(
transformer=transformer,
ip_adapter_scale=ip_adapter_scale,
return_hidden_states_first=False,
device=transformer.device,
)
]
)
if getattr(transformer, "_is_cached", False):
IPA_transformer_blocks[0].update_residual_diff_threshold(
use_double_fb_cache=transformer.use_double_fb_cache,
residual_diff_threshold_multi=transformer.residual_diff_threshold_multi,
residual_diff_threshold_single=transformer.residual_diff_threshold_single,
)
undo_all_mods_on_transformer(transformer)
if not hasattr(transformer, "_original_forward"):
transformer._original_forward = transformer.forward
if not hasattr(transformer, "_original_blocks"):
transformer._original_blocks = transformer.transformer_blocks
dummy_single_transformer_blocks = nn.ModuleList()
IPA_transformer_blocks[0].load_ip_adapter_weights_per_layer(repo_id=repo_id)
transformer.transformer_blocks = IPA_transformer_blocks
transformer.single_transformer_blocks = dummy_single_transformer_blocks
original_forward = transformer.forward
@functools.wraps(original_forward)
def new_forward(self, *args, **kwargs):
with (
unittest.mock.patch.object(self, "transformer_blocks", IPA_transformer_blocks),
unittest.mock.patch.object(self, "single_transformer_blocks", dummy_single_transformer_blocks),
):
return original_forward(*args, **kwargs)
transformer.forward = new_forward.__get__(transformer)
transformer._is_IPA = True
return transformer
def apply_IPA_on_pipe(pipe: DiffusionPipeline, **kwargs):
"""
Apply IP-Adapter modifications to a DiffusionPipeline.
This function modifies the pipeline's transformer to support IP-Adapter
conditioning. If the pipeline is cached, it also wraps the pipeline's
__call__ method to ensure cache context is used.
Parameters
----------
pipe : DiffusionPipeline
The pipeline to modify. Must contain a FluxTransformer2DModel as its transformer.
**kwargs
Additional keyword arguments passed to `apply_IPA_on_transformer`.
Returns
-------
DiffusionPipeline
The modified pipeline with IP-Adapter support.
"""
if getattr(pipe, "_is_cached", False):
original_call = pipe.__class__.__call__
@functools.wraps(original_call)
def new_call(self, *args, **kwargs):
with cache_context(create_cache_context()):
return original_call(self, *args, **kwargs)
pipe.__class__.__call__ = new_call
pipe.__class__._is_cached = True
apply_IPA_on_transformer(pipe.transformer, **kwargs)
return pipe
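
Because apply_IPA_on_transformer reads _is_cached and the thresholds that apply_cache_on_transformer now stores, FBCache must be applied before the IP-Adapter. A hedged combination sketch, assuming apply_cache_on_pipe (imported as in the Sana example above) marks the pipeline and transformer as cached:

from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

# Order matters: cache first, so the IPA blocks inherit use_double_fb_cache and
# both residual-diff thresholds before undo_all_mods_on_transformer runs.
apply_cache_on_pipe(pipeline, residual_diff_threshold=0.12)
apply_IPA_on_pipe(pipeline, ip_adapter_scale=1.0, repo_id="XLabs-AI/flux-ip-adapter-v2")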