"vscode:/vscode.git/clone" did not exist on "06be7fb4be4d8a56eb73ab4c4176f027325720c0"
Unverified Commit b0484ae0 authored by dengdong, committed by GitHub

feat: sdxl model support (#674)

* feat: sdxl model support

* code style auto-modified by pre-commit hook

* refine comments

* add tests and examples for sdxl

* refine sdxl tests code

* make linter happy

* mv the locations of the examples

* move the locations of the tests

* refine tests and examples

* add API documentation for unet_sdxl.py

* usage doc for sdxl

* update docs

* update

* refine pipeline initialization

* refine tests for sdxl/sdxl-turbo
parent 657863bb
@@ -20,6 +20,7 @@ Check out `DeepCompressor <github_deepcompressor_>`_ for the quantization library
    usage/qwen-image-edit.rst
    usage/lora.rst
    usage/kontext.rst
+   usage/sdxl.rst
    usage/controlnet.rst
    usage/qencoder.rst
    usage/offload.rst
...
@@ -5,6 +5,7 @@ nunchaku.models
    :maxdepth: 4
    nunchaku.models.transformers
+   nunchaku.models.unets
    nunchaku.models.text_encoders
    nunchaku.models.linear
    nunchaku.models.attention
...
nunchaku.models.unets
=====================

.. toctree::
   :maxdepth: 4

   nunchaku.models.unets.unet_sdxl

nunchaku.models.unets.unet\_sdxl
================================

.. automodule:: nunchaku.models.unets.unet_sdxl
   :members:
   :undoc-members:
   :show-inheritance:
Stable Diffusion XL
===================

The following examples show how to run the Nunchaku INT4 versions of the SDXL and SDXL-Turbo text-to-image pipelines.

.. tabs::

   .. tab:: SDXL

      .. literalinclude:: ../../../examples/v1/sdxl.py
         :language: python
         :caption: Running Nunchaku SDXL (`examples/v1/sdxl.py <https://github.com/nunchaku-tech/nunchaku/blob/main/examples/v1/sdxl.py>`__)
         :linenos:

   .. tab:: SDXL-Turbo

      .. literalinclude:: ../../../examples/v1/sdxl-turbo.py
         :language: python
         :caption: Running Nunchaku SDXL-Turbo (`examples/v1/sdxl-turbo.py <https://github.com/nunchaku-tech/nunchaku/blob/main/examples/v1/sdxl-turbo.py>`__)
         :linenos:

For more details, see :class:`~nunchaku.models.unets.unet_sdxl.NunchakuSDXLUNet2DConditionModel`.
import torch
from diffusers import StableDiffusionXLPipeline

from nunchaku.models.unets.unet_sdxl import NunchakuSDXLUNet2DConditionModel

if __name__ == "__main__":
    unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
        "nunchaku-tech/nunchaku-sdxl-turbo/svdq-int4_r32-sdxl-turbo.safetensors"
    )
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/sdxl-turbo", unet=unet, torch_dtype=torch.bfloat16, variant="fp16"
    ).to("cuda")

    prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
    image = pipeline(prompt=prompt, guidance_scale=0.0, num_inference_steps=4).images[0]
    image.save("sdxl-turbo.png")
import torch
from diffusers import StableDiffusionXLPipeline

from nunchaku.models.unets.unet_sdxl import NunchakuSDXLUNet2DConditionModel

if __name__ == "__main__":
    unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
        "nunchaku-tech/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors"
    )
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        unet=unet,
        torch_dtype=torch.bfloat16,
        use_safetensors=True,
        variant="fp16",
    ).to("cuda")

    prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
    image = pipeline(prompt=prompt, guidance_scale=5.0, num_inference_steps=50).images[0]
    image.save("sdxl.png")
from typing import Optional
import torch
from torch.nn import functional as F
class NunchakuSDXLFA2Processor:
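    """
    Attention processor for :class:`NunchakuSDXLAttention` that computes attention with
    ``torch.nn.functional.scaled_dot_product_attention``. Adapted from the diffusers
    ``AttnProcessor2_0``, with the QKV projection fused into a single linear for self-attention.
    """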
def __call__(
self,
attn,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**cross_attention_kwargs,
):
# Adapted from https://github.com/huggingface/diffusers/blob/50dea89dc6036e71a00bc3d57ac062a80206d9eb/src/diffusers/models/attention_processor.py#AttnProcessor2_0
# if len(args) > 0 or kwargs.get("scale", None) is not None:
# deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
# deprecate("scale", "1.0.0", deprecation_message)
# residual = hidden_states
# if attn.spatial_norm is not None:
# hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
# attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# # scaled_dot_product_attention expects attention_mask shape to be
# # (batch, heads, source_length, target_length)
# attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
raise NotImplementedError("attention_mask is not supported")
# if attn.group_norm is not None:
# hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
############# qkv ################
# query = attn.to_q(hidden_states)
# if encoder_hidden_states is None:
# encoder_hidden_states = hidden_states
# elif attn.norm_cross:
# encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# key = attn.to_k(encoder_hidden_states)
# value = attn.to_v(encoder_hidden_states)
if not attn.is_cross_attention:
qkv = attn.to_qkv(hidden_states)
query, key, value = qkv.chunk(3, dim=-1)
# query, key, value = attn.to_q(hidden_states), attn.to_k(hidden_states), attn.to_v(hidden_states)
else:
query, key, value = (
attn.to_q(hidden_states),
attn.to_k(encoder_hidden_states),
attn.to_v(encoder_hidden_states),
)
############# end of qkv ################
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# if attn.norm_q is not None:
# query = attn.norm_q(query)
# if attn.norm_k is not None:
# key = attn.norm_k(key)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
# if attn.residual_connection:
# hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
from .unet_sdxl import (
    NunchakuSDXLAttention,
    NunchakuSDXLConcatShiftedConv2d,
    NunchakuSDXLFeedForward,
    NunchakuSDXLShiftedConv2d,
    NunchakuSDXLTransformerBlock,
    NunchakuSDXLUNet2DConditionModel,
)
__all__ = [
"NunchakuSDXLAttention",
"NunchakuSDXLTransformerBlock",
"NunchakuSDXLShiftedConv2d",
"NunchakuSDXLConcatShiftedConv2d",
"NunchakuSDXLUNet2DConditionModel",
"NunchakuSDXLFeedForward",
]
"""
Implements :class:`NunchakuSDXLUNet2DConditionModel`, the Nunchaku quantized version of the Stable Diffusion XL ``UNet2DConditionModel``, together with its building blocks.
"""
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from diffusers.models.attention import BasicTransformerBlock, FeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.unets.unet_2d_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
Transformer2DModel,
UNetMidBlock2DCrossAttn,
UpBlock2D,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
from huggingface_hub import utils
from torch import nn
from nunchaku.utils import get_precision
from ..attention import NunchakuBaseAttention, _patch_linear
from ..attention_processors.sdxl import NunchakuSDXLFA2Processor
from ..linear import SVDQW4A4Linear
from ..transformers.utils import NunchakuModelLoaderMixin
from ..utils import fuse_linears
class NunchakuSDXLAttention(NunchakuBaseAttention):
"""
Nunchaku-optimized SDXLAttention module with quantized and fused QKV projections.
Parameters
----------
orig_attn : Attention
The original Stable Diffusion XL ``Attention`` module to wrap and quantize.
processor : str, optional
The attention processor to use (valid value: "flashattn2").
**kwargs
Additional arguments for quantization.
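Examples
--------
A minimal sketch (``orig_attn`` is an existing diffusers ``Attention`` module; the ``precision`` and ``rank`` values shown are illustrative)::

    quant_attn = NunchakuSDXLAttention(orig_attn, processor="flashattn2", precision="int4", rank=32)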
"""
def __init__(self, orig_attn: Attention, processor: str = "flashattn2", **kwargs):
super(NunchakuSDXLAttention, self).__init__(processor)
self.is_cross_attention = orig_attn.is_cross_attention
self.heads = orig_attn.heads
self.rescale_output_factor = orig_attn.rescale_output_factor
if not orig_attn.is_cross_attention:
# fuse the qkv
with torch.device("meta"):
to_qkv = fuse_linears([orig_attn.to_q, orig_attn.to_k, orig_attn.to_v])
self.to_qkv = SVDQW4A4Linear.from_linear(to_qkv, **kwargs)
else:
self.to_q = SVDQW4A4Linear.from_linear(orig_attn.to_q, **kwargs)
self.to_k = orig_attn.to_k
self.to_v = orig_attn.to_v
self.to_out = orig_attn.to_out
self.to_out[0] = SVDQW4A4Linear.from_linear(self.to_out[0], **kwargs)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**cross_attention_kwargs,
) -> torch.Tensor:
"""
Forward pass for NunchakuSDXLAttention.
Parameters
----------
hidden_states : torch.Tensor
Input tensor.
encoder_hidden_states : torch.Tensor, optional
Encoder hidden states for cross-attention.
attention_mask : torch.Tensor, optional
Attention mask.
**cross_attention_kwargs
Additional arguments for cross attention.
Returns
-------
Output of the attention processor.
"""
return self.processor(
self,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
def set_processor(self, processor: str):
"""
Set the attention processor.
Parameters
----------
processor : str
Name of the processor, "flashattn2" is supported. Others would be supported in future.
- ``"flashattn2"``: Standard FlashAttention-2. Adapted from https://github.com/huggingface/diffusers/blob/50dea89dc6036e71a00bc3d57ac062a80206d9eb/src/diffusers/models/attention_processor.py#AttnProcessor2_0
Raises
------
ValueError
If the processor is not supported.
"""
if processor == "flashattn2":
self.processor = NunchakuSDXLFA2Processor()
else:
raise ValueError(f"Processor {processor} is not supported")
class NunchakuSDXLFeedForward(FeedForward):
"""
Quantized feed-forward (MLP) block for :class:`NunchakuSDXLTransformerBlock`.
Replaces linear layers in a FeedForward block with :class:`~nunchaku.models.linear.SVDQW4A4Linear` for quantized inference.
Parameters
----------
ff : FeedForward
Source FeedForward block to quantize.
**kwargs :
Additional arguments for SVDQW4A4Linear.
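Examples
--------
A minimal sketch (``ff`` is an existing diffusers ``FeedForward`` block; ``precision`` and ``rank`` are illustrative)::

    quant_ff = NunchakuSDXLFeedForward(ff, precision="int4", rank=32)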
"""
def __init__(self, ff: FeedForward, **kwargs):
super(FeedForward, self).__init__()
self.net = _patch_linear(ff.net, SVDQW4A4Linear, **kwargs)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Forward pass for the quantized feed-forward block.
Parameters
----------
hidden_states : torch.Tensor, shape (B, N, D)
Input tensor.
Returns
-------
torch.Tensor, shape (B, N, D)
Output tensor after feed-forward transformation.
"""
for module in self.net:
hidden_states = module(hidden_states)
return hidden_states
class NunchakuSDXLTransformerBlock(BasicTransformerBlock):
"""
Nunchaku-optimized transformer block for Stable Diffusion XL with quantized attention and feedforward layers.
Parameters
----------
block : BasicTransformerBlock
The original block from within UNet2DConditionModel to wrap and quantize.
**kwargs
Additional arguments for quantization.
"""
def __init__(self, block: BasicTransformerBlock, **kwargs):
super(BasicTransformerBlock, self).__init__()
self.norm_type = block.norm_type
self.pos_embed = block.pos_embed
self.only_cross_attention = block.only_cross_attention
self.norm1 = block.norm1
self.norm2 = block.norm2
self.norm3 = block.norm3
self.attn1 = NunchakuSDXLAttention(block.attn1, **kwargs)
self.attn2 = NunchakuSDXLAttention(block.attn2, **kwargs)
self.ff = NunchakuSDXLFeedForward(block.ff, **kwargs)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
) -> torch.Tensor:
"""
Forward pass for the transformer block.
Parameters
----------
hidden_states : torch.Tensor
Input hidden states.
attention_mask: torch.Tensor, optional
The attention mask.
encoder_hidden_states : torch.Tensor
Encoder hidden states for cross-attention.
encoder_attention_mask: torch.Tensor, optional
The encoder attention mask.
cross_attention_kwargs: dict
Additional cross-attention keyword arguments.
Returns
-------
hidden_states: torch.Tensor
The hidden states after processing.
Raises
------
ValueError
If norm_type is not "layer_norm" or only_cross_attetion is true.
"""
# Adapted from diffusers.models.attention#BasicTransformerBlock#forward
if self.norm_type == "layer_norm":
norm_hidden_states = self.norm1(hidden_states)
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
if self.only_cross_attention:
raise ValueError("only_cross_attetion cannot be True")
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 1.2 GLIGEN Control # TODO check
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 3. Cross-Attention
if self.attn2 is not None:
if self.norm_type == "layer_norm":
norm_hidden_states = self.norm2(hidden_states)
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.norm_type != "ada_norm_single":
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)
ff_output = self.ff(norm_hidden_states)
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class NunchakuSDXLShiftedConv2d(nn.Module):
# Adapted from https://github.com/nunchaku-tech/deepcompressor/blob/main/deepcompressor/nn/patch/conv.py#ShiftedConv2d
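    # Wraps a Conv2d and adds a fixed per-tensor shift to its input. Padding is moved out of the
    # Conv2d and applied explicitly so that, for originally "zeros"-padded convs, the border is
    # padded with the shift value and stays consistent with the shifted interior.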
def __init__(
self,
orig_in_channels,
orig_out_channels,
orig_kernel_size,
orig_stride,
orig_padding,
orig_dilation,
orig_groups,
# orig_bias,
orig_padding_mode,
orig_device,
orig_dtype,
):
super().__init__()
self.conv = nn.Conv2d(
in_channels=orig_in_channels,
out_channels=orig_out_channels,
kernel_size=orig_kernel_size,
stride=orig_stride,
padding=orig_padding,
dilation=orig_dilation,
groups=orig_groups,
bias=True,
padding_mode=orig_padding_mode,
device=orig_device,
dtype=orig_dtype,
)
self.shift = nn.Parameter(torch.empty(1, 1, 1, 1, dtype=orig_dtype), requires_grad=False) # hard code
self.out_channels = orig_out_channels
self.padding_size = self.conv._reversed_padding_repeated_twice
if all(p == 0 for p in self.padding_size):
self.padding_mode = ""
elif orig_padding_mode == "zeros":
self.padding_mode = "constant"
# use shift
else:
self.padding_mode = orig_padding_mode
self.conv.padding = "valid"
self.conv.padding_mode = "zeros"
self.conv._reversed_padding_repeated_twice = [0, 0] * len(self.conv.kernel_size)
def forward(self, input: torch.Tensor) -> torch.Tensor:
input = input + self.shift
if self.padding_mode == "constant":
input = F.pad(input, self.padding_size, mode=self.padding_mode, value=self.shift.item())
elif self.padding_mode:
input = F.pad(input, self.padding_size, mode=self.padding_mode, value=None)
return self.conv(input)
class NunchakuSDXLConcatShiftedConv2d(nn.Module):
# Adapted from https://github.com/nunchaku-tech/deepcompressor/blob/main/deepcompressor/nn/patch/conv.py#ConcatConv2d
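    # Splits the input channels into groups, applies a separate NunchakuSDXLShiftedConv2d to each
    # group, and sums the partial outputs, implementing a convolution over the concatenated input
    # with the weight split along the input-channel dimension.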
def __init__(self, orig_conv: nn.Conv2d, split: int):
super().__init__()
splits = [split, orig_conv.in_channels - split] if orig_conv.in_channels - split > 0 else [split]
assert len(splits) > 1, "ConcatShiftedConv2d requires at least 2 input channels"
self.in_channels_list = splits
self.in_channels = orig_conv.in_channels
self.out_channels = orig_conv.out_channels
self.convs = nn.ModuleList(
[
NunchakuSDXLShiftedConv2d(
split_in_channels,
orig_conv.out_channels,
orig_conv.kernel_size,
orig_conv.stride,
orig_conv.padding,
orig_conv.dilation,
orig_conv.groups,
# bias if idx == num_convs - 1 else False,
orig_conv.padding_mode,
orig_conv.weight.device,
orig_conv.weight.dtype,
)
for split_in_channels in splits
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# slice x based on in_channels_list
x_splits: list[torch.Tensor] = x.split(self.in_channels_list, dim=1)
# apply each conv to each slice (we have to make contiguous input for quantization)
# out_splits = [conv(x_split.contiguous()) for conv, x_split in zip(self.convs, x_splits, strict=True)]
out_splits = [conv(x_split) for conv, x_split in zip(self.convs, x_splits, strict=True)]
# sum the results
return sum(out_splits)
class NunchakuSDXLUNet2DConditionModel(UNet2DConditionModel, NunchakuModelLoaderMixin):
"""
Nunchaku-optimized UNet2DConditionModel for Stable Diffusion XL.
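Examples
--------
Typical usage, mirroring ``examples/v1/sdxl.py`` from this commit::

    import torch
    from diffusers import StableDiffusionXLPipeline

    unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
        "nunchaku-tech/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors"
    )
    pipeline = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", unet=unet, torch_dtype=torch.bfloat16, variant="fp16"
    ).to("cuda")
    image = pipeline(prompt="A cinematic shot of a baby racoon wearing an intricate italian priest robe.").images[0]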
"""
def _patch_model(self, **kwargs):
"""
Patch the model by replacing the original ``BasicTransformerBlock`` modules with :class:`NunchakuSDXLTransformerBlock`.
Parameters
----------
**kwargs
Additional arguments for quantization.
Returns
-------
self : NunchakuSDXLUNet2DConditionModel
The patched model.
"""
def _patch_attentions(block: CrossAttnDownBlock2D | CrossAttnUpBlock2D | UNetMidBlock2DCrossAttn):
for _, attn in enumerate(block.attentions):
assert isinstance(attn, Transformer2DModel), "Dual cross attention is not supported"
nunchaku_sdxl_transformer_blocks = []
for _, transformer_block in enumerate(attn.transformer_blocks):
assert isinstance(transformer_block, BasicTransformerBlock)
nunchaku_sdxl_transformer_block = NunchakuSDXLTransformerBlock(transformer_block, **kwargs)
nunchaku_sdxl_transformer_blocks.append(nunchaku_sdxl_transformer_block)
attn.transformer_blocks = nn.ModuleList(nunchaku_sdxl_transformer_blocks)
# _patch_resnets_convs is not used yet, since support from the inference engine is not complete.
def _patch_resnets_convs(
block: CrossAttnDownBlock2D | CrossAttnUpBlock2D | UNetMidBlock2DCrossAttn | UpBlock2D | DownBlock2D,
up_block_idx: int | None = None,
):
for resnet_idx, resnet in enumerate(block.resnets):
if isinstance(block, (CrossAttnUpBlock2D, UpBlock2D)):
if resnet_idx == 0:
if up_block_idx == 0:
prev_block = self.mid_block
else:
prev_block = self.up_blocks[up_block_idx - 1]
split = prev_block.resnets[-1].conv2.out_channels
else:
split = block.resnets[resnet_idx - 1].conv2.out_channels
resnet.conv1 = NunchakuSDXLConcatShiftedConv2d(resnet.conv1, split)
else:
resnet.conv1 = NunchakuSDXLShiftedConv2d(
resnet.conv1.in_channels,
resnet.conv1.out_channels,
resnet.conv1.kernel_size,
resnet.conv1.stride,
resnet.conv1.padding,
resnet.conv1.dilation,
resnet.conv1.groups,
# orig_bias,
resnet.conv1.padding_mode,
resnet.conv1.weight.device,
resnet.conv1.weight.dtype,
)
resnet.conv2 = NunchakuSDXLShiftedConv2d(
resnet.conv2.in_channels,
resnet.conv2.out_channels,
resnet.conv2.kernel_size,
resnet.conv2.stride,
resnet.conv2.padding,
resnet.conv2.dilation,
resnet.conv2.groups,
# orig_bias,
resnet.conv2.padding_mode,
resnet.conv2.weight.device,
resnet.conv2.weight.dtype,
)
for _, down_block in enumerate(self.down_blocks):
if isinstance(down_block, CrossAttnDownBlock2D):
_patch_attentions(down_block)
for _, up_block in enumerate(self.up_blocks):
if isinstance(up_block, CrossAttnUpBlock2D):
_patch_attentions(up_block)
assert isinstance(self.mid_block, UNetMidBlock2DCrossAttn), "Only UNetMidBlock2DCrossAttn is supported"
_patch_attentions(self.mid_block)
return self
@classmethod
@utils.validate_hf_hub_args
def from_pretrained(cls, pretrained_model_path: str | os.PathLike[str], **kwargs):
"""
Load a pretrained NunchakuSDXLUNet2DConditionModel from a safetensors file.
Parameters
----------
pretrained_model_path : str or os.PathLike
Path to the safetensors file. It can be a local file or a remote HuggingFace path.
**kwargs
Additional arguments (e.g., device, torch_dtype).
Returns
-------
NunchakuSDXLUNet2DConditionModel
The loaded and quantized model.
Raises
------
NotImplementedError
If offload is requested.
AssertionError
If the file is not a safetensors file.
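Examples
--------
The path can be a local ``.safetensors`` file or a Hugging Face Hub path, e.g.::

    unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
        "nunchaku-tech/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors"
    )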
"""
device = kwargs.get("device", "cpu")
offload = kwargs.get("offload", False)
if offload:
raise NotImplementedError("Offload is not supported for NunchakuSDXLUNet2DConditionModel")
torch_dtype = kwargs.get("torch_dtype", torch.bfloat16)
if isinstance(pretrained_model_path, str):
pretrained_model_path = Path(pretrained_model_path)
assert pretrained_model_path.is_file() or pretrained_model_path.name.endswith(
(".safetensors", ".sft")
), "Only safetensors are supported"
unet, model_state_dict, metadata = cls._build_model(pretrained_model_path, **kwargs)
quantization_config = json.loads(metadata.get("quantization_config", "{}"))
rank = quantization_config.get("rank", 32)
unet = unet.to(torch_dtype)
precision = get_precision()
if precision == "fp4":
precision = "nvfp4"
unet._patch_model(precision=precision, rank=rank)
unet = unet.to_empty(device=device)
converted_state_dict = convert_sdxl_state_dict(model_state_dict)
unet.load_state_dict(converted_state_dict)
return unet
def convert_sdxl_state_dict(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
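    """
    Rename Nunchaku checkpoint keys inside ``.transformer_blocks.`` to the parameter names used by
    :class:`~nunchaku.models.linear.SVDQW4A4Linear` (``lora_down``/``lora_up`` -> ``proj_down``/``proj_up``,
    ``smooth``/``smooth_orig`` -> ``smooth_factor``/``smooth_factor_orig``). All other keys pass through unchanged.
    """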
new_state_dict = {}
for k, v in state_dict.items():
if ".transformer_blocks." in k:
if ".lora_down" in k:
new_k = k.replace(".lora_down", ".proj_down")
elif ".lora_up" in k:
new_k = k.replace(".lora_up", ".proj_up")
elif ".smooth_orig" in k:
new_k = k.replace(".smooth_orig", ".smooth_factor_orig")
elif ".smooth" in k:
new_k = k.replace(".smooth", ".smooth_factor")
else:
new_k = k
new_state_dict[new_k] = v
else:
new_state_dict[k] = v
return new_state_dict
@@ -82,9 +82,9 @@ def fused_gelu_mlp(x: torch.Tensor, fc1: SVDQW4A4Linear, fc2: SVDQW4A4Linear, pa
 def fused_qkv_norm_rottary(
     x: torch.Tensor,
     proj: SVDQW4A4Linear,
-    norm_q: RMSNorm,
-    norm_k: RMSNorm,
-    rotary_emb: torch.Tensor,
+    norm_q: RMSNorm | None = None,
+    norm_k: RMSNorm | None = None,
+    rotary_emb: torch.Tensor | None = None,
     output: torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
     attn_tokens: int = 0,
 ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -124,8 +124,8 @@ def fused_qkv_norm_rottary(
     - C_in: input features
     - C_out: output features
     """
-    assert isinstance(norm_q, RMSNorm)
-    assert isinstance(norm_k, RMSNorm)
+    assert norm_q is None or isinstance(norm_q, RMSNorm)
+    assert norm_k is None or isinstance(norm_k, RMSNorm)
     batch_size, seq_len, channels = x.shape
     x = x.view(batch_size * seq_len, channels)
@@ -148,8 +148,8 @@ def fused_qkv_norm_rottary(
         fp4=proj.precision == "nvfp4",
         alpha=proj.wtscale,
         wcscales=proj.wcscales,
-        norm_q=norm_q.weight,
-        norm_k=norm_k.weight,
+        norm_q=norm_q.weight if norm_q is not None else None,
+        norm_k=norm_k.weight if norm_k is not None else None,
         rotary_emb=rotary_emb,
         out_q=output_q,
         out_k=output_k,
@@ -170,8 +170,8 @@ def fused_qkv_norm_rottary(
         fp4=proj.precision == "nvfp4",
         alpha=proj.wtscale,
         wcscales=proj.wcscales,
-        norm_q=norm_q.weight,
-        norm_k=norm_k.weight,
+        norm_q=norm_q.weight if norm_q is not None else None,
+        norm_k=norm_k.weight if norm_k is not None else None,
         rotary_emb=rotary_emb,
     )
     output = output.view(batch_size, seq_len, -1)
...
import gc
import os
from pathlib import Path
import pytest
import torch
from diffusers import StableDiffusionXLPipeline
from nunchaku.models.unets.unet_sdxl import NunchakuSDXLUNet2DConditionModel
from nunchaku.utils import get_precision, is_turing
from ...flux.utils import already_generate, compute_lpips, hash_str_to_int
from .test_sdxl_turbo import plot, run_benchmark
@pytest.mark.skipif(
is_turing() or get_precision() == "fp4", reason="Skip tests due to using Turing GPUs or FP4 precision"
)
@pytest.mark.parametrize("expected_lpips", [0.25 if get_precision() == "int4" else 0.18])
def test_sdxl_lpips(expected_lpips: float):
gc.collect()
torch.cuda.empty_cache()
precision = get_precision()
ref_root = Path(os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref")))
results_dir_original = ref_root / "fp16" / "sdxl"
results_dir_nunchaku = ref_root / precision / "sdxl"
os.makedirs(results_dir_original, exist_ok=True)
os.makedirs(results_dir_nunchaku, exist_ok=True)
prompts = [
"Ilya Repin, Moebius, Yoshitaka Amano, 1980s nubian punk rock glam core fashion shoot, closeup, 35mm ",
"A honeybee sitting on a flower in a garden full of yellow flowers",
"Vibrant, tropical rainforest, teeming with wildlife, nature photography ",
"very realistic photo of barak obama in a wing eating contest",
"oil paint of colorful wildflowers in a meadow, Paul Signac divisionism style ",
]
if not already_generate(results_dir_original, 5):
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16, use_safetensors=True, variant="fp16"
).to("cuda")
for prompt in prompts:
seed = hash_str_to_int(prompt)
result = pipeline(
prompt=prompt, guidance_scale=5.0, num_inference_steps=50, generator=torch.Generator().manual_seed(seed)
).images[0]
result.save(os.path.join(results_dir_original, f"{seed}.png"))
del pipeline.unet
del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.vae
del pipeline
del result
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
free, total = torch.cuda.mem_get_info()
print(f"After original generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
if not already_generate(results_dir_nunchaku, 5):
quantized_unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
"nunchaku-tech/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors"
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
unet=quantized_unet,
torch_dtype=torch.bfloat16,
use_safetensors=True,
variant="fp16",
)
pipeline.unet = quantized_unet
pipeline = pipeline.to("cuda")
for prompt in prompts:
seed = hash_str_to_int(prompt)
result = pipeline(
prompt=prompt, guidance_scale=5.0, num_inference_steps=50, generator=torch.Generator().manual_seed(seed)
).images[0]
result.save(os.path.join(results_dir_nunchaku, f"{seed}.png"))
del pipeline
del quantized_unet
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
free, total = torch.cuda.mem_get_info()
print(f"After Nunchaku generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
lpips = compute_lpips(results_dir_original, results_dir_nunchaku)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.15
@pytest.mark.skipif(
is_turing() or get_precision() == "fp4", reason="Skip tests due to using Turing GPUs or FP4 precision"
)
@pytest.mark.parametrize("expected_latency", [7.455])
def test_sdxl_time_cost(expected_latency: float):
batch_size = 2
runs = 5
inference_steps = 50
guidance_scale = 5.0
device_name = torch.cuda.get_device_name(0)
results = {"Nunchaku INT4": []}
quantized_unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
"nunchaku-tech/nunchaku-sdxl/svdq-int4_r32-sdxl.safetensors"
)
pipeline_quantized = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
unet=quantized_unet,
torch_dtype=torch.bfloat16,
use_safetensors=True,
variant="fp16",
)
pipeline_quantized = pipeline_quantized.to("cuda")
benchmark_quantized = run_benchmark(
pipeline_quantized, batch_size, guidance_scale, device_name, runs, inference_steps
)
avg_latency = benchmark_quantized.mean() * inference_steps
results["Nunchaku INT4"].append(avg_latency)
ref_root = Path(os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref")))
plot_save_path = ref_root / "time_cost" / "sdxl"
os.makedirs(plot_save_path, exist_ok=True)
plot([batch_size], results, device_name, runs, inference_steps, plot_save_path, "SDXL")
assert avg_latency < expected_latency * 1.1
import gc
import os
import time
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import pytest
import torch
from diffusers import StableDiffusionXLPipeline
from nunchaku.models.unets.unet_sdxl import NunchakuSDXLUNet2DConditionModel
from nunchaku.utils import get_precision, is_turing
from ...flux.utils import already_generate, compute_lpips, hash_str_to_int
@pytest.mark.skipif(
is_turing() or get_precision() == "fp4", reason="Skip tests due to using Turing GPUs or FP4 precision"
)
@pytest.mark.parametrize("expected_lpips", [0.25 if get_precision() == "int4" else 0.18])
def test_sdxl_turbo_lpips(expected_lpips: float):
gc.collect()
torch.cuda.empty_cache()
precision = get_precision()
ref_root = Path(os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref")))
results_dir_original = ref_root / "fp16" / "sdxl-turbo"
results_dir_nunchaku = ref_root / precision / "sdxl-turbo"
os.makedirs(results_dir_original, exist_ok=True)
os.makedirs(results_dir_nunchaku, exist_ok=True)
prompts = [
"Ilya Repin, Moebius, Yoshitaka Amano, 1980s nubian punk rock glam core fashion shoot, closeup, 35mm ",
"A honeybee sitting on a flower in a garden full of yellow flowers",
"Vibrant, tropical rainforest, teeming with wildlife, nature photography ",
"very realistic photo of barak obama in a wing eating contest",
"oil paint of colorful wildflowers in a meadow, Paul Signac divisionism style ",
]
if not already_generate(results_dir_original, 5):
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/sdxl-turbo", torch_dtype=torch.bfloat16, variant="fp16"
).to("cuda")
for prompt in prompts:
seed = hash_str_to_int(prompt)
result = pipeline(
prompt=prompt, guidance_scale=0.0, num_inference_steps=4, generator=torch.Generator().manual_seed(seed)
).images[0]
result.save(os.path.join(results_dir_original, f"{seed}.png"))
del pipeline.unet
del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.vae
del pipeline
del result
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
free, total = torch.cuda.mem_get_info()
print(f"After original generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
if not already_generate(results_dir_nunchaku, 5):
quantized_unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
"nunchaku-tech/nunchaku-sdxl-turbo/svdq-int4_r32-sdxl-turbo.safetensors"
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/sdxl-turbo", unet=quantized_unet, torch_dtype=torch.bfloat16, variant="fp16"
)
pipeline = pipeline.to("cuda")
for prompt in prompts:
seed = hash_str_to_int(prompt)
result = pipeline(
prompt=prompt, guidance_scale=0.0, num_inference_steps=4, generator=torch.Generator().manual_seed(seed)
).images[0]
result.save(os.path.join(results_dir_nunchaku, f"{seed}.png"))
del pipeline
del quantized_unet
gc.collect()
torch.cuda.synchronize()
torch.cuda.empty_cache()
free, total = torch.cuda.mem_get_info()
print(f"After Nunchaku generation: Free: {free/1024**2:.0f} MB / Total: {total/1024**2:.0f} MB")
lpips = compute_lpips(results_dir_original, results_dir_nunchaku)
print(f"lpips: {lpips}")
assert lpips < expected_lpips * 1.15
class PerfHook:
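    """Forward-hook helper that records wall-clock start/end timestamps of each UNet forward call."""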
def __init__(self):
self.start = []
self.end = []
def pre_hook(self, module, input):
self.start.append(time.perf_counter())
def post_hook(self, module, input, output):
self.end.append(time.perf_counter())
def run_benchmark(pipeline, batch_size, guidance_scale, device, runs, inference_steps):
prompt = "A cinematic shot of a baby racoon wearing an intricate italian priest robe."
# warmup
_ = pipeline(
prompt=prompt,
guidance_scale=guidance_scale,
num_inference_steps=inference_steps,
num_images_per_prompt=batch_size,
).images
time_cost = []
unet = pipeline.unet
perf_hook = PerfHook()
handle_pre = unet.register_forward_pre_hook(perf_hook.pre_hook)
handle_post = unet.register_forward_hook(perf_hook.post_hook)
# run
for _ in range(runs):
_ = pipeline(
prompt=prompt,
guidance_scale=guidance_scale,
num_inference_steps=inference_steps,
num_images_per_prompt=batch_size,
).images
time_cost = [perf_hook.end[i] - perf_hook.start[i] for i in range(len(perf_hook.start))]
# to numpy for stats
time_cost = np.array(time_cost)
print(f"device: {device}")
print(f"runs :{runs}")
print(f"batch_size: {batch_size}")
print(f"max :{time_cost.max():.4f}")
print(f"min :{time_cost.min():.4f}")
print(f"avg :{time_cost.mean():.4f}")
print(f"std :{time_cost.std():.4f}")
handle_pre.remove()
handle_post.remove()
return time_cost
def plot(batch_sizes, results, device_name, runs, inference_steps, plot_save_path, title):
x = np.arange(len(batch_sizes))
width = 0.35
fig, ax = plt.subplots()
rects2 = ax.bar(x + width / 2, results["Nunchaku INT4"], width, label="Nunchaku INT4")
ax.set_ylabel(f"Average time cost (seconds)\n{runs} runs of {inference_steps} inference steps each.")
ax.set_xlabel("Batch size")
ax.set_title(f"{title} diffusion time cost\n(GPU: {device_name})")
ax.set_xticks(x)
ax.set_xticklabels(batch_sizes)
ax.legend()
def autolabel(rects):
for rect in rects:
height = rect.get_height()
ax.annotate(
f"{height:.3f}",
xy=(rect.get_x() + rect.get_width() / 2, height),
xytext=(0, 3),
textcoords="offset points",
ha="center",
va="bottom",
)
autolabel(rects2)
plt.tight_layout()
plt.savefig(plot_save_path / "plot.png", dpi=300, bbox_inches="tight")
@pytest.mark.skipif(
is_turing() or get_precision() == "fp4", reason="Skip tests due to using Turing GPUs or FP4 precision"
)
@pytest.mark.parametrize("expected_latency", [0.306])
def test_sdxl_turbo_time_cost(expected_latency: float):
batch_size = 8
runs = 5
guidance_scale = 0.0
inference_steps = 4
device_name = torch.cuda.get_device_name(0)
results = {"Nunchaku INT4": []}
quantized_unet = NunchakuSDXLUNet2DConditionModel.from_pretrained(
"nunchaku-tech/nunchaku-sdxl-turbo/svdq-int4_r32-sdxl-turbo.safetensors"
)
pipeline_quantized = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/sdxl-turbo", unet=quantized_unet, torch_dtype=torch.bfloat16, variant="fp16"
)
pipeline_quantized = pipeline_quantized.to("cuda")
benchmark_quantized = run_benchmark(
pipeline_quantized, batch_size, guidance_scale, device_name, runs, inference_steps
)
avg_latency = benchmark_quantized.mean() * inference_steps
results["Nunchaku INT4"].append(avg_latency)
ref_root = Path(os.environ.get("NUNCHAKU_TEST_CACHE_ROOT", os.path.join("test_results", "ref")))
plot_save_path = ref_root / "time_cost" / "sdxl-turbo"
os.makedirs(plot_save_path, exist_ok=True)
plot([batch_size], results, device_name, runs, inference_steps, plot_save_path, "SDXL-Turbo")
assert avg_latency < expected_latency * 1.1