"vscode:/vscode.git/clone" did not exist on "98085015d5826f173e19572905d29ff1908cb3e0"
Unverified Commit 06b7a518 authored by SMG's avatar SMG Committed by GitHub
Browse files

feat: enable IP-Adapter (XLabs-AI/flux-ip-adapter-v2) support (#418)



* feat: support IP-adapter

* FBCache and ComfyUI

* fixing conflicts

* update

* update example

* update example

* style: make linter happy

* update

* update ipa test

* add docs and rename IP to ip

* docs: add docs for ipa

* docs: add docs for ipa

* add an example for pulid

* update

* save gpu memory

* change the threshold to 0.8

---------
Co-authored-by: Muyang Li <lmxyy1999@foxmail.com>
parent 24c2f925
......@@ -24,6 +24,7 @@ Check out `DeepCompressor <github_deepcompressor_>`_ for the quantization librar
usage/attention.rst
usage/fbcache.rst
usage/pulid.rst
usage/ip_adapter.rst
.. toctree::
:maxdepth: 1
......
......@@ -8,3 +8,4 @@
.. _hf_nunchaku-flux1-dev-int4: https://huggingface.co/mit-han-lab/nunchaku-flux.1-dev/blob/main/svdq-int4_r32-flux.1-dev.safetensors
.. _hf_depth_anything: https://huggingface.co/LiheYoung/depth-anything-large-hf
.. _hf_nunchaku_wheels: https://huggingface.co/nunchaku-tech/nunchaku
.. _hf_ip-adapterv2: https://huggingface.co/XLabs-AI/flux-ip-adapter-v2
nunchaku.models.ip_adapter.diffusers_adapters.flux
==================================================
.. automodule:: nunchaku.models.ip_adapter.diffusers_adapters.flux
:members:
:undoc-members:
:show-inheritance:
nunchaku.models.ip_adapter.diffusers_adapters
=============================================
.. automodule:: nunchaku.models.ip_adapter.diffusers_adapters
:members:
:undoc-members:
:show-inheritance:
.. toctree::
:maxdepth: 4
nunchaku.models.ip_adapter.diffusers_adapters.flux
nunchaku.models.ip_adapter
==========================
.. toctree::
:maxdepth: 4
nunchaku.models.ip_adapter.diffusers_adapters
nunchaku.models.ip_adapter.utils
nunchaku.models.ip_adapter.utils
================================
.. automodule:: nunchaku.models.ip_adapter.utils
:members:
:undoc-members:
:show-inheritance:
......@@ -7,4 +7,5 @@ nunchaku.models
nunchaku.models.transformers
nunchaku.models.text_encoders
nunchaku.models.pulid
nunchaku.models.ip_adapter
nunchaku.models.safety_checker
IP Adapter
==========
Nunchaku supports `IP Adapter <hf_ip-adapterv2_>`_, an adapter that adds image-prompt capability to the FLUX.1-dev model.
.. literalinclude:: ../../../examples/flux.1-dev-IP-adapter.py
:language: python
:caption: IP Adapter Example (`examples/flux.1-dev-IP-adapter.py <https://github.com/nunchaku-tech/nunchaku/blob/main/examples/flux.1-dev-IP-adapter.py>`__)
:linenos:
The IP Adapter integration in Nunchaku follows these main steps:
**Model Initialization**:
- Load a Nunchaku FLUX.1-dev transformer model using :meth:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel.from_pretrained`.
- Initialize the FLUX pipeline with :class:`diffusers.FluxPipeline`, passing the transformer and setting the appropriate precision; a condensed sketch follows below.
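A condensed sketch of this step; the model paths and dtype follow the bundled example above:

.. code-block:: python

    import torch
    from diffusers import FluxPipeline

    from nunchaku import NunchakuFluxTransformer2dModel
    from nunchaku.utils import get_precision

    precision = get_precision()  # auto-detects 'int4' or 'fp4' based on your GPU
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
    )
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")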
**IP Adapter Loading**:
- Use ``pipeline.load_ip_adapter`` to load the IP Adapter weights and the CLIP image encoder.
- ``pretrained_model_name_or_path_or_dict``: Hugging Face repo or local path for the IP Adapter weights.
- ``weight_name``: Name of the weights file (e.g., ``ip_adapter.safetensors``).
- ``image_encoder_pretrained_model_name_or_path``: Name or path of the CLIP image encoder.
- Apply the IP Adapter to the pipeline with :func:`~nunchaku.models.ip_adapter.diffusers_adapters.apply_IPA_on_pipe`, specifying the adapter scale and repo ID, as in the sketch below.
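Continuing the sketch, with the weight file and image-encoder names taken from the bundled example:

.. code-block:: python

    from nunchaku.models.ip_adapter.diffusers_adapters import apply_IPA_on_pipe

    # Load the IP Adapter weights and the CLIP image encoder.
    pipeline.load_ip_adapter(
        pretrained_model_name_or_path_or_dict="XLabs-AI/flux-ip-adapter-v2",
        weight_name="ip_adapter.safetensors",
        image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
    )
    # Patch the pipeline so the adapter is applied at inference time.
    apply_IPA_on_pipe(pipeline, ip_adapter_scale=1.1, repo_id="XLabs-AI/flux-ip-adapter-v2")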
**Caching (Optional)**:
Enable caching for faster inference and reduced memory usage with :func:`~nunchaku.caching.diffusers_adapters.apply_cache_on_pipe`. See :doc:`fbcache` for more details.
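For example, mirroring the cache settings used in the bundled example (tune the thresholds for your own speed/quality trade-off):

.. code-block:: python

    from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

    apply_cache_on_pipe(
        pipeline,
        use_double_fb_cache=True,
        residual_diff_threshold_multi=0.09,
        residual_diff_threshold_single=0.12,
    )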
**Image Generation**:
- Load the image to be used as the image prompt (IP Adapter reference).
- Call the pipeline (see the sketch after this list) with:
- ``prompt``: The text prompt for generation.
- ``ip_adapter_image``: The reference image (must be RGB).
- The output image will reflect both the text prompt and the visual style/content of the reference image.
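A condensed sketch of the generation step, using the prompt and reference image from the bundled example:

.. code-block:: python

    from diffusers.utils import load_image

    ip_image = load_image(
        "https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/ComfyUI-nunchaku/inputs/monalisa.jpg"
    )
    image = pipeline(
        prompt="holding a sign saying 'SVDQuant is fast!'",
        ip_adapter_image=ip_image.convert("RGB"),  # reference image must be RGB
        num_inference_steps=50,
    ).images[0]
    image.save(f"flux.1-dev-IP-adapter-{precision}.png")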
PuLID
=====
Nunchaku integrates `PuLID <_pulid_paper>`_, a tuning-free identity customization method for text-to-image generation.
.. image:: https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/ComfyUI-nunchaku/workflows/nunchaku-flux.1-dev-pulid.png
Nunchaku integrates `PuLID <paper_pulid_>`_, a tuning-free identity customization method for text-to-image generation.
This feature allows you to generate images that maintain specific identity characteristics from reference photos.
.. literalinclude:: ../../../examples/flux.1-dev-pulid.py
......@@ -9,13 +11,10 @@ This feature allows you to generate images that maintain specific identity chara
:caption: PuLID Example (`examples/flux.1-dev-pulid.py <https://github.com/nunchaku-tech/nunchaku/blob/main/examples/flux.1-dev-pulid.py>`__)
:linenos:
Implementation Overview
-----------------------
The PuLID integration follows these key steps:
**Model Initialization** (lines 12-20):
Load a Nunchaku FLUX.1-dev model using :class:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel`
Load a Nunchaku FLUX.1-dev model using :meth:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel.from_pretrained`
and initialize the FLUX PuLID pipeline with :class:`~nunchaku.pipeline.pipeline_flux_pulid.PuLIDFluxPipeline`.
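A hedged sketch of this step; it assumes :class:`~nunchaku.pipeline.pipeline_flux_pulid.PuLIDFluxPipeline` accepts the same ``from_pretrained`` arguments as :class:`diffusers.FluxPipeline` (see the bundled example for the exact call):

.. code-block:: python

    import torch

    from nunchaku import NunchakuFluxTransformer2dModel
    from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline
    from nunchaku.utils import get_precision

    precision = get_precision()  # auto-detects 'int4' or 'fp4' based on your GPU
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
    )
    # Assumption: PuLIDFluxPipeline mirrors FluxPipeline.from_pretrained.
    pipeline = PuLIDFluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")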
**Forward Method Override** (line 22):
......
......@@ -8,7 +8,7 @@ from nunchaku.utils import get_precision
precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-canny-dev/svdq-{precision}_r32-flux.1-canny-dev.safetensors"
f"nunchaku-tech/nunchaku-flux.1-canny-dev/svdq-{precision}_r32-flux.1-canny-dev.safetensors"
)
pipe = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-Canny-dev", transformer=transformer, torch_dtype=torch.bfloat16
......
......@@ -8,7 +8,7 @@ from nunchaku.utils import get_precision
precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-depth-dev/svdq-{precision}_r32-flux.1-depth-dev.safetensors"
f"nunchaku-tech/nunchaku-flux.1-depth-dev/svdq-{precision}_r32-flux.1-depth-dev.safetensors"
)
pipe = FluxControlPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
......
......@@ -8,7 +8,7 @@ from nunchaku.utils import get_precision
precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-depth-dev/svdq-{precision}_r32-flux.1-depth-dev.safetensors"
f"nunchaku-tech/nunchaku-flux.1-depth-dev/svdq-{precision}_r32-flux.1-depth-dev.safetensors"
)
pipe = FluxControlPipeline.from_pretrained(
......
import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image
from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.models.ip_adapter.diffusers_adapters import apply_IPA_on_pipe
from nunchaku.utils import get_precision
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
pipeline.load_ip_adapter(
pretrained_model_name_or_path_or_dict="XLabs-AI/flux-ip-adapter-v2",
weight_name="ip_adapter.safetensors",
image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
apply_IPA_on_pipe(pipeline, ip_adapter_scale=1.1, repo_id="XLabs-AI/flux-ip-adapter-v2")
apply_cache_on_pipe(
pipeline,
use_double_fb_cache=True,
residual_diff_threshold_multi=0.09,
residual_diff_threshold_single=0.12,
)
ip_image = load_image(
"https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/ComfyUI-nunchaku/inputs/monalisa.jpg"
)
image = pipeline(
prompt="holding an sign saying 'SVDQuant is fast!'",
ip_adapter_image=ip_image.convert("RGB"),
num_inference_steps=50,
).images[0]
image.save(f"flux.1-dev-IP-adapter-{precision}.png")
......@@ -7,7 +7,7 @@ from nunchaku.utils import get_precision
precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
......
......@@ -15,7 +15,7 @@ controlnet = FluxMultiControlNetModel([controlnet_union]) # we always recommend
precision = get_precision()
need_offload = get_gpu_memory() < 36
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
torch_dtype=torch.bfloat16,
offload=need_offload,
)
......
......@@ -8,7 +8,7 @@ from nunchaku.utils import get_precision
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
......
......@@ -8,7 +8,7 @@ from nunchaku.utils import get_precision
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
offload=True,
)
......
......@@ -6,7 +6,7 @@ from nunchaku.utils import get_precision
precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
transformer.set_attention_impl("nunchaku-fp16") # set attention implementation to fp16
pipeline = FluxPipeline.from_pretrained(
......
......@@ -6,7 +6,7 @@ from nunchaku.utils import get_precision
precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
......
......@@ -7,7 +7,7 @@ from nunchaku.utils import get_precision
precision = get_precision()  # auto-detect whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
......