Unverified Commit 06b7a518 authored by SMG, committed by GitHub

feat: enable IP-Adapter (XLabs-AI/flux-ip-adapter-v2) support (#418)



* feat: support IP-adapter

* FBCache and comfyUI

* fixing conflicts

* update

* update example

* update example

* style: make linter happy

* update

* update ipa test

* add docs and rename IP to ip

* docs: add docs for ipa

* docs: add docs for ipa

* add an example for pulid

* update

* save gpu memory

* change the threshold to 0.8

---------
Co-authored-by: Muyang Li <lmxyy1999@foxmail.com>
parent 24c2f925
@@ -24,6 +24,7 @@ Check out `DeepCompressor <github_deepcompressor_>`_ for the quantization library
    usage/attention.rst
    usage/fbcache.rst
    usage/pulid.rst
+   usage/ip_adapter.rst

 .. toctree::
    :maxdepth: 1
......
@@ -8,3 +8,4 @@
 .. _hf_nunchaku-flux1-dev-int4: https://huggingface.co/mit-han-lab/nunchaku-flux.1-dev/blob/main/svdq-int4_r32-flux.1-dev.safetensors
 .. _hf_depth_anything: https://huggingface.co/LiheYoung/depth-anything-large-hf
 .. _hf_nunchaku_wheels: https://huggingface.co/nunchaku-tech/nunchaku
+.. _hf_ip-adapterv2: https://huggingface.co/XLabs-AI/flux-ip-adapter-v2
nunchaku.models.ip_adapter.diffusers_adapters.flux
===================================================

.. automodule:: nunchaku.models.ip_adapter.diffusers_adapters.flux
   :members:
   :undoc-members:
   :show-inheritance:
nunchaku.models.ip_adapter.diffusers_adapters
=============================================

.. automodule:: nunchaku.models.ip_adapter.diffusers_adapters
   :members:
   :undoc-members:
   :show-inheritance:

.. toctree::
   :maxdepth: 4

   nunchaku.models.ip_adapter.diffusers_adapters.flux
nunchaku.models.ip_adapter
==========================

.. toctree::
   :maxdepth: 4

   nunchaku.models.ip_adapter.diffusers_adapters
   nunchaku.models.ip_adapter.utils
nunchaku.models.ip_adapter.utils
================================

.. automodule:: nunchaku.models.ip_adapter.utils
   :members:
   :undoc-members:
   :show-inheritance:
@@ -7,4 +7,5 @@ nunchaku.models
    nunchaku.models.transformers
    nunchaku.models.text_encoders
    nunchaku.models.pulid
+   nunchaku.models.ip_adapter
    nunchaku.models.safety_checker
IP Adapter
==========

Nunchaku supports `IP Adapter <hf_ip-adapterv2_>`_, an adapter that adds image-prompt capability to FLUX.1-dev.

.. literalinclude:: ../../../examples/flux.1-dev-IP-adapter.py
   :language: python
   :caption: IP Adapter Example (`examples/flux.1-dev-IP-adapter.py <https://github.com/nunchaku-tech/nunchaku/blob/main/examples/flux.1-dev-IP-adapter.py>`__)
   :linenos:
The IP Adapter integration in Nunchaku follows these main steps:

**Model Initialization**:

- Load a Nunchaku FLUX.1-dev transformer model using :meth:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel.from_pretrained`.
- Initialize the FLUX pipeline with :class:`diffusers.FluxPipeline`, passing the transformer and setting the appropriate precision, as sketched below.
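A minimal sketch of this step, mirroring the bundled example (the model IDs below come from that example; adjust them to your setup):

.. code-block:: python

   import torch
   from diffusers import FluxPipeline
   from nunchaku import NunchakuFluxTransformer2dModel
   from nunchaku.utils import get_precision

   precision = get_precision()  # 'int4' or 'fp4', depending on your GPU
   transformer = NunchakuFluxTransformer2dModel.from_pretrained(
       f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
   )
   pipeline = FluxPipeline.from_pretrained(
       "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
   ).to("cuda")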
**IP Adapter Loading**:

- Use ``pipeline.load_ip_adapter`` to load the IP Adapter weights and the CLIP image encoder:

  - ``pretrained_model_name_or_path_or_dict``: Hugging Face repo or local path of the IP Adapter weights.
  - ``weight_name``: name of the weights file (e.g., ``ip_adapter.safetensors``).
  - ``image_encoder_pretrained_model_name_or_path``: name or path of the CLIP image encoder.

- Apply the IP Adapter to the pipeline with :func:`~nunchaku.models.ip_adapter.diffusers_adapters.apply_IPA_on_pipe`, specifying the adapter scale and repo ID (see the sketch after this list).
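Taken together, the two calls look like this (values copied from the bundled example; ``ip_adapter_scale`` is a tunable strength):

.. code-block:: python

   from nunchaku.models.ip_adapter.diffusers_adapters import apply_IPA_on_pipe

   # Load the IP Adapter weights and the CLIP image encoder into the pipeline.
   pipeline.load_ip_adapter(
       pretrained_model_name_or_path_or_dict="XLabs-AI/flux-ip-adapter-v2",
       weight_name="ip_adapter.safetensors",
       image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
   )

   # Attach the adapter to the Nunchaku transformer; the scale controls its influence.
   apply_IPA_on_pipe(pipeline, ip_adapter_scale=1.1, repo_id="XLabs-AI/flux-ip-adapter-v2")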
**Caching (Optional)**:

Enable caching for faster inference and reduced memory usage with :func:`~nunchaku.caching.diffusers_adapters.apply_cache_on_pipe`. See :doc:`fbcache` for more details.
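A sketch using the cache settings from the bundled example (the thresholds are tunable):

.. code-block:: python

   from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe

   # Double first-block cache with the thresholds used in the example.
   apply_cache_on_pipe(
       pipeline,
       use_double_fb_cache=True,
       residual_diff_threshold_multi=0.09,
       residual_diff_threshold_single=0.12,
   )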
**Image Generation**:

- Load the image that will serve as the image prompt (the IP Adapter reference).
- Call the pipeline with:

  - ``prompt``: the text prompt for generation.
  - ``ip_adapter_image``: the reference image (must be RGB).

- The output image reflects both the text prompt and the visual style/content of the reference image; see the example call below.
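A generation call in the style of the bundled example (the reference URL below is the test image used there; the variable and output names are illustrative):

.. code-block:: python

   from diffusers.utils import load_image

   # Reference image used as the image prompt; convert to RGB before passing it in.
   ref_image = load_image(
       "https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/ComfyUI-nunchaku/inputs/monalisa.jpg"
   )

   image = pipeline(
       prompt="holding a sign saying 'SVDQuant is fast!'",
       ip_adapter_image=ref_image.convert("RGB"),
       num_inference_steps=50,
   ).images[0]
   image.save("flux.1-dev-IP-adapter.png")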
 PuLID
 =====

-Nunchaku integrates `PuLID <_pulid_paper>`_, a tuning-free identity customization method for text-to-image generation.
+.. image:: https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/ComfyUI-nunchaku/workflows/nunchaku-flux.1-dev-pulid.png
+
+Nunchaku integrates `PuLID <paper_pulid_>`_, a tuning-free identity customization method for text-to-image generation.
 This feature allows you to generate images that maintain specific identity characteristics from reference photos.

 .. literalinclude:: ../../../examples/flux.1-dev-pulid.py
@@ -9,13 +11,10 @@ This feature allows you to generate images that maintain specific identity characteristics
    :caption: PuLID Example (`examples/flux.1-dev-pulid.py <https://github.com/nunchaku-tech/nunchaku/blob/main/examples/flux.1-dev-pulid.py>`__)
    :linenos:

-Implementation Overview
------------------------
-
 The PuLID integration follows these key steps:

 **Model Initialization** (lines 12-20):
-   Load a Nunchaku FLUX.1-dev model using :class:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel`
+   Load a Nunchaku FLUX.1-dev model using :meth:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel.from_pretrained`
    and initialize the FLUX PuLID pipeline with :class:`~nunchaku.pipeline.pipeline_flux_pulid.PuLIDFluxPipeline`.

 **Forward Method Override** (line 22):
......
@@ -8,7 +8,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-canny-dev/svdq-{precision}_r32-flux.1-canny-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-canny-dev/svdq-{precision}_r32-flux.1-canny-dev.safetensors"
 )
 pipe = FluxControlPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-Canny-dev", transformer=transformer, torch_dtype=torch.bfloat16
......
@@ -8,7 +8,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-depth-dev/svdq-{precision}_r32-flux.1-depth-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-depth-dev/svdq-{precision}_r32-flux.1-depth-dev.safetensors"
 )
 pipe = FluxControlPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
......
@@ -8,7 +8,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-depth-dev/svdq-{precision}_r32-flux.1-depth-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-depth-dev/svdq-{precision}_r32-flux.1-depth-dev.safetensors"
 )
 pipe = FluxControlPipeline.from_pretrained(
......
import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image

from nunchaku import NunchakuFluxTransformer2dModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.models.ip_adapter.diffusers_adapters import apply_IPA_on_pipe
from nunchaku.utils import get_precision

precision = get_precision()  # auto-detect 'int4' or 'fp4' based on your GPU

# Load the SVDQuant FLUX.1-dev transformer and build the standard FLUX pipeline around it.
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

# Load the IP-Adapter weights and the CLIP image encoder, then attach the adapter to the pipeline.
pipeline.load_ip_adapter(
    pretrained_model_name_or_path_or_dict="XLabs-AI/flux-ip-adapter-v2",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
apply_IPA_on_pipe(pipeline, ip_adapter_scale=1.1, repo_id="XLabs-AI/flux-ip-adapter-v2")

# Optional: double first-block caching for faster inference.
apply_cache_on_pipe(
    pipeline,
    use_double_fb_cache=True,
    residual_diff_threshold_multi=0.09,
    residual_diff_threshold_single=0.12,
)

# Reference image used as the image prompt.
IP_image = load_image(
    "https://huggingface.co/datasets/nunchaku-tech/test-data/resolve/main/ComfyUI-nunchaku/inputs/monalisa.jpg"
)

image = pipeline(
    prompt="holding a sign saying 'SVDQuant is fast!'",
    ip_adapter_image=IP_image.convert("RGB"),
    num_inference_steps=50,
).images[0]
image.save(f"flux.1-dev-IP-adapter-{precision}.png")
@@ -7,7 +7,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
 )
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
......
@@ -15,7 +15,7 @@ controlnet = FluxMultiControlNetModel([controlnet_union])  # we always recommend
 precision = get_precision()
 need_offload = get_gpu_memory() < 36
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
+    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
     torch_dtype=torch.bfloat16,
     offload=need_offload,
 )
......
@@ -8,7 +8,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
 )
 pipeline = FluxPipeline.from_pretrained(
......
@@ -8,7 +8,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
+    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
     offload=True,
 )
......
@@ -6,7 +6,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
 )
 transformer.set_attention_impl("nunchaku-fp16")  # set attention implementation to fp16
 pipeline = FluxPipeline.from_pretrained(
......
@@ -6,7 +6,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
 )
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
......
@@ -7,7 +7,7 @@ from nunchaku.utils import get_precision
 precision = get_precision()  # auto-detect your precision is 'int4' or 'fp4' based on your GPU
 transformer = NunchakuFluxTransformer2dModel.from_pretrained(
-    f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
+    f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
 )
 pipeline = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
......