Unverified Commit 51732b7a authored by Muyang Li, committed by GitHub

docs: add the docs of nunchaku (#517)

* update sphinx docs

* update the doc configuration

* configure doxyfile

* start building the docs

* building docs

* building docs

* update docs

* finish the installation documents

* finish the installation documents

* finish the installation documents

* start using rst

* use rst instead of md

* need to figure out how to maintain rst

* update

* make linter happy

* update

* link management

* rst is hard to handle

* fix the title-only errors

* setup the rst linter

* add the lora files

* lora added, need to be more comprehensive

* update

* update

* finished lora docs

* finished the LoRA docs

* finished the cn docs

* finished the cn docs

* finished the qencoder docs

* finished the cpu offload

* finished the offload docs

* add the attention docs

* finished the attention docs

* finished the fbcache

* update

* finished the pulid docs

* make linter happy

* make linter happy

* add kontext

* update

* add the docs for gradio demos

* add docs for test.py

* add the docs for utils.py

* make the doc better displayed

* update

* update

* add some docs

* style: make linter happy

* add docs

* update

* add caching docs

* make linter happy

* add api docs

* fix the t5 docs

* fix the t5 docs

* fix the t5 docs

* hide the private functions

* update

* fix the docs of caching utils

* update docs

* finished the docstring of nunchaku caching

* update packer

* revert the docs

* better docs for packer.py

* better docs for packer.py

* better docs for packer.py

* better docs for packer.py

* update

* update docs

* caching done

* caching done

* lora

* lora

* lora

* update

* python docs

* reorg docs

* add the initial version of faq

* update

* make linter happy

* reorg

* reorg

* add crossref

* make linter happy

* better docs

* make linter happy

* preliminary version of the docs done

* update

* update README

* update README

* docs done

* update README

* update docs

* not using datasets 4 for now
FLUX.1-Kontext
==============
.. image:: https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/nunchaku/assets/kontext.png
:alt: FLUX.1-Kontext-dev integration with Nunchaku
Nunchaku supports `FLUX.1-Kontext-dev <flux1_kontext_dev_>`_,
an advanced model that enables precise image editing through natural language prompts.
The implementation follows the same pattern as described in :doc:`Basic Usage <./basic_usage>`.
.. literalinclude:: ../../../examples/flux.1-kontext-dev.py
:language: python
:caption: Running FLUX.1-Kontext-dev (`examples/flux.1-kontext-dev.py <https://github.com/mit-han-lab/nunchaku/blob/main/examples/flux.1-kontext-dev.py>`__)
:linenos:
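For quick reference, a condensed inline sketch of the same pattern is shown below. The quantized checkpoint path, input image, and prompt are illustrative; the packaged example above remains the authoritative version.

.. code-block:: python

    import torch
    from diffusers import FluxKontextPipeline
    from diffusers.utils import load_image

    from nunchaku import NunchakuFluxTransformer2dModel

    # Illustrative checkpoint path; substitute the quantized Kontext model you use.
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        "mit-han-lab/nunchaku-flux.1-kontext-dev/svdq-int4_r32-flux.1-kontext-dev.safetensors"
    )
    pipeline = FluxKontextPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Kontext-dev", transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")

    image = load_image("input.png")  # the image to edit (illustrative path)
    edited = pipeline(image=image, prompt="Make the sky look like a vivid sunset").images[0]
    edited.save("edited.png")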
Customized LoRAs
================
.. image:: https://huggingface.co/mit-han-lab/nunchaku-artifacts/resolve/main/nunchaku/assets/lora.jpg
:alt: LoRA integration with Nunchaku
Single LoRA
-----------
`Nunchaku <nunchaku_repo_>`_ seamlessly integrates with off-the-shelf LoRAs without requiring requantization.
Instead of fusing the LoRA branch into the main branch, we directly concatenate the LoRA weights to our low-rank branch.
Because Nunchaku uses fused kernels, the overhead of the separate low-rank branch is largely reduced.
Below is an example of running FLUX.1-dev with `Ghibsky <ghibsky_lora_>`_ LoRA.
.. literalinclude:: ../../../examples/flux.1-dev-lora.py
:language: python
:caption: Running FLUX.1-dev with `Ghibsky <ghibsky_lora_>`_ LoRA (`examples/flux.1-dev-lora.py <https://github.com/mit-han-lab/nunchaku/blob/main/examples/flux.1-dev-lora.py>`__)
:linenos:
:emphasize-lines: 16-19
The LoRA integration in Nunchaku works through two key methods:
**Loading LoRA Parameters** (lines 16-17):
The :meth:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel.update_lora_params` method loads LoRA weights from a safetensors file. It supports:
- **Local file path**: ``"/path/to/your/lora.safetensors"``
- **HuggingFace repository with specific file**: ``"aleksa-codes/flux-ghibsky-illustration/lora.safetensors"``. The system automatically downloads and caches the LoRA file on first access.
**Controlling LoRA Strength** (lines 18-19):
The :meth:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel.set_lora_strength` method sets the LoRA strength parameter, which controls how much influence the LoRA has on the final output. A value of 1.0 applies the full LoRA effect, while lower values (e.g., 0.5) apply a more subtle influence.
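For quick reference, here is a minimal sketch of these two calls. The base model and LoRA paths mirror the example above; the prompt is illustrative.

.. code-block:: python

    import torch
    from diffusers import FluxPipeline

    from nunchaku import NunchakuFluxTransformer2dModel

    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        "mit-han-lab/nunchaku-flux.1-dev/svdq-fp4_r32-flux.1-dev.safetensors"
    )
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")

    # Load LoRA weights from a local path or a HuggingFace repo + filename ...
    transformer.update_lora_params("aleksa-codes/flux-ghibsky-illustration/lora.safetensors")
    # ... and control how strongly the LoRA affects the output (1.0 = full effect).
    transformer.set_lora_strength(1.0)

    image = pipeline(prompt="GHIBSKY style, a serene mountain lake at dawn", num_inference_steps=28).images[0]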
Multiple LoRAs
--------------
To load multiple LoRAs simultaneously, Nunchaku provides the :func:`~nunchaku.lora.flux.compose.compose_lora` function,
which combines multiple LoRA weights into a single composed LoRA before loading.
This approach enables efficient multi-LoRA inference without requiring separate loading operations.
The following example demonstrates how to compose and load multiple LoRAs:
.. literalinclude:: ../../../examples/flux.1-dev-multiple-lora.py
:language: python
:caption: Running FLUX.1-dev with `Ghibsky <ghibsky_lora_>`_ and `FLUX-Turbo <turbo_lora_>`_ LoRA (`examples/flux.1-dev-multiple-lora.py <https://github.com/mit-han-lab/nunchaku/blob/main/examples/flux.1-dev-multiple-lora.py>`__)
:linenos:
:emphasize-lines: 17-23
The :func:`~nunchaku.lora.flux.compose.compose_lora` function accepts a list of tuples, where each tuple contains:
- **LoRA path**: Either a local file path or a HuggingFace repository path with a specific file
- **Strength value**: A float value (typically between 0.0 and 1.0) that controls the influence of that specific LoRA
This composition method allows for precise control over individual LoRA strengths while maintaining computational efficiency through a single loading operation.
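For quick reference, a minimal composition sketch is shown below, assuming the ``transformer`` from the previous snippet; the LoRA paths and strengths are illustrative.

.. code-block:: python

    from nunchaku.lora.flux.compose import compose_lora

    # Compose two LoRAs with per-LoRA strengths and save the result to disk.
    compose_lora(
        [
            ("aleksa-codes/flux-ghibsky-illustration/lora.safetensors", 0.8),
            ("path/to/another_lora.safetensors", 0.6),
        ],
        output_path="composed_lora.safetensors",
    )

    # Load the composed LoRA as usual. set_lora_strength is not needed here,
    # because the per-LoRA strengths are already baked into the composed weights.
    transformer.update_lora_params("composed_lora.safetensors")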
.. warning::
When using multiple LoRAs,
the :meth:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel.set_lora_strength` method
applies a uniform strength value across all loaded LoRAs, which may not provide the desired level of control.
For precise management of individual LoRA influences, specify strength values for each LoRA within the
:func:`~nunchaku.lora.flux.compose.compose_lora` function call.
.. warning::
Nunchaku's current implementation maintains the LoRA branch separately from the main branch.
This design choice may impact inference performance when the composed rank becomes large (e.g., > 256).
A future release will include quantization tools to fuse the LoRA branch into the main branch.
LoRA Conversion
---------------
Nunchaku utilizes the `Diffusers <diffusers_repo_>`_ LoRA format as an intermediate representation for converting LoRAs to Nunchaku's native format.
Both the :meth:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel.update_lora_params` method and :func:`~nunchaku.lora.flux.compose.compose_lora` function internally invoke the `to_diffusers <to_diffusers_lora_>`_ function to convert LoRAs to the `Diffusers <diffusers_repo_>`_ format.
If LoRA functionality is not working as expected, verify that the LoRA has been properly converted to the `Diffusers <diffusers_repo_>`_ format. Please check `to_diffusers <to_diffusers_lora_>`_ for more details.
Following the conversion to `Diffusers <diffusers_repo_>`_ format,
the :meth:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel.update_lora_params`
method calls the :func:`~nunchaku.lora.flux.nunchaku_converter.to_nunchaku` function
to perform the final conversion to Nunchaku's format.
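If a LoRA fails to load, it can help to run the intermediate conversion manually and inspect the result. A minimal sketch is shown below; the module path of ``to_diffusers`` is assumed here and may differ in your installed version, so see the `to_diffusers <to_diffusers_lora_>`_ reference above.

.. code-block:: python

    # Module path assumed; see the to_diffusers reference linked above.
    from nunchaku.lora.flux.diffusers_converter import to_diffusers

    # Convert a LoRA to the Diffusers format and optionally save it for inspection.
    diffusers_lora = to_diffusers("my_lora.safetensors", output_path="my_lora.diffusers.safetensors")
    print(list(diffusers_lora.keys())[:10])  # peek at the converted parameter names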
Exporting Converted LoRAs
-------------------------
The current implementation employs single-threaded conversion, which may result in extended processing times, particularly for large LoRA files.
To address this limitation, users can pre-compose LoRAs using the :mod:`nunchaku.lora.flux.compose` command-line interface.
The syntax is as follows:
.. code-block:: bash
python -m nunchaku.lora.flux.compose -i lora1.safetensors lora2.safetensors -s 0.8 0.6 -o composed_lora.safetensors
**Arguments**:
- ``-i``, ``--input-paths``: Paths to the LoRA safetensors files (supports multiple files)
- ``-s``, ``--strengths``: Strength values for each LoRA (must correspond to the number of input files)
- ``-o``, ``--output-path``: Output path for the composed LoRA safetensors file
This command composes the specified LoRAs with their respective strength values and saves the result to the output file,
which can subsequently be loaded using :meth:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel.update_lora_params` for optimized inference performance.
Following composition, users may either load the file directly
(via the ComfyUI LoRA loader or :meth:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel.update_lora_params`)
or utilize :mod:`nunchaku.lora.flux.convert` to convert the composed LoRA to Nunchaku's format and export it.
The syntax is as follows:
.. code-block:: bash
python -m nunchaku.lora.flux.convert --lora-path composed_lora.safetensors --quant-path mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors --output-root ./converted --dtype bfloat16
**Arguments**:
- ``--lora-path``: Path to the LoRA weights safetensor file (required)
- ``--quant-path``: Path to the quantized model safetensor file (default: ``mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors``)
- ``--output-root``: Root directory for the output safetensor file (default: parent directory of the lora file)
- ``--lora-name``: Name of the LoRA weights (optional, auto-generated if not provided)
- ``--dtype``: Data type of the converted weights, either ``bfloat16`` or ``float16`` (default: ``bfloat16``)
This command converts the LoRA to Nunchaku's format and saves it with an appropriate filename based on the quantization precision (fp4 or int4).
.. warning::
LoRAs in Nunchaku format should not be composed with other LoRAs. Additionally, LoRA strength values are permanently embedded in the composed LoRA. To apply different strength values, the LoRAs must be recomposed.
CPU Offload
===========
Nunchaku provides CPU offload capabilities to significantly reduce GPU memory usage with minimal performance impact.
This feature is fully compatible with `Diffusers <diffusers_repo_>`_ offload mechanisms.
.. literalinclude:: ../../../examples/flux.1-dev-offload.py
:language: python
:caption: Running FLUX.1-dev with CPU Offload (`examples/flux.1-dev-offload.py <https://github.com/mit-han-lab/nunchaku/blob/main/examples/flux.1-dev-offload.py>`__)
:linenos:
:emphasize-lines: 9, 13, 14
The following modifications are required compared to :doc:`Basic Usage <../basic_usage/basic_usage>`:
**Nunchaku CPU Offload** (line 9):
Enable Nunchaku's built-in CPU offload by setting ``offload=True`` during transformer initialization.
This intelligently offloads inactive model components to CPU memory, reducing GPU memory footprint.
**Diffusers Sequential Offload** (line 14):
Activate Diffusers' sequential CPU offload with ``pipeline.enable_sequential_cpu_offload()``.
This provides automatic device management and additional memory optimization.
.. note::
When using CPU offload, manual device placement with ``.to('cuda')`` is unnecessary,
as ``pipeline.enable_sequential_cpu_offload()`` handles all device management automatically.
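For quick reference, a minimal sketch combining both mechanisms, with the same model paths as the basic example:

.. code-block:: python

    import torch
    from diffusers import FluxPipeline

    from nunchaku import NunchakuFluxTransformer2dModel

    # Enable Nunchaku's built-in CPU offload when loading the transformer.
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        "mit-han-lab/nunchaku-flux.1-dev/svdq-fp4_r32-flux.1-dev.safetensors", offload=True
    )
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
    )
    # No manual .to("cuda"): sequential offload handles device placement automatically.
    pipeline.enable_sequential_cpu_offload()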
PuLID
=====
Nunchaku integrates `PuLID <pulid_paper_>`_, a tuning-free identity customization method for text-to-image generation.
This feature allows you to generate images that maintain specific identity characteristics from reference photos.
.. literalinclude:: ../../../examples/flux.1-dev-pulid.py
:language: python
:caption: PuLID Example (`examples/flux.1-dev-pulid.py <https://github.com/mit-han-lab/nunchaku/blob/main/examples/flux.1-dev-pulid.py>`__)
:linenos:
Implementation Overview
-----------------------
The PuLID integration follows these key steps:
**Model Initialization** (lines 12-20):
Load a Nunchaku FLUX.1-dev model using :class:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel`
and initialize the FLUX PuLID pipeline with :class:`~nunchaku.pipeline.pipeline_flux_pulid.PuLIDFluxPipeline`.
**Forward Method Override** (line 22):
Replace the transformer's forward method with PuLID's specialized implementation using
``MethodType(pulid_forward, pipeline.transformer)``.
This modification enables identity-aware generation capabilities.
See :func:`~nunchaku.models.pulid.pulid_forward.pulid_forward` for more details.
**Reference Image Processing** (line 24):
Load and prepare the reference identity image that will guide the generation process.
This image defines the identity characteristics to be preserved in the output.
**Identity-Controlled Generation** (lines 26-32):
Execute the pipeline with identity-specific parameters:
- ``id_image``: The reference identity image
- ``id_weight``: Identity influence strength (range: 0.0-1.0, where 1.0 provides maximum identity preservation)
- Standard generation parameters (prompt, inference steps, guidance scale)
The generated image will incorporate the identity features from the reference photo while adhering to the provided text prompt.
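A condensed sketch of the forward override and the identity-controlled call is shown below. It assumes ``pipeline`` is a :class:`~nunchaku.pipeline.pipeline_flux_pulid.PuLIDFluxPipeline` built around a Nunchaku FLUX.1-dev transformer, as in the packaged example; the reference photo path and generation settings are illustrative.

.. code-block:: python

    from types import MethodType

    from diffusers.utils import load_image

    from nunchaku.models.pulid.pulid_forward import pulid_forward

    # Route the transformer's forward pass through PuLID's implementation.
    pipeline.transformer.forward = MethodType(pulid_forward, pipeline.transformer)

    id_image = load_image("reference_face.jpg")  # illustrative reference identity photo
    image = pipeline(
        prompt="portrait of a person hiking in the Alps",
        id_image=id_image,
        id_weight=1.0,  # 0.0-1.0; 1.0 preserves the identity most strongly
        num_inference_steps=28,
        guidance_scale=3.5,
    ).images[0]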
Quantized Text Encoders
=======================
Nunchaku provides a quantized T5 encoder for FLUX.1 to reduce GPU memory usage.
.. literalinclude:: ../../../examples/flux.1-dev-qencoder.py
:language: python
:caption: Running FLUX.1-dev with Quantized T5 (`examples/flux.1-dev-qencoder.py <https://github.com/mit-han-lab/nunchaku/blob/main/examples/flux.1-dev-qencoder.py>`__)
:linenos:
:emphasize-lines: 11, 14
The key changes from :doc:`Basic Usage <./basic_usage>` are:
**Loading Quantized T5 Encoder** (line 11):
Use :class:`~nunchaku.models.text_encoders.t5_encoder.NunchakuT5EncoderModel` to load the quantized encoder.
This reduces GPU memory usage while maintaining quality. Supports local or Hugging Face remote paths.
**Pipeline Integration** (line 14):
Pass the quantized encoder to the pipeline via the ``text_encoder_2`` parameter,
replacing the default T5 encoder.
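For quick reference, a minimal sketch of both steps is shown below. The quantized T5 path is illustrative (any supported local or Hugging Face path works), and loading is assumed to mirror the transformer's ``from_pretrained`` API.

.. code-block:: python

    import torch
    from diffusers import FluxPipeline

    from nunchaku import NunchakuFluxTransformer2dModel
    from nunchaku.models.text_encoders.t5_encoder import NunchakuT5EncoderModel

    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        "mit-han-lab/nunchaku-flux.1-dev/svdq-fp4_r32-flux.1-dev.safetensors"
    )
    # Illustrative path for the quantized T5 encoder.
    text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
        "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
    )
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=torch.bfloat16,
    ).to("cuda")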
.. note::
    The quantized T5 encoder currently supports only the CUDA backend; Turing GPU support will be added later.
"""
Pipeline adapters for first-block caching in Nunchaku models.
"""
from diffusers import DiffusionPipeline
def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
"""
Apply caching to a diffusers pipeline with automatic type detection.
This function serves as a unified interface for applying first-block caching
    to different types of diffusion pipelines built on Nunchaku models. It automatically detects the
pipeline type based on the class name and delegates to the appropriate
caching implementation.
Parameters
----------
pipe : DiffusionPipeline
The diffusers pipeline to apply caching to.
*args
Variable positional arguments passed to the specific caching function.
**kwargs
Variable keyword arguments passed to the specific caching function.
Common arguments include:
- ``residual_diff_threshold`` (float): Similarity threshold for cache validity
- ``use_double_fb_cache`` (bool): Whether to use double first-block caching
- ``verbose`` (bool): Whether to enable verbose caching messages
Returns
-------
DiffusionPipeline
The same pipeline instance with caching applied.
Raises
------
ValueError
If the pipeline type is not supported (doesn't start with "Flux" or "Sana").
AssertionError
If the input is not a DiffusionPipeline instance.
Examples
--------
With a Flux pipeline:
.. code-block:: python
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-fp4_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
cached_pipe = apply_cache_on_pipe(pipeline, residual_diff_threshold=0.12)
.. note::
This function modifies the pipeline in-place and returns the same instance.
.. warning::
Only pipelines with class names starting with ``Flux`` or ``Sana`` are supported.
"""
assert isinstance(pipe, DiffusionPipeline)
pipe_cls_name = pipe.__class__.__name__
......
"""
Adapters for efficient caching in Flux diffusion pipelines.
This module enables advanced first-block caching for Flux models, supporting both single and double caching strategies. It provides:
- :func:`apply_cache_on_transformer` — Add caching to a ``FluxTransformer2DModel``.
- :func:`apply_cache_on_pipe` — Add caching to a complete Flux pipeline.
Caching is context-managed and only active within a cache context.
"""
import functools
import unittest
......@@ -15,6 +26,34 @@ def apply_cache_on_transformer(
residual_diff_threshold_multi: float | None = None,
residual_diff_threshold_single: float = 0.1,
):
"""
Enable caching for a ``FluxTransformer2DModel``.
This function wraps the transformer to use cached transformer blocks for faster inference.
Supports both single and double first-block caching with configurable thresholds.
Parameters
----------
transformer : FluxTransformer2DModel
The transformer to modify.
use_double_fb_cache : bool, optional
If True, cache both multi-head and single-head attention blocks (default: False).
residual_diff_threshold : float, optional
Default similarity threshold for caching (default: 0.12).
residual_diff_threshold_multi : float, optional
Threshold for multi-head (double) blocks. If None, uses ``residual_diff_threshold``.
residual_diff_threshold_single : float, optional
Threshold for single-head blocks (default: 0.1).
Returns
-------
FluxTransformer2DModel
The transformer with caching enabled.
Notes
-----
If already cached, only updates thresholds. Caching is only active within a cache context.
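    Examples
    --------
    A minimal sketch, assuming an already-loaded ``FluxTransformer2DModel``::

        transformer = apply_cache_on_transformer(
            transformer, use_double_fb_cache=True, residual_diff_threshold=0.12
        )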
"""
if residual_diff_threshold_multi is None:
residual_diff_threshold_multi = residual_diff_threshold
......@@ -59,7 +98,31 @@ def apply_cache_on_transformer(
return transformer
def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
def apply_cache_on_pipe(pipe: DiffusionPipeline, **kwargs):
"""
Enable caching for a complete Flux diffusion pipeline.
This function wraps the pipeline's ``__call__`` method to manage cache contexts,
and optionally applies transformer-level caching.
Parameters
----------
pipe : DiffusionPipeline
The Flux pipeline to modify.
shallow_patch : bool, optional
If True, only patch the pipeline (do not modify the transformer). Useful for testing (default: False).
**kwargs
Passed to :func:`apply_cache_on_transformer` (e.g., ``use_double_fb_cache``, ``residual_diff_threshold``, etc.).
Returns
-------
DiffusionPipeline
The pipeline with caching enabled.
Notes
-----
The pipeline class's ``__call__`` is patched for all instances.
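    Examples
    --------
    A minimal sketch, assuming an existing Flux ``pipeline``::

        pipeline = apply_cache_on_pipe(
            pipeline, use_double_fb_cache=True, residual_diff_threshold=0.12
        )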
"""
if not getattr(pipe, "_is_cached", False):
original_call = pipe.__class__.__call__
......@@ -71,7 +134,4 @@ def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False,
pipe.__class__.__call__ = new_call
pipe.__class__._is_cached = True
if not shallow_patch:
apply_cache_on_transformer(pipe.transformer, **kwargs)
return pipe
"""
Adapters for efficient caching in SANA diffusion pipelines.
This module enables first-block caching for SANA models, providing:
- :func:`apply_cache_on_transformer` — Add caching to a ``SanaTransformer2DModel``.
- :func:`apply_cache_on_pipe` — Add caching to a complete SANA pipeline.
Caching is context-managed and only active within a cache context.
"""
import functools
import unittest
......@@ -8,6 +19,28 @@ from ...caching import utils
def apply_cache_on_transformer(transformer: SanaTransformer2DModel, *, residual_diff_threshold=0.12):
"""
Enable caching for a ``SanaTransformer2DModel``.
This function wraps the transformer to use cached transformer blocks for faster inference.
Uses single first-block caching with configurable similarity thresholds.
Parameters
----------
transformer : SanaTransformer2DModel
The transformer to modify.
residual_diff_threshold : float, optional
Similarity threshold for caching (default: 0.12).
Returns
-------
SanaTransformer2DModel
The transformer with caching enabled.
Notes
-----
If already cached, returns the transformer unchanged. Caching is only active within a cache context.
"""
if getattr(transformer, "_is_cached", False):
return transformer
......@@ -36,7 +69,29 @@ def apply_cache_on_transformer(transformer: SanaTransformer2DModel, *, residual_
return transformer
def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
def apply_cache_on_pipe(pipe: DiffusionPipeline, **kwargs):
"""
Enable caching for a complete SANA diffusion pipeline.
This function wraps the pipeline's ``__call__`` method to manage cache contexts,
and applies transformer-level caching.
Parameters
----------
pipe : DiffusionPipeline
The SANA pipeline to modify.
**kwargs
Passed to :func:`apply_cache_on_transformer` (e.g., ``residual_diff_threshold``).
Returns
-------
DiffusionPipeline
The pipeline with caching enabled.
Notes
-----
The pipeline class's ``__call__`` is patched for all instances.
"""
if not getattr(pipe, "_is_cached", False):
original_call = pipe.__class__.__call__
......@@ -48,7 +103,6 @@ def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False,
pipe.__class__.__call__ = new_call
pipe.__class__._is_cached = True
if not shallow_patch:
apply_cache_on_transformer(pipe.transformer, **kwargs)
apply_cache_on_transformer(pipe.transformer, **kwargs)
return pipe
"""
This file is deprecated.
TeaCache: Temporal Embedding Analysis Caching for Flux Transformers.
This module implements TeaCache, a temporal caching mechanism that optimizes
transformer model inference by skipping computation when input changes are
below a threshold. The approach is based on Temporal Embedding Analysis (TEA)
that tracks the relative L1 distance of modulated inputs across timesteps.
The TeaCache system works by:
1. Analyzing the modulated input from the first transformer block
2. Computing a relative L1 distance compared to the previous timestep
3. Applying a rescaling function to the distance metric
4. Skipping transformer computation when accumulated distance is below threshold
5. Reusing previous residual computations for efficiency
Key Components:
TeaCache: Context manager for applying temporal caching to transformer models
make_teacache_forward: Factory function that creates a cached forward method
The caching strategy is particularly effective for diffusion models during
inference where consecutive timesteps often have similar inputs, allowing
significant computational savings without meaningful quality loss.
Example:
Basic usage with a Flux transformer::
from nunchaku.caching.teacache import TeaCache
from diffusers import FluxTransformer2DModel
model = FluxTransformer2DModel.from_pretrained("black-forest-labs/FLUX.1-dev")
with TeaCache(model, num_steps=50, rel_l1_thresh=0.6, skip_steps=10):
# Model forward passes will use temporal caching
for step in range(50):
output = model(inputs_for_step)
Note:
The rescaling function uses polynomial coefficients optimized for Flux models:
[4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01]
"""
from types import MethodType
from typing import Any, Callable, Optional, Union
......@@ -16,6 +59,36 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_steps: int = 0) -> Callable:
"""
Create a cached forward method for Flux transformers using TeaCache.
This factory function creates a modified forward method that implements temporal
caching based on the relative L1 distance of modulated inputs. The caching
decision is made by analyzing the first transformer block's modulated input
and comparing it to the previous timestep.
Args:
num_steps (int, optional): Total number of inference steps. Used to determine
when to reset the counter. Defaults to 50.
rel_l1_thresh (float, optional): Relative L1 distance threshold for caching.
            Higher values allow more aggressive caching (more skipped steps). Defaults to 0.6.
skip_steps (int, optional): Number of initial steps to skip caching.
Useful for allowing the model to stabilize. Defaults to 0.
Returns:
Callable: A cached forward method that can be bound to a transformer model
Example:
>>> model = FluxTransformer2DModel.from_pretrained("model_name")
>>> cached_forward = make_teacache_forward(num_steps=50, rel_l1_thresh=0.6)
>>> model.forward = cached_forward.__get__(model, type(model))
Note:
The rescaling function uses polynomial coefficients optimized for Flux models.
The accumulated distance is reset when it exceeds the threshold or at the
beginning/end of the inference sequence.
"""
def teacache_forward(
self: Union[FluxTransformer2DModel, NunchakuFluxTransformer2dModel],
hidden_states: torch.Tensor,
......@@ -332,10 +405,52 @@ def make_teacache_forward(num_steps: int = 50, rel_l1_thresh: float = 0.6, skip_
return teacache_forward
# A context manager to add teacache support to a block of code
# When the context manager is applied, the model passed to the context manager is modified
# to support teacache
class TeaCache:
"""
Context manager for applying TeaCache temporal caching to transformer models.
This class provides a context manager that temporarily modifies a Flux transformer
model to use TeaCache temporal caching. When entering the context, the model's
forward method is replaced with a cached version that tracks temporal changes
and skips computation when appropriate.
Args:
model (Union[FluxTransformer2DModel, NunchakuFluxTransformer2dModel]):
The transformer model to apply caching to
num_steps (int, optional): Total number of inference steps. Defaults to 50.
rel_l1_thresh (float, optional): Relative L1 distance threshold for caching.
            Higher values enable more aggressive caching. Defaults to 0.6.
skip_steps (int, optional): Number of initial steps to skip caching.
Useful for model stabilization. Defaults to 0.
enabled (bool, optional): Whether caching is enabled. If False, the model
behaves normally. Defaults to True.
Attributes:
model: Reference to the transformer model
num_steps (int): Total number of inference steps
rel_l1_thresh (float): Caching threshold
skip_steps (int): Number of steps to skip caching
enabled (bool): Caching enabled flag
previous_model_forward: Original forward method (for restoration)
Example:
Basic usage::
with TeaCache(model, num_steps=50, rel_l1_thresh=0.6):
for step in range(50):
output = model(inputs[step])
Disabling caching conditionally::
with TeaCache(model, enabled=use_caching):
# Model will use caching only if use_caching is True
output = model(inputs)
Note:
The context manager automatically restores the original forward method
when exiting, ensuring the model can be used normally afterward.
"""
def __init__(
self,
model: Union[FluxTransformer2DModel, NunchakuFluxTransformer2dModel],
......@@ -352,6 +467,19 @@ class TeaCache:
self.previous_model_forward = self.model.forward
def __enter__(self) -> "TeaCache":
"""
Enter the TeaCache context and apply caching to the model.
This method is called when entering the 'with' block. It replaces the
model's forward method with a cached version and initializes the
necessary state variables for tracking temporal changes.
Returns:
TeaCache: Self reference for context manager protocol
Note:
If caching is disabled (enabled=False), the model is left unchanged.
"""
if self.enabled:
# self.model.__class__.forward = make_teacache_forward(self.num_steps, self.rel_l1_thresh, self.skip_steps) # type: ignore
self.model.forward = MethodType(
......@@ -364,6 +492,21 @@ class TeaCache:
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
"""
Exit the TeaCache context and restore the original model.
This method is called when exiting the 'with' block. It restores the
model's original forward method and cleans up the state variables
that were added for caching.
Args:
exc_type: Exception type (if any occurred)
exc_value: Exception value (if any occurred)
traceback: Exception traceback (if any occurred)
Note:
If caching was disabled (enabled=False), no cleanup is performed.
"""
if self.enabled:
self.model.forward = self.previous_model_forward
del self.model.cnt
......
# This caching functionality is largely brought from https://github.com/chengzeyi/ParaAttention/src/para_attn/first_block_cache/
"""
Caching utilities for transformer models.
Implements first-block caching to accelerate transformer inference by reusing computations
when input changes are minimal. Supports SANA and Flux architectures.
**Main Classes**
- :class:`CacheContext` : Manages cache buffers and incremental naming.
- :class:`SanaCachedTransformerBlocks` : Cached transformer blocks for SANA models.
- :class:`FluxCachedTransformerBlocks` : Cached transformer blocks for Flux models.
**Key Functions**
- :func:`get_buffer`, :func:`set_buffer` : Cache buffer management.
- :func:`cache_context` : Context manager for cache operations.
- :func:`are_two_tensors_similar` : Tensor similarity check.
- :func:`apply_prev_hidden_states_residual` : Applies cached residuals.
- :func:`get_can_use_cache` : Checks cache usability.
- :func:`check_and_apply_cache` : Main cache logic.
**Caching Strategy**
1. Compute the first transformer block.
2. Compare the residual with the cached residual.
3. If similar, reuse cached results for the remaining blocks; otherwise, recompute and update cache.
.. note::
Adapted from ParaAttention:
https://github.com/chengzeyi/ParaAttention/src/para_attn/first_block_cache/
"""
import contextlib
import dataclasses
......@@ -16,10 +46,34 @@ num_single_transformer_blocks = 38 # FIXME
@dataclasses.dataclass
class CacheContext:
"""
Manages cache buffers and incremental naming for transformer model inference.
Attributes
----------
buffers : Dict[str, torch.Tensor]
Stores cached tensor buffers.
incremental_name_counters : DefaultDict[str, int]
Counters for generating unique incremental cache entry names.
"""
buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
incremental_name_counters: DefaultDict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int))
def get_incremental_name(self, name=None):
"""
Generate an incremental cache entry name.
Parameters
----------
name : str, optional
Base name for the counter. If None, uses "default".
Returns
-------
str
Incremental name in the format ``"{name}_{counter}"``.
"""
if name is None:
name = "default"
idx = self.incremental_name_counters[name]
......@@ -27,28 +81,106 @@ class CacheContext:
return f"{name}_{idx}"
def reset_incremental_name(self):
"""
Reset all incremental name counters.
After calling this, :meth:`get_incremental_name` will start from 0 for each name.
"""
self.incremental_name_counters.clear()
# @torch.compiler.disable # This is a torchscript feature
def get_buffer(self, name=str):
def get_buffer(self, name: str) -> Optional[torch.Tensor]:
"""
Retrieve a cached tensor buffer by name.
Parameters
----------
name : str
Name of the buffer to retrieve.
Returns
-------
torch.Tensor or None
The cached tensor if found, otherwise None.
"""
return self.buffers.get(name)
def set_buffer(self, name, buffer):
def set_buffer(self, name: str, buffer: torch.Tensor):
"""
Store a tensor buffer in the cache.
Parameters
----------
name : str
The name to associate with the buffer.
buffer : torch.Tensor
The tensor to cache.
"""
self.buffers[name] = buffer
def clear_buffers(self):
"""
Clear all cached tensor buffers.
Removes all stored tensors from the cache.
"""
self.buffers.clear()
@torch.compiler.disable
def get_buffer(name):
def get_buffer(name: str) -> torch.Tensor:
"""
Retrieve a cached tensor buffer from the current cache context.
Parameters
----------
name : str
The name of the buffer to retrieve.
Returns
-------
torch.Tensor or None
The cached tensor if found, otherwise None.
Raises
------
AssertionError
If no cache context is currently active.
Examples
--------
>>> with cache_context(create_cache_context()):
... set_buffer("my_tensor", torch.randn(2, 3))
... cached = get_buffer("my_tensor")
"""
cache_context = get_current_cache_context()
assert cache_context is not None, "cache_context must be set before"
return cache_context.get_buffer(name)
@torch.compiler.disable
def set_buffer(name, buffer):
def set_buffer(name: str, buffer: torch.Tensor):
"""
Store a tensor buffer in the current cache context.
Parameters
----------
name : str
The name to associate with the buffer.
buffer : torch.Tensor
The tensor to cache.
Raises
------
AssertionError
If no cache context is currently active.
Examples
--------
>>> with cache_context(create_cache_context()):
... set_buffer("my_tensor", torch.randn(2, 3))
... cached = get_buffer("my_tensor")
"""
cache_context = get_current_cache_context()
assert cache_context is not None, "cache_context must be set before"
cache_context.set_buffer(name, buffer)
......@@ -58,15 +190,62 @@ _current_cache_context = None
def create_cache_context():
"""
Create a new :class:`CacheContext` for managing cached computations.
Returns
-------
CacheContext
A new cache context instance.
Examples
--------
>>> context = create_cache_context()
>>> with cache_context(context):
... # Cached operations here
... pass
"""
return CacheContext()
def get_current_cache_context():
"""
Get the currently active cache context.
Returns:
CacheContext or None: The current cache context if one is active, None otherwise
Example:
>>> with cache_context(create_cache_context()):
... current = get_current_cache_context()
... assert current is not None
"""
return _current_cache_context
@contextlib.contextmanager
def cache_context(cache_context):
"""
Context manager to set the active cache context.
Sets the global cache context for the duration of the ``with`` block, restoring the previous context on exit.
Parameters
----------
cache_context : CacheContext
The cache context to activate.
Yields
------
None
Examples
--------
>>> context = create_cache_context()
>>> with cache_context(context):
... set_buffer("key", torch.tensor([1, 2, 3]))
... cached = get_buffer("key")
"""
global _current_cache_context
old_cache_context = _current_cache_context
_current_cache_context = cache_context
......@@ -77,7 +256,30 @@ def cache_context(cache_context):
@torch.compiler.disable
def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
def are_two_tensors_similar(t1: torch.Tensor, t2: torch.Tensor, *, threshold: float, parallelized: bool = False):
"""
Check if two tensors are similar based on relative L1 distance.
The relative distance is computed as
``mean(abs(t1 - t2)) / mean(abs(t1))`` and compared to ``threshold``.
Parameters
----------
t1 : torch.Tensor
First tensor.
t2 : torch.Tensor
Second tensor.
threshold : float
Similarity threshold. Tensors are similar if relative distance < threshold.
parallelized : bool, optional
Unused. For API compatibility.
Returns
-------
tuple of (bool, float)
- bool: True if tensors are similar, False otherwise.
- float: The computed relative L1 distance.
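    Examples
    --------
    A minimal numeric sketch (relative L1 distance of 0.05 against a 0.1 threshold)::

        t1 = torch.ones(8)
        t2 = t1 * 1.05
        similar, diff = are_two_tensors_similar(t1, t2, threshold=0.1)  # similar is True, diff ~= 0.05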
"""
mean_diff = (t1 - t2).abs().mean()
mean_t1 = t1.abs().mean()
diff = (mean_diff / mean_t1).item()
......@@ -87,9 +289,34 @@ def are_two_tensors_similar(t1, t2, *, threshold, parallelized=False):
@torch.compiler.disable
def apply_prev_hidden_states_residual(
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
encoder_hidden_states: torch.Tensor | None = None,
mode: str = "multi",
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply cached residuals to hidden states.
Parameters
----------
hidden_states : torch.Tensor
Current hidden states.
encoder_hidden_states : torch.Tensor, optional
Encoder hidden states (required for ``mode="multi"``).
mode : {"multi", "single"}, default: "multi"
Whether to apply residuals for Flux double blocks or single blocks.
Returns
-------
tuple or torch.Tensor
- If ``mode="multi"``: (updated_hidden_states, updated_encoder_hidden_states)
- If ``mode="single"``: updated_hidden_states
Raises
------
AssertionError
If required cached residuals are not found.
ValueError
If mode is not "multi" or "single".
"""
if mode == "multi":
hidden_states_residual = get_buffer("multi_hidden_states_residual")
assert hidden_states_residual is not None, "multi_hidden_states_residual must be set before"
......@@ -122,6 +349,31 @@ def apply_prev_hidden_states_residual(
def get_can_use_cache(
first_hidden_states_residual: torch.Tensor, threshold: float, parallelized: bool = False, mode: str = "multi"
):
"""
Check if cached computations can be reused based on residual similarity.
Parameters
----------
first_hidden_states_residual : torch.Tensor
Current first block residual.
threshold : float
Similarity threshold for cache validity.
parallelized : bool, optional
Whether computation is parallelized. Default is False.
mode : {"multi", "single"}, optional
Caching mode. Default is "multi".
Returns
-------
tuple of (bool, float)
- bool: True if cache can be used (residuals are similar), False otherwise.
- float: The computed similarity difference, or threshold if no cache exists.
Raises
------
ValueError
If mode is not "multi" or "single".
"""
if mode == "multi":
buffer_name = "first_multi_hidden_states_residual"
elif mode == "single":
......@@ -155,6 +407,42 @@ def check_and_apply_cache(
call_remaining_fn,
remaining_kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], float]:
"""
Check and apply cache based on residual similarity.
This function determines whether cached results can be used by comparing the
first block residuals. If the cache is valid, it applies cached computations;
otherwise, it computes new values and updates the cache.
Parameters
----------
first_residual : torch.Tensor
First block residual for similarity comparison.
hidden_states : torch.Tensor
Current hidden states.
encoder_hidden_states : torch.Tensor, optional
Encoder hidden states (required for "multi" mode).
threshold : float
Similarity threshold for cache validity.
parallelized : bool
Whether computation is parallelized.
mode : {"multi", "single"}
Caching mode.
verbose : bool
Whether to print caching status messages.
call_remaining_fn : callable
Function to call remaining transformer blocks.
remaining_kwargs : dict
Additional keyword arguments for `call_remaining_fn`.
Returns
-------
tuple
(updated_hidden_states, updated_encoder_hidden_states, threshold)
- updated_hidden_states (torch.Tensor)
- updated_encoder_hidden_states (torch.Tensor or None)
- threshold (float)
"""
can_use_cache, diff = get_can_use_cache(
first_residual,
threshold=threshold,
......@@ -200,6 +488,30 @@ def check_and_apply_cache(
class SanaCachedTransformerBlocks(nn.Module):
"""
Caching wrapper for SANA transformer blocks.
Parameters
----------
transformer : nn.Module
The original SANA transformer model to wrap.
residual_diff_threshold : float
Similarity threshold for cache validity.
verbose : bool, optional
Print caching status messages (default: False).
Attributes
----------
transformer : nn.Module
Reference to the original transformer.
transformer_blocks : nn.ModuleList
The transformer blocks to cache.
residual_diff_threshold : float
Current similarity threshold.
verbose : bool
Verbosity flag.
"""
def __init__(
self,
*,
......@@ -223,10 +535,21 @@ class SanaCachedTransformerBlocks(nn.Module):
post_patch_height=None,
post_patch_width=None,
):
"""
Forward pass with caching for SANA transformer blocks.
See also
--------
nunchaku.models.transformers.transformer_sana.NunchakuSanaTransformerBlocks.forward
Notes
-----
If batch size > 2 or residual_diff_threshold <= 0, caching is disabled for now.
"""
batch_size = hidden_states.shape[0]
if self.residual_diff_threshold <= 0.0 or batch_size > 2:
if batch_size > 2:
print("Batch size > 2 (for SANA CFG)" " currently not supported")
print("Batch size > 2 (for SANA CFG) currently not supported")
first_transformer_block = self.transformer_blocks[0]
hidden_states = first_transformer_block(
......@@ -299,6 +622,35 @@ class SanaCachedTransformerBlocks(nn.Module):
post_patch_height=None,
post_patch_width=None,
):
"""
Call remaining SANA transformer blocks.
Called when the cache is invalid. Skips the first layer and processes
the remaining blocks.
Parameters
----------
hidden_states : torch.Tensor
Hidden states from the first block.
attention_mask : torch.Tensor
Attention mask for the input.
encoder_hidden_states : torch.Tensor
Encoder hidden states.
encoder_attention_mask : torch.Tensor, optional
Encoder attention mask (default: None).
timestep : torch.Tensor, optional
Timestep tensor for conditioning (default: None).
post_patch_height : int, optional
Height after patch embedding (default: None).
post_patch_width : int, optional
Width after patch embedding (default: None).
Returns
-------
tuple[torch.Tensor, torch.Tensor]
- Final hidden states after processing all blocks.
- Residual difference for caching.
"""
first_transformer_block = self.transformer_blocks[0]
original_hidden_states = hidden_states
hidden_states = first_transformer_block(
......@@ -317,6 +669,54 @@ class SanaCachedTransformerBlocks(nn.Module):
class FluxCachedTransformerBlocks(nn.Module):
"""
Caching wrapper for Flux transformer blocks.
Parameters
----------
transformer : nn.Module
The original Flux transformer model.
use_double_fb_cache : bool, optional
Cache both double and single transformer blocks (default: True).
residual_diff_threshold_multi : float
Similarity threshold for double blocks.
residual_diff_threshold_single : float
Similarity threshold for single blocks.
return_hidden_states_first : bool, optional
If True, return hidden states first (default: True).
return_hidden_states_only : bool, optional
If True, return only hidden states (default: False).
verbose : bool, optional
Print caching status messages (default: False).
Attributes
----------
transformer : nn.Module
Reference to the original transformer.
transformer_blocks : nn.ModuleList
Double transformer blocks.
single_transformer_blocks : nn.ModuleList
Single transformer blocks.
use_double_fb_cache : bool
Whether both block types are cached.
residual_diff_threshold_multi : float
Threshold for double blocks.
residual_diff_threshold_single : float
Threshold for single blocks.
return_hidden_states_first : bool
Output order flag.
return_hidden_states_only : bool
Output type flag.
verbose : bool
Verbosity flag.
m : object
Nunchaku C model interface.
dtype : torch.dtype
Computation data type.
device : torch.device
Computation device.
"""
def __init__(
self,
*,
......@@ -347,6 +747,13 @@ class FluxCachedTransformerBlocks(nn.Module):
@staticmethod
def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor:
"""
Packs rotary embeddings for efficient computation.
See also
--------
nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformerBlocks.pack_rotemb
"""
assert rotemb.dtype == torch.float32
B = rotemb.shape[0]
M = rotemb.shape[1]
......@@ -368,6 +775,26 @@ class FluxCachedTransformerBlocks(nn.Module):
def update_residual_diff_threshold(
self, use_double_fb_cache=True, residual_diff_threshold_multi=0.12, residual_diff_threshold_single=0.09
):
"""
Update caching configuration parameters.
Parameters
----------
use_double_fb_cache : bool, optional
Use double first-block caching. Default is True.
residual_diff_threshold_multi : float, optional
Similarity threshold for Flux double blocks. Default is 0.12.
residual_diff_threshold_single : float, optional
Similarity threshold for Flux single blocks (used if
``use_double_fb_cache`` is False). Default is 0.09.
Examples
--------
>>> cached_blocks.update_residual_diff_threshold(
... use_double_fb_cache=False,
... residual_diff_threshold_multi=0.15
... )
"""
self.use_double_fb_cache = use_double_fb_cache
self.residual_diff_threshold_multi = residual_diff_threshold_multi
self.residual_diff_threshold_single = residual_diff_threshold_single
......@@ -383,6 +810,17 @@ class FluxCachedTransformerBlocks(nn.Module):
controlnet_single_block_samples=None,
skip_first_layer=False,
):
"""
Forward pass with advanced caching for Flux transformer blocks.
See also
--------
nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformerBlocks.forward
Notes
-----
If batch size > 2 or residual_diff_threshold <= 0, caching is disabled for now.
"""
batch_size = hidden_states.shape[0]
txt_tokens = encoder_hidden_states.shape[1]
img_tokens = hidden_states.shape[1]
......@@ -550,6 +988,43 @@ class FluxCachedTransformerBlocks(nn.Module):
skip_first_layer=True,
txt_tokens=None,
):
"""
Call remaining Flux transformer blocks.
Parameters
----------
hidden_states : torch.Tensor
Input hidden states.
temb : torch.Tensor
Time embedding tensor.
encoder_hidden_states : torch.Tensor
Encoder hidden states.
rotary_emb_img : torch.Tensor
Image rotary embeddings.
rotary_emb_txt : torch.Tensor
Text rotary embeddings.
rotary_emb_single : torch.Tensor
Single-head rotary embeddings.
controlnet_block_samples : list, optional
ControlNet block samples.
controlnet_single_block_samples : list, optional
ControlNet single block samples.
skip_first_layer : bool, optional
Whether to skip the first layer. Default is True.
txt_tokens : int, optional
Number of text tokens.
Returns
-------
hidden_states : torch.Tensor
Updated hidden states.
encoder_hidden_states : torch.Tensor
Updated encoder hidden states.
hidden_states_residual : torch.Tensor
Residual of hidden states.
enc_residual : torch.Tensor
Residual of encoder hidden states.
"""
original_dtype = hidden_states.dtype
original_device = hidden_states.device
original_hidden_states = hidden_states
......@@ -592,6 +1067,43 @@ class FluxCachedTransformerBlocks(nn.Module):
skip_first_layer=False,
txt_tokens=None,
):
"""
Call remaining Flux double blocks.
Parameters
----------
hidden_states : torch.Tensor
Input hidden states.
temb : torch.Tensor
Time embedding tensor.
encoder_hidden_states : torch.Tensor
Encoder hidden states.
rotary_emb_img : torch.Tensor
Image rotary embeddings.
rotary_emb_txt : torch.Tensor
Text rotary embeddings.
rotary_emb_single : torch.Tensor
Single-head rotary embeddings.
controlnet_block_samples : list, optional
ControlNet block samples.
controlnet_single_block_samples : list, optional
ControlNet single block samples.
skip_first_layer : bool, optional
Whether to skip the first layer. Default is False.
txt_tokens : int, optional
Number of text tokens.
Returns
-------
hidden_states : torch.Tensor
Updated hidden states.
encoder_hidden_states : torch.Tensor
Updated encoder hidden states.
hidden_states_residual : torch.Tensor
Residual of hidden states.
enc_residual : torch.Tensor
Residual of encoder hidden states.
"""
start_idx = 1
original_hidden_states = hidden_states.clone()
original_encoder_hidden_states = encoder_hidden_states.clone()
......@@ -628,6 +1140,39 @@ class FluxCachedTransformerBlocks(nn.Module):
skip_first_layer=False,
txt_tokens=None,
):
"""
Call remaining Flux single blocks.
Parameters
----------
hidden_states : torch.Tensor
Input hidden states (concatenated).
temb : torch.Tensor
Time embedding tensor.
encoder_hidden_states : torch.Tensor
Encoder hidden states (unused).
rotary_emb_img : torch.Tensor
Image rotary embeddings (unused).
rotary_emb_txt : torch.Tensor
Text rotary embeddings (unused).
rotary_emb_single : torch.Tensor
Single-head rotary embeddings.
controlnet_block_samples : list, optional
ControlNet block samples (unused).
controlnet_single_block_samples : list, optional
ControlNet single block samples (unused).
skip_first_layer : bool, optional
Whether to skip the first layer. Default is False.
txt_tokens : int, optional
Number of text tokens (unused).
Returns
-------
hidden_states : torch.Tensor
Updated hidden states.
hidden_states_residual : torch.Tensor
Residual of hidden states.
"""
start_idx = 1
original_hidden_states = hidden_states.clone()
......
"""
Compose multiple LoRA weights into a single LoRA for FLUX models.
This script merges several LoRA safetensors files into one, applying individual strength values to each.
**Example Usage:**
.. code-block:: bash
python -m nunchaku.lora.flux.compose \\
-i lora1.safetensors lora2.safetensors \\
-s 0.8 1.0 \\
-o composed_lora.safetensors
**Arguments:**
- ``-i``, ``--input-paths``: Input LoRA safetensors files (one or more).
- ``-s``, ``--strengths``: Strength value for each LoRA (must match number of inputs).
- ``-o``, ``--output-path``: Output path for the composed LoRA safetensors file.
This will merge ``lora1.safetensors`` (strength 0.8) and ``lora2.safetensors`` (strength 1.0) into ``composed_lora.safetensors``.
**Main Function**
:func:`compose_lora`
"""
import argparse
import os
......@@ -11,6 +38,44 @@ from .utils import is_nunchaku_format, load_state_dict_in_safetensors
def compose_lora(
loras: list[tuple[str | dict[str, torch.Tensor], float]], output_path: str | None = None
) -> dict[str, torch.Tensor]:
"""
Compose multiple LoRA weights into a single LoRA representation.
Parameters
----------
loras : list of (str or dict[str, torch.Tensor], float)
Each tuple contains:
- Path to a LoRA safetensors file or a LoRA weights dictionary.
- Strength/scale factor for that LoRA.
output_path : str, optional
Path to save the composed LoRA weights as a safetensors file. If None, does not save.
Returns
-------
dict[str, torch.Tensor]
The composed LoRA weights.
Raises
------
AssertionError
If LoRA weights are in Nunchaku format (must be converted to Diffusers format first)
or if tensor shapes are incompatible.
Notes
-----
- Converts all input LoRAs to Diffusers format.
- Handles QKV projection fusion for attention layers.
- Applies strength scaling to LoRA weights.
- Concatenates multiple LoRAs along appropriate dimensions.
- Handles normalization layers, bias vectors, and FLUX.1-tools LoRA compatibility.
Examples
--------
>>> lora_paths = [("lora1.safetensors", 0.8), ("lora2.safetensors", 0.6)]
>>> composed = compose_lora(lora_paths, "composed_lora.safetensors")
>>> lora_dicts = [({"layer.weight": torch.randn(10, 20)}, 1.0)]
>>> composed = compose_lora(lora_dicts)
"""
if len(loras) == 1:
if is_nunchaku_format(loras[0][0]) and (loras[0][1] - 1) < 1e-5:
if isinstance(loras[0][0], str):
......
"""
CLI tool to convert LoRA weights to Nunchaku format.
**Example Usage:**
.. code-block:: bash
python -m nunchaku.lora.flux.convert \\
--lora-path composed_lora.safetensors \\
--quant-path mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors \\
--output-root ./converted \\
--dtype bfloat16
**Arguments:**
- ``--lora-path``: Path to the LoRA weights safetensor file (required)
- ``--quant-path``: Path to the quantized model safetensor file (default: ``mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors``)
- ``--output-root``: Root directory for the output safetensor file (default: parent directory of the lora file)
- ``--lora-name``: Name of the LoRA weights (optional, auto-generated if not provided)
- ``--dtype``: Data type of the converted weights, either ``bfloat16`` or ``float16`` (default: ``bfloat16``)
**Main Function**
:func:`nunchaku.lora.flux.nunchaku_converter.to_nunchaku`
"""
import argparse
import os
......@@ -9,27 +35,26 @@ if __name__ == "__main__":
parser.add_argument(
"--quant-path",
type=str,
help="path to the quantized model safetensor file",
help="Path to the quantized model safetensors file.",
default="mit-han-lab/svdq-int4-flux.1-dev/transformer_blocks.safetensors",
)
parser.add_argument("--lora-path", type=str, required=True, help="path to LoRA weights safetensor file")
parser.add_argument("--output-root", type=str, default="", help="root to the output safetensor file")
parser.add_argument("--lora-name", type=str, default=None, help="name of the LoRA weights")
parser.add_argument("--lora-path", type=str, required=True, help="Path to LoRA weights safetensors file.")
parser.add_argument("--output-root", type=str, default="", help="Root directory for output safetensors file.")
parser.add_argument("--lora-name", type=str, default=None, help="Name for the output LoRA weights.")
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["bfloat16", "float16"],
help="data type of the converted weights",
help="Data type of the converted weights.",
)
args = parser.parse_args()
if is_nunchaku_format(args.lora_path):
print("Already in nunchaku format, no conversion needed.")
print("Already in Nunchaku format, no conversion needed.")
exit(0)
if not args.output_root:
# output to the parent directory of the lora safetensors file
args.output_root = os.path.dirname(args.lora_path)
if args.lora_name is None:
base_name = os.path.basename(args.lora_path)
......
"""
This module implements the functions to convert FLUX LoRA weights from various formats
to the Diffusers format, which will later be converted to Nunchaku format.
"""
import argparse
import logging
import os
......@@ -7,7 +12,7 @@ from diffusers.loaders import FluxLoraLoaderMixin
from diffusers.utils.state_dict_utils import convert_unet_state_dict_to_peft
from safetensors.torch import save_file
from .utils import load_state_dict_in_safetensors
from ...utils import load_state_dict_in_safetensors
# Get log level from environment variable (default to INFO)
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
......@@ -18,6 +23,19 @@ logger = logging.getLogger(__name__)
def handle_kohya_lora(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""
Convert Kohya LoRA format keys to Diffusers format.
Parameters
----------
state_dict : dict[str, torch.Tensor]
LoRA weights, possibly in Kohya format.
Returns
-------
dict[str, torch.Tensor]
LoRA weights in Diffusers format.
"""
# first check if the state_dict is in the kohya format
# like: https://civitai.com/models/1118358?modelVersionId=1256866
if any([not k.startswith("lora_transformer_") for k in state_dict.keys()]):
......@@ -57,6 +75,21 @@ def handle_kohya_lora(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Te
def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | None = None) -> dict[str, torch.Tensor]:
"""
Convert LoRA weights to Diffusers format, which will later be converted to Nunchaku format.
Parameters
----------
input_lora : str or dict[str, torch.Tensor]
Path to a safetensors file or a LoRA weight dictionary.
output_path : str, optional
If given, save the converted weights to this path.
Returns
-------
dict[str, torch.Tensor]
LoRA weights in Diffusers format.
"""
if isinstance(input_lora, str):
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
else:
......@@ -64,7 +97,7 @@ def to_diffusers(input_lora: str | dict[str, torch.Tensor], output_path: str | N
tensors = handle_kohya_lora(tensors)
### convert the FP8 tensors to BF16
# Convert FP8 tensors to BF16
for k, v in tensors.items():
if v.dtype not in [torch.float64, torch.float32, torch.bfloat16, torch.float16]:
tensors[k] = v.to(torch.bfloat16)
......
# convert the diffusers lora to nunchaku format
"""Convert LoRA weights to Nunchaku format."""
"""
Nunchaku LoRA format converter for Flux models.
This module provides utilities to convert LoRA weights from Diffusers format
to Nunchaku format for efficient quantized inference in Flux models.
Key functions
-------------
- :func:`to_nunchaku` : Main conversion entry point
- :func:`fuse_vectors` : Vector fusion for bias terms
"""
import logging
import os
......@@ -26,6 +36,28 @@ logger = logging.getLogger(__name__)
def update_state_dict(
lhs: dict[str, torch.Tensor], rhs: dict[str, torch.Tensor], prefix: str = ""
) -> dict[str, torch.Tensor]:
"""
Update a state dictionary with values from another, optionally adding a prefix to keys.
Parameters
----------
lhs : dict[str, torch.Tensor]
Target state dictionary.
rhs : dict[str, torch.Tensor]
Source state dictionary.
prefix : str, optional
Prefix to add to keys from rhs.
Returns
-------
dict[str, torch.Tensor]
Updated state dictionary.
Raises
------
AssertionError
If any key already exists in the target dictionary.
"""
for rkey, value in rhs.items():
lkey = f"{prefix}.{rkey}" if prefix else rkey
assert lkey not in lhs, f"Key {lkey} already exists in the state dict."
......@@ -37,13 +69,20 @@ def update_state_dict(
def pack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
"""Pack Low-Rank Weight.
Args:
weight (`torch.Tensor`):
low-rank weight tensor.
down (`bool`):
whether the weight is for down projection in low-rank branch.
"""
Pack the low-rank weight tensor for W4A4 linear layers.
Parameters
----------
weight : torch.Tensor
Low-rank weight tensor.
down : bool
If True, pack as down-projection; else as up-projection.
Returns
-------
torch.Tensor
Packed weight tensor.
"""
assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
lane_n, lane_k = 1, 2 # lane_n is always 1, lane_k is 32 bits // 16 bits = 2
......@@ -66,13 +105,20 @@ def pack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
def unpack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
"""Unpack Low-Rank Weight.
Args:
weight (`torch.Tensor`):
low-rank weight tensor.
down (`bool`):
whether the weight is for down projection in low-rank branch.
"""
Unpack the low-rank weight tensor from W4A4 linear layers.
Parameters
----------
weight : torch.Tensor
Packed low-rank weight tensor.
down : bool
If True, unpack as down-projection; else as up-projection.
Returns
-------
torch.Tensor
Unpacked weight tensor.
"""
c, r = weight.shape
assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
......@@ -96,6 +142,21 @@ def unpack_lowrank_weight(weight: torch.Tensor, down: bool) -> torch.Tensor:
def reorder_adanorm_lora_up(lora_up: torch.Tensor, splits: int) -> torch.Tensor:
"""
Reorder AdaNorm LoRA up-projection tensor for correct shape.
Parameters
----------
lora_up : torch.Tensor
LoRA up-projection tensor.
splits : int
Number of splits for AdaNorm.
Returns
-------
torch.Tensor
Reordered tensor.
"""
c, r = lora_up.shape
assert c % splits == 0
return lora_up.view(splits, c // splits, r).transpose(0, 1).reshape(c, r).contiguous()
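# Concrete illustration of the reordering above: with c=6 rows and splits=3, the rows of
# the three AdaNorm chunks are interleaved into the order [0, 2, 4, 1, 3, 5] while the
# overall (c, r) shape is preserved.
rows = torch.arange(6.0).view(6, 1).expand(6, 4).contiguous()  # row i is filled with the value i
assert reorder_adanorm_lora_up(rows, splits=3)[:, 0].tolist() == [0.0, 2.0, 4.0, 1.0, 3.0, 5.0]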
......@@ -110,6 +171,41 @@ def convert_to_nunchaku_transformer_block_lowrank_dict( # noqa: C901
convert_map: dict[str, str],
default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
"""
Convert LoRA weights for a transformer block from Diffusers to Nunchaku format.
Merges and converts LoRA weights from the original SVDQuant low-rank branch and an extra LoRA dict
for a given transformer block, producing a Nunchaku-compatible dictionary. Handles both fused and
unfused LoRA branches (e.g., qkv), and merges multiple LoRA branches as needed.
Parameters
----------
orig_state_dict : dict[str, torch.Tensor]
Original state dict with LoRA weights, keys like ``"{block}.{local}.lora_down"`` and ``"{block}.{local}.lora_up"``.
extra_lora_dict : dict[str, torch.Tensor]
Extra LoRA weights, keys like ``"{block}.{local}.lora_A.weight"`` and ``"{block}.{local}.lora_B.weight"``.
converted_block_name : str
Block name for output (e.g., ``"transformer_blocks.0"``).
candidate_block_name : str
Block name for input lookup (e.g., ``"blocks.0"``).
local_name_map : dict[str, str | list[str]]
Maps output local names (e.g., ``"attn.qkv"``) to one or more input local names.
convert_map : dict[str, str]
Maps output local names to conversion types: ``"adanorm_single"``, ``"adanorm_zero"``, or ``"linear"``.
default_dtype : torch.dtype, optional
Output tensor dtype (default: ``torch.bfloat16``).
Returns
-------
dict[str, torch.Tensor]
A dictionary containing the converted LoRA weights in Nunchaku format.
Notes
-----
- If both original and extra LoRA weights are present, they are merged by concatenation.
- Handles both fused and unfused attention projections (e.g., qkv).
- Applies special packing for W4A16 linear layers (e.g., ``"adanorm_single"`` and ``"adanorm_zero"``).
"""
logger.debug(f"Converting LoRA branch for block {candidate_block_name}...")
converted: dict[str, torch.Tensor] = {}
for converted_local_name, candidate_local_names in local_name_map.items():
......@@ -254,6 +350,37 @@ def convert_to_nunchaku_flux_single_transformer_block_lowrank_dict(
candidate_block_name: str,
default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
"""
Convert LoRA weights for a single FLUX transformer block from Diffusers to Nunchaku format.
This function merges and converts LoRA weights from the original SVDQuant low-rank branch and an
extra LoRA dictionary for a given transformer block, producing a Nunchaku-compatible dictionary.
It handles both fused and unfused LoRA branches (e.g., qkv), and merges multiple LoRA branches as needed.
Parameters
----------
orig_state_dict : dict[str, torch.Tensor]
Original state dict with LoRA weights, keys like ``"{block}.{local}.lora_down"`` and ``"{block}.{local}.lora_up"``.
extra_lora_dict : dict[str, torch.Tensor]
Extra LoRA weights, keys like ``"{block}.{local}.lora_A.weight"`` and ``"{block}.{local}.lora_B.weight"``.
converted_block_name : str
Block name for output (e.g., ``"transformer_blocks.0"``).
candidate_block_name : str
Block name for input lookup (e.g., ``"blocks.0"``).
default_dtype : torch.dtype, optional
Output tensor dtype (default: ``torch.bfloat16``).
Returns
-------
dict[str, torch.Tensor]
A dictionary containing the converted LoRA weights in Nunchaku format.
Notes
-----
- If both original and extra LoRA weights are present, they are merged by concatenation.
- Handles both fused and unfused attention projections (e.g., qkv).
- Applies special packing for W4A16 linear layers (e.g., ``"adanorm_single"`` and ``"adanorm_zero"``).
"""
if f"{candidate_block_name}.proj_out.lora_A.weight" in extra_lora_dict:
assert f"{converted_block_name}.out_proj.qweight" in orig_state_dict
assert f"{converted_block_name}.mlp_fc2.qweight" in orig_state_dict
......@@ -317,6 +444,27 @@ def convert_to_nunchaku_flux_transformer_block_lowrank_dict(
candidate_block_name: str,
default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
"""
Convert LoRA weights for a single transformer block from Diffusers to Nunchaku format.
Parameters
----------
orig_state_dict : dict[str, torch.Tensor]
Original model state dict.
extra_lora_dict : dict[str, torch.Tensor]
LoRA weights state dict.
converted_block_name : str
Output block name for the converted weights.
candidate_block_name : str
Input block name for lookup.
default_dtype : torch.dtype, optional
Output tensor dtype (default: torch.bfloat16).
Returns
-------
dict[str, torch.Tensor]
Converted LoRA weights in Nunchaku format.
"""
return convert_to_nunchaku_transformer_block_lowrank_dict(
orig_state_dict=orig_state_dict,
extra_lora_dict=extra_lora_dict,
......@@ -359,6 +507,23 @@ def convert_to_nunchaku_flux_lowrank_dict(
lora: dict[str, torch.Tensor] | str,
default_dtype: torch.dtype = torch.bfloat16,
) -> dict[str, torch.Tensor]:
"""
Convert a base model and LoRA weights from Diffusers format to Nunchaku format.
Parameters
----------
base_model : dict[str, torch.Tensor] or str
Base model weights or path to safetensors file.
lora : dict[str, torch.Tensor] or str
LoRA weights or path to safetensors file.
default_dtype : torch.dtype, optional
Output tensor dtype (default: torch.bfloat16).
Returns
-------
dict[str, torch.Tensor]
LoRA weights in Nunchaku format.
"""
if isinstance(base_model, str):
orig_state_dict = load_state_dict_in_safetensors(base_model)
else:
......@@ -377,7 +542,7 @@ def convert_to_nunchaku_flux_lowrank_dict(
elif "transformer_blocks" not in k:
unquantized_lora_dict[k] = extra_lora_dict.pop(k)
# Concatenate qkv_proj biases if present
for k in list(vector_dict.keys()):
if ".to_q." in k or ".add_q_proj." in k:
k_q = k
......@@ -445,6 +610,32 @@ def to_nunchaku(
dtype: str | torch.dtype = torch.bfloat16,
output_path: str | None = None,
) -> dict[str, torch.Tensor]:
"""
Convert LoRA weights to Nunchaku format.
Parameters
----------
input_lora : str or dict[str, torch.Tensor]
Path or dictionary of LoRA weights in Diffusers format. Can be composed of multiple LoRA weights.
base_sd : str or dict[str, torch.Tensor]
Path or dictionary of base quantized model weights.
dtype : str or torch.dtype, optional
Output data type ("bfloat16", "float16", or torch dtype). Default is torch.bfloat16.
output_path : str, optional
If provided, saves the result to this path.
Returns
-------
dict[str, torch.Tensor]
LoRA weights in Nunchaku format.
Example
-------
.. code-block:: python
nunchaku_weights = to_nunchaku("lora.safetensors", "base_model.safetensors")
nunchaku_weights = to_nunchaku(lora_dict, base_dict)
"""
if isinstance(input_lora, str):
tensors = load_state_dict_in_safetensors(input_lora, device="cpu")
else:
......@@ -486,6 +677,23 @@ def to_nunchaku(
def fuse_vectors(
vectors: dict[str, torch.Tensor], base_sd: dict[str, torch.Tensor], strength: float = 1
) -> dict[str, torch.Tensor]:
"""
Fuse vector (bias) terms from LoRA into the base model.
Parameters
----------
vectors : dict[str, torch.Tensor]
LoRA vector terms.
base_sd : dict[str, torch.Tensor]
Base model state dict.
strength : float, optional
Scaling factor for LoRA vectors.
Returns
-------
dict[str, torch.Tensor]
State dict with fused vectors.
"""
tensors: dict[str, torch.Tensor] = {}
packer = NunchakuWeightPacker(bits=4)
for k, v in base_sd.items():
......
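# Why the converters above can merge the SVDQuant low-rank branch and an extra LoRA
# "by concatenation": stacking two low-rank factors along the rank dimension is the same
# as summing their outputs. Sizes below are illustrative; torch is imported as above.
up1, down1 = torch.randn(64, 8), torch.randn(8, 32)  # original low-rank branch
up2, down2 = torch.randn(64, 4), torch.randn(4, 32)  # extra LoRA branch
merged_up = torch.cat([up1, up2], dim=1)  # (64, 12)
merged_down = torch.cat([down1, down2], dim=0)  # (12, 32)
assert torch.allclose(merged_up @ merged_down, up1 @ down1 + up2 @ down2, atol=1e-5)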
# Copied from https://github.com/mit-han-lab/deepcompressor/
"""
Weight packing utilities for Nunchaku quantization.
This module provides concise tools for packing and unpacking weight tensors,
optimized for efficient GPU computation using Matrix Multiply and Accumulate (MMA) operations.
"""
import torch
from ...utils import ceil_divide
......@@ -6,27 +12,87 @@ from .utils import pad
class MmaWeightPackerBase:
"""
Base class for Matrix Multiply and Accumulate (MMA) weight packing.
Packs weight tensors for efficient GPU computation using MMA operations.
Handles tile sizes, memory layout, and packing parameters.
Parameters
----------
bits : int
Quantization bits. Must be 1, 4, 8, 16, or 32.
warp_n : int
Warp size in the n dimension.
comp_n : int, optional
Computation tile size in n (default: 16).
comp_k : int, optional
Computation tile size in k (default: 256 // bits).
Raises
------
AssertionError
If bits or tile/pack sizes are invalid.
Attributes
----------
comp_n : int
Tile size in n for MMA computation.
comp_k : int
Tile size in k for MMA computation.
insn_n : int
MMA instruction tile size in n.
insn_k : int
MMA instruction tile size in k.
num_lanes : int
Number of lanes (threads) in a warp.
num_k_lanes : int
Number of lanes in k.
num_n_lanes : int
Number of lanes in n.
warp_n : int
Warp size in n.
reg_k : int
Elements in a register in k.
reg_n : int
Elements in a register in n.
k_pack_size : int
Elements in a pack in k.
n_pack_size : int
Elements in a pack in n.
pack_size : int
Elements in a pack accessed by a lane.
mem_k : int
Tile size in k for one memory access.
mem_n : int
Tile size in n for one memory access.
num_k_packs : int
Packs in k for one memory access.
num_n_packs : int
Packs in n for one memory access.
"""
def __init__(self, bits: int, warp_n: int, comp_n: int = None, comp_k: int = None):
self.bits = bits
assert self.bits in (1, 4, 8, 16, 32), "weight bits should be 1, 4, 8, 16, or 32."
# region compute tile size
self.comp_n = comp_n if comp_n is not None else 16
# smallest tile size in `n` dimension for MMA computation.
self.comp_k = comp_k if comp_k is not None else 256 // self.bits
# smallest tile size in `k` dimension for MMA computation.
# the smallest MMA computation may contain several MMA instructions
self.insn_n = 8 # mma instruction tile size in `n` dimension
# tile size in `n` dimension for MMA instruction.
self.insn_k = self.comp_k
# tile size in `k` dimension for MMA instruction.
assert self.insn_k * self.bits in (
128,
256,
), f"insn_k ({self.insn_k}) * bits ({self.bits}) should be 128 or 256."
assert self.comp_n % self.insn_n == 0, f"comp_n ({self.comp_n}) should be divisible by insn_n ({self.insn_n})."
self.num_lanes = 32
# there are 32 lanes (or threads) in a warp.
self.num_k_lanes = 4
self.num_n_lanes = 8
assert (
......@@ -36,29 +102,50 @@ class MmaWeightPackerBase:
# endregion
# region memory
self.reg_k = 32 // self.bits
# number of elements in a register in `k` dimension.
self.reg_n = 1
# number of elements in a register in `n` dimension (always 1).
self.k_pack_size = self.comp_k // (self.num_k_lanes * self.reg_k)
# number of elements in a pack in `k` dimension.
self.n_pack_size = self.comp_n // (self.num_n_lanes * self.reg_n)
# number of elements in a pack in `n` dimension.
self.pack_size = self.k_pack_size * self.n_pack_size
# number of elements in a pack accessed by a lane at a time.
assert 1 <= self.pack_size <= 4, "pack size should be less than or equal to 4."
assert self.k_pack_size * self.num_k_lanes * self.reg_k == self.comp_k
assert self.n_pack_size * self.num_n_lanes * self.reg_n == self.comp_n
self.mem_k = self.comp_k
# the tile size in `k` dimension for one tensor memory access.
self.mem_n = warp_n
# the tile size in `n` dimension for one tensor memory access.
self.num_k_packs = self.mem_k // (self.k_pack_size * self.num_k_lanes * self.reg_k)
# number of packs in `k` dimension for one tensor memory access.
self.num_n_packs = self.mem_n // (self.n_pack_size * self.num_n_lanes * self.reg_n)
# number of packs in `n` dimension for one tensor memory access.
# endregion
def get_view_shape(self, n: int, k: int) -> tuple[int, int, int, int, int, int, int, int, int, int]:
"""
Returns the tensor view shape for MMA operations.
Parameters
----------
n : int
Output channel size (must be divisible by mem_n).
k : int
Input channel size (must be divisible by mem_k).
Returns
-------
tuple of int
(n_tiles, num_n_packs, n_pack_size, num_n_lanes, reg_n,
k_tiles, num_k_packs, k_pack_size, num_k_lanes, reg_k)
Raises
------
AssertionError
If n or k is not divisible by mem_n or mem_k.
"""
assert n % self.mem_n == 0, "output channel size should be divisible by mem_n."
assert k % self.mem_k == 0, "input channel size should be divisible by mem_k."
return (
......@@ -76,11 +163,41 @@ class MmaWeightPackerBase:
class NunchakuWeightPacker(MmaWeightPackerBase):
"""
Nunchaku-specific weight packer. Packs quantized weights, scales, and low-rank
weights into the layouts used by Nunchaku's MMA kernels.
Parameters
----------
bits : int
Number of quantization bits. Must be 1, 4, 8, 16, or 32.
warp_n : int, optional
Warp size in the n dimension. Default is 128.
Attributes
----------
num_k_unrolls : int
Number of unrolls in the k dimension (always 2 for Nunchaku).
"""
def __init__(self, bits: int, warp_n: int = 128):
super().__init__(bits=bits, warp_n=warp_n)
self.num_k_unrolls = 2
def pack_weight(self, weight: torch.Tensor) -> torch.Tensor:
"""
Pack quantized weight tensor for Nunchaku MMA.
Parameters
----------
weight : torch.Tensor
Quantized weight tensor of dtype torch.int32 and shape (n, k).
Returns
-------
torch.Tensor
Packed weight tensor of dtype torch.int8.
"""
assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
n, k = weight.shape
assert n % self.mem_n == 0, f"output channel size ({n}) should be divisible by mem_n ({self.mem_n})."
......@@ -122,6 +239,21 @@ class NunchakuWeightPacker(MmaWeightPackerBase):
return weight.view(dtype=torch.int8).view(n, -1) # assume little-endian
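# Layout note (an observation about the view above, not an extra transform): with 4-bit
# weights each int32 register holds eight values, so the int8 view has n rows of
# k * bits / 8 bytes each.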
def pack_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
"""
Pack scale tensor for Nunchaku MMA.
Parameters
----------
scale : torch.Tensor
Scale tensor of dtype torch.float16 or torch.bfloat16.
group_size : int
Group size for quantization.
Returns
-------
torch.Tensor
Packed scale tensor.
"""
if self.check_if_micro_scale(group_size=group_size):
return self.pack_micro_scale(scale, group_size=group_size)
# note: refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#mma-16864-c
......@@ -169,6 +301,21 @@ class NunchakuWeightPacker(MmaWeightPackerBase):
return scale.view(-1) if group_size == -1 else scale.view(-1, n) # the shape is just used for validation
def pack_micro_scale(self, scale: torch.Tensor, group_size: int) -> torch.Tensor:
"""
Pack micro scale tensor for Nunchaku MMA.
Parameters
----------
scale : torch.Tensor
Scale tensor of dtype torch.float16 or torch.bfloat16.
group_size : int
Group size for quantization (must be 16).
Returns
-------
torch.Tensor
Packed micro scale tensor.
"""
assert scale.dtype in (torch.float16, torch.bfloat16), "currently nunchaku only supports fp16 and bf16."
assert scale.max() <= 448, "scale should be less than 448."
assert scale.min() >= -448, "scale should be greater than -448."
......@@ -213,13 +360,20 @@ class NunchakuWeightPacker(MmaWeightPackerBase):
return scale.view(-1, n) # the shape is just used for validation
def pack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
"""Pack Low-Rank Weight.
"""
Pack low-rank weight tensor.
Parameters
----------
weight : torch.Tensor
Low-rank weight tensor of dtype torch.float16 or torch.bfloat16.
down : bool
If True, weight is for down projection in low-rank branch.
Returns
-------
torch.Tensor
Packed low-rank weight tensor.
"""
assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
reg_n, reg_k = 1, 2 # reg_n is always 1, reg_k is 32 bits // 16 bits = 2
......@@ -244,13 +398,20 @@ class NunchakuWeightPacker(MmaWeightPackerBase):
return weight.view(c, r)
def unpack_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
"""Unpack Low-Rank Weight.
"""
Unpack low-rank weight tensor.
Parameters
----------
weight : torch.Tensor
Packed low-rank weight tensor of dtype torch.float16 or torch.bfloat16.
down : bool
If True, weight is for down projection in low-rank branch.
Returns
-------
torch.Tensor
Unpacked low-rank weight tensor.
"""
c, r = weight.shape
assert weight.dtype in (torch.float16, torch.bfloat16), f"Unsupported weight dtype {weight.dtype}."
......@@ -276,13 +437,56 @@ class NunchakuWeightPacker(MmaWeightPackerBase):
return weight
def check_if_micro_scale(self, group_size: int) -> bool:
"""
Check if micro scale packing is required.
Parameters
----------
group_size : int
Group size for quantization.
Returns
-------
bool
True if micro scale packing is required.
"""
return self.insn_k == group_size * 4
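# Example: for 4-bit weights insn_k defaults to 256 // 4 = 64, so a quantization group
# size of 16 (64 == 16 * 4) selects the micro-scale packing path.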
def pad_weight(self, weight: torch.Tensor) -> torch.Tensor:
"""
Pad weight tensor to required shape.
Parameters
----------
weight : torch.Tensor
Weight tensor of shape (n, k).
Returns
-------
torch.Tensor
Padded weight tensor.
"""
assert weight.ndim == 2, "weight tensor should be 2D."
return pad(weight, divisor=(self.mem_n, self.mem_k * self.num_k_unrolls), dim=(0, 1))
def pad_scale(self, scale: torch.Tensor, group_size: int, fill_value: float = 0) -> torch.Tensor:
"""
Pad scale tensor to required shape.
Parameters
----------
scale : torch.Tensor
Scale tensor.
group_size : int
Group size for quantization.
fill_value : float, optional
Value to use for padding. Default is 0.
Returns
-------
torch.Tensor
Padded scale tensor.
"""
if group_size > 0 and scale.numel() > scale.shape[0]:
scale = scale.view(scale.shape[0], 1, -1, 1)
if self.check_if_micro_scale(group_size=group_size):
......@@ -294,5 +498,20 @@ class NunchakuWeightPacker(MmaWeightPackerBase):
return scale
def pad_lowrank_weight(self, weight: torch.Tensor, down: bool) -> torch.Tensor:
"""
Pad low-rank weight tensor to required shape.
Parameters
----------
weight : torch.Tensor
Low-rank weight tensor.
down : bool
If True, weight is for down projection in low-rank branch.
Returns
-------
torch.Tensor
Padded low-rank weight tensor.
"""
assert weight.ndim == 2, "weight tensor should be 2D."
return pad(weight, divisor=self.warp_n, dim=1 if down else 0)
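# Hedged usage sketch: pad and pack a 4-bit-quantized weight for the Nunchaku kernels.
# The (n, k) = (256, 256) shape and the [0, 16) integer range are illustrative; real
# inputs come from the SVDQuant quantization pipeline.
packer = NunchakuWeightPacker(bits=4)
w_q = torch.randint(0, 16, (256, 256), dtype=torch.int32)
packed = packer.pack_weight(packer.pad_weight(w_q))  # int8 tensor in an MMA-friendly layout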
"""
Utility functions for LoRAs in Flux models.
"""
import typing as tp
import torch
......@@ -6,8 +10,27 @@ from ...utils import ceil_divide, load_state_dict_in_safetensors
def is_nunchaku_format(lora: str | dict[str, torch.Tensor]) -> bool:
"""
Check if LoRA weights are in Nunchaku format.
Parameters
----------
lora : str or dict[str, torch.Tensor]
Path to a safetensors file or a dictionary of LoRA weights.
Returns
-------
bool
True if the weights are in Nunchaku format, False otherwise.
Examples
--------
>>> is_nunchaku_format("path/to/lora.safetensors")
True
"""
if isinstance(lora, str):
tensors = load_state_dict_in_safetensors(lora, device="cpu", return_metadata=False)
assert isinstance(tensors, dict), "Expected dict when return_metadata=False"
else:
tensors = lora
......@@ -23,6 +46,33 @@ def pad(
dim: int | tp.Sequence[int],
fill_value: float | int = 0,
) -> torch.Tensor | None:
"""
Pad a tensor so specified dimensions are divisible by given divisors.
Parameters
----------
tensor : torch.Tensor or None
The tensor to pad. If None, returns None.
divisor : int or sequence of int
Divisor(s) for the dimension(s) to pad.
dim : int or sequence of int
Dimension(s) to pad.
fill_value : float or int, optional
Value to use for padding (default: 0).
Returns
-------
torch.Tensor or None
The padded tensor, or None if input tensor was None.
Examples
--------
>>> tensor = torch.randn(10, 20)
>>> pad(tensor, divisor=16, dim=0).shape
torch.Size([16, 20])
>>> pad(tensor, divisor=[16, 32], dim=[0, 1]).shape
torch.Size([16, 32])
"""
if isinstance(divisor, int):
if divisor <= 1:
return tensor
......
"""
Merge split safetensors model files into a single safetensors file.
**Example usage**
.. code-block:: bash
python -m nunchaku.merge_safetensors -i <input_path_or_repo> -o <output_path>
**Arguments**
- ``-i``, ``--input-path`` (Path): Path to the model directory or HuggingFace repo.
- ``-o``, ``--output-path`` (Path): Path to save the merged safetensors file.
It will combine the ``unquantized_layers.safetensors`` and ``transformer_blocks.safetensors``
files (and associated config files) from a local directory or a HuggingFace Hub repository
into a single safetensors file with appropriate metadata.
**Main Function**
:func:`merge_safetensors`
"""
import argparse
import json
import os
......@@ -13,6 +38,28 @@ from .utils import load_state_dict_in_safetensors
def merge_safetensors(
pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
) -> tuple[dict[str, torch.Tensor], dict[str, str]]:
"""
Merge split safetensors model files into a single state dict and metadata.
This function loads the ``unquantized_layers.safetensors`` and ``transformer_blocks.safetensors``
files (and associated config files) from a local directory or a HuggingFace Hub repository,
and merges them into a single state dict and metadata dictionary.
Parameters
----------
pretrained_model_name_or_path : str or os.PathLike
Path to the model directory or HuggingFace repo.
**kwargs
Additional keyword arguments for subfolder, comfy_config_path, and HuggingFace download options.
Returns
-------
tuple[dict[str, torch.Tensor], dict[str, str]]
The merged state dict and metadata dictionary.
- **state_dict**: The merged model state dict.
- **metadata**: Dictionary containing ``config``, ``comfy_config``, ``model_class``, and ``quantization_config``.
"""
subfolder = kwargs.get("subfolder", None)
comfy_config_path = kwargs.get("comfy_config_path", None)
......
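# Hedged usage sketch: merge a split Nunchaku checkpoint (local directory or Hub repo)
# and save it as a single file. The repo path is a placeholder; save_file comes from the
# safetensors package.
from safetensors.torch import save_file
state_dict, metadata = merge_safetensors("path/to/nunchaku-model")
save_file(state_dict, "merged.safetensors", metadata=metadata)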
"""
This module implements the encoders for PuLID.
.. note::
This module is adapted from https://github.com/ToTheBeginning/PuLID.
"""
import math
import torch
from torch import nn
# FFN
def FeedForward(dim, mult=4):
"""
Feed-forward network (FFN) block with LayerNorm and GELU activation.
Parameters
----------
dim : int
Input and output feature dimension.
mult : int, optional
Expansion multiplier for the hidden dimension (default: 4).
Returns
-------
nn.Sequential
A sequential FFN block: LayerNorm -> Linear -> GELU -> Linear.
"""
inner_dim = int(dim * mult)
return nn.Sequential(
nn.LayerNorm(dim),
......@@ -17,17 +37,44 @@ def FeedForward(dim, mult=4):
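# Quick shape check for the FFN block above (sizes are illustrative): the output feature
# dimension matches the input dimension, as described in the docstring.
ffn = FeedForward(dim=1024, mult=4)
assert ffn(torch.randn(2, 16, 1024)).shape == (2, 16, 1024)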
def reshape_tensor(x, heads):
"""
Reshape a tensor for multi-head attention.
Parameters
----------
x : torch.Tensor
Input tensor of shape (batch_size, seq_len, width).
heads : int
Number of attention heads.
Returns
-------
torch.Tensor
Reshaped tensor of shape (batch_size, heads, seq_len, dim_per_head).
"""
bs, length, width = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
x = x.transpose(1, 2)
# final shape: (bs, n_heads, length, dim_per_head)
x = x.reshape(bs, heads, length, -1)
return x
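# Shape walk-through (illustrative sizes): 2 sequences of 7 tokens with width 64 split
# across 8 heads gives 8 dims per head.
x_demo = torch.randn(2, 7, 64)
assert reshape_tensor(x_demo, heads=8).shape == (2, 8, 7, 8)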
class PerceiverAttentionCA(nn.Module):
"""
Perceiver-style cross-attention module.
Parameters
----------
dim : int, optional
Input feature dimension for queries (default: 3072).
dim_head : int, optional
Dimension per attention head (default: 128).
heads : int, optional
Number of attention heads (default: 16).
kv_dim : int, optional
Input feature dimension for keys/values (default: 2048).
"""
def __init__(self, *, dim=3072, dim_head=128, heads=16, kv_dim=2048):
super().__init__()
self.scale = dim_head**-0.5
......@@ -44,11 +91,19 @@ class PerceiverAttentionCA(nn.Module):
def forward(self, x, latents):
"""
Forward pass for cross-attention.
Parameters
----------
x : torch.Tensor
Image features of shape (batch_size, n1, D).
latents : torch.Tensor
Latent features of shape (batch_size, n2, D).
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, n2, D).
"""
x = self.norm1(x)
latents = self.norm2(latents)
......@@ -64,7 +119,7 @@ class PerceiverAttentionCA(nn.Module):
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
......@@ -74,6 +129,21 @@ class PerceiverAttentionCA(nn.Module):
class PerceiverAttention(nn.Module):
"""
Perceiver-style self-attention module with optional cross-attention.
Parameters
----------
dim : int
Input feature dimension for queries.
dim_head : int, optional
Dimension per attention head (default: 64).
heads : int, optional
Number of attention heads (default: 8).
kv_dim : int, optional
Input feature dimension for keys/values (default: None).
"""
def __init__(self, *, dim, dim_head=64, heads=8, kv_dim=None):
super().__init__()
self.scale = dim_head**-0.5
......@@ -90,11 +160,19 @@ class PerceiverAttention(nn.Module):
def forward(self, x, latents):
"""
Forward pass for (cross-)attention.
Parameters
----------
x : torch.Tensor
Image features of shape (batch_size, n1, D).
latents : torch.Tensor
Latent features of shape (batch_size, n2, D).
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, n2, D).
"""
x = self.norm1(x)
latents = self.norm2(latents)
......@@ -111,7 +189,7 @@ class PerceiverAttention(nn.Module):
# attention
scale = 1 / math.sqrt(math.sqrt(self.dim_head))
weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
out = weight @ v
......@@ -122,11 +200,34 @@ class PerceiverAttention(nn.Module):
class IDFormer(nn.Module):
"""
IDFormer: Perceiver-style transformer encoder for identity and vision features.
This module fuses identity embeddings (e.g., from ArcFace) and multi-scale ViT features
using a stack of PerceiverAttention and FeedForward layers.
The architecture:
- Concatenates ID embedding tokens and query tokens as latents.
- Latents attend to each other and interact with ViT features via cross-attention.
- Multi-scale ViT features are inserted in order, each scale processed by a block of layers.
Parameters
----------
dim : int, optional
Embedding dimension for all tokens (default: 1024).
depth : int, optional
Total number of transformer layers (must be divisible by 5, default: 10).
dim_head : int, optional
Dimension per attention head (default: 64).
heads : int, optional
Number of attention heads (default: 16).
num_id_token : int, optional
Number of ID embedding tokens (default: 5).
num_queries : int, optional
Number of query tokens (default: 32).
output_dim : int, optional
Output projection dimension (default: 2048).
ff_mult : int, optional
Feed-forward expansion multiplier (default: 4).
"""
def __init__(
......@@ -189,6 +290,21 @@ class IDFormer(nn.Module):
)
def forward(self, x, y):
"""
Forward pass for IDFormer.
Parameters
----------
x : torch.Tensor
ID embedding tensor of shape (batch_size, 1280) or (batch_size, N, 1280).
y : list of torch.Tensor
List of 5 ViT feature tensors, each of shape (batch_size, feature_dim).
Returns
-------
torch.Tensor
Output tensor of shape (batch_size, num_queries, output_dim).
"""
latents = self.latents.repeat(x.size(0), 1, 1)
num_duotu = x.shape[1] if x.ndim == 3 else 1
......
"""
This module implements the PuLID forward function for the :class:`nunchaku.models.transformers.NunchakuFluxTransformer2dModel`.
.. note::
This module is adapted from the original PuLID repository:
https://github.com/ToTheBeginning/PuLID
"""
import logging
from typing import Any, Dict, Optional, Union
......@@ -28,30 +35,56 @@ def pulid_forward(
end_timestep: float | None = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
Implements the PuLID-conditioned forward pass for the Nunchaku FLUX transformer model.
This function supports time and text conditioning, rotary embeddings, ControlNet integration,
and joint attention. It is adapted from
``diffusers.models.flux.transformer_flux.py`` and the original PuLID repository.
Parameters
----------
self : nn.Module
The :class:`nunchaku.models.transformers.NunchakuFluxTransformer2dModel` instance.
This function is intended to be bound as a method.
hidden_states : torch.Tensor
Input hidden states of shape ``(batch_size, channels, height, width)``.
id_embeddings : torch.Tensor, optional
Optional PuLID ID embeddings for conditioning (default: None).
id_weight : torch.Tensor, optional
Optional PuLID ID weights for conditioning (default: None).
encoder_hidden_states : torch.Tensor, optional
Conditional embeddings (e.g., from text encoder) of shape ``(batch_size, sequence_len, embed_dim)``.
pooled_projections : torch.Tensor, optional
Embeddings projected from input conditions, shape ``(batch_size, projection_dim)``.
timestep : torch.LongTensor, optional
Timestep tensor indicating the denoising step.
img_ids : torch.Tensor, optional
Image token IDs for rotary embedding.
txt_ids : torch.Tensor, optional
Text token IDs for rotary embedding.
guidance : torch.Tensor, optional
Optional guidance tensor for classifier-free guidance or similar.
joint_attention_kwargs : dict, optional
Additional keyword arguments for joint attention, passed to the attention processor.
controlnet_block_samples : Any, optional
ControlNet block samples for multi-block conditioning (default: None).
controlnet_single_block_samples : Any, optional
ControlNet single block samples for single-block conditioning (default: None).
return_dict : bool, optional
If True (default), returns a :class:`~diffusers.models.modeling_outputs.Transformer2DModelOutput`.
If False, returns a tuple containing the output tensor.
controlnet_blocks_repeat : bool, optional
Whether to repeat ControlNet blocks (default: False).
start_timestep : float, optional
If specified, disables ID embeddings for timesteps before this value.
end_timestep : float, optional
If specified, disables ID embeddings for timesteps after this value.
Returns
-------
torch.FloatTensor or Transformer2DModelOutput
If ``return_dict`` is True, returns a :class:`~diffusers.models.modeling_outputs.Transformer2DModelOutput`
with the output sample. Otherwise, returns a tuple containing the output tensor.
"""
hidden_states = self.x_embedder(hidden_states)
......@@ -79,14 +112,8 @@ def pulid_forward(
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
if txt_ids.ndim == 3:
txt_ids = txt_ids[0]
if img_ids.ndim == 3:
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
......