fengzch-das / nunchaku · Commits · cd214093

Commit cd214093 (unverified), authored Jul 14, 2025 by Muyang Li, committed by GitHub on Jul 14, 2025:

    Merge pull request #530 from mit-han-lab/dev

Parents: 2a785f77, 51732b7a

Showing 16 changed files with 1446 additions and 258 deletions (+1446 −258):
- nunchaku/models/pulid/utils.py                    +83  −31
- nunchaku/models/safety_checker.py                 +54  −0
- nunchaku/models/text_encoders/linear.py           +84  −22
- nunchaku/models/text_encoders/t5_encoder.py       +58  −2
- nunchaku/models/text_encoders/tinychat_utils.py   +108 −43
- nunchaku/models/transformers/transformer_flux.py  +340 −27
- nunchaku/models/transformers/transformer_sana.py  +155 −0
- nunchaku/models/transformers/utils.py             +65  −1
- nunchaku/pipeline/pipeline_flux_pulid.py          +205 −101
- nunchaku/test.py                                  +15  −0
- nunchaku/utils.py                                 +148 −30
- pyproject.toml                                    +9   −0
- src/FluxModel.cpp                                 +2   −0
- src/Tensor.h                                      +7   −0
- tests/flux/test_flux_txt2img_cache_controlnet.py  +112 −0
- tests/requirements.txt                            +1   −1
nunchaku/models/pulid/utils.py (view file @ cd214093)
```python
"""
This module provides utility functions for PuLID.

.. note::
   This module is adapted from the original PuLID repository:
   https://github.com/ToTheBeginning/PuLID
"""

import math

import cv2

# ... @@ -8,6 +15,23 @@ from torchvision.utils import make_grid


def resize_numpy_image_long(image, resize_long_edge=768):
    """
    Resize a numpy image so that its longest edge matches ``resize_long_edge``, preserving aspect ratio.

    If the image's longest edge is already less than or equal to ``resize_long_edge``, the image is returned unchanged.

    Parameters
    ----------
    image : np.ndarray
        Input image as a numpy array of shape (H, W, C).
    resize_long_edge : int, optional
        Desired size for the longest edge (default: 768).

    Returns
    -------
    np.ndarray
        The resized image as a numpy array.
    """
    h, w = image.shape[:2]
    if max(h, w) <= resize_long_edge:
        return image

# ... @@ -18,18 +42,27 @@ def resize_numpy_image_long(image, resize_long_edge=768):

    return image
```
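The resize arithmetic itself is elided by the hunk above; a minimal sketch of the usual approach (scale both edges by ``resize_long_edge / max(h, w)`` and resize with OpenCV; the interpolation choice is an assumption, not necessarily what this commit does):

```python
import cv2
import numpy as np


def _resize_long_sketch(image: np.ndarray, resize_long_edge: int = 768) -> np.ndarray:
    # Hypothetical sketch: scale so the longest edge equals resize_long_edge.
    h, w = image.shape[:2]
    if max(h, w) <= resize_long_edge:
        return image
    k = resize_long_edge / max(h, w)
    # cv2.resize takes (width, height); INTER_AREA is an assumed choice for downscaling.
    return cv2.resize(image, (int(w * k), int(h * k)), interpolation=cv2.INTER_AREA)
```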
```python
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """
    Convert numpy images to PyTorch tensors.

    This function supports both single images and lists of images. The images are converted from
    HWC (height, width, channel) format to CHW (channel, height, width) format. Optionally, BGR images
    can be converted to RGB, and the output can be cast to float32.

    Parameters
    ----------
    imgs : np.ndarray or list of np.ndarray
        Input image(s) as numpy array(s).
    bgr2rgb : bool, optional
        Whether to convert BGR images to RGB (default: True).
    float32 : bool, optional
        Whether to cast the output tensor(s) to float32 (default: True).

    Returns
    -------
    torch.Tensor or list of torch.Tensor
        Converted tensor(s). If a single image is provided, returns a tensor; if a list is provided, returns a list of tensors.
    """

    def _totensor(img, bgr2rgb, float32):

# ... @@ -48,25 +81,44 @@ def img2tensor(imgs, bgr2rgb=True, float32=True):
```
```python
def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
    """
    Convert PyTorch tensor(s) to image numpy array(s).

    This function supports 4D mini-batch tensors, 3D tensors, and 2D tensors. The output is a numpy array
    in HWC (height, width, channel) or HW (height, width) format. Optionally, RGB images can be converted to BGR,
    and the output type can be specified.

    After clamping to [min, max], values are normalized to [0, 1].

    Parameters
    ----------
    tensor : torch.Tensor or list of torch.Tensor
        Input tensor(s). Accepts:

        1) a 4D mini-batch tensor of shape (B x 3/1 x H x W);
        2) a 3D tensor of shape (3/1 x H x W);
        3) a 2D tensor of shape (H x W).

        The channel order should be RGB.
    rgb2bgr : bool, optional
        Whether to convert RGB images to BGR (default: True).
    out_type : numpy type, optional
        Output data type. If ``np.uint8``, output is in [0, 255]; otherwise, in [0, 1] (default: np.uint8).
    min_max : tuple of int, optional
        Min and max values for clamping (default: (0, 1)).

    Returns
    -------
    np.ndarray or list of np.ndarray
        Converted image(s) as numpy array(s). If a single tensor is provided, returns a numpy array; if a list is provided, returns a list of numpy arrays.

    Raises
    ------
    TypeError
        If the input is not a tensor or list of tensors, or if the tensor has unsupported dimensions.
    """
    if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
        raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}")

# ...
```
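A round-trip usage sketch of the two helpers above (the BGR input is an assumption, matching the cv2/basicsr convention these functions follow):

```python
import cv2
import numpy as np

# Load a BGR image with OpenCV (these basicsr-style helpers assume BGR input).
img_bgr = cv2.imread("face.png")  # hypothetical path
t = img2tensor(img_bgr, bgr2rgb=True, float32=True) / 255.0   # CHW, RGB, float32 in [0, 1]
img_back = tensor2img(t, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1))  # HWC, BGR, uint8
```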
nunchaku/models/safety_checker.py (view file @ cd214093)
"""
This module provides a `SafetyChecker` class for evaluating user prompts against
defined safety policies using a large language model. Only used deploying online gradio demos.
"""
import
torch
import
torch
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
from
transformers
import
AutoModelForCausalLM
,
AutoTokenizer
#: Template for the safety check prompt.
safety_check_template
=
"""You are a policy expert trying to help determine whether a user
safety_check_template
=
"""You are a policy expert trying to help determine whether a user
prompt is in violation of the defined safety policies.
prompt is in violation of the defined safety policies.
...
@@ -19,7 +25,40 @@ safety_check_template = """You are a policy expert trying to help determine whet
...
@@ -19,7 +25,40 @@ safety_check_template = """You are a policy expert trying to help determine whet
```python
class SafetyChecker:
    """
    SafetyChecker(device, disabled=False)

    A class to check whether a user prompt violates safety policies using a language model.

    Parameters
    ----------
    device : str or torch.device
        The device to run the model on (e.g., "cuda", "cpu").
    disabled : bool, optional
        If True, disables the safety check and always returns True (default: False).

    Examples
    --------
    >>> checker = SafetyChecker(device="cuda")
    >>> checker("Generate a nude girl image")
    False
    >>> checker = SafetyChecker(device="cpu", disabled=True)
    >>> checker("Any prompt")
    True
    """

    def __init__(self, device: str | torch.device, disabled: bool = False):
        """
        Initialize the SafetyChecker.

        Parameters
        ----------
        device : str or torch.device
            The device to run the model on.
        disabled : bool, optional
            If True, disables the safety check (default: False).
        """
        if not disabled:
            self.device = device
            self.tokenizer = AutoTokenizer.from_pretrained("google/shieldgemma-2b")

# ... @@ -29,6 +68,21 @@ class SafetyChecker:

        self.disabled = disabled

    def __call__(self, user_prompt: str, threshold: float = 0.2) -> bool:
        """
        Evaluate whether a user prompt is safe according to the defined policy.

        Parameters
        ----------
        user_prompt : str
            The user prompt to evaluate.
        threshold : float, optional
            The probability threshold for flagging a prompt as unsafe (default: 0.2).

        Returns
        -------
        bool
            True if the prompt is considered safe, False otherwise.
        """
        if self.disabled:
            return True
        device = self.device

# ...
```
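The body of ``__call__`` is elided by the diff; a minimal sketch of the usual ShieldGemma scoring pattern it presumably follows (the template filling and the single-token "Yes"/"No" reads are assumptions, not this commit's exact code):

```python
import torch


def shieldgemma_unsafe_probability(model, tokenizer, user_prompt: str, device) -> float:
    # Hypothetical sketch: ask the model whether the prompt violates the
    # policy and read off P("Yes") from the next-token distribution.
    # Assumption: the template is completed with the user prompt (exact
    # formatting is elided in this diff) and "Yes"/"No" are single tokens.
    text = safety_check_template.replace("{user_prompt}", user_prompt)  # placeholder filling
    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = model(**inputs).logits[0, -1]  # next-token logits
    yes_id = tokenizer.convert_tokens_to_ids("Yes")
    no_id = tokenizer.convert_tokens_to_ids("No")
    probs = torch.softmax(torch.stack([logits[yes_id], logits[no_id]]), dim=0)
    return probs[0].item()  # probability that the prompt is unsafe
```

Under this reading, ``__call__`` treats the prompt as safe when the probability is at or below ``threshold``.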
nunchaku/models/text_encoders/linear.py (view file @ cd214093)
```python
# -*- coding: utf-8 -*-
"""
This module provides the :class:`W4Linear` quantized linear layer, which implements
4-bit weight-only quantization for efficient inference.
"""

import torch
import torch.nn as nn

# ... @@ -11,6 +14,40 @@ __all__ = ["W4Linear"]


class W4Linear(nn.Module):
    """
    4-bit quantized linear layer with group-wise quantization.

    Parameters
    ----------
    in_features : int
        Number of input features.
    out_features : int
        Number of output features.
    bias : bool, optional
        If True, adds a learnable bias (default: False).
    group_size : int, optional
        Number of input channels per quantization group (default: 128).
        If -1, uses the full input dimension as a single group.
    dtype : torch.dtype, optional
        Data type for quantization scales and zeros (default: torch.float16).
    device : str or torch.device, optional
        Device for weights and buffers (default: "cuda").

    Attributes
    ----------
    in_features : int
    out_features : int
    group_size : int
    qweight : torch.Tensor
        Quantized weight tensor (int16).
    scales : torch.Tensor
        Per-group scale tensor.
    scaled_zeros : torch.Tensor
        Per-group zero-point tensor (scaled).
    bias : torch.Tensor or None
        Optional bias tensor.
    """

    def __init__(
        self,
        in_features: int,

# ... @@ -61,14 +98,33 @@ class W4Linear(nn.Module):
```
```python
    @property
    def weight_bits(self) -> int:
        """
        Number of bits per quantized weight (always 4).
        """
        return 4

    @property
    def interleave(self) -> int:
        """
        Interleave factor for quantized weights (always 4).
        """
        return 4

    @torch.no_grad()
    def forward(self, x):
        """
        Forward pass.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (..., in_features).

        Returns
        -------
        torch.Tensor
            Output tensor of shape (..., out_features).
        """
        if x.numel() / x.shape[-1] < 8:
            out = gemv_awq(
                x,

# ... @@ -97,27 +153,30 @@ class W4Linear(nn.Module):
```
```python
        zero: torch.Tensor | None = None,
        zero_pre_scaled: bool = False,
    ) -> "W4Linear":
        """
        Convert a standard nn.Linear layer to a TinyChat 4-bit weight-only quantized W4Linear layer.

        Parameters
        ----------
        linear : nn.Linear
            The linear layer to convert.
        group_size : int
            Quantization group size.
        init_only : bool, optional
            If True, only initializes the quantized layer (default: False).
        weight : torch.Tensor, optional
            Precomputed quantized weight (default: None).
        scale : torch.Tensor, optional
            Precomputed scale tensor (default: None).
        zero : torch.Tensor, optional
            Precomputed zero-point tensor (default: None).
        zero_pre_scaled : bool, optional
            Whether the zero-point tensor is pre-scaled (default: False).

        Returns
        -------
        W4Linear
            Quantized linear layer.
        """
        assert isinstance(linear, nn.Linear)
        weight = linear.weight.data if weight is None else weight.data

# ... @@ -167,6 +226,9 @@ class W4Linear(nn.Module):

        return _linear

    def extra_repr(self) -> str:
        """
        Returns a string describing the layer configuration.
        """
        return "in_features={}, out_features={}, bias={}, weight_bits={}, group_size={}".format(
            self.in_features,
            self.out_features,

# ...
```
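A usage sketch for the conversion path above. The truncated hunk hides the classmethod's name; ``W4Linear.from_linear`` is assumed here, as is the layer's ability to derive quantization parameters internally when ``scale``/``zero`` are not supplied:

```python
import torch
import torch.nn as nn

# Sketch: quantize one fp16 linear layer to 4-bit and run it.
# Assumptions: the converter is exposed as W4Linear.from_linear, and with
# init_only=False it produces real quantized weights from the fp16 weights.
fp16_linear = nn.Linear(4096, 4096, bias=False).half().cuda()
q_linear = W4Linear.from_linear(fp16_linear, group_size=128, init_only=False)

x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
y = q_linear(x)  # (1, 4096); small batches take the gemv_awq path shown above
```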
nunchaku/models/text_encoders/t5_encoder.py (view file @ cd214093)
"""
The NunchakuT5EncoderModel class enables loading T5 encoder weights from safetensors files,
automatically replacing supported linear layers with quantized :class:`~nunchaku.models.text_encoders.linear.W4Linear`
modules for improved performance and memory efficiency.
"""
import
json
import
json
import
logging
import
logging
import
os
import
os
...
@@ -20,12 +26,62 @@ logger = logging.getLogger(__name__)
...
@@ -20,12 +26,62 @@ logger = logging.getLogger(__name__)
class
NunchakuT5EncoderModel
(
T5EncoderModel
):
class
NunchakuT5EncoderModel
(
T5EncoderModel
):
"""
Nunchaku T5 Encoder Model
Extends :class:`transformers.T5EncoderModel` to support quantized weights and
memory-efficient inference using :class:`~nunchaku.models.text_encoders.linear.W4Linear`.
This class provides a convenient interface for loading T5 encoder weights from
safetensors files, automatically replacing supported linear layers with quantized
modules for improved speed and reduced memory usage.
Example
-------
.. code-block:: python
model = NunchakuT5EncoderModel.from_pretrained(
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
"""
@
classmethod
@
classmethod
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
str
|
os
.
PathLike
[
str
],
**
kwargs
):
def
from_pretrained
(
cls
,
pretrained_model_name_or_path
:
str
|
os
.
PathLike
[
str
],
**
kwargs
):
"""
Load a :class:`NunchakuT5EncoderModel` from a safetensors file.
This method loads the model configuration and weights from a safetensors file,
initializes the model on the 'meta' device (no memory allocation for weights),
and replaces supported linear layers with quantized :class:`~nunchaku.models.text_encoders.linear.W4Linear` modules.
Parameters
----------
pretrained_model_name_or_path : str or os.PathLike
Path to the safetensors file containing the model weights and metadata.
torch_dtype : torch.dtype, optional
Data type for model initialization (default: ``torch.bfloat16``).
Set to ``torch.float16`` for Turing GPUs.
device : str or torch.device, optional
Device to load the model onto (default: ``"cuda"``).
If the model is loaded on CPU, it will be automatically moved to GPU.
Returns
-------
NunchakuT5EncoderModel
The loaded and quantized T5 encoder model.
Example
-------
.. code-block:: python
model = NunchakuT5EncoderModel.from_pretrained(
"mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
"""
```python
        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)

        # Load the config file from metadata
        config = json.loads(metadata["config"])
        config = T5Config(**config)

# ... @@ -35,7 +91,7 @@ class NunchakuT5EncoderModel(T5EncoderModel):

        t5_encoder.eval()

        # Load the model weights from the safetensors file and quantize supported linear layers
        named_modules = {}
        for name, module in t5_encoder.named_modules():
            assert isinstance(name, str)

# ...
```
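For context, a sketch of how this quantized encoder is typically plugged into a Diffusers FLUX pipeline (the pipeline wiring is illustrative and not part of this diff):

```python
import torch
from diffusers import FluxPipeline

# Illustrative wiring: pass the quantized T5 encoder in as text_encoder_2.
text_encoder_2 = NunchakuT5EncoderModel.from_pretrained(
    "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors"
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    text_encoder_2=text_encoder_2,
    torch_dtype=torch.bfloat16,
).to("cuda")
```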
nunchaku/models/text_encoders/tinychat_utils.py (view file @ cd214093)
```python
# -*- coding: utf-8 -*-
"""
This module provides utility functions for quantized linear layers in the TinyChat backend.
"""

import torch

# ... @@ -7,35 +9,50 @@ __all__ = ["ceil_num_groups", "convert_to_tinychat_w4x16y16_linear_weight"]


def ceil_divide(x: int, divisor: int) -> int:
    """
    Compute the ceiling of integer division.

    Parameters
    ----------
    x : int
        Dividend.
    divisor : int
        Divisor.

    Returns
    -------
    int
        The smallest integer greater than or equal to ``x / divisor``.
    """
    return (x + divisor - 1) // divisor
```
```python
def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) -> int:
    """
    Calculate the padded number of quantization groups for TinyChat quantization.

    This ensures the number of groups is compatible with TinyChat's packing and kernel requirements.

    Parameters
    ----------
    in_features : int
        Input channel size (number of input features).
    group_size : int
        Quantization group size.
    weight_bits : int, optional
        Number of bits per quantized weight (default: 4).

    Returns
    -------
    int
        The padded number of quantization groups.

    Raises
    ------
    AssertionError
        If ``in_features`` is not divisible by ``group_size``, or if ``weight_bits`` is not 4, 2, or 1.
    NotImplementedError
        If ``group_size`` is not one of the supported values (>=128, 64, 32).
    """
    assert in_features % group_size == 0, "input channel size should be divisible by group size."
    num_groups = in_features // group_size

# ... @@ -49,7 +66,7 @@ def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) ->

    elif group_size == 32:
        num_packs_factor = 4
    else:
        raise NotImplementedError("Unsupported group size for TinyChat quantization.")
    # make sure num_packs is a multiple of num_packs_factor
    num_packs = ceil_divide(num_packs, num_packs_factor) * num_packs_factor
    num_groups = num_packs * pack_size

# ... @@ -57,6 +74,28 @@ def ceil_num_groups(in_features: int, group_size: int, weight_bits: int = 4) ->
```
```python
def pack_w4(weight: torch.Tensor) -> torch.Tensor:
    """
    Pack quantized 4-bit weights into TinyChat's int16 format.

    This function rearranges and packs 4-bit quantized weights (stored as int32) into
    the format expected by TinyChat CUDA kernels.

    Parameters
    ----------
    weight : torch.Tensor
        Quantized weight tensor of shape (out_features, in_features), dtype int32.
        The input channel dimension must be divisible by 32.

    Returns
    -------
    torch.Tensor
        Packed weight tensor of dtype int16.

    Raises
    ------
    AssertionError
        If the input tensor is not int32 or the input channel size is not divisible by 32.
    """
    assert weight.dtype == torch.int32, f"quantized weight should be torch.int32, but got {weight.dtype}."
    oc, ic = weight.shape
    assert ic % 32 == 0, "input channel size should be divisible by 32."

# ... @@ -74,23 +113,49 @@ def convert_to_tinychat_w4x16y16_linear_weight(
```
```python
    group_size: int = -1,
    zero_pre_scaled: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Convert a floating-point weight tensor to TinyChat W4-X16-Y16 quantized linear format.

    This function quantizes the input weights to 4 bits per value, applies group-wise
    scaling and zero-point, and packs the result into the format expected by TinyChat
    quantized linear layers.

    Parameters
    ----------
    weight : torch.Tensor
        Floating-point weight tensor of shape (out_features, in_features).
        Must be of dtype ``torch.float16`` or ``torch.bfloat16``.
    scale : torch.Tensor
        Per-group scale tensor (can be broadcastable).
    zero : torch.Tensor
        Per-group zero-point tensor (can be broadcastable).
    group_size : int, optional
        Quantization group size. If set to -1 (default), uses the full input dimension as a single group.
    zero_pre_scaled : bool, optional
        If True, the zero tensor is already scaled by the scale tensor (default: False).

    Returns
    -------
    tuple of torch.Tensor
        - packed_weight : torch.Tensor
            Packed quantized weight tensor (int16).
        - packed_scale : torch.Tensor
            Packed scale tensor (shape: [num_groups, out_features], dtype matches input).
        - packed_zero : torch.Tensor
            Packed zero-point tensor (shape: [num_groups, out_features], dtype matches input).

    Raises
    ------
    AssertionError
        If input types or shapes are invalid, or quantized values are out of range.

    Example
    -------
    .. code-block:: python

        qweight, qscale, qzero = convert_to_tinychat_w4x16y16_linear_weight(
            weight, scale, zero, group_size=128
        )
    """
    dtype, device = weight.dtype, weight.device
    assert dtype in (torch.float16, torch.bfloat16), "currently tinychat only supports fp16 and bf16."

# ...
```
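For intuition, a self-contained sketch of the group-wise 4-bit quantize/dequantize arithmetic these helpers implement (plain reference math, not TinyChat's interleaved int16 packing):

```python
import torch

# Reference group-wise asymmetric 4-bit quantization with per-group scale/zero.
w = torch.randn(8, 128, dtype=torch.float16)          # (out_features, in_features)
group_size = 64
g = w.view(8, -1, group_size)                         # (oc, num_groups, group_size)
w_min, w_max = g.amin(-1, keepdim=True), g.amax(-1, keepdim=True)
scale = (w_max - w_min).clamp(min=1e-5) / 15          # 4 bits -> 16 levels (0..15)
zero = (-w_min / scale).round()
q = (g / scale + zero).round().clamp(0, 15)           # quantized values in [0, 15]
deq = (q - zero) * scale                              # dequantized approximation
print((deq - g).abs().max())                          # quantization error is small
```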
nunchaku/models/transformers/transformer_flux.py (view file @ cd214093)
"""
Implements the :class:`NunchakuFluxTransformer2dModel`, a quantized transformer for Diffusers with efficient inference and LoRA support.
"""
import
json
import
json
import
logging
import
logging
import
os
import
os
...
@@ -32,6 +36,20 @@ logger = logging.getLogger(__name__)
...
@@ -32,6 +36,20 @@ logger = logging.getLogger(__name__)
class
NunchakuFluxTransformerBlocks
(
nn
.
Module
):
class
NunchakuFluxTransformerBlocks
(
nn
.
Module
):
"""
Wrapper for quantized Nunchaku FLUX transformer blocks.
This class manages the forward pass, rotary embedding packing, and optional
residual callbacks for ID embeddings.
Parameters
----------
m : QuantizedFluxModel
The quantized transformer model.
device : str or torch.device
Device to run the model on.
"""
def
__init__
(
self
,
m
:
QuantizedFluxModel
,
device
:
str
|
torch
.
device
):
def
__init__
(
self
,
m
:
QuantizedFluxModel
,
device
:
str
|
torch
.
device
):
super
(
NunchakuFluxTransformerBlocks
,
self
).
__init__
()
super
(
NunchakuFluxTransformerBlocks
,
self
).
__init__
()
self
.
m
=
m
self
.
m
=
m
...
@@ -40,6 +58,19 @@ class NunchakuFluxTransformerBlocks(nn.Module):
...
@@ -40,6 +58,19 @@ class NunchakuFluxTransformerBlocks(nn.Module):
```python
    @staticmethod
    def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor:
        """
        Packs rotary embeddings for efficient computation.

        Parameters
        ----------
        rotemb : torch.Tensor
            Rotary embedding tensor of shape (B, M, D//2, 1, 2), dtype float32.

        Returns
        -------
        torch.Tensor
            Packed rotary embedding tensor of shape (B, M, D).
        """
        assert rotemb.dtype == torch.float32
        B = rotemb.shape[0]
        M = rotemb.shape[1]

# ... @@ -73,6 +104,38 @@ class NunchakuFluxTransformerBlocks(nn.Module):
```
```python
        controlnet_single_block_samples=None,
        skip_first_layer=False,
    ):
        """
        Forward pass for the quantized transformer blocks.

        This calls the ``forward`` method of ``m`` in the C++ backend.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Input hidden states for image tokens.
        temb : torch.Tensor
            Temporal embedding tensor.
        encoder_hidden_states : torch.Tensor
            Input hidden states for text tokens.
        image_rotary_emb : torch.Tensor
            Rotary embedding tensor for all tokens.
        id_embeddings : torch.Tensor, optional
            Optional ID embeddings for the residual callback.
        id_weight : float, optional
            Weight for the ID embedding residual.
        joint_attention_kwargs : dict, optional
            Additional kwargs for joint attention.
        controlnet_block_samples : list[torch.Tensor], optional
            ControlNet block samples.
        controlnet_single_block_samples : list[torch.Tensor], optional
            ControlNet single block samples.
        skip_first_layer : bool, optional
            Whether to skip the first layer.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            (encoder_hidden_states, hidden_states) after the transformer blocks.
        """
        # batch_size = hidden_states.shape[0]
        txt_tokens = encoder_hidden_states.shape[1]
        img_tokens = hidden_states.shape[1]

# ... @@ -149,6 +212,33 @@ class NunchakuFluxTransformerBlocks(nn.Module):
```
```python
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
    ):
        """
        Forward pass for a specific transformer layer in ``m``.

        Parameters
        ----------
        idx : int
            Index of the transformer layer.
        hidden_states : torch.Tensor
            Input hidden states for image tokens.
        encoder_hidden_states : torch.Tensor
            Input hidden states for text tokens.
        temb : torch.Tensor
            Temporal embedding tensor.
        image_rotary_emb : torch.Tensor
            Rotary embedding tensor for all tokens.
        joint_attention_kwargs : dict, optional
            Additional kwargs for joint attention.
        controlnet_block_samples : list[torch.Tensor], optional
            ControlNet block samples.
        controlnet_single_block_samples : list[torch.Tensor], optional
            ControlNet single block samples.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            (encoder_hidden_states, hidden_states) after the specified layer.
        """
        # batch_size = hidden_states.shape[0]
        txt_tokens = encoder_hidden_states.shape[1]
        img_tokens = hidden_states.shape[1]

# ... @@ -195,6 +285,9 @@ class NunchakuFluxTransformerBlocks(nn.Module):
```
```python
        return encoder_hidden_states, hidden_states

    def set_pulid_residual_callback(self):
        """
        Sets the residual callback for PuLID (personalized ID) embeddings.
        """
        id_embeddings = self.id_embeddings
        pulid_ca = self.pulid_ca
        pulid_ca_idx = [self.pulid_ca_idx]

# ... @@ -209,10 +302,16 @@ class NunchakuFluxTransformerBlocks(nn.Module):

        self.m.set_residual_callback(callback)

    def reset_pulid_residual_callback(self):
        """
        Resets the PuLID residual callback to None.
        """
        self.callback_holder = None
        self.m.set_residual_callback(None)

    def __del__(self):
        """
        Destructor to reset the quantized model.
        """
        self.m.reset()

    def norm1(

# ... @@ -221,11 +320,44 @@ class NunchakuFluxTransformerBlocks(nn.Module):
```
```python
        emb: torch.Tensor,
        idx: int = 0,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Runs ``norm_one_forward`` for a specific layer in ``m``.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Input hidden states.
        emb : torch.Tensor
            Embedding tensor.
        idx : int, optional
            Layer index (default: 0).

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Output tensors from ``norm_one_forward``.
        """
        return self.m.norm_one_forward(idx, hidden_states, emb)
```
```python
## copied from diffusers 0.30.3
def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """
    Rotary positional embedding function.

    Parameters
    ----------
    pos : torch.Tensor
        Position tensor of shape (..., n).
    dim : int
        Embedding dimension (must be even).
    theta : int
        Rotary base.

    Returns
    -------
    torch.Tensor
        Rotary embedding tensor.
    """
    assert dim % 2 == 0, "The dimension must be even."
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim

# ... @@ -247,6 +379,19 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
```
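The rest of the function is elided by the hunk; for reference, a sketch matching the diffusers-0.30.x ``rope`` it was copied from (reconstructed from that release's published code, so treat it as an approximation rather than this diff's exact bytes):

```python
import torch


def rope_sketch(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    # Approximate reconstruction of diffusers 0.30.x `rope`.
    assert dim % 2 == 0, "The dimension must be even."
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)                       # per-pair rotation frequency
    out = torch.einsum("...n,d->...nd", pos, omega)    # angle for each position/pair
    cos, sin = torch.cos(out), torch.sin(out)
    out = torch.stack([cos, -sin, sin, cos], dim=-1)   # 2x2 rotation matrix entries
    return out.reshape(*out.shape[:-1], 2, 2).float()  # (..., n, dim/2, 2, 2)
```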
```python
class EmbedND(nn.Module):
    """
    Multi-dimensional rotary embedding module.

    Parameters
    ----------
    dim : int
        Embedding dimension.
    theta : int
        Rotary base.
    axes_dim : list[int]
        List of axis dimensions for each spatial axis.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super(EmbedND, self).__init__()
        self.dim = dim

# ... @@ -254,6 +399,19 @@ class EmbedND(nn.Module):
```
```python
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        """
        Computes rotary embeddings for multi-dimensional positions.

        Parameters
        ----------
        ids : torch.Tensor
            Position indices tensor of shape (..., n_axes).

        Returns
        -------
        torch.Tensor
            Rotary embedding tensor.
        """
        if Version(diffusers.__version__) >= Version("0.31.0"):
            ids = ids[None, ...]
        n_axes = ids.shape[-1]

# ... @@ -268,6 +426,27 @@ def load_quantized_module(
```
```python
    offload: bool = False,
    bf16: bool = True,
) -> QuantizedFluxModel:
    """
    Loads a quantized Nunchaku FLUX model from a state dict or file.

    Parameters
    ----------
    path_or_state_dict : str, os.PathLike, or dict
        Path to the quantized model file or a state dict.
    device : str or torch.device, optional
        Device to load the model on (default: "cuda").
    use_fp4 : bool, optional
        Whether to use FP4 quantization (default: False).
    offload : bool, optional
        Whether to offload weights to the CPU (default: False).
    bf16 : bool, optional
        Whether to use bfloat16 (default: True).

    Returns
    -------
    QuantizedFluxModel
        The loaded quantized model.
    """
    device = torch.device(device)
    assert device.type == "cuda"
    m = QuantizedFluxModel()

# ... @@ -281,6 +460,38 @@ def load_quantized_module(
```
```python
class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoaderMixin):
    """
    Nunchaku FLUX Transformer 2D Model.

    This class implements a quantized transformer model compatible with the Diffusers
    library, supporting LoRA, rotary embeddings, and efficient inference.

    Parameters
    ----------
    patch_size : int, optional
        Patch size for input images (default: 1).
    in_channels : int, optional
        Number of input channels (default: 64).
    out_channels : int or None, optional
        Number of output channels (default: None).
    num_layers : int, optional
        Number of transformer layers (default: 19).
    num_single_layers : int, optional
        Number of single transformer layers (default: 38).
    attention_head_dim : int, optional
        Dimension of each attention head (default: 128).
    num_attention_heads : int, optional
        Number of attention heads (default: 24).
    joint_attention_dim : int, optional
        Joint attention dimension (default: 4096).
    pooled_projection_dim : int, optional
        Pooled projection dimension (default: 768).
    guidance_embeds : bool, optional
        Whether to use guidance embeddings (default: False).
    axes_dims_rope : tuple[int], optional
        Axes dimensions for rotary embeddings (default: (16, 56, 56)).
    """

    @register_to_config
    def __init__(
        self,

# ... @@ -323,6 +534,21 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
```
```python
    @classmethod
    @utils.validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
        """
        Loads a Nunchaku FLUX transformer model from pretrained weights.

        Parameters
        ----------
        pretrained_model_name_or_path : str or os.PathLike
            Path to the model directory or HuggingFace repo.
        **kwargs
            Additional keyword arguments for device, offload, torch_dtype, precision, etc.

        Returns
        -------
        NunchakuFluxTransformer2dModel or (NunchakuFluxTransformer2dModel, dict)
            The loaded model, and optionally metadata if ``return_metadata=True``.
        """
        device = kwargs.get("device", "cuda")
        if isinstance(device, str):
            device = torch.device(device)

# ... @@ -395,6 +621,21 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
```
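A typical end-to-end usage sketch for this class, mirroring the project's published examples (the checkpoint names are illustrative and are not asserted by this diff):

```python
import torch
from diffusers import FluxPipeline

# Illustrative usage: load the quantized transformer, swap it into a pipeline.
transformer = NunchakuFluxTransformer2dModel.from_pretrained(
    "mit-han-lab/svdq-int4-flux.1-schnell"
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
).to("cuda")
image = pipe("a photo of a corgi", num_inference_steps=4, guidance_scale=0).images[0]
```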
```python
        return transformer

    def inject_quantized_module(self, m: QuantizedFluxModel, device: str | torch.device = "cuda"):
        """
        Injects a quantized module into the model and sets up the transformer blocks.

        Parameters
        ----------
        m : QuantizedFluxModel
            The quantized transformer model.
        device : str or torch.device, optional
            Device to run the model on (default: "cuda").

        Returns
        -------
        self : NunchakuFluxTransformer2dModel
            The model with the injected quantized module.
        """
        print("Injecting quantized module")
        self.pos_embed = EmbedND(dim=self.inner_dim, theta=10000, axes_dim=[16, 56, 56])

# ... @@ -405,6 +646,17 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
```
```python
        return self

    def set_attention_impl(self, impl: str):
        """
        Set the attention implementation for the quantized transformer block.

        Parameters
        ----------
        impl : str
            Attention implementation to use. Supported values:

            - ``"flash-attention2"`` (default): standard FlashAttention-2.
            - ``"nunchaku-fp16"``: uses FP16 attention accumulation, up to 1.2× faster than
              FlashAttention-2 on NVIDIA 30-, 40-, and 50-series GPUs.
        """
        block = self.transformer_blocks[0]
        assert isinstance(block, NunchakuFluxTransformerBlocks)
        block.m.setAttentionImpl(impl)

# ... @@ -412,6 +664,17 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
```
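After loading, switching kernels is a one-liner (sketch, assuming the `transformer` object from the loading example above):

```python
# Trade exact fp32 accumulation for speed on 30/40/50-series GPUs.
transformer.set_attention_impl("nunchaku-fp16")
```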
```python
    ### LoRA Related Functions

    def _expand_module(self, module_name: str, new_shape: tuple[int, int]):
        """
        Expands a linear module to a new shape for LoRA compatibility.

        This is mostly for the FLUX.1-tools LoRAs, which change the number of input channels.

        Parameters
        ----------
        module_name : str
            Name of the module to expand.
        new_shape : tuple[int, int]
            New shape (out_features, in_features) for the module.
        """
        module = self.get_submodule(module_name)
        assert isinstance(module, nn.Linear)
        weight_shape = module.weight.shape

# ... @@ -443,6 +706,14 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
```
```python
        setattr(self.config, "in_channels", new_value)

    def _update_unquantized_part_lora_params(self, strength: float = 1):
        """
        Updates the unquantized part of the model with LoRA parameters.

        Parameters
        ----------
        strength : float, optional
            LoRA scaling strength (default: 1).
        """
        # check if we need to expand the linear layers
        device = next(self.parameters()).device
        for k, v in self._unquantized_part_loras.items():

# ... @@ -505,6 +776,18 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
```
```python
        self.load_state_dict(new_state_dict, strict=True)

    def update_lora_params(self, path_or_state_dict: str | dict[str, torch.Tensor]):
        """
        Update the model with new LoRA parameters.

        Parameters
        ----------
        path_or_state_dict : str or dict
            Path to a LoRA weights file or a state dict. The path supports:

            - a local file path, e.g., ``"/path/to/your/lora.safetensors"``;
            - a HuggingFace repo with a file, e.g., ``"user/repo/lora.safetensors"``
              (automatically downloaded and cached).
        """
        if isinstance(path_or_state_dict, dict):
            state_dict = {
                k: v for k, v in path_or_state_dict.items()

# ... @@ -543,9 +826,20 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
```
```python
        block.m.loadDict(state_dict, True)

    def set_lora_strength(self, strength: float = 1):
        """
        Sets the LoRA scaling strength for the model.

        .. note::
            This changes the strength of *all* loaded LoRAs, so only use it when a
            single LoRA is loaded. For multiple LoRAs, fuse each LoRA scale into the
            weights instead.

        Parameters
        ----------
        strength : float, optional
            LoRA scaling strength (default: 1).
        """
        block = self.transformer_blocks[0]
        assert isinstance(block, NunchakuFluxTransformerBlocks)
        block.m.setLoraScale(SVD_RANK, strength)

# ... @@ -556,6 +850,10 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
```
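Putting the two LoRA entry points together (sketch; the LoRA repo/file name is a placeholder, not from this diff):

```python
# Single-LoRA workflow: load the weights, then dial the strength.
transformer.update_lora_params("user/repo/lora.safetensors")  # placeholder path
transformer.set_lora_strength(0.8)
```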
```python
        block.m.loadDict(vector_dict, True)

    def reset_x_embedder(self):
        """
        Resets the x_embedder module if the input channel count has changed.

        This is used to remove the effect of the FLUX.1-tools LoRAs, which change the number of input channels.
        """
        # if the model's in_channels changed, we need to update the x_embedder
        if self._original_in_channels != self.config.in_channels:
            assert self._original_in_channels < self.config.in_channels

# ... @@ -577,6 +875,9 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
```
```python
        setattr(self.config, "in_channels", self._original_in_channels)

    def reset_lora(self):
        """
        Resets all LoRA parameters to their default state.
        """
        unquantized_part_loras = {}
        if len(self._unquantized_part_loras) > 0 or len(unquantized_part_loras) > 0:
            self._unquantized_part_loras = unquantized_part_loras

# ... @@ -606,30 +907,42 @@ class NunchakuFluxTransformer2dModel(FluxTransformer2DModel, NunchakuModelLoader
```
```python
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        Forward pass for the Nunchaku FLUX transformer model.

        Copied from diffusers.models.flux.transformer_flux.py. This method is compatible
        with the Diffusers pipeline and supports LoRA, rotary embeddings, and ControlNet.

        Parameters
        ----------
        hidden_states : torch.FloatTensor
            Input hidden states of shape (batch_size, channel, height, width).
        encoder_hidden_states : torch.FloatTensor, optional
            Conditional embeddings (e.g., prompt embeddings) of shape (batch_size, sequence_len, embed_dims).
        pooled_projections : torch.FloatTensor, optional
            Embeddings projected from the input conditions.
        timestep : torch.LongTensor, optional
            Denoising step.
        img_ids : torch.Tensor, optional
            Image token indices.
        txt_ids : torch.Tensor, optional
            Text token indices.
        guidance : torch.Tensor, optional
            Guidance tensor for classifier-free guidance.
        joint_attention_kwargs : dict, optional
            Additional kwargs for joint attention.
        controlnet_block_samples : list[torch.Tensor], optional
            ControlNet block samples.
        controlnet_single_block_samples : list[torch.Tensor], optional
            ControlNet single block samples.
        return_dict : bool, optional
            Whether to return a Transformer2DModelOutput (default: True).
        controlnet_blocks_repeat : bool, optional
            Whether to repeat ControlNet blocks (default: False).

        Returns
        -------
        torch.FloatTensor or Transformer2DModelOutput
            Output tensor or output object containing the sample.
        """
        hidden_states = self.x_embedder(hidden_states)

# ...
```
nunchaku/models/transformers/transformer_sana.py (view file @ cd214093)
"""
Implements the :class:`NunchakuSanaTransformer2DModel`,
a quantized Sana transformer for Diffusers with efficient inference support.
"""
import
os
import
os
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Optional
from
typing
import
Optional
...
@@ -18,6 +23,22 @@ SVD_RANK = 32
...
@@ -18,6 +23,22 @@ SVD_RANK = 32
class
NunchakuSanaTransformerBlocks
(
nn
.
Module
):
class
NunchakuSanaTransformerBlocks
(
nn
.
Module
):
"""
Wrapper for quantized Sana transformer blocks.
This module wraps a QuantizedSanaModel and provides forward methods compatible
with the expected transformer block interface.
Parameters
----------
m : QuantizedSanaModel
The quantized transformer model.
dtype : torch.dtype
The data type to use for computation.
device : str or torch.device
The device to run the model on.
"""
def
__init__
(
self
,
m
:
QuantizedSanaModel
,
dtype
:
torch
.
dtype
,
device
:
str
|
torch
.
device
):
def
__init__
(
self
,
m
:
QuantizedSanaModel
,
dtype
:
torch
.
dtype
,
device
:
str
|
torch
.
device
):
super
(
NunchakuSanaTransformerBlocks
,
self
).
__init__
()
super
(
NunchakuSanaTransformerBlocks
,
self
).
__init__
()
self
.
m
=
m
self
.
m
=
m
...
@@ -35,7 +56,33 @@ class NunchakuSanaTransformerBlocks(nn.Module):
...
@@ -35,7 +56,33 @@ class NunchakuSanaTransformerBlocks(nn.Module):
```python
        width: Optional[int] = None,
        skip_first_layer: Optional[bool] = False,
    ):
        """
        Forward pass through all quantized transformer blocks.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Input hidden states of shape (batch_size, img_tokens, ...).
        attention_mask : torch.Tensor, optional
            Not used.
        encoder_hidden_states : torch.Tensor, optional
            Encoder hidden states of shape (batch_size, txt_tokens, ...).
        encoder_attention_mask : torch.Tensor, optional
            Encoder attention mask of shape (batch_size, 1, txt_tokens).
        timestep : torch.LongTensor, optional
            Timestep tensor.
        height : int, optional
            Image height.
        width : int, optional
            Image width.
        skip_first_layer : bool, optional
            Whether to skip the first layer.

        Returns
        -------
        torch.Tensor
            Output tensor after passing through the quantized transformer blocks.
        """
        batch_size = hidden_states.shape[0]
        img_tokens = hidden_states.shape[1]
        txt_tokens = encoder_hidden_states.shape[1]

# ... @@ -90,6 +137,33 @@ class NunchakuSanaTransformerBlocks(nn.Module):
```
```python
        height: Optional[int] = None,
        width: Optional[int] = None,
    ):
        """
        Forward pass through a specific quantized transformer layer.

        Parameters
        ----------
        idx : int
            Index of the layer to run.
        hidden_states : torch.Tensor
            Input hidden states.
        attention_mask : torch.Tensor, optional
            Not used.
        encoder_hidden_states : torch.Tensor, optional
            Encoder hidden states.
        encoder_attention_mask : torch.Tensor, optional
            Encoder attention mask.
        timestep : torch.LongTensor, optional
            Timestep tensor.
        height : int, optional
            Image height.
        width : int, optional
            Image width.

        Returns
        -------
        torch.Tensor
            Output tensor after passing through the specified quantized transformer layer.
        """
        batch_size = hidden_states.shape[0]
        img_tokens = hidden_states.shape[1]
        txt_tokens = encoder_hidden_states.shape[1]

# ... @@ -134,13 +208,41 @@ class NunchakuSanaTransformerBlocks(nn.Module):
```
```python
        )

    def __del__(self):
        """
        Destructor to reset the quantized model and free resources.
        """
        self.m.reset()


class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoaderMixin):
    """
    SanaTransformer2DModel with Nunchaku quantized backend support.

    This class extends the base SanaTransformer2DModel to support loading and
    injecting quantized transformer blocks using Nunchaku's custom backend.
    """
```
```python
    @classmethod
    @utils.validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
        """
        Load a pretrained NunchakuSanaTransformer2DModel from a local file or the HuggingFace Hub.

        This method supports both quantized and unquantized checkpoints, and will
        automatically inject quantized transformer blocks if available.

        Parameters
        ----------
        pretrained_model_name_or_path : str or os.PathLike
            Path to the model checkpoint or HuggingFace Hub model name.
        **kwargs
            Additional keyword arguments for model loading.

        Returns
        -------
        NunchakuSanaTransformer2DModel or (NunchakuSanaTransformer2DModel, dict)
            The loaded model, and optionally metadata if ``return_metadata=True``.
        """
        device = kwargs.get("device", "cuda")
        if isinstance(device, str):
            device = torch.device(device)

# ... @@ -184,6 +286,21 @@ class NunchakuSanaTransformer2DModel(SanaTransformer2DModel, NunchakuModelLoader
```
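For context, a sketch of how this class is typically used with a Diffusers ``SanaPipeline`` (the checkpoint names follow the project's published examples but are illustrative here, not part of this diff):

```python
import torch
from diffusers import SanaPipeline

# Illustrative wiring: swap the quantized transformer into a Sana pipeline.
transformer = NunchakuSanaTransformer2DModel.from_pretrained(
    "mit-han-lab/svdq-int4-sana-1600m"
)
pipe = SanaPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
    transformer=transformer,
    variant="bf16",
    torch_dtype=torch.bfloat16,
).to("cuda")
image = pipe("a cyberpunk cat", num_inference_steps=20).images[0]
```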
```python
        return transformer

    def inject_quantized_module(self, m: QuantizedSanaModel, device: str | torch.device = "cuda"):
        """
        Inject a quantized transformer module into this model.

        Parameters
        ----------
        m : QuantizedSanaModel
            The quantized transformer module to inject.
        device : str or torch.device, optional
            The device to place the module on (default: "cuda").

        Returns
        -------
        NunchakuSanaTransformer2DModel
            The model with the quantized module injected.
        """
        self.transformer_blocks = torch.nn.ModuleList([NunchakuSanaTransformerBlocks(m, self.dtype, device)])
        return self

# ... @@ -195,6 +312,27 @@ def load_quantized_module(
```
```python
    pag_layers: int | list[int] | None = None,
    use_fp4: bool = False,
) -> QuantizedSanaModel:
    """
    Load quantized weights into a QuantizedSanaModel.

    Parameters
    ----------
    net : SanaTransformer2DModel
        The base transformer model (for config and dtype).
    path_or_state_dict : str, os.PathLike, or dict
        Path to the quantized weights or a state dict.
    device : str or torch.device, optional
        Device to load the quantized model on (default: "cuda").
    pag_layers : int, list of int, or None, optional
        Layers to apply PAG (Perturbed-Attention Guidance) to (default: None).
    use_fp4 : bool, optional
        Whether to use FP4 quantization (default: False).

    Returns
    -------
    QuantizedSanaModel
        The loaded quantized model.
    """
    if pag_layers is None:
        pag_layers = []
    elif isinstance(pag_layers, int):

# ... @@ -215,5 +353,22 @@ def load_quantized_module(
```
def
inject_quantized_module
(
def
inject_quantized_module
(
net
:
SanaTransformer2DModel
,
m
:
QuantizedSanaModel
,
device
:
torch
.
device
net
:
SanaTransformer2DModel
,
m
:
QuantizedSanaModel
,
device
:
torch
.
device
)
->
SanaTransformer2DModel
:
)
->
SanaTransformer2DModel
:
"""
Inject a quantized transformer module into a SanaTransformer2DModel.
Parameters
----------
net : SanaTransformer2DModel
The base transformer model.
m : QuantizedSanaModel
The quantized transformer module to inject.
device : torch.device
The device to place the module on.
Returns
-------
SanaTransformer2DModel
The model with the quantized module injected.
"""
net
.
transformer_blocks
=
torch
.
nn
.
ModuleList
([
NunchakuSanaTransformerBlocks
(
m
,
net
.
dtype
,
device
)])
net
.
transformer_blocks
=
torch
.
nn
.
ModuleList
([
NunchakuSanaTransformerBlocks
(
m
,
net
.
dtype
,
device
)])
return
net
return
net
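
For orientation, a minimal usage sketch of the loader above. The checkpoint filename and base pipeline repo are illustrative placeholders, not confirmed paths:

.. code-block:: python

    import torch
    from diffusers import SanaPipeline

    from nunchaku.models.transformers.transformer_sana import NunchakuSanaTransformer2DModel

    # Placeholder checkpoint path: substitute a real Nunchaku-quantized SANA safetensors file.
    transformer = NunchakuSanaTransformer2DModel.from_pretrained(
        "path/to/svdq-int4-sana.safetensors", device="cuda"
    )
    # from_pretrained() wires load_quantized_module() and inject_quantized_module() together.
    pipe = SanaPipeline.from_pretrained(
        "Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",  # illustrative base repo
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    image = pipe("A serene mountain lake at dawn", num_inference_steps=20).images[0]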
nunchaku/models/transformers/utils.py
View file @ cd214093

+"""
+Utilities for Nunchaku transformer model loading.
+"""
 import json
 import logging
 import os
 ...
@@ -20,16 +24,38 @@ logger = logging.getLogger(__name__)
 class NunchakuModelLoaderMixin:
+    """
+    Mixin for standardized model loading in Nunchaku transformer models.
+
+    Provides:
+
+    - :meth:`_build_model`: Load model from a safetensors file.
+    - :meth:`_build_model_legacy`: Load model from a legacy folder structure (deprecated).
+    """
+
     @classmethod
     def _build_model(
         cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs
     ) -> tuple[nn.Module, dict[str, torch.Tensor], dict[str, str]]:
+        """
+        Build a transformer model from a safetensors file.
+
+        Parameters
+        ----------
+        pretrained_model_name_or_path : str or os.PathLike
+            Path to the safetensors file.
+        **kwargs
+            Additional keyword arguments (e.g., ``torch_dtype``).
+
+        Returns
+        -------
+        tuple
+            (transformer, state_dict, metadata)
+        """
         if isinstance(pretrained_model_name_or_path, str):
             pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
         state_dict, metadata = load_state_dict_in_safetensors(pretrained_model_name_or_path, return_metadata=True)
+        # Load the config file
         config = json.loads(metadata["config"])
         with torch.device("meta"):
         ...
@@ -41,6 +67,25 @@ class NunchakuModelLoaderMixin:
     def _build_model_legacy(
         cls, pretrained_model_name_or_path: str | os.PathLike, **kwargs
     ) -> tuple[nn.Module, str, str]:
+        """
+        Build a transformer model from a legacy folder structure.
+
+        .. warning::
+            This method is deprecated and will be removed in v0.4.
+            Please migrate to safetensors-based model loading.
+
+        Parameters
+        ----------
+        pretrained_model_name_or_path : str or os.PathLike
+            Path to the folder containing model weights.
+        **kwargs
+            Additional keyword arguments for HuggingFace Hub download and config loading.
+
+        Returns
+        -------
+        tuple
+            (transformer, unquantized_part_path, transformer_block_path)
+        """
         logger.warning(
             "Loading models from a folder will be deprecated in v0.4. "
             "Please download the latest safetensors model, or use one of the following tools to "
         ...
@@ -109,6 +154,25 @@ class NunchakuModelLoaderMixin:
 def pad_tensor(tensor: Optional[torch.Tensor], multiples: int, dim: int, fill: Any = 0) -> torch.Tensor | None:
+    """
+    Pad a tensor along a given dimension to the next multiple of a specified value.
+
+    Parameters
+    ----------
+    tensor : torch.Tensor or None
+        Input tensor. If None, returns None.
+    multiples : int
+        Pad to this multiple. If <= 1, no padding is applied.
+    dim : int
+        Dimension along which to pad.
+    fill : Any, optional
+        Value to use for padding (default: 0).
+
+    Returns
+    -------
+    torch.Tensor or None
+        The padded tensor, or None if input was None.
+    """
     if multiples <= 1:
         return tensor
     if tensor is None:
     ...
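
A quick illustration of ``pad_tensor`` based on the docstring above (fill semantics beyond what is documented are assumed):

.. code-block:: python

    import torch

    from nunchaku.models.transformers.utils import pad_tensor

    x = torch.ones(2, 5)
    y = pad_tensor(x, multiples=8, dim=1, fill=0)  # pad dim 1 from 5 up to the next multiple of 8
    assert y.shape == (2, 8)
    assert pad_tensor(x, multiples=1, dim=1).shape == (2, 5)  # multiples <= 1: returned unchanged
    assert pad_tensor(None, multiples=8, dim=0) is None       # None passes through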
nunchaku/pipeline/pipeline_flux_pulid.py
View file @ cd214093

-# Adapted from https://github.com/ToTheBeginning/PuLID/blob/main/pulid/pipeline.py
+"""
+This module provides the PuLID FluxPipeline for personalized image generation with identity preservation.
+
+It integrates face analysis, alignment, and embedding extraction using InsightFace and FaceXLib, and injects
+identity embeddings into a Flux transformer pipeline.
+
+.. note::
+    This module is adapted from https://github.com/ToTheBeginning/PuLID/blob/main/pulid/pipeline.py
+"""
 import gc
 import logging
 import os
 ...
@@ -11,9 +20,8 @@ import numpy as np
 import torch
 from diffusers import FluxPipeline
 from diffusers.image_processor import PipelineImageInput
-from diffusers.pipelines.flux.pipeline_flux import EXAMPLE_DOC_STRING, calculate_shift, retrieve_timesteps
+from diffusers.pipelines.flux.pipeline_flux import calculate_shift, retrieve_timesteps
 from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
-from diffusers.utils import replace_example_docstring
 from facexlib.parsing import init_parsing_model
 from facexlib.utils.face_restoration_helper import FaceRestoreHelper
 from huggingface_hub import snapshot_download
 ...
@@ -39,6 +47,19 @@ logger = logging.getLogger(__name__)
 def check_antelopev2_dir(antelopev2_dirpath: str | os.PathLike[str]) -> bool:
+    """
+    Check if the given directory contains all required AntelopeV2 ONNX model files with correct SHA256 hashes.
+
+    Parameters
+    ----------
+    antelopev2_dirpath : str or os.PathLike
+        Path to the directory containing AntelopeV2 ONNX models.
+
+    Returns
+    -------
+    bool
+        True if all required files exist and have correct hashes, False otherwise.
+    """
     antelopev2_dirpath = Path(antelopev2_dirpath)
     required_files = {
         "1k3d68.onnx": "df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc",
     ...
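
A sketch of how this check is typically paired with a download fallback; the repo id is illustrative, chosen only because the ``snapshot_download`` import above suggests this pattern:

.. code-block:: python

    from huggingface_hub import snapshot_download

    models_dir = "models/antelopev2"
    if not check_antelopev2_dir(models_dir):
        # Missing files or hash mismatch: re-fetch the ONNX models.
        snapshot_download("DIAMONIK7777/antelopev2", local_dir=models_dir)  # illustrative repo id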
@@ -64,6 +85,53 @@ def check_antelopev2_dir(antelopev2_dirpath: str | os.PathLike[str]) -> bool:
 class PuLIDPipeline(nn.Module):
+    """
+    PyTorch module for extracting identity embeddings using PuLID, InsightFace, and EVA-CLIP.
+
+    This class handles face detection, alignment, parsing, and embedding extraction for use in personalized
+    diffusion pipelines.
+
+    Parameters
+    ----------
+    dit : NunchakuFluxTransformer2dModel
+        The transformer model to inject PuLID attention modules into.
+    device : str or torch.device
+        Device to run the pipeline on.
+    weight_dtype : str or torch.dtype, optional
+        Data type for model weights (default: torch.bfloat16).
+    onnx_provider : str, optional
+        ONNX runtime provider, "gpu" or "cpu" (default: "gpu").
+    pulid_path : str or os.PathLike, optional
+        Path to PuLID weights in safetensors format.
+    eva_clip_path : str or os.PathLike, optional
+        Path to EVA-CLIP weights.
+    insightface_dirpath : str or os.PathLike or None, optional
+        Path to InsightFace models directory.
+    facexlib_dirpath : str or os.PathLike or None, optional
+        Path to FaceXLib models directory.
+
+    Attributes
+    ----------
+    pulid_encoder : IDFormer
+        The IDFormer encoder for identity embedding.
+    pulid_ca : nn.ModuleList
+        List of PerceiverAttentionCA modules injected into the transformer.
+    face_helper : FaceRestoreHelper
+        Helper for face alignment and parsing.
+    clip_vision_model : nn.Module
+        EVA-CLIP visual backbone.
+    eva_transform_mean : tuple
+        Mean for image normalization.
+    eva_transform_std : tuple
+        Std for image normalization.
+    app : FaceAnalysis
+        InsightFace face analysis application.
+    handler_ante : insightface.model_zoo.model_zoo.Model
+        InsightFace embedding model.
+    debug_img_list : list
+        List of debug images (for visualization).
+    """
+
     def __init__(
         self,
         dit: NunchakuFluxTransformer2dModel,
     ...
@@ -177,6 +245,19 @@ class PuLIDPipeline(nn.Module):
         self.debug_img_list = []

     def to_gray(self, img):
+        """
+        Convert an image tensor to grayscale (3 channels).
+
+        Parameters
+        ----------
+        img : torch.Tensor
+            Image tensor of shape (B, 3, H, W).
+
+        Returns
+        -------
+        torch.Tensor
+            Grayscale image tensor of shape (B, 3, H, W).
+        """
         x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
         x = x.repeat(1, 3, 1, 1)
         return x
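
The channel weights in ``to_gray`` are the ITU-R BT.601 luma coefficients; a standalone sanity check of the conversion (a sketch, independent of the class):

.. code-block:: python

    import torch

    img = torch.rand(2, 3, 8, 8)
    gray = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
    gray3 = gray.repeat(1, 3, 1, 1)  # replicate the single luma channel into 3 identical channels
    assert torch.allclose(gray3[:, 0], gray3[:, 1])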
 ...
@@ -184,8 +265,21 @@ class PuLIDPipeline(nn.Module):
     @torch.no_grad()
     def get_id_embedding(self, image, cal_uncond=False):
-        """
-        Args:
-            image: numpy rgb image, range [0, 255]
-        """
+        """
+        Extract identity embedding from an RGB image.
+
+        Parameters
+        ----------
+        image : np.ndarray
+            Input RGB image as a numpy array, range [0, 255].
+        cal_uncond : bool, optional
+            If True, also compute unconditional embedding (default: False).
+
+        Returns
+        -------
+        id_embedding : torch.Tensor
+            Identity embedding tensor.
+        uncond_id_embedding : torch.Tensor or None
+            Unconditional embedding tensor if cal_uncond is True, else None.
+        """
         self.face_helper.clean_all()
         self.debug_img_list = []
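
A minimal sketch of calling ``get_id_embedding`` (construction of ``pulid_pipeline`` and the image path are assumed):

.. code-block:: python

    import cv2

    # Load a reference face as an RGB uint8 array in [0, 255], as the docstring requires.
    img = cv2.cvtColor(cv2.imread("face.jpg"), cv2.COLOR_BGR2RGB)
    id_embedding, uncond_id_embedding = pulid_pipeline.get_id_embedding(img, cal_uncond=True)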
 ...
@@ -260,6 +354,41 @@ class PuLIDPipeline(nn.Module):
 class PuLIDFluxPipeline(FluxPipeline):
+    """
+    FluxPipeline with PuLID identity embedding support.
+
+    This pipeline extends the standard FluxPipeline to support personalized image generation using
+    identity embeddings extracted from a reference image. It injects the PuLID identity encoder into
+    the transformer and supports all standard FluxPipeline features.
+
+    Parameters
+    ----------
+    scheduler : SchedulerMixin
+        Scheduler for diffusion process.
+    vae : AutoencoderKL
+        Variational autoencoder for image encoding/decoding.
+    text_encoder : PreTrainedModel
+        Text encoder for prompt embeddings.
+    tokenizer : PreTrainedTokenizer
+        Tokenizer for text encoder.
+    text_encoder_2 : PreTrainedModel
+        Second text encoder (optional).
+    tokenizer_2 : PreTrainedTokenizer
+        Second tokenizer (optional).
+    transformer : NunchakuFluxTransformer2dModel
+        Transformer model for denoising.
+    image_encoder : nn.Module, optional
+        Image encoder for IP-Adapter (default: None).
+    feature_extractor : nn.Module, optional
+        Feature extractor for images (default: None).
+    pulid_device : str, optional
+        Device for PuLID pipeline (default: "cuda").
+    weight_dtype : torch.dtype, optional
+        Data type for model weights (default: torch.bfloat16).
+    onnx_provider : str, optional
+        ONNX runtime provider (default: "gpu").
+    """
+
     def __init__(
         self,
         scheduler,
     ...
@@ -301,7 +430,6 @@ class PuLIDFluxPipeline(FluxPipeline):
     )

     @torch.no_grad()
-    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
     ...
@@ -335,103 +463,79 @@ class PuLIDFluxPipeline(FluxPipeline):
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
     ):
-        r"""
-        Function invoked when calling the pipeline for generation.
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
-                instead.
-            prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
-                will be used instead.
-            negative_prompt (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation. If not defined, one has to pass
-                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
-                not greater than `1`).
-            negative_prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
-                `text_encoder_2`. If not defined, `negative_prompt` is used in all the text-encoders.
-            true_cfg_scale (`float`, *optional*, defaults to 1.0):
-                When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The height in pixels of the generated image. This is set to 1024 by default for the best results.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-                The width in pixels of the generated image. This is set to 1024 by default for the best results.
-            num_inference_steps (`int`, *optional*, defaults to 50):
-                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-                expense of slower inference.
-            sigmas (`List[float]`, *optional*):
-                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
-                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
-                will be used.
-            guidance_scale (`float`, *optional*, defaults to 3.5):
-                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-                `guidance_scale` is defined as `w` of equation 2. of [Imagen
-                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
-                usually at the expense of lower image quality.
-            num_images_per_prompt (`int`, *optional*, defaults to 1):
-                The number of images to generate per prompt.
-            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
-                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
-                to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
-                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
-                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor will ge generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
-                If not provided, pooled text embeddings will be generated from `prompt` input argument.
-            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
-            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
-                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
-                provided, embeddings are computed from the `ip_adapter_image` input argument.
-            negative_ip_adapter_image:
-                (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
-            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
-                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
-                provided, embeddings are computed from the `ip_adapter_image` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-                argument.
-            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
-                input argument.
-            output_type (`str`, *optional*, defaults to `"pil"`):
-                The output format of the generate image. Choose between
-                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
-            joint_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-            callback_on_step_end (`Callable`, *optional*):
-                A function that calls at the end of each denoising steps during the inference. The function is called
-                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
-                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
-                `callback_on_step_end_tensor_inputs`.
-            callback_on_step_end_tensor_inputs (`List`, *optional*):
-                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
-                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
-                `._callback_tensor_inputs` attribute of your pipeline class.
-            max_sequence_length (`int` defaults to 512): Maximum sequence length to use with the `prompt`.
-
-        Examples:
-
-        Returns:
-            [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
-            is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
-            images.
-        """
+        """
+        Function invoked when calling the pipeline for generation.
+
+        See the parent class :class:`diffusers.FluxPipeline` for full documentation.
+
+        Parameters
+        ----------
+        prompt : str or List[str], optional
+            The prompt(s) to guide image generation.
+        prompt_2 : str or List[str], optional
+            Second prompt(s) for dual-encoder pipelines.
+        negative_prompt : str or List[str], optional
+            Negative prompt(s) to avoid in generation.
+        negative_prompt_2 : str or List[str], optional
+            Second negative prompt(s) for dual-encoder pipelines.
+        true_cfg_scale : float, optional
+            True classifier-free guidance scale.
+        height : int, optional
+            Output image height.
+        width : int, optional
+            Output image width.
+        num_inference_steps : int, optional
+            Number of denoising steps.
+        sigmas : List[float], optional
+            Custom sigmas for the scheduler.
+        guidance_scale : float, optional
+            Classifier-free guidance scale.
+        num_images_per_prompt : int, optional
+            Number of images per prompt.
+        generator : torch.Generator or List[torch.Generator], optional
+            Random generator(s) for reproducibility.
+        latents : torch.FloatTensor, optional
+            Pre-generated latents.
+        prompt_embeds : torch.FloatTensor, optional
+            Pre-generated prompt embeddings.
+        pooled_prompt_embeds : torch.FloatTensor, optional
+            Pre-generated pooled prompt embeddings.
+        ip_adapter_image : PipelineImageInput, optional
+            Image input for IP-Adapter.
+        id_image : PIL.Image.Image or np.ndarray, optional
+            Reference image for identity embedding.
+        id_weight : float, optional
+            Weight for identity embedding.
+        start_step : int, optional
+            Step to start from (for advanced use).
+        ip_adapter_image_embeds : List[torch.Tensor], optional
+            Precomputed IP-Adapter image embeddings.
+        negative_ip_adapter_image : PipelineImageInput, optional
+            Negative image input for IP-Adapter.
+        negative_ip_adapter_image_embeds : List[torch.Tensor], optional
+            Precomputed negative IP-Adapter image embeddings.
+        negative_prompt_embeds : torch.FloatTensor, optional
+            Precomputed negative prompt embeddings.
+        negative_pooled_prompt_embeds : torch.FloatTensor, optional
+            Precomputed negative pooled prompt embeddings.
+        output_type : str, optional
+            Output format ("pil" or "np").
+        return_dict : bool, optional
+            Whether to return a dict or tuple.
+        joint_attention_kwargs : dict, optional
+            Additional kwargs for joint attention.
+        callback_on_step_end : Callable, optional
+            Callback at the end of each denoising step.
+        callback_on_step_end_tensor_inputs : List[str], optional
+            List of tensor names for callback.
+        max_sequence_length : int, optional
+            Maximum sequence length for prompts.
+
+        Returns
+        -------
+        FluxPipelineOutput or tuple
+            Output images and additional info.
+        """
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
         ...
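
Putting it together, a minimal usage sketch of ``PuLIDFluxPipeline``. The checkpoint filename and the ``from_pretrained`` construction path are assumptions, not verified values:

.. code-block:: python

    import torch
    from diffusers.utils import load_image

    from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
    from nunchaku.pipeline.pipeline_flux_pulid import PuLIDFluxPipeline

    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        "mit-han-lab/nunchaku-flux.1-dev/svdq-int4_r32-flux.1-dev.safetensors"  # illustrative
    )
    pipe = PuLIDFluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")

    id_image = load_image("face.jpg")  # reference face for identity preservation
    image = pipe(
        prompt="portrait of a person, cinematic lighting",
        id_image=id_image,
        id_weight=1.0,
        num_inference_steps=28,
        guidance_scale=3.5,
    ).images[0]
    image.save("pulid-flux.png")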
nunchaku/test.py
View file @ cd214093

+"""
+Test script for generating an image with the Nunchaku FLUX.1-schnell model.
+
+This script demonstrates how to load a quantized Nunchaku FLUX transformer model and
+use it with the Diffusers :class:`~diffusers.FluxPipeline` to generate an image from a text prompt.
+
+**Example usage**
+
+.. code-block:: bash
+
+    python -m nunchaku.test
+
+The generated image will be saved as ``flux.1-schnell.png`` in the current directory.
+"""
 import torch
 from diffusers import FluxPipeline
 ...
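
The body of the script is collapsed in this view; in sketch form, the pattern it follows is (checkpoint name and prompt are assumptions based on the docstring and the test file below):

.. code-block:: python

    from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
    from nunchaku.utils import get_precision

    precision = get_precision()  # "int4" or "fp4", depending on the GPU
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"mit-han-lab/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"  # illustrative
    )
    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")
    image = pipeline(
        "A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0.0
    ).images[0]
    image.save("flux.1-schnell.png")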
nunchaku/utils.py
View file @ cd214093

+"""
+Utility functions for Nunchaku.
+"""
 import hashlib
 import os
 import warnings
 ...
@@ -9,6 +13,19 @@ from huggingface_hub import hf_hub_download
 def sha256sum(filepath: str | os.PathLike[str]) -> str:
+    """
+    Compute the SHA-256 checksum of a file.
+
+    Parameters
+    ----------
+    filepath : str or os.PathLike
+        Path to the file.
+
+    Returns
+    -------
+    str
+        The SHA-256 hexadecimal digest of the file.
+    """
     sha256 = hashlib.sha256()
     with open(filepath, "rb") as f:
         for chunk in iter(lambda: f.read(8192), b""):
     ...
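
A usage sketch; the file path is a placeholder, and the expected digest shown is the one pinned for AntelopeV2's ``1k3d68.onnx`` in the PuLID pipeline above:

.. code-block:: python

    from nunchaku.utils import sha256sum

    expected = "df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc"
    ok = sha256sum("models/antelopev2/1k3d68.onnx") == expected  # placeholder path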
@@ -17,6 +34,28 @@ def sha256sum(filepath: str | os.PathLike[str]) -> str:
 def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path:
+    """
+    Fetch a file from a local path or download from HuggingFace Hub if not present.
+
+    The remote path should be in the format: ``<repo_id>/<filename>`` or ``<repo_id>/<subfolder>/<filename>``.
+
+    Parameters
+    ----------
+    path : str or Path
+        Local file path or HuggingFace Hub path.
+    repo_type : str, optional
+        Type of HuggingFace repo (default: "model").
+
+    Returns
+    -------
+    Path
+        Path to the local file.
+
+    Raises
+    ------
+    ValueError
+        If the path is too short to extract repo_id and subfolder.
+    """
     path = Path(path)
     if path.exists():
     ...
@@ -29,24 +68,27 @@ def fetch_or_download(path: str | Path, repo_type: str = "model") -> Path:
     repo_id = "/".join(parts[:2])
     sub_path = Path(*parts[2:])
     filename = sub_path.name
-    subfolder = sub_path.parent if sub_path.parent != Path(".") else None
+    subfolder = str(sub_path.parent) if sub_path.parent != Path(".") else None
     path = hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type)
     return Path(path)
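
For example, a Hub-style path is split into ``repo_id`` and ``filename`` exactly as the code above shows (this model path appears in the test file later in this commit):

.. code-block:: python

    from nunchaku.utils import fetch_or_download

    # Resolves locally if the path exists; otherwise parses it as
    # <repo_id>/<filename> and downloads from the HuggingFace Hub.
    local_path = fetch_or_download("mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors")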
 def ceil_divide(x: int, divisor: int) -> int:
-    """Ceiling division.
-
-    Args:
-        x (`int`):
-            dividend.
-        divisor (`int`):
-            divisor.
-
-    Returns:
-        `int`:
-            ceiling division result.
-    """
+    """
+    Compute the ceiling of x divided by divisor.
+
+    Parameters
+    ----------
+    x : int
+        Dividend.
+    divisor : int
+        Divisor.
+
+    Returns
+    -------
+    int
+        The smallest integer >= x / divisor.
+    """
     return (x + divisor - 1) // divisor
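
A quick worked check of the identity ``ceil_divide(x, d) == ceil(x / d)``:

.. code-block:: python

    from nunchaku.utils import ceil_divide

    assert ceil_divide(10, 4) == 3  # (10 + 4 - 1) // 4 == 13 // 4
    assert ceil_divide(8, 4) == 2   # exact division is unchanged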
 ...
@@ -57,6 +99,25 @@ def load_state_dict_in_safetensors(
     filter_prefix: str = "",
     return_metadata: bool = False,
 ) -> dict[str, torch.Tensor] | tuple[dict[str, torch.Tensor], dict[str, str]]:
+    """
+    Load a state dict from a safetensors file, optionally filtering by prefix.
+
+    Parameters
+    ----------
+    path : str or os.PathLike
+        Path to the safetensors file (local or HuggingFace Hub).
+    device : str or torch.device, optional
+        Device to load tensors onto (default: "cpu").
+    filter_prefix : str, optional
+        Only load keys starting with this prefix (default: "", no filter).
+    return_metadata : bool, optional
+        Whether to return safetensors metadata (default: False).
+
+    Returns
+    -------
+    dict[str, torch.Tensor] or tuple[dict[str, torch.Tensor], dict[str, str]]
+        The loaded state dict, and optionally the metadata if ``return_metadata`` is True.
+    """
     state_dict = {}
     with safetensors.safe_open(fetch_or_download(path), framework="pt", device=device) as f:
         metadata = f.metadata()
     ...
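
A usage sketch (the checkpoint path mirrors the test file below; the prefix value is illustrative):

.. code-block:: python

    from nunchaku.utils import load_state_dict_in_safetensors

    state_dict, metadata = load_state_dict_in_safetensors(
        "mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors",  # local or Hub path
        filter_prefix="",        # keep all keys
        return_metadata=True,
    )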
@@ -71,17 +132,20 @@ def load_state_dict_in_safetensors(
 def filter_state_dict(state_dict: dict[str, torch.Tensor], filter_prefix: str = "") -> dict[str, torch.Tensor]:
-    """Filter state dict.
-
-    Args:
-        state_dict (`dict`):
-            state dict.
-        filter_prefix (`str`):
-            filter prefix.
-
-    Returns:
-        `dict`:
-            filtered state dict.
-    """
+    """
+    Filter a state dict to only include keys starting with a given prefix.
+
+    Parameters
+    ----------
+    state_dict : dict[str, torch.Tensor]
+        The input state dict.
+    filter_prefix : str, optional
+        Prefix to filter keys by (default: "", no filter).
+
+    Returns
+    -------
+    dict[str, torch.Tensor]
+        Filtered state dict with prefix removed from keys.
+    """
     return {k.removeprefix(filter_prefix): v for k, v in state_dict.items() if k.startswith(filter_prefix)}
 ...
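
Since the implementation is a single comprehension, its behavior is easy to pin down:

.. code-block:: python

    import torch

    from nunchaku.utils import filter_state_dict

    sd = {"encoder.weight": torch.zeros(2), "decoder.weight": torch.ones(2)}
    filtered = filter_state_dict(sd, filter_prefix="encoder.")
    assert list(filtered) == ["weight"]  # prefix stripped, non-matching keys dropped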
@@ -91,6 +155,28 @@ def get_precision(
     device: str | torch.device = "cuda",
     pretrained_model_name_or_path: str | os.PathLike[str] | None = None,
 ) -> str:
+    """
+    Determine the quantization precision to use based on device and model.
+
+    Parameters
+    ----------
+    precision : str, optional
+        "auto", "int4", or "fp4" (default: "auto").
+    device : str or torch.device, optional
+        Device to check (default: "cuda").
+    pretrained_model_name_or_path : str or os.PathLike or None, optional
+        Model name or path for warning checks.
+
+    Returns
+    -------
+    str
+        The selected precision ("int4" or "fp4").
+
+    Raises
+    ------
+    AssertionError
+        If precision is not one of "auto", "int4", or "fp4".
+    """
     assert precision in ("auto", "int4", "fp4")
     if precision == "auto":
         if isinstance(device, str):
     ...
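
A usage sketch mirroring the test file below; which architectures map to "int4" versus "fp4" is not spelled out in this hunk, so treat the comment as an assumption:

.. code-block:: python

    from nunchaku.utils import get_precision

    precision = get_precision()  # assumed: "fp4" on Blackwell-class GPUs, "int4" otherwise
    path = f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"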
@@ -109,11 +195,18 @@ def get_precision(
 def is_turing(device: str | torch.device = "cuda") -> bool:
-    """Check if the current GPU is a Turing GPU.
-
-    Returns:
-        `bool`:
-            True if the current GPU is a Turing GPU, False otherwise.
-    """
+    """
+    Check if the current GPU is a Turing GPU (compute capability 7.5).
+
+    Parameters
+    ----------
+    device : str or torch.device, optional
+        Device to check (default: "cuda").
+
+    Returns
+    -------
+    bool
+        True if the current GPU is a Turing GPU, False otherwise.
+    """
     if isinstance(device, str):
         device = torch.device(device)
     ...
@@ -124,15 +217,25 @@ def is_turing(device: str | torch.device = "cuda") -> bool:
 def get_gpu_memory(device: str | torch.device = "cuda", unit: str = "GiB") -> int:
-    """Get the GPU memory of the current device.
-
-    Args:
-        device (`str` | `torch.device`, optional, defaults to `"cuda"`):
-            device.
-
-    Returns:
-        `int`:
-            GPU memory in bytes.
-    """
+    """
+    Get the total memory of the current GPU.
+
+    Parameters
+    ----------
+    device : str or torch.device, optional
+        Device to check (default: "cuda").
+    unit : str, optional
+        Unit for memory ("GiB", "MiB", or "B") (default: "GiB").
+
+    Returns
+    -------
+    int
+        GPU memory in the specified unit.
+
+    Raises
+    ------
+    AssertionError
+        If unit is not one of "GiB", "MiB", or "B".
+    """
     if isinstance(device, str):
         device = torch.device(device)
     ...
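
These helpers are typically combined to decide whether to enable CPU offload; a sketch (the threshold is illustrative):

.. code-block:: python

    from nunchaku.utils import get_gpu_memory, is_turing

    low_vram = get_gpu_memory(unit="GiB") < 18  # illustrative cutoff
    legacy_gpu = is_turing()                    # SM 7.5 cards, e.g. RTX 20-series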
@@ -147,6 +250,21 @@ def get_gpu_memory(device: str | torch.device = "cuda", unit: str = "GiB") -> int:
 def check_hardware_compatibility(quantization_config: dict, device: str | torch.device = "cuda"):
+    """
+    Check if the quantization config is compatible with the current GPU.
+
+    Parameters
+    ----------
+    quantization_config : dict
+        Quantization configuration dictionary.
+    device : str or torch.device, optional
+        Device to check (default: "cuda").
+
+    Raises
+    ------
+    ValueError
+        If the quantization config is not compatible with the GPU architecture.
+    """
     if isinstance(device, str):
         device = torch.device(device)
     capability = torch.cuda.get_device_capability(0 if device.index is None else device.index)
     ...
pyproject.toml
View file @ cd214093

 ...
@@ -34,3 +34,12 @@ build-backend = "setuptools.build_meta"
 [tool.setuptools.packages.find]
 include = ["nunchaku"]
+
+[tool.doc8]
+max-line-length = 120
+ignore-path = ["docs/_build"]
+ignore = ["D000", "D001"]
+
+[tool.rstcheck]
+ignore_directives = ["tabs"]
+ignore_messages = ["ERROR/3", "INFO/1"]
src/FluxModel.cpp
View file @ cd214093

 ...
@@ -778,6 +778,8 @@ std::tuple<Tensor, Tensor> JointTransformerBlock::forward(Tensor hidden_states,
 FluxModel::FluxModel(bool use_fp4, bool offload, Tensor::ScalarType dtype, Device device)
     : dtype(dtype), offload(offload) {
+    CUDADeviceContext model_construction_ctx(device.idx);
     for (int i = 0; i < 19; i++) {
         transformer_blocks.push_back(
             std::make_unique<JointTransformerBlock>(3072, 24, 3072, false, use_fp4, dtype, device));
 ...
src/Tensor.h
View file @ cd214093

 ...
@@ -432,6 +432,13 @@ public:
             return *this;
         }
+        // Pin the CUDA context to whichever side of the copy lives on a CUDA device.
+        std::optional<CUDADeviceContext> operation_ctx_guard;
+        if (this->device().type == Device::CUDA) {
+            operation_ctx_guard.emplace(this->device().idx);
+        } else if (other.device().type == Device::CUDA) {
+            operation_ctx_guard.emplace(other.device().idx);
+        }
         if (this->device().type == Device::CPU && other.device().type == Device::CPU) {
             memcpy(data_ptr<char>(), other.data_ptr<char>(), shape.size() * scalar_size());
             return *this;
 ...
tests/flux/test_flux_txt2img_cache_controlnet.py
0 → 100644
View file @ cd214093

import gc

import torch
from diffusers import (
    AutoencoderKL,
    FlowMatchEulerDiscreteScheduler,
    FluxControlNetModel,
    FluxControlNetPipeline,
    FluxPipeline,
)
from diffusers.models import FluxMultiControlNetModel
from diffusers.utils import load_image
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from nunchaku import NunchakuT5EncoderModel
from nunchaku.caching.diffusers_adapters import apply_cache_on_pipe
from nunchaku.models.transformers.transformer_flux import NunchakuFluxTransformer2dModel
from nunchaku.utils import get_precision


def test_flux_txt2img_cache_controlnet():
    bfl_repo = "black-forest-labs/FLUX.1-dev"
    dtype = torch.bfloat16  # or torch.float16, or torch.float32
    device = "cuda"  # or "cpu" if you want to run on CPU

    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(bfl_repo, subfolder="scheduler", torch_dtype=dtype)
    text_encoder = CLIPTextModel.from_pretrained(bfl_repo, subfolder="text_encoder", torch_dtype=dtype)
    text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype)
    tokenizer = CLIPTokenizer.from_pretrained(
        bfl_repo, subfolder="tokenizer", torch_dtype=dtype, clean_up_tokenization_spaces=True
    )
    tokenizer_2 = T5TokenizerFast.from_pretrained(
        bfl_repo, subfolder="tokenizer_2", torch_dtype=dtype, clean_up_tokenization_spaces=True
    )
    vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder="vae", torch_dtype=dtype)

    precision = get_precision()
    transformer = NunchakuFluxTransformer2dModel.from_pretrained(
        f"mit-han-lab/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors",
        # offload=True
    )
    transformer.set_attention_impl("nunchaku-fp16")

    # qencoder
    text_encoder_2 = NunchakuT5EncoderModel.from_pretrained("mit-han-lab/nunchaku-t5/awq-int4-flux.1-t5xxl.safetensors")

    controlnet_union = FluxControlNetModel.from_pretrained(
        "Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro-2.0", torch_dtype=torch.bfloat16
    )
    # we always recommend loading via FluxMultiControlNetModel
    controlnet = FluxMultiControlNetModel([controlnet_union])

    params = {
        "scheduler": scheduler,
        "vae": vae,
        "tokenizer": tokenizer,
        "tokenizer_2": tokenizer_2,
        "text_encoder": text_encoder,
        "text_encoder_2": text_encoder_2,
        "transformer": transformer,
    }

    # pipe
    pipe = FluxPipeline(**params).to(device, dtype=dtype)
    pipe_cn = FluxControlNetPipeline(**params, controlnet=controlnet).to(device, dtype)

    # offload
    pipe.enable_sequential_cpu_offload(device=device)
    pipe_cn.enable_sequential_cpu_offload(device=device)

    # cache
    apply_cache_on_pipe(
        pipe_cn,
        use_double_fb_cache=True,
        residual_diff_threshold_multi=0.09,
        residual_diff_threshold_single=0.12,
    )

    params = {
        "prompt": "A bohemian-style female travel blogger with sun-kissed skin and messy beach waves.",
        "height": 1152,
        "width": 768,
        "num_inference_steps": 30,
        "guidance_scale": 3.5,
    }

    # pipe
    txt2img_res = pipe(**params).images[0]
    txt2img_res.save("flux.1-dev-txt2img.jpg")

    gc.collect()
    torch.cuda.empty_cache()

    # cache
    apply_cache_on_pipe(
        pipe_cn,
        use_double_fb_cache=True,
        residual_diff_threshold_multi=0.09,
        residual_diff_threshold_single=0.12,
    )

    # pipe_cn
    control_image = load_image(
        "https://huggingface.co/Shakker-Labs/FLUX.1-dev-ControlNet-Union-Pro/resolve/main/assets/openpose.jpg"
    )
    params["control_image"] = [control_image]
    params["controlnet_conditioning_scale"] = [0.9]
    params["control_guidance_end"] = [0.65]

    cn_res = pipe_cn(**params).images[0]
    cn_res.save("flux.1-dev-cn-txt2img.jpg")
tests/requirements.txt
View file @ cd214093

 # additional requirements for testing
 pytest
-datasets
+datasets<4
 torchmetrics
 mediapipe
 controlnet_aux
 ...