Unverified commit cd214093, authored by Muyang Li and committed by GitHub

Merge pull request #530 from mit-han-lab/dev

parents 2a785f77 51732b7a
# Adapted from https://github.com/ToTheBeginning/PuLID
"""
This module provides utility functions for PuLID.
.. note::
This module is adapted from the original PuLID repository:
https://github.com/ToTheBeginning/PuLID
"""
import math
import cv2
......@@ -8,6 +15,23 @@ from torchvision.utils import make_grid
def resize_numpy_image_long(image, resize_long_edge=768):
"""
Resize a numpy image so that its longest edge matches ``resize_long_edge``, preserving aspect ratio.
If the image's longest edge is already less than or equal to ``resize_long_edge``, the image is returned unchanged.
Parameters
----------
image : np.ndarray
Input image as a numpy array of shape (H, W, C).
resize_long_edge : int, optional
Desired size for the longest edge (default: 768).
Returns
-------
np.ndarray
The resized image as a numpy array.
"""
h, w = image.shape[:2]
if max(h, w) <= resize_long_edge:
return image
......@@ -18,18 +42,27 @@ def resize_numpy_image_long(image, resize_long_edge=768):
return image
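# Hedged usage sketch for ``resize_numpy_image_long`` (the path below is purely illustrative):
#
#   img = cv2.imread("face.jpg")                          # HWC uint8 BGR array, any size
#   img = resize_numpy_image_long(img, resize_long_edge=768)
#   assert max(img.shape[:2]) <= 768                      # aspect ratio is preserved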
# from basicsr
def img2tensor(imgs, bgr2rgb=True, float32=True):
"""Numpy array to tensor.
Args:
imgs (list[ndarray] | ndarray): Input images.
bgr2rgb (bool): Whether to change bgr to rgb.
float32 (bool): Whether to change to float32.
Returns:
list[tensor] | tensor: Tensor images. If returned results only have
one element, just return tensor.
"""
Convert numpy images to PyTorch tensors.
This function supports both single images and lists of images. The images are converted from
HWC (height, width, channel) format to CHW (channel, height, width) format. Optionally, BGR images
can be converted to RGB, and the output can be cast to float32.
Parameters
----------
imgs : np.ndarray or list of np.ndarray
Input image(s) as numpy array(s).
bgr2rgb : bool, optional
Whether to convert BGR images to RGB (default: True).
float32 : bool, optional
Whether to cast the output tensor(s) to float32 (default: True).
Returns
-------
torch.Tensor or list of torch.Tensor
Converted tensor(s). If a single image is provided, returns a tensor; if a list is provided, returns a list of tensors.
"""
def _totensor(img, bgr2rgb, float32):
......@@ -48,25 +81,44 @@ def img2tensor(imgs, bgr2rgb=True, float32=True):
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
"""Convert torch Tensors into image numpy arrays.
After clamping to [min, max], values will be normalized to [0, 1].
Args:
tensor (Tensor or list[Tensor]): Accept shapes:
1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
2) 3D Tensor of shape (3/1 x H x W);
3) 2D Tensor of shape (H x W).
Tensor channel should be in RGB order.
rgb2bgr (bool): Whether to change rgb to bgr.
out_type (numpy type): output types. If ``np.uint8``, transform outputs
to uint8 type with range [0, 255]; otherwise, float type with
range [0, 1]. Default: ``np.uint8``.
min_max (tuple[int]): min and max values for clamp.
Returns:
(Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
shape (H x W). The channel order is BGR.
"""
Convert PyTorch tensor(s) to image numpy array(s).
This function supports 4D mini-batch tensors, 3D tensors, and 2D tensors. The output is a numpy array
in HWC (height, width, channel) or HW (height, width) format. Optionally, RGB images can be converted to BGR,
and the output type can be specified.
After clamping to [min, max], values are normalized to [0, 1].
Parameters
----------
tensor : torch.Tensor or list of torch.Tensor
Input tensor(s). Accepts:
1) 4D mini-batch tensor of shape (B x 3/1 x H x W)
2) 3D tensor of shape (3/1 x H x W)
3) 2D tensor of shape (H x W)
The channel order should be RGB.
rgb2bgr : bool, optional
Whether to convert RGB images to BGR (default: True).
out_type : numpy type, optional
Output data type. If ``np.uint8``, output is in [0, 255]; otherwise, in [0, 1] (default: np.uint8).
min_max : tuple of int, optional
Min and max values for clamping (default: (0, 1)).
Returns
-------
np.ndarray or list of np.ndarray
Converted image(s) as numpy array(s). If a single tensor is provided, returns a numpy array; if a list is provided, returns a list of numpy arrays.
Raises
------
TypeError
If the input is not a tensor or list of tensors, or if the tensor has unsupported dimensions.
"""
if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")
......
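# Hedged round-trip sketch for the two basicsr helpers above (``img2tensor`` / ``tensor2img``);
# it assumes both functions are in scope and uses a random image purely for illustration.
import numpy as np

bgr = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)        # fake HWC BGR image
t = img2tensor(bgr, bgr2rgb=True, float32=True)                  # CHW float32 tensor, RGB order
recovered = tensor2img(t / 255.0, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1))
assert recovered.shape == bgr.shape                              # values match up to rounding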
"""
This module provides a `SafetyChecker` class for evaluating user prompts against
defined safety policies using a large language model. It is only used when deploying the online Gradio demos.
"""
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
#: Template for the safety check prompt.
safety_check_template = """You are a policy expert trying to help determine whether a user
prompt is in violation of the defined safety policies.
......@@ -19,7 +25,40 @@ safety_check_template = """You are a policy expert trying to help determine whet
class SafetyChecker:
"""
SafetyChecker(device, disabled=False)
A class to check whether a user prompt violates safety policies using a language model.
Parameters
----------
device : str or torch.device
The device to run the model on (e.g., "cuda", "cpu").
disabled : bool, optional
If True, disables the safety check and always returns True (default: False).
Examples
--------
>>> checker = SafetyChecker(device="cuda")
>>> checker("Generate a nude girl image")
False
>>> checker = SafetyChecker(device="cpu", disabled=True)
>>> checker("Any prompt")
True
"""
def __init__(self, device: str | torch.device, disabled: bool = False):
"""
Initialize the SafetyChecker.
Parameters
----------
device : str or torch.device
The device to run the model on.
disabled : bool, optional
If True, disables the safety check (default: False).
"""
if not disabled:
self.device = device
self.tokenizer = AutoTokenizer.from_pretrained("google/shieldgemma-2b")
......@@ -29,6 +68,21 @@ class SafetyChecker:
self.disabled = disabled
def __call__(self, user_prompt: str, threshold: float = 0.2) -> bool:
"""
Evaluate whether a user prompt is safe according to the defined policy.
Parameters
----------
user_prompt : str
The user prompt to evaluate.
threshold : float, optional
The probability threshold for flagging a prompt as unsafe (default: 0.2).
Returns
-------
bool
True if the prompt is considered safe, False otherwise.
"""
if self.disabled:
return True
device = self.device
......
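# Hedged sketch of the kind of scoring ``SafetyChecker.__call__`` performs (the real body is
# elided above and may differ). It follows the documented ShieldGemma usage pattern: the
# probability of the "Yes" ("violates policy") token is compared against ``threshold``.
import torch
import torch.nn.functional as F

def score_prompt_sketch(model, tokenizer, filled_template: str, device) -> float:
    inputs = tokenizer(filled_template, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    yes_no_ids = tokenizer.convert_tokens_to_ids(["Yes", "No"])  # exact token strings may differ per tokenizer
    probs = F.softmax(logits[0, -1, yes_no_ids], dim=0)
    return probs[0].item()  # probability that the prompt violates the policy

# A prompt would then be flagged unsafe when this score exceeds ``threshold`` (0.2 by default).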
# -*- coding: utf-8 -*-
"""TinyChat Quantized Linear Module"""
"""
This module provides the :class:`W4Linear` quantized linear layer, which implements
4-bit weight-only quantization for efficient inference.
"""
import torch
import torch.nn as nn
......@@ -11,6 +14,40 @@ __all__ = ["W4Linear"]
class W4Linear(nn.Module):
"""
4-bit quantized linear layer with group-wise quantization.
Parameters
----------
in_features : int
Number of input features.
out_features : int
Number of output features.
bias : bool, optional
If True, adds a learnable bias (default: False).
group_size : int, optional
Number of input channels per quantization group (default: 128).
If -1, uses the full input dimension as a single group.
dtype : torch.dtype, optional
Data type for quantization scales and zeros (default: torch.float16).
device : str or torch.device, optional
Device for weights and buffers (default: "cuda").
Attributes
----------
in_features : int
out_features : int
group_size : int
qweight : torch.Tensor
Quantized weight tensor (int16).
scales : torch.Tensor
Per-group scale tensor.
scaled_zeros : torch.Tensor
Per-group zero-point tensor (scaled).
bias : torch.Tensor or None
Optional bias tensor.
"""
def __init__(
self,
in_features: int,
......@@ -61,14 +98,33 @@ class W4Linear(nn.Module):
@property
def weight_bits(self) -> int:
"""
Number of bits per quantized weight (always 4).
"""
return 4
@property
def interleave(self) -> int:
"""
Interleave factor for quantized weights (always 4).
"""
return 4
@torch.no_grad()
def forward(self, x):
"""
Forward pass.
Parameters
----------
x : torch.Tensor
Input tensor of shape (..., in_features).
Returns
-------
torch.Tensor
Output tensor of shape (..., out_features).
"""
if x.numel() / x.shape[-1] < 8:
out = gemv_awq(
x,
......@@ -97,27 +153,30 @@ class W4Linear(nn.Module):
zero: torch.Tensor | None = None,
zero_pre_scaled: bool = False,
) -> "W4Linear":
"""Convert a linear layer to a TinyChat 4-bit weight-only quantized linear layer.
Args:
linear (`nn.Linear`):
linear layer to be converted.
group_size (`int`):
quantization group size.
init_only (`bool`, *optional*, defaults to `False`):
whether to only initialize the quantized linear layer.
weight (`torch.Tensor`, *optional*, defaults to `None`):
weight tensor for the quantized linear layer.
scale (`torch.Tensor`, *optional*, defaults to `None`):
scale tensor for the quantized linear layer.
zero (`torch.Tensor`, *optional*, defaults to `None`):
zero point tensor for the quantized linear layer.
zero_pre_scaled (`bool`, *optional*, defaults to `False`):
whether zero point tensor is pre-scaled.
Returns:
`W4Linear`:
quantized linear layer.
"""
Convert a standard nn.Linear to a quantized W4Linear.
Parameters
----------
linear : nn.Linear
The linear layer to convert.
group_size : int
Quantization group size.
init_only : bool, optional
If True, only initializes the quantized layer (default: False).
weight : torch.Tensor, optional
Precomputed quantized weight (default: None).
scale : torch.Tensor, optional
Precomputed scale tensor (default: None).
zero : torch.Tensor, optional
Precomputed zero-point tensor (default: None).
zero_pre_scaled : bool, optional
Whether the zero-point tensor is pre-scaled (default: False).
Returns
-------
W4Linear
Quantized linear layer.
"""
assert isinstance(linear, nn.Linear)
weight = linear.weight.data if weight is None else weight.data
......@@ -167,6 +226,9 @@ class W4Linear(nn.Module):
return _linear
def extra_repr(self) -> str:
"""
Returns a string describing the layer configuration.
"""
return "in_features={}, out_features={}, bias={}, weight_bits={}, group_size={}".format(
self.in_features,
self.out_features,
......
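# Hedged sketch of creating a ``W4Linear`` from an existing ``nn.Linear``. With
# ``init_only=True`` only the packed int16 weights and scale/zero buffers are allocated;
# real quantized values are loaded afterwards from a checkpoint. Shapes, device, and the
# call sequence are illustrative.
import torch
import torch.nn as nn

fp16_linear = nn.Linear(4096, 4096, bias=True, dtype=torch.float16, device="cuda")
q_linear = W4Linear.from_linear(fp16_linear, group_size=128, init_only=True)

print(q_linear)                      # extra_repr(): in/out features, bias, weight_bits, group_size
x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
y = q_linear(x)                      # runs the 4-bit kernels; output shape (2, 4096)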
"""
The NunchakuT5EncoderModel class enables loading T5 encoder weights from safetensors files,
automatically replacing supported linear layers with quantized :class:`~nunchaku.models.text_encoders.linear.W4Linear`
modules for improved performance and memory efficiency.
"""
import json
import logging
import os
......@@ -20,12 +26,62 @@ logger = logging.getLogger(__name__)
class NunchakuT5EncoderModel(T5EncoderModel):
"""
Nunchaku T5 Encoder Model
Extends :class:`transformers.T5EncoderModel` to support quantized weights and
memory-efficient inference using :class:`~nunchaku.models.text_encoders.linear.W4Linear`.
This class provides a convenient interface for loading T5 encoder weights from
safetensors files, automatically replacing supported linear layers with quantized
modules for improved speed and reduced memory usage.
Example
-------
.. code-block:: python
model = NunchakuT5EncoderModel.from_pretrained(
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
"""
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
"""
Load a :class:`NunchakuT5EncoderModel` from a safetensors file.
This method loads the model configuration and weights from a safetensors file,
initializes the model on the 'meta' device (no memory allocation for weights),
and replaces supported linear layers with quantized :class:`~nunchaku.models.text_encoders.linear.W4Linear` modules.
Parameters
----------
pretrained_model_name_or_path : str or os.PathLike
Path to the safetensors file containing the model weights and metadata.
torch_dtype : torch.dtype, optional
Data type for model initialization (default: ``torch.bfloat16``).
Set to ``torch.float16`` for Turing GPUs.
device : str or torch.device, optional
Device to load the model onto (default: ``"cuda"``).
If the model is loaded on CPU, it will be automatically moved to GPU.
Returns
-------
NunchakuT5EncoderModel
The loaded and quantized T5 encoder model.
Example
-------
.. code-block:: python
model = NunchakuT5EncoderModel.from_pretrained(
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
"""
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)
# Load the config file
# Load the config file from metadata
config = json.loads(metadata["config"])
config = T5Config(**config)
......@@ -35,7 +91,7 @@ class NunchakuT5EncoderModel(T5EncoderModel):
t5_encoder.eval()
# Load the model weights from the safetensors file
# Load the model weights from the safetensors file and quantize supported linear layers
named_modules = {}
for name, module in t5_encoder.named_modules():
assert isinstance(name, str)
......
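# Hedged usage sketch: plugging the quantized T5 encoder into a Diffusers FluxPipeline.
# The safetensors path follows the docstring example above; the base FLUX repo id is an
# assumption and may need adjusting.
import torch
from diffusers import FluxPipeline

text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
    "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.bfloat16,
).to("cuda")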
# -*- coding: utf-8 -*-
"""TinyChat backend utilities."""
"""
This module provides utility functions for quantized linear layers in the TinyChat backend.
"""
import torch
......@@ -7,35 +9,50 @@ __all__ = ["ceil_num_groups", "convert_to_tinychat_w4x16y16_linear_weight"]
def ceil_divide(x: int, divisor: int) -> int:
"""Ceiling division.
Args:
x (`int`):
dividend.
divisor (`int`):
divisor.
Returns:
`int`:
ceiling division result.
"""
Compute the ceiling of integer division.
Parameters
----------
x : int
Dividend.
divisor : int
Divisor.
Returns
-------
int
The smallest integer greater than or equal to ``x / divisor``.
"""
return (x + divisor - 1) // divisor
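# Worked example: ceil_divide(10, 3) == 4 and ceil_divide(12, 3) == 4, i.e. ceil(x / divisor)
# computed with integer arithmetic only.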
def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) -> int:
"""Calculate the ceiling number of quantization groups.
Args:
in_features (`int`):
input channel size.
group_size (`int`):
quantization group size.
weight_bits (`int`, *optional*, defaults to `4`):
quantized weight bits.
Returns:
`int`:
ceiling number of quantization groups.
"""
Calculate the padded number of quantization groups for TinyChat quantization.
This ensures the number of groups is compatible with TinyChat's packing and kernel requirements.
Parameters
----------
in_features : int
Input channel size (number of input features).
group_size : int
Quantization group size.
weight_bits : int, optional
Number of bits per quantized weight (default: 4).
Returns
-------
int
The padded number of quantization groups.
Raises
------
AssertionError
If ``in_features`` is not divisible by ``group_size``, or if ``weight_bits`` is not 4, 2, or 1.
NotImplementedError
If ``group_size`` is not one of the supported values (>=128, 64, 32).
"""
assert in_features % group_size == 0, "input channel size should be divisible by group size."
num_groups = in_features // group_size
......@@ -49,7 +66,7 @@ def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) ->
elif group_size == 32:
num_packs_factor = 4
else:
raise NotImplementedError
raise NotImplementedError("Unsupported group size for TinyChat quantization.")
# make sure num_packs is a multiple of num_packs_factor
num_packs = ceil_divide(num_packs, num_packs_factor) * num_packs_factor
num_groups = num_packs * pack_size
......@@ -57,6 +74,28 @@ def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) ->
def pack_w4(weight: torch.Tensor) -> torch.Tensor:
"""
Pack quantized 4-bit weights into TinyChat's int16 format.
This function rearranges and packs 4-bit quantized weights (stored as int32) into
the format expected by TinyChat CUDA kernels.
Parameters
----------
weight : torch.Tensor
Quantized weight tensor of shape (out_features, in_features), dtype int32.
The input channel dimension must be divisible by 32.
Returns
-------
torch.Tensor
Packed weight tensor of dtype int16.
Raises
------
AssertionError
If input tensor is not int32 or input channel size is not divisible by 32.
"""
assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
oc, ic = weight.shape
assert ic % 32 == 0, "input channel size should be divisible by 32."
......@@ -74,23 +113,49 @@ def convert_to_tinychat_w4x16y16_linear_weight(
group_size: int = -1,
zero_pre_scaled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Convert a weight tensor to TinyChat W4-X16-Y16 linear weight format.
Args:
weight (`torch.Tensor`):
weight tensor to be converted.
scale (`torch.Tensor`):
scale tensor for the weight tensor.
zero (`torch.Tensor`):
zero point tensor for the weight tensor.
group_size (`int`, *optional*, defaults to `-1`):
quantization group size.
zero_pre_scaled (`bool`, *optional*, defaults to `False`):
whether zero point tensor is pre-scaled.
Returns:
`tuple[torch.Tensor, torch.Tensor, torch.Tensor]`:
packed quantized weight tensor, scale tensor, and zero point tensor.
"""
Convert a floating-point weight tensor to TinyChat W4-X16-Y16 quantized linear format.
This function quantizes the input weights to 4 bits per value, applies group-wise
scaling and zero-point, and packs the result into the format expected by TinyChat
quantized linear layers.
Parameters
----------
weight : torch.Tensor
Floating-point weight tensor of shape (out_features, in_features).
Must be of dtype ``torch.float16`` or ``torch.bfloat16``.
scale : torch.Tensor
Per-group scale tensor (can be broadcastable).
zero : torch.Tensor
Per-group zero-point tensor (can be broadcastable).
group_size : int, optional
Quantization group size. If set to -1 (default), uses the full input dimension as a single group.
zero_pre_scaled : bool, optional
If True, the zero tensor is already scaled by the scale tensor (default: False).
Returns
-------
tuple of torch.Tensor
- packed_weight : torch.Tensor
Packed quantized weight tensor (int16).
- packed_scale : torch.Tensor
Packed scale tensor (shape: [num_groups, out_features], dtype matches input).
- packed_zero : torch.Tensor
Packed zero-point tensor (shape: [num_groups, out_features], dtype matches input).
Raises
------
AssertionError
If input types or shapes are invalid, or quantized values are out of range.
Example
-------
.. code-block:: python
qweight, qscale, qzero = convert_to_tinychat_w4x16y16_linear_weight(
weight, scale, zero, group_size=128
)
"""
dtype, device = weight.dtype, weight.device
assert dtype in (torch.float16, torch.bfloat16), "currently tinychat only supports fp16 and bf16."
......
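# Hedged reference sketch of the group-wise 4-bit quantization arithmetic that the
# conversion above relies on (the TinyChat packing/layout itself is handled by ``pack_w4``
# and is not reproduced here). Group size and shapes are illustrative.
import torch

def fake_quantize_w4_sketch(weight: torch.Tensor, group_size: int = 128) -> torch.Tensor:
    """Quantize ``weight`` (out_features, in_features) to 4 bits per group, then dequantize."""
    oc, ic = weight.shape
    w = weight.reshape(oc, ic // group_size, group_size)
    w_min = w.amin(dim=-1, keepdim=True)
    w_max = w.amax(dim=-1, keepdim=True)
    scale = (w_max - w_min).clamp(min=1e-5) / 15           # 4 bits -> 16 integer levels (0..15)
    zero = (-w_min / scale).round().clamp(0, 15)           # integer zero point per group
    q = (w / scale + zero).round().clamp(0, 15)            # quantized integers in [0, 15]
    return ((q - zero) * scale).reshape(oc, ic)            # dequantized approximation

w = torch.randn(256, 256)
w_hat = fake_quantize_w4_sketch(w, group_size=128)
print((w - w_hat).abs().max())                             # small group-wise quantization error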
"""
Implements the :class:`NunchakuFluxTransformer2dModel`, a quantized transformer for Diffusers with efficient inference and LoRA support.
"""
import json
import logging
import os
......@@ -32,6 +36,20 @@ logger = logging.getLogger(__name__)
class NunchakuFluxTransformerBlocks(nn.Module):
"""
Wrapper for quantized Nunchaku FLUX transformer blocks.
This class manages the forward pass, rotary embedding packing, and optional
residual callbacks for ID embeddings.
Parameters
----------
m : QuantizedFluxModel
The quantized transformer model.
device : str or torch.device
Device to run the model on.
"""
def __init__(self, m: QuantizedFluxModel, device: str | torch.device):
super(NunchakuFluxTransformerBlocks, self).__init__()
self.m = m
......@@ -40,6 +58,19 @@ class NunchakuFluxTransformerBlocks(nn.Module):
@staticmethod
def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor:
"""
Packs rotary embeddings for efficient computation.
Parameters
----------
rotemb : torch.Tensor
Rotary embedding tensor of shape (B, M, D//2, 1, 2), dtype float32.
Returns
-------
torch.Tensor
Packed rotary embedding tensor of shape (B, M, D).
"""
assert rotemb.dtype == torch.float32
B = rotemb.shape[0]
M = rotemb.shape[1]
......@@ -73,6 +104,38 @@ class NunchakuFluxTransformerBlocks(nn.Module):
controlnet_single_block_samples=None,
skip_first_layer=False,
):
"""
Forward pass for the quantized transformer blocks.
It will call the forward method of ``m`` on the C backend.
Parameters
----------
hidden_states : torch.Tensor
Input hidden states for image tokens.
temb : torch.Tensor
Time embedding tensor.
encoder_hidden_states : torch.Tensor
Input hidden states for text tokens.
image_rotary_emb : torch.Tensor
Rotary embedding tensor for all tokens.
id_embeddings : torch.Tensor, optional
Optional ID embeddings for residual callback.
id_weight : float, optional
Weight for ID embedding residual.
joint_attention_kwargs : dict, optional
Additional kwargs for joint attention.
controlnet_block_samples : list[torch.Tensor], optional
ControlNet block samples.
controlnet_single_block_samples : list[torch.Tensor], optional
ControlNet single block samples.
skip_first_layer : bool, optional
Whether to skip the first layer.
Returns
-------
tuple[torch.Tensor, torch.Tensor]
(encoder_hidden_states, hidden_states) after transformer blocks.
"""
# batch_size = hidden_states.shape[0]
txt_tokens = encoder_hidden_states.shape[1]
img_tokens = hidden_states.shape[1]
......@@ -149,6 +212,33 @@ class NunchakuFluxTransformerBlocks(nn.Module):
controlnet_block_samples=None,
controlnet_single_block_samples=None,
):
"""
Forward pass for a specific transformer layer in ``m``.
Parameters
----------
idx : int
Index of the transformer layer.
hidden_states : torch.Tensor
Input hidden states for image tokens.
encoder_hidden_states : torch.Tensor
Input hidden states for text tokens.
temb : torch.Tensor
Time embedding tensor.
image_rotary_emb : torch.Tensor
Rotary embedding tensor for all tokens.
joint_attention_kwargs : dict, optional
Additional kwargs for joint attention.
controlnet_block_samples : list[torch.Tensor], optional
ControlNet block samples.
controlnet_single_block_samples : list[torch.Tensor], optional
ControlNet single block samples.
Returns
-------
tuple[torch.Tensor, torch.Tensor]
(encoder_hidden_states, hidden_states) after the specified layer.
"""
# batch_size = hidden_states.shape[0]
txt_tokens = encoder_hidden_states.shape[1]
img_tokens = hidden_states.shape[1]
......@@ -195,6 +285,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
return encoder_hidden_states, hidden_states
def set_pulid_residual_callback(self):
"""
Sets the residual callback for PuLID (personalized ID) embeddings.
"""
id_embeddings = self.id_embeddings
pulid_ca = self.pulid_ca
pulid_ca_idx = [self.pulid_ca_idx]
......@@ -209,10 +302,16 @@ class NunchakuFluxTransformerBlocks(nn.Module):
self.m.set_residual_callback(callback)
def reset_pulid_residual_callback(self):
"""
Resets the PuLID residual callback to ``None``.
"""
self.callback_holder = None
self.m.set_residual_callback(None)
def __del__(self):
"""
Destructor to reset the quantized model.
"""
self.m.reset()
def norm1(
......@@ -221,11 +320,44 @@ class NunchakuFluxTransformerBlocks(nn.Module):
emb: torch.Tensor,
idx: int = 0,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Runs the norm_one_forward for a specific layer in ``m``.
Parameters
----------
hidden_states : torch.Tensor
Input hidden states.
emb : torch.Tensor
Embedding tensor.
idx : int, optional
Layer index (default: 0).
Returns
-------
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
Output tensors from norm_one_forward.
"""
return self.m.norm_one_forward(idx, hidden_states, emb)
## copied from diffusers 0.30.3
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
"""
Rotary positional embedding function.
Parameters
----------
pos : torch.Tensor
Position tensor of shape (..., n).
dim : int
Embedding dimension (must be even).
theta : int
Rotary base.
Returns
-------
torch.Tensor
Rotary embedding tensor.
"""
assert dim % 2 == 0, "The dimension must be even."
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
......@@ -247,6 +379,19 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
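# Hedged note on the math implemented by ``rope`` (body elided above): for an even ``dim``
# it builds frequencies omega_k = theta ** (-2k / dim) for k = 0 .. dim/2 - 1, forms the
# angles a = pos * omega_k for every position, and returns the cos(a)/sin(a) pairs that
# define the 2x2 rotation applied to each query/key frequency pair.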
class EmbedND(nn.Module):
"""
Multi-dimensional rotary embedding module.
Parameters
----------
dim : int
Embedding dimension.
theta : int
Rotary base.
axes_dim : list[int]
List of axis dimensions for each spatial axis.
"""
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
super(EmbedND, self).__init__()
self.dim = dim
......@@ -254,6 +399,19 @@ class EmbedND(nn.Module):
self.axes_dim = axes_dim
def forward(self, ids: torch.Tensor) -> torch.Tensor:
"""
Computes rotary embeddings for multi-dimensional positions.
Parameters
----------
ids : torch.Tensor
Position indices tensor of shape (..., n_axes).
Returns
-------
torch.Tensor
Rotary embedding tensor.
"""
if Version(diffusers.__version__) >= Version("0.31.0"):
ids = ids[None, ...]
n_axes = ids.shape[-1]
......@@ -268,6 +426,27 @@ def load_quantized_module(
offload: bool = False,
bf16: bool = True,
) -> QuantizedFluxModel:
"""
Loads a quantized Nunchaku FLUX model from a state dict or file.
Parameters
----------
path_or_state_dict : str, os.PathLike, or dict
Path to the quantized model file or a state dict.
device : str or torch.device, optional
Device to load the model on (default: "cuda").
use_fp4 : bool, optional
Whether to use FP4 quantization (default: False).
offload : bool, optional
Whether to offload weights to CPU (default: False).
bf16 : bool, optional
Whether to use bfloat16 (default: True).
Returns
-------
QuantizedFluxModel
Loaded quantized model.
"""
device = torch.device(device)
assert device.type == "cuda"
m = QuantizedFluxModel()
......@@ -281,6 +460,38 @@ def load_quantized_module(
class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin):
"""
Nunchaku FLUX Transformer 2D Model.
This class implements a quantized transformer model compatible with the Diffusers
library, supporting LoRA, rotary embeddings, and efficient inference.
Parameters
----------
patch_size : int, optional
Patch size for input images (default: 1).
in_channels : int, optional
Number of input channels (default: 64).
out_channels : int or None, optional
Number of output channels (default: None).
num_layers : int, optional
Number of transformer layers (default: 19).
num_single_layers : int, optional
Number of single transformer layers (default: 38).
attention_head_dim : int, optional
Dimension of each attention head (default: 128).
num_attention_heads : int, optional
Number of attention heads (default: 24).
joint_attention_dim : int, optional
Joint attention dimension (default: 4096).
pooled_projection_dim : int, optional
Pooled projection dimension (default: 768).
guidance_embeds : bool, optional
Whether to use guidance embeddings (default: False).
axes_dims_rope : tuple[int], optional
Axes dimensions for rotary embeddings (default: (16, 56, 56)).
"""
@register_to_config
def __init__(
self,
......@@ -323,6 +534,21 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
"""
Loads a Nunchaku FLUX transformer model from pretrained weights.
Parameters
----------
pretrained_model_name_or_path : str or os.PathLike
Path to the model directory or HuggingFace repo.
**kwargs
Additional keyword arguments for device, offload, torch_dtype, precision, etc.
Returns
-------
NunchakuFluxTransformer2dModel or (NunchakuFluxTransformer2dModel, dict)
The loaded model, and optionally metadata if `return_metadata=True`.
"""
device = kwargs.get("device", "cuda")
if isinstance(device, str):
device = torch.device(device)
......@@ -395,6 +621,21 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
return transformer
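# Hedged usage sketch for ``from_pretrained`` (repo and file names are illustrative):
#
#   transformer = NunchakuFluxTransformer2dModel.from_pretrained(
#       "mit-han-lab/nunchaku-flux.1-schnell/svdq-int4_r32-flux.1-schnell.safetensors"
#   )
#   pipeline = FluxPipeline.from_pretrained(
#       "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
#   ).to("cuda")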
def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
"""
Injects a quantized module into the model and sets up transformer blocks.
Parameters
----------
m : QuantizedFluxModel
The quantized transformer model.
device : str or torch.device, optional
Device to run the model on (default: "cuda").
Returns
-------
self : NunchakuFluxTransformer2dModel
The model with injected quantized module.
"""
print("Injecting quantized module")
self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])
......@@ -405,6 +646,17 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
return self
def set_attention_impl(self, impl: str):
"""
Set the attention implementation for the quantized transformer block.
Parameters
----------
impl : str
Attention implementation to use. Supported values:
- ``"flash-attention2"`` (default): Standard FlashAttention-2.
- ``"nunchaku-fp16"``: Uses FP16 attention accumulation, up to 1.2× faster than FlashAttention-2 on NVIDIA 30-, 40-, and 50-series GPUs.
"""
block = self.transformer_blocks[0]
assert isinstance(block, NunchakuFluxTransformerBlocks)
block.m.setAttentionImpl(impl)
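# Hedged example: switch a loaded transformer to the FP16-accumulation attention kernel with
# ``transformer.set_attention_impl("nunchaku-fp16")``; pass ``"flash-attention2"`` to restore
# the default implementation.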
......@@ -412,6 +664,17 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
### LoRA Related Functions
def _expand_module(self, module_name: str, new_shape: tuple[int, int]):
"""
Expands a linear module to a new shape for LoRA compatibility.
This is mostly needed for FLUX.1-tools LoRAs, which change the number of input channels.
Parameters
----------
module_name : str
Name of the module to expand.
new_shape : tuple[int, int]
New shape (out_features, in_features) for the module.
"""
module = self.get_submodule(module_name)
assert isinstance(module, nn.Linear)
weight_shape = module.weight.shape
......@@ -443,6 +706,14 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
setattr(self.config, "in_channels", new_value)
def _update_unquantized_part_lora_params(self, strength: float = 1):
"""
Updates the unquantized part of the model with LoRA parameters.
Parameters
----------
strength : float, optional
LoRA scaling strength (default: 1).
"""
# check if we need to expand the linear layers
device = next(self.parameters()).device
for k, v in self._unquantized_part_loras.items():
......@@ -505,6 +776,18 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
self.load_state_dict(new_state_dict, strict=True)
def update_lora_params(self, path_or_state_dict: str | dict[str, torch.Tensor]):
"""
Update the model with new LoRA parameters.
Parameters
----------
path_or_state_dict : str or dict
Path to a LoRA weights file or a state dict. The path supports:
- Local file path, e.g., ``"/path/to/your/lora.safetensors"``
- HuggingFace repo with file, e.g., ``"user/repo/lora.safetensors"``
(automatically downloaded and cached)
"""
if isinstance(path_or_state_dict, dict):
state_dict = {
k: v for k, v in path_or_state_dict.items()
......@@ -543,9 +826,20 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
block.m.loadDict(state_dict, True)
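# Hedged usage sketch (the LoRA repo/path is illustrative): load a LoRA and adjust its
# strength via ``set_lora_strength`` defined below.
#
#   transformer.update_lora_params("aleksa-codes/flux-ghibsky-illustration/lora.safetensors")
#   transformer.set_lora_strength(0.8)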
# This function can only be used with a single LoRA.
# For multiple LoRAs, please fuse the lora scale into the weights.
def set_lora_strength(self, strength: float = 1):
"""
Sets the LoRA scaling strength for the model.
Note: this changes the strength of every loaded LoRA, so it should only be used when a
single LoRA is loaded. For multiple LoRAs, fuse the LoRA scale into the weights instead.
Parameters
----------
strength : float, optional
LoRA scaling strength (default: 1).
"""
block = self.transformer_blocks[0]
assert isinstance(block, NunchakuFluxTransformerBlocks)
block.m.setLoraScale(SVD_RANK, strength)
......@@ -556,6 +850,10 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
block.m.loadDict(vector_dict, True)
def reset_x_embedder(self):
"""
Resets the x_embedder module if the input channel count has changed.
This is used to remove the effect of FLUX.1-tools LoRAs, which change the number of input channels.
"""
# if change the model in channels, we need to update the x_embedder
if self._original_in_channels != self.config.in_channels:
assert self._original_in_channels < self.config.in_channels
......@@ -577,6 +875,9 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
setattr(self.config, "in_channels", self._original_in_channels)
def reset_lora(self):
"""
Resets all LoRA parameters to their default state.
"""
unquantized_part_loras = {}
if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0:
self._unquantized_part_loras = unquantized_part_loras
......@@ -606,30 +907,42 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
controlnet_blocks_repeat: bool = False,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
Copied from diffusers.models.flux.transformer_flux.py
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input `hidden_states`.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
Forward pass for the Nunchaku FLUX transformer model.
This method is compatible with the Diffusers pipeline and supports LoRA,
rotary embeddings, and ControlNet.
Parameters
----------
hidden_states : torch.FloatTensor
Input hidden states of shape (batch_size, channel, height, width).
encoder_hidden_states : torch.FloatTensor, optional
Conditional embeddings (e.g., prompt embeddings) of shape (batch_size, sequence_len, embed_dims).
pooled_projections : torch.FloatTensor, optional
Embeddings projected from the input conditions.
timestep : torch.LongTensor, optional
Denoising step.
img_ids : torch.Tensor, optional
Image token indices.
txt_ids : torch.Tensor, optional
Text token indices.
guidance : torch.Tensor, optional
Guidance tensor for classifier-free guidance.
joint_attention_kwargs : dict, optional
Additional kwargs for joint attention.
controlnet_block_samples : list[torch.Tensor], optional
ControlNet block samples.
controlnet_single_block_samples : list[torch.Tensor], optional
ControlNet single block samples.
return_dict : bool, optional
Whether to return a Transformer2DModelOutput (default: True).
controlnet_blocks_repeat : bool, optional
Whether to repeat ControlNet blocks (default: False).
Returns
-------
torch.FloatTensor or Transformer2DModelOutput
Output tensor or output object containing the sample.
"""
hidden_states = self.x_embedder(hidden_states)
......
"""
Implements the :class:`NunchakuSanaTransformer2DModel`,
a quantized Sana transformer for Diffusers with efficient inference support.
"""
import os
from pathlib import Path
from typing import Optional
......@@ -18,6 +23,22 @@ SVD_RANK = 32
class NunchakuSanaTransformerBlocks(nn.Module):
"""
Wrapper for quantized Sana transformer blocks.
This module wraps a QuantizedSanaModel and provides forward methods compatible
with the expected transformer block interface.
Parameters
----------
m : QuantizedSanaModel
The quantized transformer model.
dtype : torch.dtype
The data type to use for computation.
device : str or torch.device
The device to run the model on.
"""
def __init__(self, m: QuantizedSanaModel, dtype: torch.dtype, device: str | torch.device):
super(NunchakuSanaTransformerBlocks, self).__init__()
self.m = m
......@@ -35,7 +56,33 @@ class NunchakuSanaTransformerBlocks(nn.Module):
width: Optional[int] = None,
skip_first_layer: Optional[bool] = False,
):
"""
Forward pass through all quantized transformer blocks.
Parameters
----------
hidden_states : torch.Tensor
Input hidden states of shape (batch_size, img_tokens, ...).
attention_mask : torch.Tensor, optional
Not used.
encoder_hidden_states : torch.Tensor, optional
Encoder hidden states of shape (batch_size, txt_tokens, ...).
encoder_attention_mask : torch.Tensor, optional
Encoder attention mask of shape (batch_size, 1, txt_tokens).
timestep : torch.LongTensor, optional
Timestep tensor.
height : int, optional
Image height.
width : int, optional
Image width.
skip_first_layer : bool, optional
Whether to skip the first layer.
Returns
-------
torch.Tensor
Output tensor after passing through the quantized transformer blocks.
"""
batch_size = hidden_states.shape[0]
img_tokens = hidden_states.shape[1]
txt_tokens = encoder_hidden_states.shape[1]
......@@ -90,6 +137,33 @@ class NunchakuSanaTransformerBlocks(nn.Module):
height: Optional[int] = None,
width: Optional[int] = None,
):
"""
Forward pass through a specific quantized transformer layer.
Parameters
----------
idx : int
Index of the layer to run.
hidden_states : torch.Tensor
Input hidden states.
attention_mask : torch.Tensor, optional
Not used.
encoder_hidden_states : torch.Tensor, optional
Encoder hidden states.
encoder_attention_mask : torch.Tensor, optional
Encoder attention mask.
timestep : torch.LongTensor, optional
Timestep tensor.
height : int, optional
Image height.
width : int, optional
Image width.
Returns
-------
torch.Tensor
Output tensor after passing through the specified quantized transformer layer.
"""
batch_size = hidden_states.shape[0]
img_tokens = hidden_states.shape[1]
txt_tokens = encoder_hidden_states.shape[1]
......@@ -134,13 +208,41 @@ class NunchakuSanaTransformerBlocks(nn.Module):
)
def __del__(self):
"""
Destructor to reset the quantized model and free resources.
"""
self.m.reset()
class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoaderMixin):
"""
SanaTransformer2DModel with Nunchaku quantized backend support.
This class extends the base SanaTransformer2DModel to support loading and
injecting quantized transformer blocks using Nunchaku's custom backend.
"""
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
"""
Load a pretrained NunchakuSanaTransformer2DModel from a local file or HuggingFace Hub.
This method supports both quantized and unquantized checkpoints, and will
automatically inject quantized transformer blocks if available.
Parameters
----------
pretrained_model_name_or_path : str or os.PathLike
Path to the model checkpoint or HuggingFace Hub model name.
**kwargs
Additional keyword arguments for model loading.
Returns
-------
NunchakuSanaTransformer2DModel or (NunchakuSanaTransformer2DModel, dict)
The loaded model, and optionally metadata if ``return_metadata=True``.
"""
device = kwargs.get("device", "cuda")
if isinstance(device, str):
device = torch.device(device)
......@@ -184,6 +286,21 @@ class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoader
return transformer
def inject_quantized_module(self, m: QuantizedSanaModel, device: str | torch.device = "cuda"):
"""
Inject a quantized transformer module into this model.
Parameters
----------
m : QuantizedSanaModel
The quantized transformer module to inject.
device : str or torch.device, optional
The device to place the module on (default: "cuda").
Returns
-------
NunchakuSanaTransformer2DModel
The model with the quantized module injected.
"""
self.transformer_blocks = torch.nn.ModuleList([NunchakuSanaTransformerBlocks(m, self.dtype, device)])
return self
......@@ -195,6 +312,27 @@ def load_quantized_module(
pag_layers: int | list[int] | None = None,
use_fp4: bool = False,
) -> QuantizedSanaModel:
"""
Load quantized weights into a QuantizedSanaModel.
Parameters
----------
net : SanaTransformer2DModel
The base transformer model (for config and dtype).
path_or_state_dict : str, os.PathLike, or dict
Path to the quantized weights or a state dict.
device : str or torch.device, optional
Device to load the quantized model on (default: "cuda").
pag_layers : int, list of int, or None, optional
List of layer indices to apply PAG (Perturbed-Attention Guidance) to (default: None).
use_fp4 : bool, optional
Whether to use FP4 quantization (default: False).
Returns
-------
QuantizedSanaModel
The loaded quantized model.
"""
if pag_layers is None:
pag_layers = []
elif isinstance(pag_layers, int):
......@@ -215,5 +353,22 @@ def load_quantized_module(
def inject_quantized_module(
net: SanaTransformer2DModel, m: QuantizedSanaModel, device: torch.device
) -> SanaTransformer2DModel:
"""
Inject a quantized transformer module into a SanaTransformer2DModel.
Parameters
----------
net : SanaTransformer2DModel
The base transformer model.
m : QuantizedSanaModel
The quantized transformer module to inject.
device : torch.device
The device to place the module on.
Returns
-------
SanaTransformer2DModel
The model with the quantized module injected.
"""
net.transformer_blocks = torch.nn.ModuleList([NunchakuSanaTransformerBlocks(m, net.dtype, device)])
return net
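# Hedged usage sketch for the quantized Sana transformer (repo ids are illustrative;
# substitute the checkpoints you actually use).
import torch
from diffusers import SanaPipeline

transformer = NunchakuSanaTransformer2DModel.from_pretrained("mit-han-lab/svdq-int4-sana-1600m")
pipeline = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
image = pipeline("a cyberpunk cat", num_inference_steps=20, guidance_scale=4.5).images[0]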
"""
Utilities for Nunchaku transformer model loading.
"""
import json
import logging
import os
......@@ -20,16 +24,38 @@ logger = logging.getLogger(__name__)
class NunchakuModelLoaderMixin:
"""
Mixin for standardized model loading in Nunchaku transformer models.
Provides:
- :meth:`_build_model`: Load model from a safetensors file.
- :meth:`_build_model_legacy`: Load model from a legacy folder structure (deprecated).
"""
@classmethod
def _build_model(
cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
) -> tuple[nn.Module, dict[str, torch.Tensor], dict[str, str]]:
"""
Build a transformer model from a safetensors file.
Parameters
----------
pretrained_model_name_or_path : str or os.PathLike
Path to the safetensors file.
**kwargs
Additional keyword arguments (e.g., ``torch_dtype``).
Returns
-------
tuple
(transformer, state_dict, metadata)
"""
if isinstance(pretrained_model_name_or_path, str):
pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)
# Load the config file
config = json.loads(metadata["config"])
with torch.device("meta"):
......@@ -41,6 +67,25 @@ class NunchakuModelLoaderMixin:
def _build_model_legacy(
cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
) -> tuple[nn.Module, str, str]:
"""
Build a transformer model from a legacy folder structure.
.. warning::
This method is deprecated and will be removed in v0.4.
Please migrate to safetensors-based model loading.
Parameters
----------
pretrained_model_name_or_path : str or os.PathLike
Path to the folder containing model weights.
**kwargs
Additional keyword arguments for HuggingFace Hub download and config loading.
Returns
-------
tuple
(transformer, unquantized_part_path, transformer_block_path)
"""
logger.warning(
"Loading models from a folder will be deprecated in v0.4. "
"Please download the latest safetensors model, or use one of the following tools to "
......@@ -109,6 +154,25 @@ class NunchakuModelLoaderMixin:
def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: Any = 0) -> torch.Tensor | None:
"""
Pad a tensor along a given dimension to the next multiple of a specified value.
Parameters
----------
tensor : torch.Tensor or None
Input tensor. If None, returns None.
multiples : int
Pad to this multiple. If <= 1, no padding is applied.
dim : int
Dimension along which to pad.
fill : Any, optional
Value to use for padding (default: 0).
Returns
-------
torch.Tensor or None
The padded tensor, or None if input was None.
"""
if multiples <= 1:
return tensor
if tensor is None:
......
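# Hedged worked example for ``pad_tensor``: padding a (3, 5) tensor along dim=1 to a
# multiple of 8 yields shape (3, 8), with the three new columns set to ``fill``.
#
#   t = torch.zeros(3, 5)
#   padded = pad_tensor(t, multiples=8, dim=1, fill=0)
#   assert padded.shape == (3, 8)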
# Adapted from https://github.com/ToTheBeginning/PuLID/blob/main/pulid/pipeline.py
"""
This module provides the PuLID FluxPipeline for personalized image generation with identity preservation.
It integrates face analysis, alignment, and embedding extraction using InsightFace and FaceXLib, and injects
identity embeddings into a Flux transformer pipeline.
.. note::
This module is adapted from https://github.com/ToTheBeginning/PuLID/blob/main/pulid/pipeline.py
"""
import gc
import logging
import os
......@@ -11,9 +20,8 @@ import numpy as np
import torch
from diffusers import FluxPipeline
from diffusers.image_processor import PipelineImageInput
from diffusers.pipelines.flux.pipeline_flux import EXAMPLE_DOC_STRING, calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
from diffusers.utils import replace_example_docstring
from facexlib.parsing import init_parsing_model
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from huggingface_hub import snapshot_download
......@@ -39,6 +47,19 @@ logger = logging.getLogger(__name__)
def check_antelopev2_dir(antelopev2_dirpath: str | os.PathLike[str]) -> bool:
"""
Check if the given directory contains all required AntelopeV2 ONNX model files with correct SHA256 hashes.
Parameters
----------
antelopev2_dirpath : str or os.PathLike
Path to the directory containing AntelopeV2 ONNX models.
Returns
-------
bool
True if all required files exist and have correct hashes, False otherwise.
"""
antelopev2_dirpath = Path(antelopev2_dirpath)
required_files = {
"1k3d68.onnx": "df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc",
......@@ -64,6 +85,53 @@ def check_antelopev2_dir(antelopev2_dirpath: str | os.PathLike[str]) -> bool:
class PuLIDPipeline(nn.Module):
"""
PyTorch module for extracting identity embeddings using PuLID, InsightFace, and EVA-CLIP.
This class handles face detection, alignment, parsing, and embedding extraction for use in personalized
diffusion pipelines.
Parameters
----------
dit : NunchakuFluxTransformer2dModel
The transformer model to inject PuLID attention modules into.
device : str or torch.device
Device to run the pipeline on.
weight_dtype : str or torch.dtype, optional
Data type for model weights (default: torch.bfloat16).
onnx_provider : str, optional
ONNX runtime provider, "gpu" or "cpu" (default: "gpu").
pulid_path : str or os.PathLike, optional
Path to PuLID weights in safetensors format.
eva_clip_path : str or os.PathLike, optional
Path to EVA-CLIP weights.
insightface_dirpath : str or os.PathLike or None, optional
Path to InsightFace models directory.
facexlib_dirpath : str or os.PathLike or None, optional
Path to FaceXLib models directory.
Attributes
----------
pulid_encoder : IDFormer
The IDFormer encoder for identity embedding.
pulid_ca : nn.ModuleList
List of PerceiverAttentionCA modules injected into the transformer.
face_helper : FaceRestoreHelper
Helper for face alignment and parsing.
clip_vision_model : nn.Module
EVA-CLIP visual backbone.
eva_transform_mean : tuple
Mean for image normalization.
eva_transform_std : tuple
Std for image normalization.
app : FaceAnalysis
InsightFace face analysis application.
handler_ante : insightface.model_zoo.model_zoo.Model
InsightFace embedding model.
debug_img_list : list
List of debug images (for visualization).
"""
def __init__(
self,
dit: NunchakuFluxTransformer2dModel,
......@@ -177,6 +245,19 @@ class PuLIDPipeline(nn.Module):
self.debug_img_list = []
def to_gray(self, img):
"""
Convert an image tensor to grayscale (3 channels).
Parameters
----------
img : torch.Tensor
Image tensor of shape (B, 3, H, W).
Returns
-------
torch.Tensor
Grayscale image tensor of shape (B, 3, H, W).
"""
x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
x = x.repeat(1, 3, 1, 1)
return x
......@@ -184,8 +265,21 @@ class PuLIDPipeline(nn.Module):
@torch.no_grad()
def get_id_embedding(self, image, cal_uncond=False):
"""
Args:
image: numpy rgb image, range [0, 255]
Extract identity embedding from an RGB image.
Parameters
----------
image : np.ndarray
Input RGB image as a numpy array, range [0, 255].
cal_uncond : bool, optional
If True, also compute unconditional embedding (default: False).
Returns
-------
id_embedding : torch.Tensor
Identity embedding tensor.
uncond_id_embedding : torch.Tensor or None
Unconditional embedding tensor if cal_uncond is True, else None.
"""
self.face_helper.clean_all()
self.debug_img_list = []
......@@ -260,6 +354,41 @@ class PuLIDPipeline(nn.Module):
class PuLIDFluxPipeline(FluxPipeline):
"""
FluxPipeline with PuLID identity embedding support.
This pipeline extends the standard FluxPipeline to support personalized image generation using
identity embeddings extracted from a reference image. It injects the PuLID identity encoder into
the transformer and supports all standard FluxPipeline features.
Parameters
----------
scheduler : SchedulerMixin
Scheduler for diffusion process.
vae : AutoencoderKL
Variational autoencoder for image encoding/decoding.
text_encoder : PreTrainedModel
Text encoder for prompt embeddings.
tokenizer : PreTrainedTokenizer
Tokenizer for text encoder.
text_encoder_2 : PreTrainedModel
Second text encoder (optional).
tokenizer_2 : PreTrainedTokenizer
Second tokenizer (optional).
transformer : NunchakuFluxTransformer2dModel
Transformer model for denoising.
image_encoder : nn.Module, optional
Image encoder for IP-Adapter (default: None).
feature_extractor : nn.Module, optional
Feature extractor for images (default: None).
pulid_device : str, optional
Device for PuLID pipeline (default: "cuda").
weight_dtype : torch.dtype, optional
Data type for model weights (default: torch.bfloat16).
onnx_provider : str, optional
ONNX runtime provider (default: "gpu").
"""
def __init__(
self,
scheduler,
......@@ -301,7 +430,6 @@ class PuLIDFluxPipeline(FluxPipeline):
)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
......@@ -335,103 +463,79 @@ class PuLIDFluxPipeline(FluxPipeline):
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
will be used instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
not greater than `1`).
negative_prompt_2 (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
`text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
true_cfg_scale (`float`, *optional*, defaults to 1.0):
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
If not provided, pooled text embeddings will be generated from `prompt` input argument.
ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
negative_ip_adapter_image:
(`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
provided, embeddings are computed from the `ip_adapter_image` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
images.
"""
Function invoked when calling the pipeline for generation.
See the parent class :class:`diffusers.FluxPipeline` for full documentation.
Parameters
----------
prompt : str or List[str], optional
The prompt(s) to guide image generation.
prompt_2 : str or List[str], optional
Second prompt(s) for dual-encoder pipelines.
negative_prompt : str or List[str], optional
Negative prompt(s) to avoid in generation.
negative_prompt_2 : str or List[str], optional
Second negative prompt(s) for dual-encoder pipelines.
true_cfg_scale : float, optional
True classifier-free guidance scale.
height : int, optional
Output image height.
width : int, optional
Output image width.
num_inference_steps : int, optional
Number of denoising steps.
sigmas : List[float], optional
Custom sigmas for the scheduler.
guidance_scale : float, optional
Classifier-free guidance scale.
num_images_per_prompt : int, optional
Number of images per prompt.
generator : torch.Generator or List[torch.Generator], optional
Random generator(s) for reproducibility.
latents : torch.FloatTensor, optional
Pre-generated latents.
prompt_embeds : torch.FloatTensor, optional
Pre-generated prompt embeddings.
pooled_prompt_embeds : torch.FloatTensor, optional
Pre-generated pooled prompt embeddings.
ip_adapter_image : PipelineImageInput, optional
Image input for IP-Adapter.
id_image : PIL.Image.Image or np.ndarray, optional
Reference image for identity embedding.
id_weight : float, optional
Weight for identity embedding.
start_step : int, optional
Step to start from (for advanced use).
ip_adapter_image_embeds : List[torch.Tensor], optional
Precomputed IP-Adapter image embeddings.
negative_ip_adapter_image : PipelineImageInput, optional
Negative image input for IP-Adapter.
negative_ip_adapter_image_embeds : List[torch.Tensor], optional
Precomputed negative IP-Adapter image embeddings.
negative_prompt_embeds : torch.FloatTensor, optional
Precomputed negative prompt embeddings.
negative_pooled_prompt_embeds : torch.FloatTensor, optional
Precomputed negative pooled prompt embeddings.
output_type : str, optional
Output format ("pil" or "np").
return_dict : bool, optional
Whether to return a dict or tuple.
joint_attention_kwargs : dict, optional
Additional kwargs for joint attention.
callback_on_step_end : Callable, optional
Callback at the end of each denoising step.
callback_on_step_end_tensor_inputs : List[str], optional
List of tensor names for callback.
max_sequence_length : int, optional
Maximum sequence length for prompts.
Returns
-------
FluxPipelineOutput or tuple
Output images and additional info.
"""
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
......
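For readers skimming the diff, a minimal usage sketch of the call signature documented above may help. The pipeline object, checkpoint, reference-image path, and prompt below are placeholders; only the keyword arguments (`id_image`, `id_weight`, `true_cfg_scale`, and the standard FLUX parameters) come from the docstring itself.

from diffusers.utils import load_image

def run_pulid_flux(pipe, face_path="face.jpg"):
    # `pipe` is an already-constructed PuLID FLUX pipeline exposing the __call__ documented above.
    id_image = load_image(face_path)  # reference image for the identity embedding (placeholder path)
    output = pipe(
        prompt="portrait photo, studio lighting",
        id_image=id_image,
        id_weight=1.0,           # strength of the identity conditioning
        true_cfg_scale=1.0,      # true classifier-free guidance scale
        num_inference_steps=12,
        guidance_scale=3.5,
        height=1024,
        width=1024,
    )
    return output.images[0]      # FluxPipelineOutput.images holds the generated PIL images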
"""
Test script for generating an image with the Nunchaku FLUX.1-schnell model.
This script demonstrates how to load a quantized Nunchaku FLUX transformer model and
use it with the Diffusers :class:`~diffusers.FluxPipeline` to generate an image from a text prompt.
**Example usage**
.. code-block:: bash
python -m nunchaku.test
The generated image will be saved as ``flux.1-schnell.png`` in the current directory.
"""
import torch
from diffusers import FluxPipeline
......
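The elided body can be summarized by the hedged sketch below, reconstructed from the module docstring and the checkpoint-loading pattern used later in this diff; the exact Hub path of the schnell checkpoint is an assumption, and `torch` and `FluxPipeline` are already imported above.

from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision

precision = get_precision()  # "int4" or "fp4", depending on the local GPU
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"  # assumed path
)
pipeline = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline(
    "A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0.0
).images[0]
image.save("flux.1-schnell.png")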
"""
Utility functions for Nunchaku.
"""
import hashlib
import os
import warnings
......@@ -9,6 +13,19 @@ from huggingface_hub import hf_hub_download
def sha256sum(filepath: str | os.PathLike[str]) -> str:
"""
Compute the SHA-256 checksum of a file.
Parameters
----------
filepath : str or os.PathLike
Path to the file.
Returns
-------
str
The SHA-256 hexadecimal digest of the file.
"""
sha256 = hashlib.sha256()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(8192), b""):
......@@ -17,6 +34,28 @@ def sha256sum(filepath: str | os.PathLike[str]) -> str:
def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path:
"""
Fetch a file from a local path or download from HuggingFace Hub if not present.
The remote path should be in the format: ``<repo_id>/<filename>`` or ``<repo_id>/<subfolder>/<filename>``.
Parameters
----------
path : str or Path
Local file path or HuggingFace Hub path.
repo_type : str, optional
Type of HuggingFace repo (default: "model").
Returns
-------
Path
Path to the local file.
Raises
------
ValueError
If the path is too short to extract repo_id and subfolder.
"""
path = Path(path)
if path.exists():
......@@ -29,24 +68,27 @@ def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path:
repo_id = "/".join(parts[:2])
sub_path = Path(*parts[2:])
filename = sub_path.name
subfolder = sub_path.parent if sub_path.parent != Path(".") else None
subfolder = str(sub_path.parent) if sub_path.parent != Path(".") else None
path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type)
return Path(path)
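A short illustration of how ``fetch_or_download`` composes with ``sha256sum``. The Hub path simply reuses the T5 checkpoint referenced later in this diff and is only an example; any ``<repo_id>/<subfolder>/<filename>`` path behaves the same way.

local_path = fetch_or_download("mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors")  # downloaded on first use
print(sha256sum(local_path))  # checksum of the locally cached file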
def ceil_divide(x: int, divisor: int) -> int:
"""Ceiling division.
"""
Compute the ceiling of x divided by divisor.
Args:
x (`int`):
dividend.
divisor (`int`):
divisor.
Parameters
----------
x : int
Dividend.
divisor : int
Divisor.
Returns:
`int`:
ceiling division result.
Returns
-------
int
The smallest integer >= x / divisor.
"""
return (x + divisor - 1) // divisor
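A tiny worked example, e.g. padding a length of 300 up to a whole number of 128-wide blocks:

num_blocks = ceil_divide(300, 128)  # (300 + 127) // 128 == 3
padded_len = num_blocks * 128       # 384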
......@@ -57,6 +99,25 @@ def load_state_dict_in_safetensors(
filter_prefix: str = "",
return_metadata: bool = False,
) -> dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], dict[str, str]]:
"""
Load a state dict from a safetensors file, optionally filtering by prefix.
Parameters
----------
path : str or os.PathLike
Path to the safetensors file (local or HuggingFace Hub).
device : str or torch.device, optional
Device to load tensors onto (default: "cpu").
filter_prefix : str, optional
Only load keys starting with this prefix (default: "", no filter).
return_metadata : bool, optional
Whether to return safetensors metadata (default: False).
Returns
-------
dict[str, torch.Tensor] or tuple[dict[str, torch.Tensor], dict[str, str]]
The loaded state dict, and optionally the metadata if ``return_metadata`` is True.
"""
state_dict = {}
with safetensors.safe_open(fetch_or_download(path), framework="pt", device=device) as f:
metadata = f.metadata()
......@@ -71,17 +132,20 @@ def load_state_dict_in_safetensors(
def filter_state_dict(state_dict: dict[str, torch.Tensor], filter_prefix: str = "") -> dict[str, torch.Tensor]:
"""Filter state dict.
"""
Filter a state dict to only include keys starting with a given prefix.
Args:
state_dict (`dict`):
state dict.
filter_prefix (`str`):
filter prefix.
Parameters
----------
state_dict : dict[str, torch.Tensor]
The input state dict.
filter_prefix : str, optional
Prefix to filter keys by (default: "", no filter).
Returns:
`dict`:
filtered state dict.
Returns
-------
dict[str, torch.Tensor]
Filtered state dict with prefix removed from keys.
"""
return {k.removeprefix(filter_prefix): v for k, v in state_dict.items() if k.startswith(filter_prefix)}
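A hedged sketch combining the two loaders above; the concrete checkpoint path and the ``transformer_blocks.0.`` prefix are illustrative assumptions, not values prescribed by the code.

state_dict = load_state_dict_in_safetensors(
    "mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors",  # local or Hub path (assumed)
    device="cpu",
)
block0 = filter_state_dict(state_dict, filter_prefix="transformer_blocks.0.")  # assumed key prefix
print(f"{len(block0)} tensors for the first transformer block")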
......@@ -91,6 +155,28 @@ def get_precision(
device: str | torch.device = "cuda",
pretrained_model_name_or_path: str | os.PathLike[str] | None = None,
) -> str:
"""
Determine the quantization precision to use based on device and model.
Parameters
----------
precision : str, optional
"auto", "int4", or "fp4" (default: "auto").
device : str or torch.device, optional
Device to check (default: "cuda").
pretrained_model_name_or_path : str or os.PathLike or None, optional
Model name or path for warning checks.
Returns
-------
str
The selected precision ("int4" or "fp4").
Raises
------
AssertionError
If precision is not one of "auto", "int4", or "fp4".
"""
assert precision in ("auto", "int4", "fp4")
if precision == "auto":
if isinstance(device, str):
......@@ -109,11 +195,18 @@ def get_precision(
def is_turing(device: str | torch.device = "cuda") -> bool:
"""Check if the current GPU is a Turing GPU.
"""
Check if the current GPU is a Turing GPU (compute capability 7.5).
Returns:
`bool`:
True if the current GPU is a Turing GPU, False otherwise.
Parameters
----------
device : str or torch.device, optional
Device to check (default: "cuda").
Returns
-------
bool
True if the current GPU is a Turing GPU, False otherwise.
"""
if isinstance(device, str):
device = torch.device(device)
......@@ -124,15 +217,25 @@ def is_turing(device: str | torch.device = "cuda") -> bool:
def get_gpu_memory(device: str | torch.device = "cuda", unit: str = "GiB") -> int:
"""Get the GPU memory of the current device.
"""
Get the total memory of the current GPU.
Args:
device (`str` | `torch.device`, optional, defaults to `"cuda"`):
device.
Parameters
----------
device : str or torch.device, optional
Device to check (default: "cuda").
unit : str, optional
Unit for memory ("GiB", "MiB", or "B") (default: "GiB").
Returns:
`int`:
GPU memory in bytes.
Returns
-------
int
GPU memory in the specified unit.
Raises
------
AssertionError
If unit is not one of "GiB", "MiB", or "B".
"""
if isinstance(device, str):
device = torch.device(device)
......@@ -147,6 +250,21 @@ def get_gpu_memory(device: str | torch.device = "cuda", unit: str = "GiB") -> in
def check_hardware_compatibility(quantization_config: dict, device: str | torch.device = "cuda"):
"""
Check if the quantization config is compatible with the current GPU.
Parameters
----------
quantization_config : dict
Quantization configuration dictionary.
device : str or torch.device, optional
Device to check (default: "cuda").
Raises
------
ValueError
If the quantization config is not compatible with the GPU architecture.
"""
if isinstance(device, str):
device = torch.device(device)
capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
......
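A quick sketch tying the device helpers together; all printed values depend on the local GPU, and the snippet assumes it runs inside this module or after importing the helpers from ``nunchaku.utils``.

import torch

if torch.cuda.is_available():
    precision = get_precision("auto", device="cuda")  # resolves to "int4" or "fp4" for this GPU
    print("selected precision:", precision)
    print("Turing GPU:", is_turing("cuda"))
    print("total GPU memory:", get_gpu_memory("cuda", unit="GiB"), "GiB")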
......@@ -34,3 +34,12 @@ build-backend = "setuptools.build_meta"
[tool.setuptools.packages.find]
include = ["nunchaku"]
[tool.doc8]
max-line-length = 120
ignore-path = ["docs/_build"]
ignore = ["D000", "D001"]
[tool.rstcheck]
ignore_directives = ["tabs"]
ignore_messages = ["ERROR/3", "INFO/1"]
......@@ -778,6 +778,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device)
: dtype(dtype), offload(offload) {
CUDADeviceContext model_construction_ctx(device.idx);
for (int i = 0; i < 19; i++) {
transformer_blocks.push_back(
std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
......
......@@ -432,6 +432,13 @@ public:
return *this;
}
// Scope the copy to whichever side lives on a CUDA device, so the transfer runs in the right device context.
std::optional<CUDADeviceContext> operation_ctx_guard;
if (this->device().type == Device::CUDA) {
operation_ctx_guard.emplace(this->device().idx);
} else if (other.device().type == Device::CUDA) {
operation_ctx_guard.emplace(other.device().idx);
}
if (this->device().type == Device::CPU && other.device().type == Device::CPU) {
memcpy(data_ptr<char>(), other.data_ptr<char>(), shape.size() * scalar_size());
return *this;
......
import gc
import torch
from diffusers import (
AutoencoderKL,
FlowMatchEulerDiscreteScheduler,
FluxControlNetModel,
FluxControlNetPipeline,
FluxPipeline,
)
from diffusers.models import FluxMultiControlNetModel
from diffusers.utils import load_image
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from nunchaku import NunchakuT5EncoderModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision
def test_flux_txt2img_cache_controlnet():
bfl_repo = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16 # or torch.float16, or torch.float32
device = "cuda" # or "cpu" if you want to run on CPU
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder="scheduler", torch_dtype=dtype)
text_encoder = CLIPTextModel.from_pretrained(bfl_repo, subfolder="text_encoder", torch_dtype=dtype)
text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(
bfl_repo, subfolder="tokenizer", torch_dtype=dtype, clean_up_tokenization_spaces=True
)
tokenizer_2 = T5TokenizerFast.from_pretrained(
bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype, clean_up_tokenization_spaces=True
)
vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder="vae", torch_dtype=dtype)
precision = get_precision()
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
# offload=True
)
transformer.set_attention_impl("nunchaku-fp16")
# qencoder
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors")
controlnet_union = FluxControlNetModel.from_pretrained(
"Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0", torch_dtype=torch.bfloat16
)
controlnet = FluxMultiControlNetModel(
[controlnet_union]
) # we always recommend loading via FluxMultiControlNetModel
params = {
"scheduler": scheduler,
"vae": vae,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"transformer": transformer,
}
# pipe
pipe = FluxPipeline(**params).to(device, dtype=dtype)
pipe_cn = FluxControlNetPipeline(**params, controlnet=controlnet).to(device, dtype)
# offload
pipe.enable_sequential_cpu_offload(device=device)
pipe_cn.enable_sequential_cpu_offload(device=device)
# cache
apply_cache_on_pipe(
pipe_cn,
use_double_fb_cache=True,
residual_diff_threshold_multi=0.09,
residual_diff_threshold_single=0.12,
)
params = {
"prompt": "A bohemian-style female travel blogger with sun-kissed skin and messy beach waves.",
"height": 1152,
"width": 768,
"num_inference_steps": 30,
"guidance_scale": 3.5,
}
# pipe
txt2img_res = pipe(
**params,
).images[0]
txt2img_res.save("flux.1-dev-txt2img.jpg")
gc.collect()
torch.cuda.empty_cache()
# cache
apply_cache_on_pipe(
pipe_cn,
use_double_fb_cache=True,
residual_diff_threshold_multi=0.09,
residual_diff_threshold_single=0.12,
)
# pipe_cn
control_image = load_image(
"https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/openpose.jpg"
)
params["control_image"] = [control_image]
params["controlnet_conditioning_scale"] = [0.9]
params["control_guidance_end"] = [0.65]
cn_res = pipe_cn(
**params,
).images[0]
cn_res.save("flux.1-dev-cn-txt2img.jpg")
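The test above is a plain function, so besides invoking it through pytest it can be run directly; the entry point below is not part of the original file and is added only for illustration.

if __name__ == "__main__":  # hypothetical direct entry point for local runs
    test_flux_txt2img_cache_controlnet()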
# additional requirements for testing
pytest
datasets
datasets<4
torchmetrics
mediapipe
controlnet_aux
......