Unverified Commit 070c45bb authored by Muyang Li, committed by GitHub

docs: add the docstrings for v1.0.0 (#656)

* add v2 flux examples

* add the docs

* add docs

* update

* finished ops

* add ops

* update

* update

* update

* update

* update

* update

* update

* update docstrings

* update

* update

* update

* update

* update

* update

* update

* finished the api docs

* update

* update
parent e0392e42
nunchaku.ops
============
.. toctree::
:maxdepth: 4
nunchaku.ops.gemm
nunchaku.ops.gemv
nunchaku.ops.quantize
nunchaku.ops.fused
......@@ -11,6 +11,7 @@ Subpackages
nunchaku.models
nunchaku.lora
nunchaku.pipeline
nunchaku.ops
nunchaku.caching
nunchaku.utils
......
Qwen-Image
==========
Original Qwen-Image
-------------------
.. image:: https://huggingface.co/datasets/nunchaku-tech/cdn/resolve/main/nunchaku/assets/qwen-image.jpg
:alt: Qwen-Image with Nunchaku
Below is a minimal example for running the 4-bit quantized `Qwen-Image <hf_qwen-image_>`_ model with Nunchaku.
Nunchaku offers an API compatible with `Diffusers <github_diffusers_>`_, allowing for a familiar user experience.
.. literalinclude:: ../../../examples/v1/qwen-image.py
:language: python
:caption: Running Qwen-Image (`examples/v1/qwen-image.py <https://github.com/nunchaku-tech/nunchaku/blob/main/examples/v1/qwen-image.py>`__)
:linenos:
When using Nunchaku, replace the standard ``QwenImageTransformer2DModel`` with :class:`~nunchaku.models.transformers.transformer_qwenimage.NunchakuQwenImageTransformer2DModel`.
The :meth:`~nunchaku.models.transformers.transformer_qwenimage.NunchakuQwenImageTransformer2DModel.from_pretrained` method loads quantized models from either Hugging Face or local file paths.
.. note::
- The :func:`~nunchaku.utils.get_precision` function automatically detects whether your GPU supports INT4 or FP4 quantization.
Use FP4 models for Blackwell GPUs (RTX 50-series) and INT4 models for other architectures.
- Increasing the rank (e.g., to 128) can improve output quality.
- To reduce VRAM usage, enable asynchronous CPU offloading with :meth:`~nunchaku.models.transformers.transformer_qwenimage.NunchakuQwenImageTransformer2DModel.set_offload`. For further savings, you may also enable Diffusers' ``pipeline.enable_sequential_cpu_offload()``, but be sure to exclude ``transformer`` from offloading, as Nunchaku's offloading mechanism differs from Diffusers'. With these settings, VRAM usage can be reduced to approximately 3GB.
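For quick reference, the sketch below mirrors the flow of the example script above in a few lines.
The checkpoint filename and the ``Qwen/Qwen-Image`` base repository are illustrative assumptions; treat the bundled example script as the canonical version.

.. code-block:: python

    import torch
    from diffusers import DiffusionPipeline

    from nunchaku import NunchakuQwenImageTransformer2DModel
    from nunchaku.utils import get_precision

    precision = get_precision()  # 'int4' or 'fp4', depending on your GPU

    # NOTE: placeholder checkpoint path; check the example script for the exact filename.
    transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
        f"nunchaku-tech/nunchaku-qwen-image/svdq-{precision}_r32-qwen-image.safetensors"
    )
    # Optional: reduce VRAM usage with asynchronous CPU offloading (see the note above).
    # transformer.set_offload(True)

    pipeline = DiffusionPipeline.from_pretrained(
        "Qwen/Qwen-Image", transformer=transformer, torch_dtype=torch.bfloat16
    ).to("cuda")
    image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50).images[0]
    image.save(f"qwen-image-{precision}.png")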
Distilled Qwen-Image (Qwen-Image-Lightning)
-------------------------------------------
For faster inference, we provide pre-quantized 4-step and 8-step Qwen-Image models by integrating `Qwen-Image-Lightning LoRAs <hf_qwen-image-lightning_>`_.
See the example script below:
.. literalinclude:: ../../../examples/v1/qwen-image-lightning.py
:language: python
:caption: Running Qwen-Image-Lightning (`examples/v1/qwen-image-lightning.py <https://github.com/nunchaku-tech/nunchaku/blob/main/examples/v1/qwen-image-lightning.py>`__)
:linenos:
Custom LoRA support is under development.
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2DModelV2
from nunchaku.utils import get_precision
precision = get_precision()  # auto-detects whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
f"nunchaku-tech/nunchaku-flux.1-dev/svdq-{precision}_r32-flux.1-dev.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=50, guidance_scale=3.5).images[0]
image.save(f"flux.1-dev-{precision}.png")
import torch
from diffusers import FluxPipeline
from nunchaku import NunchakuFluxTransformer2DModelV2
from nunchaku.utils import get_precision
precision = get_precision()  # auto-detects whether your precision is 'int4' or 'fp4' based on your GPU
transformer = NunchakuFluxTransformer2DModelV2.from_pretrained(
f"nunchaku-tech/nunchaku-flux.1-schnell/svdq-{precision}_r32-flux.1-schnell.safetensors"
)
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")
image = pipeline(
"A cat holding a sign that says hello world",
width=1024,
height=1024,
num_inference_steps=4,
guidance_scale=0,
).images[0]
image.save(f"flux.1-schnell-{precision}.png")
from .models import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel, NunchakuT5EncoderModel
from .models import (
NunchakuFluxTransformer2dModel,
NunchakuFluxTransformer2DModelV2,
NunchakuQwenImageTransformer2DModel,
NunchakuSanaTransformer2DModel,
NunchakuT5EncoderModel,
)
__all__ = ["NunchakuFluxTransformer2dModel", "NunchakuSanaTransformer2DModel", "NunchakuT5EncoderModel"]
__all__ = [
"NunchakuFluxTransformer2dModel",
"NunchakuSanaTransformer2DModel",
"NunchakuT5EncoderModel",
"NunchakuFluxTransformer2DModelV2",
"NunchakuQwenImageTransformer2DModel",
]
from .text_encoders.t5_encoder import NunchakuT5EncoderModel
from .transformers import NunchakuFluxTransformer2dModel, NunchakuSanaTransformer2DModel
from .transformers import (
NunchakuFluxTransformer2dModel,
NunchakuFluxTransformer2DModelV2,
NunchakuQwenImageTransformer2DModel,
NunchakuSanaTransformer2DModel,
)
__all__ = ["NunchakuFluxTransformer2dModel", "NunchakuSanaTransformer2DModel", "NunchakuT5EncoderModel"]
__all__ = [
"NunchakuFluxTransformer2dModel",
"NunchakuSanaTransformer2DModel",
"NunchakuT5EncoderModel",
"NunchakuFluxTransformer2DModelV2",
"NunchakuQwenImageTransformer2DModel",
]
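# Illustrative sketch only (it simply mirrors the updated ``__all__`` above): all of the
# following classes are now importable from the top-level package.
from nunchaku import (
    NunchakuFluxTransformer2dModel,
    NunchakuFluxTransformer2DModelV2,
    NunchakuQwenImageTransformer2DModel,
    NunchakuSanaTransformer2DModel,
    NunchakuT5EncoderModel,
)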
"""
Nunchaku quantized attention-related modules.
"""
import torch
from diffusers.models.activations import GELU
from diffusers.models.attention import FeedForward
......@@ -8,16 +12,59 @@ from .linear import SVDQW4A4Linear
class NunchakuBaseAttention(nn.Module):
"""
Base class for Nunchaku attention modules.
Provides a common interface for attention modules with processor selection.
Parameters
----------
processor : str, optional
Name of the attention processor to use. Default is "flashattn2".
*args, **kwargs :
Additional arguments for subclass initialization.
"""
def __init__(self, processor: str = "flashattn2", *args, **kwargs):
super(NunchakuBaseAttention, self).__init__()
self.processor = None
self.set_processor(processor)
def set_processor(self, processor: str):
"""
Set the attention processor. Must be implemented by subclasses.
Parameters
----------
processor : str
Name of the processor to use.
Raises
------
NotImplementedError
If not implemented in subclass.
"""
raise NotImplementedError("Subclass must implement this method")
def _patch_linear(module: nn.Module, linear_cls, **kwargs) -> nn.Module:
"""
Recursively replace all nn.Linear modules in a given module with a custom linear class.
Parameters
----------
module : nn.Module
The module to patch.
linear_cls : type
The custom linear class to use for replacement.
**kwargs :
Additional arguments passed to ``from_linear``.
Returns
-------
nn.Module
The patched module with custom linear layers.
"""
for name, child in module.named_children():
if isinstance(child, nn.Linear):
setattr(module, name, linear_cls.from_linear(child, **kwargs))
......@@ -27,17 +74,50 @@ def _patch_linear(module: nn.Module, linear_cls, **kwargs) -> nn.Module:
class NunchakuFeedForward(FeedForward):
"""
Quantized feed-forward (MLP) block with fused GELU support.
Replaces linear layers in a FeedForward block with :class:`~nunchaku.models.linear.SVDQW4A4Linear` for quantized inference.
Supports fused GELU-MLP computation for efficiency.
Parameters
----------
ff : FeedForward
Source FeedForward block to quantize.
**kwargs :
Additional arguments for SVDQW4A4Linear.
Notes
-----
For int4 quantization, the activation of the second MLP layer is shifted to be unsigned.
"""
def __init__(self, ff: FeedForward, **kwargs):
super(FeedForward, self).__init__()
self.net = _patch_linear(ff.net, SVDQW4A4Linear, **kwargs)
# for int4, we shift the activation of mlp_fc2 to make it unsigned
# For int4, shift the activation of mlp_fc2 to make it unsigned
self.net[2].act_unsigned = self.net[2].precision != "nvfp4"
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the quantized feed-forward block.
It will call :func:`~nunchaku.ops.fused.fused_gelu_mlp` if the first layer is GELU;
otherwise, it applies the modules sequentially.
Parameters
----------
hidden_states : torch.Tensor, shape (B, D)
Input tensor.
Returns
-------
torch.Tensor, shape (B, D)
Output tensor after feed-forward transformation.
"""
if isinstance(self.net[0], GELU):
return fused_gelu_mlp(hidden_states, self.net[0].proj, self.net[2])
else:
# fallback to original implementation
# Fallback to original implementation
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
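# Illustrative usage sketch (names as imported/defined above). The wrapped layers
# hold dummy quantized weights until a checkpoint is loaded, and running forward()
# requires Nunchaku's CUDA kernels on a supported GPU.
ff = FeedForward(dim=3072, dim_out=3072, activation_fn="gelu-approximate")
qff = NunchakuFeedForward(ff, precision="int4", rank=32)
assert isinstance(qff.net[0], GELU) and isinstance(qff.net[2], SVDQW4A4Linear)
assert qff.net[2].act_unsigned  # for int4, the second MLP activation is unsigned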
"""
Attention processor implementations for :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxAttention`.
"""
import math
from typing import Optional, Tuple
......@@ -9,6 +13,12 @@ from ...ops.fused import fused_qkv_norm_rottary
class NunchakuFluxFA2Processor:
"""
Fused attention processor using PyTorch's scaled dot-product attention.
This processor applies fused QKV projection, normalization, and rotary embedding,
then computes attention using PyTorch's built-in scaled_dot_product_attention.
"""
def __call__(
self,
......@@ -19,7 +29,40 @@ class NunchakuFluxFA2Processor:
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor = None,
**kwargs,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
# Adapted from https://github.com/huggingface/diffusers/blob/50dea89dc6036e71a00bc3d57ac062a80206d9eb/src/diffusers/models/attention_processor.py#L2275
"""
Forward pass for fused attention.
Parameters
----------
attn : :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxAttention`
Attention module.
hidden_states : torch.Tensor, shape (B, T, C), dtype float32/float16
Input hidden states.
encoder_hidden_states : Optional[torch.Tensor], shape (B, T_enc, C), optional
Encoder hidden states for cross/joint attention.
attention_mask : Optional[torch.Tensor], optional
Not supported. Must be None.
image_rotary_emb : tuple or torch.Tensor
Rotary embeddings. Tuple for joint attention, tensor for single stream.
**kwargs
Additional arguments (unused).
Returns
-------
out : torch.Tensor or tuple of torch.Tensor
Output hidden states. If joint attention, returns (img_out, txt_out).
Raises
------
NotImplementedError
If attention_mask is not None.
Notes
-----
- B: batch size
- T: sequence length
- C: channels (C = heads * head_dim)
"""
if attention_mask is not None:
raise NotImplementedError("attention_mask is not supported")
......@@ -69,6 +112,15 @@ class NunchakuFluxFA2Processor:
class NunchakuFluxFP16AttnProcessor:
"""
Fused attention processor with custom Nunchaku FP16 accumulation.
This is faster than the :class:`~nunchaku.models.attention_processors.flux.NunchakuFluxFA2Processor`.
Parameters
----------
pad_size : int, optional
Padding size for sequence length. Default is 256.
"""
def __init__(self, pad_size: int = 256):
self.pad_size = pad_size
......@@ -82,6 +134,40 @@ class NunchakuFluxFP16AttnProcessor:
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor = None,
**kwargs,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for Nunchaku FP16 attention.
Parameters
----------
attn : :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxAttention`
Attention module.
hidden_states : torch.Tensor, shape (B, T, C), dtype float16
Input hidden states.
encoder_hidden_states : Optional[torch.Tensor], shape (B, T_enc, C), optional
Encoder hidden states for cross/joint attention.
attention_mask : Optional[torch.Tensor], optional
Not supported. Must be None.
image_rotary_emb : tuple or torch.Tensor
Rotary embeddings. Tuple for joint attention, tensor for single stream.
**kwargs
Additional arguments (unused).
Returns
-------
out : torch.Tensor or tuple of torch.Tensor
Output hidden states. If joint attention, returns (img_out, txt_out).
Raises
------
AssertionError
If input shapes or types are incompatible.
Notes
-----
- B: batch size
- T: sequence length
- C: channels (C = heads * head_dim)
"""
pad_size = self.pad_size
batch_size, _, channels = hidden_states.shape
assert channels == attn.heads * attn.head_dim
......
"""
Attention processors for :class:`~nunchaku.models.transformers.transformer_qwenimage.NunchakuQwenAttention`.
"""
from typing import Optional, Tuple
import torch
......@@ -6,23 +10,62 @@ from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb
class NunchakuQwenImageNaiveFA2Processor:
"""
Naive attention processor for Qwen-Image joint text-image attention.
"""
def __call__(
self,
attn,
hidden_states: torch.FloatTensor, # Image stream
encoder_hidden_states: torch.FloatTensor = None, # Text stream
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor | Tuple[torch.Tensor, torch.Tensor]:
# Adapted from https://github.com/huggingface/diffusers/blob/baa9b582f348e52aa2fc245e366611f454e1082b/src/diffusers/models/transformers/transformer_qwenimage.py#L246
"""
Forward pass for joint text-image attention.
Parameters
----------
attn : :class:`~nunchaku.models.transformers.transformer_qwenimage.NunchakuQwenAttention`
Attention module.
hidden_states : torch.FloatTensor, shape (B, L, H*D)
Image stream hidden states.
encoder_hidden_states : torch.FloatTensor, shape (B, L_txt, H*D)
Text stream hidden states.
encoder_hidden_states_mask : torch.FloatTensor, optional
Not used.
attention_mask : Optional[torch.FloatTensor], shape (B, 1, L_total, L_total), optional
Attention mask for joint attention.
image_rotary_emb : Optional[Tuple[torch.Tensor, torch.Tensor]]
Tuple of rotary embeddings for image and text streams.
Returns
-------
img_attn_output : torch.Tensor, shape (B, L, H*D)
Output for image stream after attention and projection.
txt_attn_output : torch.Tensor, shape (B, L_txt, H*D)
Output for text stream after attention and projection.
Raises
------
ValueError
If ``encoder_hidden_states`` (text stream) is not provided.
Notes
-----
- B: batch size
- L: sequence length (image)
- L_txt: sequence length (text)
- H: number of attention heads
- D: head dimension
"""
if encoder_hidden_states is None:
raise ValueError("NunchakuQwenImageFA2Processor requires encoder_hidden_states (text stream)")
seq_txt = encoder_hidden_states.shape[1]
# TODO: fuse the QKV, norm and RoPE in a single kernel to boost the performance
# Compute QKV for image stream (sample projections)
img_qkv = attn.to_qkv(hidden_states)
img_query, img_key, img_value = img_qkv.chunk(3, dim=-1)
......@@ -50,7 +93,7 @@ class NunchakuQwenImageNaiveFA2Processor:
assert attn.norm_added_k is not None
txt_key = attn.norm_added_k(txt_key)
# Apply RoPE
# Apply rotary embeddings
if image_rotary_emb is not None:
img_freqs, txt_freqs = image_rotary_emb
img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
......@@ -58,8 +101,7 @@ class NunchakuQwenImageNaiveFA2Processor:
txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
# Concatenate for joint attention
# Order: [text, image]
# Concatenate for joint attention: [text, image]
joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
......@@ -74,13 +116,6 @@ class NunchakuQwenImageNaiveFA2Processor:
is_causal=False,
backend=None,
)
# joint_query = joint_query.transpose(1, 2)
# joint_key = joint_key.transpose(1, 2)
# joint_value = joint_value.transpose(1, 2)
# joint_hidden_states = F.scaled_dot_product_attention(
# joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
# )
# joint_hidden_states = joint_hidden_states.transpose(1, 2)
# Reshape back
joint_hidden_states = joint_hidden_states.flatten(2, 3)
......
"""
Embedding layers for Nunchaku.
"""
import diffusers
import torch
from packaging.version import Version
......@@ -11,8 +15,8 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
Parameters
----------
pos : torch.Tensor
Position tensor of shape (..., n).
pos : torch.Tensor, shape (..., n), dtype int
Position indices.
dim : int
Embedding dimension (must be even).
theta : int
......@@ -20,8 +24,14 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
Returns
-------
torch.Tensor
out : torch.Tensor, shape (B, M, D//2, 1, 2), dtype float32
Rotary embedding tensor.
Notes
-----
- B: batch size
- M: sequence length
- D: embedding dimension
"""
assert dim % 2 == 0, "The dimension must be even."
......@@ -31,21 +41,18 @@ def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
batch_size, seq_length = pos.shape
out = torch.einsum("...n,d->...nd", pos, omega)
USE_SINCOS = True
if USE_SINCOS:
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 1, 2)
else:
out = out.view(batch_size, -1, dim // 2, 1, 1)
# Sin/cos representation for rotary embedding
cos_out = torch.cos(out)
sin_out = torch.sin(out)
stacked_out = torch.stack([sin_out, cos_out], dim=-1)
out = stacked_out.view(batch_size, -1, dim // 2, 1, 2)
return out.float()
class NunchakuFluxPosEmbed(nn.Module):
"""
Multi-dimensional rotary embedding module.
Nunchaku multi-dimensional rotary embedding module for FLUX.
Adapted from https://github.com/huggingface/diffusers/blob/c9ff360966327ace3faad3807dc871a4e5447501/src/diffusers/models/transformers/transformer_flux.py#L55
Parameters
......@@ -54,8 +61,8 @@ class NunchakuFluxPosEmbed(nn.Module):
Embedding dimension.
theta : int
Rotary base.
axes_dim : list[int]
List of axis dimensions for each spatial axis.
axes_dim : list of int
Dimension for each spatial axis.
"""
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
......@@ -66,17 +73,22 @@ class NunchakuFluxPosEmbed(nn.Module):
def forward(self, ids: torch.Tensor) -> torch.Tensor:
"""
Computes rotary embeddings for multi-dimensional positions.
Compute rotary embeddings for multi-dimensional positions.
Parameters
----------
ids : torch.Tensor
Position indices tensor of shape (..., n_axes).
ids : torch.Tensor, shape (..., n_axes), dtype int
Position indices.
Returns
-------
torch.Tensor
out : torch.Tensor, shape (B, 1, ...), dtype float32
Rotary embedding tensor.
Notes
-----
- B: batch size
- n_axes: number of spatial axes
"""
if Version(diffusers.__version__) >= Version("0.31.0"):
ids = ids[None, ...]
......@@ -87,17 +99,23 @@ class NunchakuFluxPosEmbed(nn.Module):
def pack_rotemb(rotemb: torch.Tensor) -> torch.Tensor:
"""
Packs rotary embeddings for efficient computation.
Pack rotary embeddings for efficient CUDA computation.
Parameters
----------
rotemb : torch.Tensor
Rotary embedding tensor of shape (B, M, D//2, 1, 2), dtype float32.
rotemb : torch.Tensor, shape (B, M, D//2, 1, 2), dtype float32
Rotary embedding tensor.
Returns
-------
torch.Tensor
Packed rotary embedding tensor of shape (B, M, D).
packed : torch.Tensor, shape (B, M, D), dtype float32
Packed rotary embedding tensor.
Notes
-----
- B: batch size
- M: sequence length (must be divisible by 16)
- D: embedding dimension (must be divisible by 8)
"""
assert rotemb.dtype == torch.float32
B = rotemb.shape[0]
......
"""
Quantized linear layers for Nunchaku.
"""
import torch
from torch import nn
......@@ -7,6 +11,60 @@ from ..ops.quantize import svdq_quantize_w4a4_act_fuse_lora_cuda
class SVDQW4A4Linear(nn.Module):
"""
`SVDQuant <paper_svdquant_>`_ W4A4 quantized linear layer.
Parameters
----------
in_features : int
Input feature dimension.
out_features : int
Output feature dimension.
rank : int, optional
SVD low-rank dimension. Default is 32.
bias : bool, optional
If True, adds a learnable bias. Default is True.
precision : {'int4', 'nvfp4'}, optional
Quantization precision. Default is 'int4'.
act_unsigned : bool, optional
If True, use unsigned activation quantization (int4 only). Default is False.
torch_dtype : torch.dtype, optional
Parameter dtype. Default is torch.bfloat16.
device : str or torch.device or None, optional
Device for parameters. Default is CPU.
Attributes
----------
in_features : int
out_features : int
rank : int
precision : str
'int4' or 'nvfp4'.
group_size : int
64 for int4, 16 for nvfp4.
qweight : nn.Parameter
Packed quantized weights, shape (out_features, in_features // 2), dtype int8.
bias : nn.Parameter or None
Bias tensor.
wscales : nn.Parameter
Weight scales, shape (in_features // group_size, out_features).
Dtype: bfloat16/float16 (int4), float8_e4m3fn (nvfp4).
smooth_factor : nn.Parameter
Smoothing factors, shape (in_features,).
smooth_factor_orig : nn.Parameter
Original smoothing factors, shape (in_features,). (Unused)
proj_down : nn.Parameter
Packed low-rank down projection, shape (in_features, rank), dtype bfloat16/float16.
proj_up : nn.Parameter
Packed low-rank up projection, shape (out_features, rank), dtype bfloat16/float16.
wtscale : float or None
Global weight scale (nvfp4 only).
wcscales : nn.Parameter or None
Channel-wise weight scale (nvfp4 only), shape (out_features,), dtype float8_e4m3fn.
act_unsigned : bool
If True, input activations are unsigned (int4 only).
"""
def __init__(
self,
in_features: int,
......@@ -27,7 +85,6 @@ class SVDQW4A4Linear(nn.Module):
self.precision = precision
self.torch_dtype = torch_dtype
self.group_size = None
if precision == "nvfp4":
self.group_size = 16
......@@ -77,6 +134,20 @@ class SVDQW4A4Linear(nn.Module):
@classmethod
def from_linear(cls, linear: nn.Linear, **kwargs):
"""
Create an SVDQW4A4Linear from a standard nn.Linear. The weight and bias are dummy tensors.
Parameters
----------
linear : nn.Linear
Source linear layer.
**kwargs
Additional init arguments.
Returns
-------
SVDQW4A4Linear
"""
in_features = kwargs.pop("in_features", linear.in_features)
return cls(
in_features=in_features,
......@@ -88,7 +159,25 @@ class SVDQW4A4Linear(nn.Module):
)
def forward(self, x: torch.Tensor, output: torch.Tensor | None = None) -> torch.Tensor:
# quantize the input run the down projection
"""
Forward pass with 16-bit input. It will call :meth:`quantize` and :meth:`forward_quant`.
Parameters
----------
x : torch.Tensor, shape (B, S, in_features), dtype float16 or bfloat16
Input tensor.
output : torch.Tensor or None, optional
Optional output buffer.
Returns
-------
torch.Tensor, shape (B, S, out_features)
Output tensor.
Notes
-----
B: batch size, S: sequence length
"""
batch_size, seq_len, channels = x.shape
x = x.view(batch_size * seq_len, channels)
if output is None:
......@@ -98,9 +187,32 @@ class SVDQW4A4Linear(nn.Module):
output = output.view(batch_size, seq_len, -1)
return output
def quantize(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
def quantize(self, x: torch.Tensor, pad_size: int = 256) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Quantize input to 4-bit and compute low-rank hidden states. It will call :func:`~nunchaku.ops.quantize.svdq_quantize_w4a4_act_fuse_lora_cuda`.
Parameters
----------
x : torch.Tensor, shape (N, in_features), dtype float16 or bfloat16
Input tensor.
pad_size : int, optional
Batch padding size. Default is 256.
Returns
-------
quantized_x : torch.Tensor
Quantized input, shape (pad_size * ceil(N / pad_size), in_features // 2), dtype uint8.
ascales : torch.Tensor
Activation scales, shape (in_features // group_size,), dtype float8_e4m3fn for nvfp4 and input dtype for int4.
lora_act_out : torch.Tensor
Low-rank hidden states, shape (pad_size * ceil(N / pad_size), rank), dtype float32.
Notes
-----
N: batch size
"""
quantized_x, ascales, lora_act_out = svdq_quantize_w4a4_act_fuse_lora_cuda(
x, lora_down=self.proj_down, smooth=self.smooth_factor, fp4=self.precision == "nvfp4"
x, lora_down=self.proj_down, smooth=self.smooth_factor, fp4=self.precision == "nvfp4", pad_size=pad_size
)
return quantized_x, ascales, lora_act_out
......@@ -111,6 +223,29 @@ class SVDQW4A4Linear(nn.Module):
lora_act: torch.Tensor,
output: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Forward pass with pre-quantized input. It will call :func:`~nunchaku.ops.gemm.svdq_gemm_w4a4_cuda`.
Parameters
----------
quantized_x : torch.Tensor
Quantized input, shape (N, in_features // 2), dtype uint8.
ascales : torch.Tensor
Activation scales, shape (in_features // group_size,), dtype float8_e4m3fn for nvfp4 and input dtype for int4.
lora_act : torch.Tensor
Low-rank hidden states, shape (N, rank), dtype float32.
output : torch.Tensor or None, optional
Optional output buffer.
Returns
-------
torch.Tensor
Output tensor, shape (N, out_features), dtype bfloat16/float16 for int4 and float8_e4m3fn for nvfp4.
Notes
-----
N: batch size
"""
if output is None:
output = torch.empty(
quantized_x.shape[0], self.out_features, dtype=self.proj_up.dtype, device=quantized_x.device
......@@ -133,10 +268,46 @@ class SVDQW4A4Linear(nn.Module):
return output
def __repr__(self):
return f"SVDQW4A4Linear(in_features={self.in_features}, out_features={self.out_features}, rank={self.rank}, precision={self.precision}, act_unsigned={self.act_unsigned})"
return (
f"SVDQW4A4Linear(in_features={self.in_features}, out_features={self.out_features}, "
f"rank={self.rank}, precision={self.precision}, act_unsigned={self.act_unsigned})"
)
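# Illustrative construction sketch: the weights are dummy placeholders, the shapes
# follow the attribute documentation above, and forward() additionally requires
# Nunchaku's CUDA kernels on a supported GPU.
qlinear = SVDQW4A4Linear.from_linear(nn.Linear(3072, 3072), rank=32, precision="int4")
assert qlinear.group_size == 64                     # int4 uses 64-element quantization groups
assert qlinear.qweight.shape == (3072, 1536)        # two 4-bit weights packed per int8
assert qlinear.wscales.shape == (3072 // 64, 3072)  # (in_features // group_size, out_features)
assert qlinear.proj_down.shape == (3072, 32)        # low-rank down projection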
class AWQW4A16Linear(nn.Module):
"""
`AWQ <paper_awq_>`_ W4A16 quantized linear layer.
Parameters
----------
in_features : int
Input feature dimension.
out_features : int
Output feature dimension.
bias : bool, optional
If True, adds learnable bias. Default is True.
group_size : int, optional
Quantization group size. Default is 64.
torch_dtype : torch.dtype, optional
Parameter dtype. Default is torch.bfloat16.
device : str or torch.device or None, optional
Device for parameters. Default is CPU.
Attributes
----------
in_features : int
out_features : int
group_size : int
qweight : nn.Parameter
Packed quantized weights, shape (out_features // 4, in_features // 2), dtype int32.
bias : nn.Parameter or None
Bias tensor.
wscales : nn.Parameter
Weight scales, shape (in_features // group_size, out_features), dtype float16 or bfloat16.
wzeros : nn.Parameter
Weight zero points, shape (in_features // group_size, out_features), dtype float16 or bfloat16.
"""
def __init__(
self,
in_features: int,
......@@ -151,7 +322,6 @@ class AWQW4A16Linear(nn.Module):
device = torch.device("cpu")
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size
self.qweight = nn.Parameter(
......@@ -172,6 +342,23 @@ class AWQW4A16Linear(nn.Module):
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass for AWQW4A16Linear.
Parameters
----------
x : torch.Tensor, shape (N, in_features)
Input tensor.
Returns
-------
torch.Tensor, shape (N, out_features)
Output tensor.
Notes
-----
N: batch size
"""
output = awq_gemv_w4a16_cuda(
in_feats=x,
kernel=self.qweight,
......@@ -196,6 +383,24 @@ class AWQW4A16Linear(nn.Module):
device: str = "cpu",
**kwargs,
):
"""
Create an uninitialized AWQW4A16Linear from a standard nn.Linear.
Parameters
----------
linear : nn.Linear
Source linear layer.
group_size : int, optional
Quantization group size.
torch_dtype : torch.dtype, optional
Parameter dtype.
device : str, optional
Device for parameters.
Returns
-------
AWQW4A16Linear
"""
return cls(
in_features=linear.in_features,
out_features=linear.out_features,
......
"""
Quantized normalization layers for efficient inference.
"""
from typing import Optional, Tuple
import torch
......@@ -7,11 +11,28 @@ from .linear import AWQW4A16Linear
class NunchakuAdaLayerNormZero(AdaLayerNormZero):
"""
Nunchaku quantized AdaLayerNormZero for diffusion models.
Replaces the linear projection with AWQW4A16Linear for quantized inference.
Parameters
----------
other : AdaLayerNormZero
Source AdaLayerNormZero instance to copy weights and structure from.
scale_shift : float, optional
Value to add to scale parameters. Default is 1.0.
Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
Notes
-----
- B: batch size
- D: hidden dimension
"""
def __init__(self, other: AdaLayerNormZero, scale_shift: float = 1.0):
super(AdaLayerNormZero, self).__init__()
self.scale_shift = scale_shift
self.emb = other.emb
self.silu = other.silu
self.linear = AWQW4A16Linear.from_linear(other.linear)
......@@ -25,6 +46,40 @@ class NunchakuAdaLayerNormZero(AdaLayerNormZero):
hidden_dtype: Optional[torch.dtype] = None,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Forward pass for quantized AdaLayerNormZero.
Parameters
----------
x : torch.Tensor, shape (B, D), dtype float32/float16
Input tensor.
timestep : Optional[torch.Tensor], shape (B,) or (1,), optional
Timestep embedding input.
class_labels : Optional[torch.LongTensor], shape (B,) or (1,), optional
Class label input.
hidden_dtype : Optional[torch.dtype], optional
Dtype for embedding computation.
emb : Optional[torch.Tensor], shape (B, E), optional
Precomputed embedding. If None, computed from timestep and class_labels.
Returns
-------
norm_x_scaled : torch.Tensor, shape (B, D)
Normalized and scaled input.
gate_msa : torch.Tensor, shape (B, D)
Gate for MSA branch.
shift_mlp : torch.Tensor, shape (B, D)
Shift for MLP branch.
scale_mlp : torch.Tensor, shape (B, D)
Scale for MLP branch.
gate_mlp : torch.Tensor, shape (B, D)
Gate for MLP branch.
Notes
-----
- B: batch size
- D: hidden dimension
"""
if self.emb is not None:
emb = self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)
emb = self.linear(self.silu(emb))
......@@ -44,10 +99,27 @@ class NunchakuAdaLayerNormZero(AdaLayerNormZero):
class NunchakuAdaLayerNormZeroSingle(AdaLayerNormZeroSingle):
"""
Nunchaku quantized AdaLayerNormZeroSingle.
Uses AWQW4A16Linear for quantized embedding projection. Suitable for single-branch normalization.
Parameters
----------
other : AdaLayerNormZeroSingle
Source AdaLayerNormZeroSingle instance to copy weights and structure from.
scale_shift : float, optional
Value to add to scale parameters. Default is 1.0.
Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
Notes
-----
- B: batch size
- D: hidden dimension
"""
def __init__(self, other: AdaLayerNormZeroSingle, scale_shift: float = 1.0):
super(AdaLayerNormZeroSingle, self).__init__()
self.scale_shift = scale_shift
self.silu = other.silu
self.linear = AWQW4A16Linear.from_linear(other.linear)
......@@ -57,7 +129,29 @@ class NunchakuAdaLayerNormZeroSingle(AdaLayerNormZeroSingle):
self,
x: torch.Tensor,
emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for quantized AdaLayerNormZeroSingle.
Parameters
----------
x : torch.Tensor, shape (B, D), dtype float32/float16
Input tensor.
emb : Optional[torch.Tensor], shape (B, E), optional
Embedding tensor.
Returns
-------
norm_x_scaled : torch.Tensor, shape (B, D)
Normalized and scaled input.
gate_msa : torch.Tensor, shape (B, D)
Gate for MSA branch.
Notes
-----
- B: batch size
- D: hidden dimension
"""
emb = self.linear(self.silu(emb))
# The weight layout has changed; use split_mod rather than chunk to separate the embedding.
......
from .transformer_flux import NunchakuFluxTransformer2dModel
from .transformer_flux_v2 import NunchakuFluxTransformer2DModelV2
from .transformer_qwenimage import NunchakuQwenImageTransformer2DModel
from .transformer_sana import NunchakuSanaTransformer2DModel
__all__ = ["NunchakuFluxTransformer2dModel", "NunchakuSanaTransformer2DModel"]
__all__ = [
"NunchakuFluxTransformer2dModel",
"NunchakuSanaTransformer2DModel",
"NunchakuFluxTransformer2DModelV2",
"NunchakuQwenImageTransformer2DModel",
]
"""
This module provides Nunchaku FluxTransformer2DModel and its building blocks in Python.
"""
import json
import os
from pathlib import Path
......@@ -26,9 +30,21 @@ from .utils import NunchakuModelLoaderMixin
class NunchakuFluxAttention(NunchakuBaseAttention):
def __init__(self, other: FluxAttention, processor: str = "nunchaku-fp16", **kwargs):
"""
Nunchaku-optimized FluxAttention module with quantized and fused QKV projections.
Parameters
----------
other : FluxAttention
The original FluxAttention module to wrap and quantize.
processor : str, optional
The attention processor to use ("flashattn2" or "nunchaku-fp16").
**kwargs
Additional arguments for quantization.
"""
def __init__(self, other: FluxAttention, processor: str = "flashattn2", **kwargs):
super(NunchakuFluxAttention, self).__init__(processor)
self.head_dim = other.head_dim
self.inner_dim = other.inner_dim
self.query_dim = other.query_dim
......@@ -44,7 +60,7 @@ class NunchakuFluxAttention(NunchakuBaseAttention):
self.norm_q = other.norm_q
self.norm_k = other.norm_k
# fuse the qkv
# Fuse the QKV projections for efficiency.
with torch.device("meta"):
to_qkv = fuse_linears([other.to_q, other.to_k, other.to_v])
self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, **kwargs)
......@@ -57,7 +73,7 @@ class NunchakuFluxAttention(NunchakuBaseAttention):
self.norm_added_q = other.norm_added_q
self.norm_added_k = other.norm_added_k
# fuse the add_qkv
# Fuse the additional QKV projections.
with torch.device("meta"):
add_qkv_proj = fuse_linears([other.add_q_proj, other.add_k_proj, other.add_v_proj])
self.add_qkv_proj = SVDQW4A4Linear.from_linear(add_qkv_proj, **kwargs)
......@@ -71,6 +87,26 @@ class NunchakuFluxAttention(NunchakuBaseAttention):
image_rotary_emb: Tuple[torch.Tensor, torch.Tensor] | torch.Tensor = None,
**kwargs,
):
"""
Forward pass for NunchakuFluxAttention.
Parameters
----------
hidden_states : torch.Tensor
Input tensor.
encoder_hidden_states : torch.Tensor, optional
Encoder hidden states for cross-attention.
attention_mask : torch.Tensor, optional
Attention mask.
image_rotary_emb : tuple or torch.Tensor, optional
Rotary embeddings for image/text tokens.
**kwargs
Additional arguments.
Returns
-------
torch.Tensor or tuple of torch.Tensor
Output of the attention processor.
"""
return self.processor(
attn=self,
hidden_states=hidden_states,
......@@ -80,6 +116,22 @@ class NunchakuFluxAttention(NunchakuBaseAttention):
)
def set_processor(self, processor: str):
"""
Set the attention processor.
Parameters
----------
processor : str
Name of the processor ("flashattn2" or "nunchaku-fp16").
- ``"flashattn2"``: Standard FlashAttention-2. See :class:`~nunchaku.models.attention_processors.flux.NunchakuFluxFA2Processor`.
- ``"nunchaku-fp16"``: Uses FP16 attention accumulation, up to 1.2× faster than FlashAttention-2 on NVIDIA 30-, 40-, and 50-series GPUs. See :class:`~nunchaku.models.attention_processors.flux.NunchakuFluxFP16AttnProcessor`.
Raises
------
ValueError
If the processor is not supported.
"""
if processor == "flashattn2":
self.processor = NunchakuFluxFA2Processor()
elif processor == "nunchaku-fp16":
......@@ -89,6 +141,19 @@ class NunchakuFluxAttention(NunchakuBaseAttention):
class NunchakuFluxTransformerBlock(FluxTransformerBlock):
"""
Nunchaku-optimized FluxTransformerBlock with quantized attention and feedforward layers.
Parameters
----------
block : FluxTransformerBlock
The original block to wrap and quantize.
scale_shift : float, optional
Value to add to scale parameters. Default is 1.0.
Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
**kwargs
Additional arguments for quantization.
"""
def __init__(self, block: FluxTransformerBlock, scale_shift: float = 1, **kwargs):
super(FluxTransformerBlock, self).__init__()
......@@ -113,6 +178,32 @@ class NunchakuFluxTransformerBlock(FluxTransformerBlock):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
):
"""
Forward pass for the transformer block.
Parameters
----------
hidden_states : torch.Tensor
Input hidden states.
encoder_hidden_states : torch.Tensor
Encoder hidden states for cross-attention.
temb : torch.Tensor
Time or conditioning embedding.
image_rotary_emb : tuple of torch.Tensor, optional
Rotary embeddings for image/text tokens.
joint_attention_kwargs : dict, optional
Additional attention arguments (not supported).
Returns
-------
tuple
(encoder_hidden_states, hidden_states) after block processing.
Raises
------
NotImplementedError
If joint_attention_kwargs is provided.
"""
if joint_attention_kwargs is not None and len(joint_attention_kwargs) > 0:
raise NotImplementedError("joint_attention_kwargs is not supported")
......@@ -167,6 +258,20 @@ class NunchakuFluxTransformerBlock(FluxTransformerBlock):
class NunchakuFluxSingleTransformerBlock(FluxSingleTransformerBlock):
"""
Nunchaku-optimized single transformer block with quantized attention and MLP.
Parameters
----------
block : FluxSingleTransformerBlock
The original block to wrap and quantize.
scale_shift : float, optional
Value to add to scale parameters. Default is 1.0.
Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
**kwargs
Additional arguments for quantization.
"""
def __init__(self, block: FluxSingleTransformerBlock, scale_shift: float = 1, **kwargs):
super(FluxSingleTransformerBlock, self).__init__()
self.mlp_hidden_dim = block.mlp_hidden_dim
......@@ -176,7 +281,7 @@ class NunchakuFluxSingleTransformerBlock(FluxSingleTransformerBlock):
self.mlp_fc1 = SVDQW4A4Linear.from_linear(block.proj_mlp, **kwargs)
self.act_mlp = block.act_mlp
self.mlp_fc2 = SVDQW4A4Linear.from_linear(block.proj_out, in_features=self.mlp_hidden_dim, **kwargs)
# for int4, we shift the activation of mlp_fc2 to make it unsigned
# For int4, we shift the activation of mlp_fc2 to make it unsigned.
self.mlp_fc2.act_unsigned = self.mlp_fc2.precision != "nvfp4"
self.attn = NunchakuFluxAttention(block.attn, **kwargs)
......@@ -189,16 +294,34 @@ class NunchakuFluxSingleTransformerBlock(FluxSingleTransformerBlock):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
"""
Forward pass for the single transformer block.
Parameters
----------
hidden_states : torch.Tensor
Input hidden states.
temb : torch.Tensor
Time or conditioning embedding.
image_rotary_emb : tuple of torch.Tensor, optional
Rotary embeddings for tokens.
joint_attention_kwargs : dict, optional
Additional attention arguments.
Returns
-------
torch.Tensor
Output hidden states after block processing.
"""
residual = hidden_states
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
# Feedforward
if isinstance(self.act_mlp, GELU):
# use fused gelu mlp
# Use fused GELU MLP for efficiency.
mlp_hidden_states = fused_gelu_mlp(norm_hidden_states, self.mlp_fc1, self.mlp_fc2)
else:
# fallback to original gelu mlp
# Fallback to original MLP.
mlp_hidden_states = self.mlp_fc1(norm_hidden_states)
mlp_hidden_states = self.act_mlp(mlp_hidden_states)
mlp_hidden_states = self.mlp_fc2(mlp_hidden_states)
......@@ -220,8 +343,25 @@ class NunchakuFluxSingleTransformerBlock(FluxSingleTransformerBlock):
class NunchakuFluxTransformer2DModelV2(FluxTransformer2DModel, NunchakuModelLoaderMixin):
"""
Nunchaku-optimized FluxTransformer2DModel.
"""
def _patch_model(self, **kwargs):
"""
Patch the model with :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxTransformerBlock`
and :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxSingleTransformerBlock`.
Parameters
----------
**kwargs
Additional arguments for quantization.
Returns
-------
self : NunchakuFluxTransformer2DModelV2
The patched model.
"""
self.pos_embed = NunchakuFluxPosEmbed(dim=self.inner_dim, theta=10000, axes_dim=self.pos_embed.axes_dim)
for i, block in enumerate(self.transformer_blocks):
self.transformer_blocks[i] = NunchakuFluxTransformerBlock(block, scale_shift=0, **kwargs)
......@@ -232,6 +372,28 @@ class NunchakuFluxTransformer2DModelV2(FluxTransformer2DModel, NunchakuModelLoad
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
"""
Load a pretrained NunchakuFluxTransformer2DModelV2 from a safetensors file.
Parameters
----------
pretrained_model_name_or_path : str or os.PathLike
Path to the safetensors file. It can be a local file or a remote HuggingFace path.
**kwargs
Additional arguments (e.g., device, torch_dtype).
Returns
-------
NunchakuFluxTransformer2DModelV2
The loaded and quantized model.
Raises
------
NotImplementedError
If offload is requested.
AssertionError
If the file is not a safetensors file.
"""
device = kwargs.get("device", "cpu")
offload = kwargs.get("offload", False)
......@@ -268,7 +430,7 @@ class NunchakuFluxTransformer2DModelV2(FluxTransformer2DModel, NunchakuModelLoad
else:
assert state_dict[k].dtype == converted_state_dict[k].dtype
# load the wtscale from the converted state dict
# Load the wtscale from the converted state dict.
for n, m in transformer.named_modules():
if isinstance(m, SVDQW4A4Linear):
if m.wtscale is not None:
......@@ -294,30 +456,44 @@ class NunchakuFluxTransformer2DModelV2(FluxTransformer2DModel, NunchakuModelLoad
controlnet_blocks_repeat: bool = False,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
Input `hidden_states`.
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
Forward pass for the NunchakuFluxTransformer2DModelV2.
Parameters
----------
hidden_states : torch.Tensor
Input hidden states of shape (batch_size, image_sequence_length, in_channels).
encoder_hidden_states : torch.Tensor, optional
Conditional embeddings (e.g., from text).
pooled_projections : torch.Tensor, optional
Projected embeddings from input conditions.
timestep : torch.LongTensor, optional
Denoising step.
img_ids : torch.Tensor, optional
Image token IDs.
txt_ids : torch.Tensor, optional
Text token IDs.
guidance : torch.Tensor, optional
Guidance tensor for classifier-free guidance.
joint_attention_kwargs : dict, optional
Additional attention arguments.
controlnet_block_samples : any, optional
Not supported.
controlnet_single_block_samples : any, optional
Not supported.
return_dict : bool, optional
Whether to return a Transformer2DModelOutput (default: True).
controlnet_blocks_repeat : bool, optional
Not supported.
Returns
-------
Transformer2DModelOutput or tuple
Output sample tensor or output tuple.
Raises
------
NotImplementedError
If controlnet is requested.
"""
hidden_states = self.x_embedder(hidden_states)
......@@ -371,7 +547,7 @@ class NunchakuFluxTransformer2DModelV2(FluxTransformer2DModel, NunchakuModelLoad
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
# Controlnet residual (not supported for now)
if controlnet_block_samples is not None:
raise NotImplementedError("Controlnet is not supported for FluxTransformer2DModelV2 for now")
......@@ -384,7 +560,7 @@ class NunchakuFluxTransformer2DModelV2(FluxTransformer2DModel, NunchakuModelLoad
joint_attention_kwargs=joint_attention_kwargs,
)
# controlnet residual
# Controlnet residual (not supported for now)
if controlnet_single_block_samples is not None:
raise NotImplementedError("Controlnet is not supported for FluxTransformer2DModelV2 for now")
......@@ -399,6 +575,20 @@ class NunchakuFluxTransformer2DModelV2(FluxTransformer2DModel, NunchakuModelLoad
def convert_flux_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
"""
Convert a state dict from the :class:`~nunchaku.models.transformers.transformer_flux.NunchakuFluxTransformer2dModel`
format to :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxTransformer2DModelV2` format.
Parameters
----------
state_dict : dict[str, torch.Tensor]
The original state dict.
Returns
-------
dict[str, torch.Tensor]
The converted state dict compatible with :class:`~nunchaku.models.transformers.transformer_flux_v2.NunchakuFluxTransformer2DModelV2`.
"""
new_state_dict = {}
for k, v in state_dict.items():
if "single_transformer_blocks." in k:
......
"""
This module provides implementations of NunchakuQwenImageTransformer2DModel and its building blocks.
"""
import gc
import json
import os
......@@ -24,6 +28,19 @@ from .utils import NunchakuModelLoaderMixin
class NunchakuQwenAttention(NunchakuBaseAttention):
"""
Nunchaku-optimized quantized attention module for QwenImage.
Parameters
----------
other : Attention
The original QwenImage Attention module to wrap and quantize.
processor : str, default="flashattn2"
The attention processor to use.
**kwargs
Additional arguments for quantization.
"""
def __init__(self, other: Attention, processor: str = "flashattn2", **kwargs):
super(NunchakuQwenAttention, self).__init__(processor)
self.inner_dim = other.inner_dim
......@@ -59,7 +76,7 @@ class NunchakuQwenAttention(NunchakuBaseAttention):
self.norm_added_q = other.norm_added_q
self.norm_added_k = other.norm_added_k
# fuse the qkv
# Fuse the QKV projections for quantization
with torch.device("meta"):
to_qkv = fuse_linears([other.to_q, other.to_k, other.to_v])
self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, **kwargs)
......@@ -67,7 +84,7 @@ class NunchakuQwenAttention(NunchakuBaseAttention):
self.to_out[0] = SVDQW4A4Linear.from_linear(self.to_out[0], **kwargs)
assert self.added_kv_proj_dim is not None
# fuse the add_qkv
# Fuse the additional QKV projections
with torch.device("meta"):
add_qkv_proj = fuse_linears([other.add_q_proj, other.add_k_proj, other.add_v_proj])
self.add_qkv_proj = SVDQW4A4Linear.from_linear(add_qkv_proj, **kwargs)
......@@ -75,13 +92,36 @@ class NunchakuQwenAttention(NunchakuBaseAttention):
def forward(
self,
hidden_states: torch.FloatTensor, # Image stream
encoder_hidden_states: torch.FloatTensor = None, # Text stream
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
encoder_hidden_states_mask: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
):
"""
Forward pass for NunchakuQwenAttention.
Parameters
----------
hidden_states : torch.FloatTensor
Image stream input.
encoder_hidden_states : torch.FloatTensor, optional
Text stream input.
encoder_hidden_states_mask : torch.FloatTensor, optional
Mask for encoder hidden states.
attention_mask : torch.FloatTensor, optional
Attention mask.
image_rotary_emb : torch.Tensor, optional
Rotary embedding for images.
**kwargs
Additional arguments.
Returns
-------
tuple
Attention outputs for image and text streams.
"""
return self.processor(
self,
hidden_states,
......@@ -93,6 +133,19 @@ class NunchakuQwenAttention(NunchakuBaseAttention):
)
def set_processor(self, processor: str):
"""
Set the attention processor.
Parameters
----------
processor : str
Name of the processor to use. Only "flashattn2" is supported for now. See :class:`~nunchaku.models.attention_processors.qwenimage.NunchakuQwenImageNaiveFA2Processor`.
Raises
------
ValueError
If the processor is not supported.
"""
if processor == "flashattn2":
self.processor = NunchakuQwenImageNaiveFA2Processor()
else:
......@@ -100,6 +153,22 @@ class NunchakuQwenAttention(NunchakuBaseAttention):
class NunchakuQwenImageTransformerBlock(QwenImageTransformerBlock):
"""
Quantized QwenImage Transformer Block.
This block supports quantized linear layers and joint attention for image and text streams.
Parameters
----------
other : QwenImageTransformerBlock
The original transformer block to wrap and quantize.
scale_shift : float, optional
Value to add to scale parameters. Default is 1.0.
Nunchaku may have already fused the scale_shift into the linear weights, so you may want to set it to 0.
**kwargs
Additional arguments for quantization.
"""
def __init__(self, other: QwenImageTransformerBlock, scale_shift: float = 1.0, **kwargs):
super(QwenImageTransformerBlock, self).__init__()
......@@ -122,7 +191,21 @@ class NunchakuQwenImageTransformerBlock(QwenImageTransformerBlock):
self.scale_shift = scale_shift
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""Apply modulation to input tensor"""
"""
Apply modulation to input tensor.
Parameters
----------
x : torch.Tensor
Input tensor.
mod_params : torch.Tensor
Modulation parameters.
Returns
-------
tuple
Modulated tensor and gate tensor.
"""
shift, scale, gate = mod_params.chunk(3, dim=-1)
if self.scale_shift != 0:
scale.add_(self.scale_shift)
......@@ -137,6 +220,29 @@ class NunchakuQwenImageTransformerBlock(QwenImageTransformerBlock):
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Forward pass for NunchakuQwenImageTransformerBlock.
Parameters
----------
hidden_states : torch.Tensor
Image stream input.
encoder_hidden_states : torch.Tensor
Text stream input.
encoder_hidden_states_mask : torch.Tensor
Mask for encoder hidden states.
temb : torch.Tensor
Temporal embedding.
image_rotary_emb : tuple of torch.Tensor, optional
Rotary embedding for images.
joint_attention_kwargs : dict, optional
Additional arguments for joint attention.
Returns
-------
tuple
Updated encoder_hidden_states and hidden_states.
"""
# Get modulation parameters for both streams
img_mod_params = self.img_mod(temb) # [B, 6*dim]
txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
......@@ -152,10 +258,6 @@ class NunchakuQwenImageTransformerBlock(QwenImageTransformerBlock):
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
# Split modulation parameters for norm1 and norm2
# img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
# txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
# Process image stream - norm1 + modulation
img_normed = self.img_norm1(hidden_states)
img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
......@@ -164,16 +266,10 @@ class NunchakuQwenImageTransformerBlock(QwenImageTransformerBlock):
txt_normed = self.txt_norm1(encoder_hidden_states)
txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
# Use QwenAttnProcessor2_0 for joint attention computation
# This directly implements the DoubleStreamLayerMegatron logic:
# 1. Computes QKV for both streams
# 2. Applies QK normalization and RoPE
# 3. Concatenates and runs joint attention
# 4. Splits results back to separate streams
joint_attention_kwargs = joint_attention_kwargs or {}
attn_output = self.attn(
hidden_states=img_modulated, # Image stream (will be processed as "sample")
encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
hidden_states=img_modulated,
encoder_hidden_states=txt_modulated,
encoder_hidden_states_mask=encoder_hidden_states_mask,
image_rotary_emb=image_rotary_emb,
**joint_attention_kwargs,
......@@ -208,6 +304,27 @@ class NunchakuQwenImageTransformerBlock(QwenImageTransformerBlock):
class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuModelLoaderMixin):
"""
Quantized QwenImage Transformer2DModel.
This model supports quantized transformer blocks and optional CPU offloading for memory efficiency.
Parameters
----------
*args
Positional arguments for the base model.
**kwargs
Keyword arguments for the base model and quantization.
Attributes
----------
offload : bool
Whether CPU offloading is enabled.
offload_manager : CPUOffloadManager or None
Manager for offloading transformer blocks.
_is_initialized : bool
Whether the model has been patched for quantization.
"""
def __init__(self, *args, **kwargs):
self.offload = kwargs.pop("offload", False)
......@@ -216,6 +333,18 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuM
super().__init__(*args, **kwargs)
def _patch_model(self, **kwargs):
"""
Patch the transformer blocks for quantization.
Parameters
----------
**kwargs
Additional arguments for quantization.
Returns
-------
self
"""
for i, block in enumerate(self.transformer_blocks):
self.transformer_blocks[i] = NunchakuQwenImageTransformerBlock(block, scale_shift=0, **kwargs)
self._is_initialized = True
......@@ -224,6 +353,26 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuM
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_name_or_path: str | os.PathLike[str], **kwargs):
"""
Load a quantized model from a pretrained checkpoint.
Parameters
----------
pretrained_model_name_or_path : str or os.PathLike
Path to the pretrained model checkpoint. It can be a local file or a remote HuggingFace path.
**kwargs
Additional arguments for loading and quantization.
Returns
-------
NunchakuQwenImageTransformer2DModel
The loaded and quantized model.
Raises
------
AssertionError
If the checkpoint is not a safetensors file.
"""
device = kwargs.get("device", "cpu")
offload = kwargs.get("offload", False)
......@@ -271,6 +420,20 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuM
return transformer
def set_offload(self, offload: bool, **kwargs):
"""
Enable or disable asynchronous CPU offloading for transformer blocks.
Parameters
----------
offload : bool
Whether to enable offloading.
**kwargs
Additional arguments for offload manager.
See Also
--------
:class:`~nunchaku.models.utils.CPUOffloadManager`
"""
if offload == self.offload:
# nothing changed, just return
return
......@@ -302,10 +465,39 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuM
timestep: torch.LongTensor = None,
img_shapes: Optional[List[Tuple[int, int, int]]] = None,
txt_seq_lens: Optional[List[int]] = None,
guidance: torch.Tensor = None, # TODO: this should probably be removed
guidance: torch.Tensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
Forward pass for the quantized QwenImage transformer model.
Parameters
----------
hidden_states : torch.Tensor
Image stream input.
encoder_hidden_states : torch.Tensor, optional
Text stream input.
encoder_hidden_states_mask : torch.Tensor, optional
Mask for encoder hidden states.
timestep : torch.LongTensor, optional
Timestep for temporal embedding.
img_shapes : list of tuple, optional
Image shapes for rotary embedding.
txt_seq_lens : list of int, optional
Text sequence lengths.
guidance : torch.Tensor, optional
Guidance tensor (for classifier-free guidance).
attention_kwargs : dict, optional
Additional attention arguments.
return_dict : bool, default=True
Whether to return a dict or tuple.
Returns
-------
torch.Tensor or Transformer2DModelOutput
Model output.
"""
device = hidden_states.device
if self.offload:
self.offload_manager.set_device(device)
......@@ -357,8 +549,26 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuM
def to(self, *args, **kwargs):
"""
Overwrite the default .to() method.
If self.offload is True, avoid moving the model to GPU.
Override the default ``.to()`` method.
If offload is enabled, prevents moving the model to GPU.
Prevents changing dtype after quantization.
Parameters
----------
*args
Positional arguments for ``.to()``.
**kwargs
Keyword arguments for ``.to()``.
Returns
-------
self
Raises
------
ValueError
If attempting to change dtype after quantization.
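Examples
--------
A minimal sketch of the intended behavior:
>>> transformer.to("cuda")  # moves the model only when offloading is disabled  # doctest: +SKIP
>>> transformer.to(dtype=torch.float16)  # raises ValueError once the model is quantized  # doctest: +SKIP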
"""
device_arg_or_kwarg_present = any(isinstance(arg, torch.device) for arg in args) or "device" in kwargs
dtype_present_in_args = "dtype" in kwargs
......@@ -382,7 +592,7 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuM
if dtype_present_in_args and self._is_initialized:
raise ValueError(
"Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
"use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`"
"use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`."
)
if self.offload:
if device_arg_or_kwarg_present:
......
"""
Utility functions and classes for efficient transformer model management in Nunchaku.
"""
import copy
import torch
......@@ -7,6 +11,28 @@ from ..utils import copy_params_into
def fuse_linears(linears: list[nn.Linear]) -> nn.Linear:
"""
Fuse a list of nn.Linear layers into a single nn.Linear with concatenated output features.
Parameters
----------
linears : list of nn.Linear
List of linear layers to fuse. All must have the same input feature dimension.
Returns
-------
fused : nn.Linear
A new linear layer with concatenated output features and the same input features.
Raises
------
AssertionError
If the input feature dimensions do not match.
Notes
-----
The fused layer does not copy weights or biases from the input layers.
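Examples
--------
A small sketch of the shape bookkeeping (weights are not copied, as noted above):
>>> import torch.nn as nn
>>> fused = fuse_linears([nn.Linear(64, 32), nn.Linear(64, 16)])
>>> (fused.in_features, fused.out_features)
(64, 48)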
"""
assert len(linears) > 0
if len(linears) == 1:
return linears[0]
......@@ -24,10 +50,45 @@ def fuse_linears(linears: list[nn.Linear]) -> nn.Linear:
class CPUOffloadManager:
"""Generic manager for per-transformer-block CPU offloading with async memory operations.
"""
Manager for per-transformer-block CPU offloading with asynchronous memory operations using a Ping-Pong buffer strategy.
This class enables memory-efficient inference or training by keeping only a subset
of transformer blocks on GPU, offloading the rest to CPU, and preloading blocks as needed.
Parameters
----------
blocks : list of nn.Module
List of transformer blocks to manage.
device : str or torch.device, optional
Target CUDA device for GPU operations. Default is "cuda".
use_pin_memory : bool, optional
Whether to use pinned memory for faster CPU-to-GPU transfers. Default is True.
on_gpu_modules : list of nn.Module, optional
Additional modules to keep on GPU at all times. Default is [].
num_blocks_on_gpu : int, optional
Number of blocks to keep on GPU simultaneously. Must be > 0. Default is 1.
empty_cache_freq : int, optional
Frequency (in forward passes) to call torch.cuda.empty_cache(). Default is 0 (never).
This class can be used with any transformer model that has a list of transformer blocks.
It provides memory-efficient processing by keeping only the current block on GPU.
Attributes
----------
blocks : list of nn.Module
The managed transformer blocks.
buffer_blocks : list of nn.Module
Buffers for preloading blocks onto GPU.
device : torch.device
The current CUDA device.
current_block_idx : int
Index of the current block on GPU.
forward_counter : int
Number of forward passes completed.
memory_stream : torch.cuda.Stream
CUDA stream for memory operations.
compute_done : torch.cuda.Event
CUDA event signaling compute completion.
memory_done : torch.cuda.Event
CUDA event signaling memory completion.
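Examples
--------
A minimal sketch of the intended per-block loop; the block call signature is model-specific,
so the snippet is illustrative only:
>>> manager = CPUOffloadManager(blocks, device="cuda", num_blocks_on_gpu=1)  # doctest: +SKIP
>>> manager.initialize()  # doctest: +SKIP
>>> for _ in range(len(blocks)):  # doctest: +SKIP
...     block = manager.get_block()
...     hidden_states = block(hidden_states)
...     manager.step()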
"""
def __init__(
......@@ -61,6 +122,22 @@ class CPUOffloadManager:
self.empty_cache_freq = empty_cache_freq
def set_device(self, device: torch.device | str, force: bool = False):
"""
Set the CUDA device for offloading and memory operations.
It will move buffer blocks and on-GPU modules to the specified device and offload other blocks to CPU, optionally using pinned memory.
Parameters
----------
device : torch.device or str
Target CUDA device.
force : bool, optional
If True, force re-initialization even if device is unchanged. Default is False.
Raises
------
AssertionError
If the device is not a CUDA device.
"""
if isinstance(device, str):
device = torch.device(device)
assert device.type == "cuda"
......@@ -84,7 +161,20 @@ class CPUOffloadManager:
b.data = b.data.pin_memory()
def load_block(self, block_idx: int, non_blocking: bool = True):
"""Move a transformer block to GPU."""
"""
Move a transformer block from CPU to GPU buffer.
Parameters
----------
block_idx : int
Index of the block to load.
non_blocking : bool, optional
Whether to use non-blocking memory copy. Default is True.
Notes
-----
- No action is taken if the block is already on GPU or the index is out of range.
"""
# if the block is already on GPU, don't load it to the buffer
if block_idx < self.num_blocks_on_gpu:
return
......@@ -96,7 +186,17 @@ class CPUOffloadManager:
copy_params_into(block, self.buffer_blocks[block_idx % 2], non_blocking=non_blocking)
def step(self, compute_stream: torch.cuda.Stream | None = None):
"""Move to the next block, triggering preload operations."""
"""
Advance to the next transformer block, triggering asynchronous preloading.
It will preload the next block onto the GPU in the background and synchronize the compute and memory streams.
After all blocks have been processed, it will call ``torch.cuda.empty_cache()`` periodically if ``empty_cache_freq`` > 0.
Parameters
----------
compute_stream : torch.cuda.Stream, optional
CUDA stream for compute operations. If None, uses current stream.
"""
if compute_stream is None:
compute_stream = torch.cuda.current_stream()
next_compute_done = torch.cuda.Event()
......@@ -121,6 +221,20 @@ class CPUOffloadManager:
torch.cuda.empty_cache()
def get_block(self, block_idx: int | None = None) -> nn.Module:
"""
Retrieve the current or specified transformer block for computation.
It will return a buffer block if the requested block is offloaded.
Parameters
----------
block_idx : int, optional
Index of the block to retrieve. If None, returns the current block.
Returns
-------
block : nn.Module
The requested transformer block (on GPU if needed).
"""
if block_idx is None:
block_idx = self.current_block_idx
if block_idx < self.num_blocks_on_gpu:
......@@ -129,6 +243,19 @@ class CPUOffloadManager:
return self.buffer_blocks[block_idx % 2]
def initialize(self, stream: torch.cuda.Stream | None = None):
"""
Initialize CUDA events for compute and memory streams.
It will record the initial events for the compute and memory streams.
Parameters
----------
stream : torch.cuda.Stream, optional
CUDA stream to record initial events. If None, uses current stream.
Notes
-----
- Should be called before the first forward pass.
"""
if stream is None:
stream = torch.cuda.current_stream()
self.compute_done.record(stream)
......
"""
High-performance fused operators for quantized neural network inference.
"""
import torch
from torch.nn import RMSNorm
......@@ -7,8 +11,38 @@ from ..utils import ceil_divide
from .gemm import svdq_gemm_w4a4_cuda
def fused_gelu_mlp(x: torch.Tensor, fc1: SVDQW4A4Linear, fc2: SVDQW4A4Linear, pad_size: int = 256):
# a fused operator of fc1 and fc2 with gelu
def fused_gelu_mlp(x: torch.Tensor, fc1: SVDQW4A4Linear, fc2: SVDQW4A4Linear, pad_size: int = 256) -> torch.Tensor:
"""
Fused quantized MLP with GELU activation.
Fuses the first quantized linear layer, the GELU activation, and re-quantization of the intermediate activations into a single CUDA kernel launch, then applies the second quantized linear layer. Supports INT4 and NVFP4 quantization.
Parameters
----------
x : torch.Tensor, shape (B, S, C_in), dtype float16 or bfloat16
Input tensor.
fc1 : SVDQW4A4Linear
First quantized linear layer (input → hidden).
fc2 : SVDQW4A4Linear
Second quantized linear layer (hidden → output).
pad_size : int, optional
Batch padding size for CUDA kernel efficiency. Default is 256.
Returns
-------
torch.Tensor, shape (B, S, C_out), dtype as input
Output tensor.
Notes
-----
- Notations:
- B: batch size
- S: sequence length
- C_in: input features
- C_out: output features
- For INT4 quantization, GELU activations are shifted by 0.171875 to ensure non-negativity, enabling unsigned quantization for improved quality. See: https://github.com/nunchaku-tech/nunchaku/blob/433f0b228a61a53fb700ac676fd2e290368ac94d/src/kernels/zgemm/gemm_w4a4_launch_impl.cuh#L286
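Examples
--------
Up to quantization error, the fused path matches the plain full-precision MLP below;
``fc1_ref`` and ``fc2_ref`` are illustrative names for the original unquantized layers:
>>> import torch.nn.functional as F
>>> ref = fc2_ref(F.gelu(fc1_ref(x)))  # doctest: +SKIP
>>> out = fused_gelu_mlp(x, fc1, fc2)  # doctest: +SKIP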
"""
batch_size, seq_len, channels = x.shape
x = x.view(batch_size * seq_len, channels)
quantized_x, ascales, lora_act = fc1.quantize(x)
......@@ -22,8 +56,6 @@ def fused_gelu_mlp(x: torch.Tensor, fc1: SVDQW4A4Linear, fc2: SVDQW4A4Linear, pa
qout_ascales = torch.empty(fc1.out_features // 64, batch_size_pad, dtype=x.dtype, device=x.device)
qout_lora_act = torch.empty(batch_size_pad, fc2.proj_down.shape[1], dtype=torch.float32, device=x.device)
# for int4, we shift the activation after gelu to make it all positive to improve quality.
# if we pass the qout to this kernel, it will do the gelu fusion.
svdq_gemm_w4a4_cuda(
act=quantized_x,
wgt=fc1.qweight,
......@@ -56,6 +88,42 @@ def fused_qkv_norm_rottary(
output: torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
attn_tokens: int = 0,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Fused quantized QKV projection with RMSNorm and rotary embeddings.
Performs quantized QKV projection, applies RMS normalization to Q and K, and fuses rotary embeddings in a single CUDA kernel call.
Parameters
----------
x : torch.Tensor, shape (B, S, C_in), dtype float16 or bfloat16
Input tensor.
proj : SVDQW4A4Linear
Quantized QKV projection layer.
norm_q : RMSNorm
RMSNorm for query.
norm_k : RMSNorm
RMSNorm for key.
rotary_emb : torch.Tensor
Packed rotary embedding tensor (see :func:`~nunchaku.models.embeddings.pack_rotemb`).
output : torch.Tensor or tuple of torch.Tensor, optional
Output tensor(s). If None, a new tensor is allocated.
If tuple, should be (output_q, output_k, output_v) for fused attention packing.
attn_tokens : int, optional
Number of attention tokens. Default is 0.
Returns
-------
torch.Tensor or tuple of torch.Tensor
Output tensor of shape (B, S, C_out), or tuple (output_q, output_k, output_v).
Notes
-----
Notations:
- B: batch size
- S: sequence length
- C_in: input features
- C_out: output features
"""
assert isinstance(norm_q, RMSNorm)
assert isinstance(norm_k, RMSNorm)
......
"""
Python wrappers for Nunchaku's quantized GEMM operations.
Python wrappers for Nunchaku's high-performance quantized GEMM (General Matrix-Matrix Multiplication) CUDA kernels.
"""
import math
......@@ -41,85 +41,86 @@ def svdq_gemm_w4a4_cuda(
attn_tokens: int = 0,
):
"""
This function wraps the high-performance CUDA kernel for SVDQuant W4A4 quantized GEMM.
Notation
--------
M : int
Batch size (number of input samples).
K : int
Number of input channels (feature dimension).
N : int
Number of output channels.
G : int
Number of groups. 64 for INT4 and 16 for NVFP4.
Quantized GEMM using SVDQuant W4A4 CUDA kernel, with support for LoRA, rotary embeddings, normalization, and fused activations.
Parameters
----------
act : torch.Tensor
Input activation tensor. Packed shape (M, K // 2). Packed datatype: torch.int8
wgt : torch.Tensor
Quantized weight tensor. Packed shape (N, K // 2). Packed datatype: torch.int8
out : torch.Tensor or None
Output tensor for the linear layer. Shape (M, N). Datatype: torch.float16 or torch.bfloat16. If None, we will create a new tensor.
qout : torch.Tensor or None
Quantized output tensor for the next layer. Packed shape (M, N // 2). Packed datatype: torch.int8. If None, we will create a new tensor.
ascales : torch.Tensor
Activation scales tensor. Shape (K // G, M). Datatype: torch.float16 or torch.bfloat16 for INT4 and torch.float8_e4m3 for NVFP4.
wscales : torch.Tensor
Weight scales tensor. Shape (K // G, N). Datatype: torch.float16 or torch.bfloat16 for INT4 and torch.float8_e4m3 for NVFP4.
oscales : torch.Tensor or None
Output scales tensor. Shape (N // G, M). Datatype: torch.float16 or torch.bfloat16 for INT4 and torch.float8_e4m3 for NVFP4.
poolout : torch.Tensor or None
Not used for now. Just leave it as None.
lora_act_in : torch.Tensor
Low-rank down output tensor. Packed shape (M, R). Packed datatype: torch.float32.
lora_up : torch.Tensor
Low-rank up-projection weights. Packed shape (N, R). Packed datatype: torch.float16 or torch.bfloat16.
lora_down : torch.Tensor or None
Low-rank down-projection weights in the next layer. Packed shape (N, R). Packed datatype: torch.float16 or torch.bfloat16.
lora_act_out : torch.Tensor or None
Output tensor for low-rank down-projection in the next layer. Packed shape (M, R). Packed datatype: torch.float32.
norm_q : torch.Tensor or None
Query normalization tensor. Shape (HEAD_DIM,). Datatype: torch.float16 or torch.bfloat16.
norm_k : torch.Tensor or None
Key normalization tensor. Shape (HEAD_DIM,). Datatype: torch.float16 or torch.bfloat16.
rotary_emb : torch.Tensor or None
Rotary embedding tensor. Shape (M, HEAD_DIM // 2, 2, 2). Datatype: torch.float32. TODO: double check this.
bias : torch.Tensor or None
Bias tensor. Shape (N,). Datatype: torch.float16 or torch.bfloat16.
smooth_factor : torch.Tensor or None
Smoothing factor tensor for quantization in the next layer. Shape (N,). Datatype: torch.float16 or torch.bfloat16.
out_vk : torch.Tensor or None
Used only in SANA.
out_linearattn : torch.Tensor or None
Used only in SANA.
act : torch.Tensor, shape (M, K // 2), dtype int8
Packed input activations.
wgt : torch.Tensor, shape (N, K // 2), dtype int8
Packed quantized weights.
out : torch.Tensor or None, shape (M, N), dtype float16 or bfloat16, optional
Output tensor for the linear layer.
qout : torch.Tensor or None, shape (M, N // 2), dtype int8, optional
Packed quantized input for the next layer.
ascales : torch.Tensor or None, shape (K // G, M), dtype float16/bfloat16 (INT4) or float8_e4m3fn (NVFP4), optional
Activation scales.
wscales : torch.Tensor or None, shape (K // G, N), dtype float16/bfloat16 (INT4) or float8_e4m3fn (NVFP4), optional
Weight scales.
oscales : torch.Tensor or None, shape (N // G, M), dtype float16/bfloat16 (INT4) or float8_e4m3fn (NVFP4), optional
Output scales.
poolout : torch.Tensor or None, optional
Reserved for future use.
lora_act_in : torch.Tensor or None, shape (M, R), dtype float32, optional
LoRA down-projection activations.
lora_up : torch.Tensor or None, shape (N, R), dtype float16 or bfloat16, optional
Packed LoRA up-projection weights.
lora_down : torch.Tensor or None, shape (N, R), dtype float16 or bfloat16, optional
Packed LoRA down-projection weights for the next layer.
lora_act_out : torch.Tensor or None, shape (M, R), dtype float32, optional
Output for LoRA down-projection in the next layer.
norm_q : torch.Tensor or None, shape (HEAD_DIM,), dtype float16 or bfloat16, optional
Query RMS normalization.
norm_k : torch.Tensor or None, shape (HEAD_DIM,), dtype float16 or bfloat16, optional
Key RMS normalization.
rotary_emb : torch.Tensor or None, shape (M, HEAD_DIM // 2, 2, 2), dtype float32, optional
Packed rotary embeddings.
bias : torch.Tensor or None, shape (N,), dtype float16 or bfloat16, optional
Bias tensor.
smooth_factor : torch.Tensor or None, shape (N,), dtype float16 or bfloat16, optional
Smoothing factor for quantization in the next layer.
out_vk : torch.Tensor or None, optional
Used only in SANA. Leave as None.
out_linearattn : torch.Tensor or None, optional
Used only in SANA. Leave as None.
act_unsigned : bool, default=False
Whether activations are unsigned.
lora_scales : list of float, default=[]
Scaling factors for the low-rank branch.
If True, treat activations as unsigned (e.g., GELU outputs shifted by 0.171875). Only used for INT4, where unsigned activation quantization improves quality.
lora_scales : list of float or None, optional
Per-group LoRA scaling factors (16 channels per group). Defaults to 1.0 per group.
fuse_silu : bool, default=False
Whether to fuse SiLU activation.
If True, fuse SiLU activation.
fp4 : bool, default=False
Whether to use 4-bit floating point quantization (NVFP4).
alpha : float, default=1.0
Per tensor scaling factor for NVFP4.
wcscales : torch.Tensor or None, default=None
Per channel scaling factors for NVFP4. Shape (N,). Datatype: torch.float8_e4m3.
out_q : torch.Tensor or None, default=None
Output tensor for quantized Q, used for Nunchaku attention. Packed shape (B, H, M, D). Datatype: torch.int8.
out_k : torch.Tensor or None, default=None
Output tensor for quantized K, used for Nunchaku attention. Packed shape (B, H, M, D). Datatype: torch.int8.
out_v : torch.Tensor or None, default=None
Output tensor for quantized V, used for Nunchaku attention. Packed shape (B, H, M, D). Datatype: torch.int8.
If True, use 4-bit floating point quantization (NVFP4).
alpha : float or None, default=1.0
Per-tensor scaling factor for NVFP4.
wcscales : torch.Tensor or None, shape (N,), dtype float8_e4m3fn, optional
Per-channel scaling for NVFP4.
out_q : torch.Tensor or None, shape (B, H, M, D), dtype int8, optional
Packed quantized Q for attention (used in ``nunchaku-fp16`` attention).
out_k : torch.Tensor or None, shape (B, H, M, D), dtype int8, optional
Packed quantized K for attention (used in ``nunchaku-fp16`` attention).
out_v : torch.Tensor or None, shape (B, H, M, D), dtype int8, optional
Packed quantized V for attention (used in ``nunchaku-fp16`` attention).
attn_tokens : int, default=0
Number of attention tokens.
Returns
-------
None
The results are written in-place to the provided output tensors.
Results are written in-place to the provided output tensors.
Notes
-----
Notations:
- M: batch size (input tokens)
- K: input channels (feature dimension)
- N: output channels
- G: group size (64 for INT4, 16 for NVFP4)
- R: LoRA rank
- B: batch size for attention
- H: number of heads
- D: head dimension
"""
if lora_scales is None:
rank = lora_up.shape[1]
......
"""
Python wrappers for Nunchaku's quantized GEMV operations.
Python wrapper for Nunchaku's high-performance GEMV (General Matrix-Vector Multiplication) CUDA kernels.
"""
import torch
......@@ -17,4 +17,40 @@ def awq_gemv_w4a16_cuda(
k: int,
group_size: int = 64,
) -> torch.Tensor:
"""
Performs quantized GEMV using the AWQ W4A16 format.
Parameters
----------
in_feats : torch.Tensor, shape (k,) or (m, k), dtype float16 or bfloat16
Input feature vector or batch of vectors.
kernel : torch.Tensor, shape (n // 4, k // 2), dtype int32
Packed quantized weight matrix.
scaling_factors : torch.Tensor, shape (k // group_size, n), dtype float16 or bfloat16
Per-group scaling factors.
zeros : torch.Tensor, shape (k // group_size, n), dtype float16 or bfloat16
Per-group zero points.
m : int
Batch size (number of input vectors).
n : int
Output feature dimension.
k : int
Input feature dimension.
group_size : int, optional
Number of input channels per quantization group. Default is 64.
Returns
-------
torch.Tensor, shape (m, n), dtype float16 or bfloat16
Output tensor.
Notes
-----
Notations:
- m: batch size
- n: output features
- k: input features
- group_size: quantization group size
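Examples
--------
A shape-only sketch; the packed weights below are zero-filled, so this illustrates the
expected layouts rather than meaningful output:
>>> m, n, k, group_size = 1, 3072, 3072, 64
>>> x = torch.randn(m, k, dtype=torch.float16, device="cuda")
>>> kernel = torch.zeros(n // 4, k // 2, dtype=torch.int32, device="cuda")
>>> scales = torch.ones(k // group_size, n, dtype=torch.float16, device="cuda")
>>> zero_points = torch.zeros(k // group_size, n, dtype=torch.float16, device="cuda")
>>> out = awq_gemv_w4a16_cuda(x, kernel, scales, zero_points, m, n, k, group_size)  # doctest: +SKIP
>>> out.shape  # doctest: +SKIP
torch.Size([1, 3072])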
"""
return ops.gemv_awq(in_feats, kernel, scaling_factors, zeros, m, n, k, group_size)