Unverified commit 0ea357e2 authored by ttio2tech, committed by GitHub

feat: support controlnet for qwenimagemodel (#681)

* add controlnet support to the qwenimagemodel and add example file for controlnet

* add controlnet support

* add controlnet support for qwenimage

* add controlnet support for qwenimage

* style: make linter happy

* update example script for qwen controlnet

* style: make linter happy

* update

* update diffusers version
parent 5809e9fa
# This example requires diffusers>=0.36. If that release is not yet on PyPI,
# install from source: pip install git+https://github.com/huggingface/diffusers
import torch
from diffusers import QwenImageControlNetModel, QwenImageControlNetPipeline
from diffusers.utils import load_image

from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel
from nunchaku.utils import get_gpu_memory, get_precision

model_name = "Qwen/Qwen-Image"
rank = 32  # you can also use the rank=128 model to improve quality

# Load the components with the correct dtype
controlnet = QwenImageControlNetModel.from_pretrained(
    "InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16
)
transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
    f"nunchaku-tech/nunchaku-qwen-image/svdq-{get_precision()}_r{rank}-qwen-image.safetensors"
)

# Create the pipeline
pipeline = QwenImageControlNetPipeline.from_pretrained(
    model_name, transformer=transformer, controlnet=controlnet, torch_dtype=torch.bfloat16
)

if get_gpu_memory() > 18:
    pipeline.enable_model_cpu_offload()
else:
    # Use per-layer offloading for low VRAM; this requires only 3-4 GB of VRAM.
    transformer.set_offload(True)
    pipeline._exclude_from_cpu_offload.append("transformer")
    pipeline.enable_sequential_cpu_offload()

control_image = load_image(
    "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union/resolve/main/conds/depth.png"
)

# Generate with control
image = pipeline(
    prompt="A swanky, minimalist living room with a huge floor-to-ceiling window letting in loads of natural light. A beige couch with white cushions sits on a wooden floor, with a matching coffee table in front. The walls are a soft, warm beige, decorated with two framed botanical prints. A potted plant chills in the corner near the window. Sunlight pours through the leaves outside, casting cool shadows on the floor.",
    negative_prompt=" ",
    control_image=control_image,
    controlnet_conditioning_scale=1.0,
    num_inference_steps=30,
    true_cfg_scale=4.0,
).images[0]

# Save the result
image.save(f"qwen-image-controlnet-r{rank}.png")
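The example above conditions on a precomputed depth map downloaded from the ControlNet repository. To condition on your own photo instead, one option (not part of this commit; the depth-estimation checkpoint and post-processing below are assumptions) is to derive a depth map with a transformers depth-estimation pipeline and pass it as control_image:

# Hypothetical pre-processing sketch, not from this commit: derive a depth
# control image from your own photo. The checkpoint below is an assumption;
# any depth estimator that returns a PIL depth image works the same way.
from diffusers.utils import load_image
from transformers import pipeline as hf_pipeline

depth_estimator = hf_pipeline("depth-estimation", model="depth-anything/Depth-Anything-V2-Small-hf")
source_image = load_image("room_photo.jpg")  # replace with your own image path or URL
control_image = depth_estimator(source_image)["depth"]  # PIL image, usable as control_image above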
@@ -9,6 +9,7 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
 from warnings import warn
 
+import numpy as np
 import torch
 from diffusers.models.attention_processor import Attention
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
@@ -17,6 +18,7 @@ from diffusers.models.transformers.transformer_qwenimage import (
     QwenImageTransformer2DModel,
     QwenImageTransformerBlock,
 )
+from diffusers.utils import logging as diffusers_logging
 from huggingface_hub import utils
 
 from ...utils import get_precision
@@ -26,6 +28,8 @@ from ..linear import AWQW4A16Linear, SVDQW4A4Linear
 from ..utils import CPUOffloadManager, fuse_linears
 from .utils import NunchakuModelLoaderMixin
 
+logger = diffusers_logging.get_logger(__name__)
+
 
 class NunchakuQwenAttention(NunchakuBaseAttention):
     """
@@ -467,19 +471,20 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuModelLoaderMixin):
         txt_seq_lens: Optional[List[int]] = None,
         guidance: torch.Tensor = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_block_samples=None,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
-        Forward pass for the quantized QwenImage transformer model.
+        Forward pass for the Nunchaku QwenImage transformer model with ControlNet support.
 
         Parameters
         ----------
         hidden_states : torch.Tensor
-            Image stream input.
+            Image stream input of shape `(batch_size, image_sequence_length, in_channels)`.
         encoder_hidden_states : torch.Tensor, optional
-            Text stream input.
+            Text stream input of shape `(batch_size, text_sequence_length, joint_attention_dim)`.
         encoder_hidden_states_mask : torch.Tensor, optional
-            Mask for encoder hidden states.
+            Mask for encoder hidden states of shape `(batch_size, text_sequence_length)`.
         timestep : torch.LongTensor, optional
             Timestep for temporal embedding.
         img_shapes : list of tuple, optional
@@ -489,14 +494,17 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuModelLoaderMixin):
         guidance : torch.Tensor, optional
             Guidance tensor (for classifier-free guidance).
         attention_kwargs : dict, optional
-            Additional attention arguments.
+            Additional attention arguments. A kwargs dictionary that if specified is passed along to the `AttentionProcessor`.
+        controlnet_block_samples : optional
+            ControlNet block samples for residual connections.
         return_dict : bool, default=True
             Whether to return a dict or tuple.
 
         Returns
         -------
         torch.Tensor or Transformer2DModelOutput
-            Model output.
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
         """
         device = hidden_states.device
         if self.offload:
@@ -526,14 +534,32 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuModelLoaderMixin):
             with torch.cuda.stream(compute_stream):
                 if self.offload:
                     block = self.offload_manager.get_block(block_idx)
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    encoder_hidden_states_mask=encoder_hidden_states_mask,
-                    temb=temb,
-                    image_rotary_emb=image_rotary_emb,
-                    joint_attention_kwargs=attention_kwargs,
-                )
+
+                if torch.is_grad_enabled() and self.gradient_checkpointing:
+                    encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
+                        block,
+                        hidden_states,
+                        encoder_hidden_states,
+                        encoder_hidden_states_mask,
+                        temb,
+                        image_rotary_emb,
+                    )
+                else:
+                    encoder_hidden_states, hidden_states = block(
+                        hidden_states=hidden_states,
+                        encoder_hidden_states=encoder_hidden_states,
+                        encoder_hidden_states_mask=encoder_hidden_states_mask,
+                        temb=temb,
+                        image_rotary_emb=image_rotary_emb,
+                        joint_attention_kwargs=attention_kwargs,
+                    )
+
+                # controlnet residual - same logic as in diffusers QwenImageTransformer2DModel
+                if controlnet_block_samples is not None:
+                    interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+                    interval_control = int(np.ceil(interval_control))
+                    hidden_states = hidden_states + controlnet_block_samples[block_idx // interval_control]
 
             if self.offload:
                 self.offload_manager.step(compute_stream)
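The interval_control arithmetic in the new residual branch maps many transformer blocks onto fewer ControlNet samples: each sample is reused for a contiguous run of ceil(num_blocks / num_samples) blocks. A minimal runnable sketch with toy counts (the numbers below are illustrative assumptions, not the model's real configuration):

# Toy demonstration of the block-to-sample mapping used in the hunk above.
import numpy as np

num_blocks = 60  # assumed toy value; really len(self.transformer_blocks)
num_samples = 8  # assumed toy value; really len(controlnet_block_samples)

interval_control = int(np.ceil(num_blocks / num_samples))  # 8
mapping = [block_idx // interval_control for block_idx in range(num_blocks)]
print(mapping[:10])  # [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]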
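The new torch.is_grad_enabled() and self.gradient_checkpointing branch trades compute for memory during training: each block's activations are recomputed in the backward pass instead of being stored. Diffusers' _gradient_checkpointing_func typically wraps torch.utils.checkpoint.checkpoint; here is a self-contained sketch of the same pattern with a stand-in block (the toy module and shapes are assumptions):

# Minimal gradient-checkpointing sketch; the Sequential is a stand-in for a
# transformer block, not the actual block used by the model.
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.GELU())
hidden = torch.randn(2, 16, 32, requires_grad=True)

gradient_checkpointing = True  # stand-in for self.gradient_checkpointing
if torch.is_grad_enabled() and gradient_checkpointing:
    out = checkpoint(block, hidden, use_reentrant=False)  # recompute in backward
else:
    out = block(hidden)
out.sum().backward()  # gradients flow through the recomputed block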