Unverified Commit 0ea357e2 authored by ttio2tech, committed by GitHub

feat: support controlnet for qwenimagemodel (#681)

* add controlnet support to the qwenimagemodel and add example file for controlnet

* add controlnet support

* add controlnet support for qwenimage

* add controlnet support for qwenimage

* style: make linter happy

* update example script for qwen controlnet

* style: make linter happy

* update

* update diffusers version
parent 5809e9fa
# please use diffusers>=0.36
import torch
from diffusers import QwenImageControlNetModel, QwenImageControlNetPipeline
from diffusers.utils import load_image
from nunchaku.models.transformers.transformer_qwenimage import NunchakuQwenImageTransformer2DModel
from nunchaku.utils import get_gpu_memory, get_precision

model_name = "Qwen/Qwen-Image"
rank = 32  # you can also use the rank=128 model to improve quality

# Load components with the correct dtype
controlnet = QwenImageControlNetModel.from_pretrained(
"InstantX/Qwen-Image-ControlNet-Union", torch_dtype=torch.bfloat16
)
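
# get_precision() picks "int4" or "fp4" for the quantized weights depending on your GPU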
transformer = NunchakuQwenImageTransformer2DModel.from_pretrained(
f"nunchaku-tech/nunchaku-qwen-image/svdq-{get_precision()}_r{rank}-qwen-image.safetensors"
)

# If diffusers>=0.36 is not yet available on PyPI, install it from source:
# pip install git+https://github.com/huggingface/diffusers

# Create the pipeline
pipeline = QwenImageControlNetPipeline.from_pretrained(
model_name, transformer=transformer, controlnet=controlnet, torch_dtype=torch.bfloat16
)
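
# Offloading: with more than 18 GB of GPU memory, whole-model CPU offload is enough;
# otherwise fall back to Nunchaku's per-layer offloading, which fits in a few GB of VRAM.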
if get_gpu_memory() > 18:
pipeline.enable_model_cpu_offload()
else:
# use per-layer offloading for low VRAM. This only requires 3-4GB of VRAM.
transformer.set_offload(True)
pipeline._exclude_from_cpu_offload.append("transformer")
pipeline.enable_sequential_cpu_offload()

control_image = load_image(
    "https://huggingface.co/InstantX/Qwen-Image-ControlNet-Union/resolve/main/conds/depth.png"
)

# Generate with control
image = pipeline(
prompt="A swanky, minimalist living room with a huge floor-to-ceiling window letting in loads of natural light. A beige couch with white cushions sits on a wooden floor, with a matching coffee table in front. The walls are a soft, warm beige, decorated with two framed botanical prints. A potted plant chills in the corner near the window. Sunlight pours through the leaves outside, casting cool shadows on the floor.",
negative_prompt=" ",
control_image=control_image,
controlnet_conditioning_scale=1.0,
num_inference_steps=30,
true_cfg_scale=4.0,
).images[0]

# Save the result
image.save(f"qwen-image-controlnet-r{rank}.png")
@@ -9,6 +9,7 @@ from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
from warnings import warn
import numpy as np
import torch
from diffusers.models.attention_processor import Attention
from diffusers.models.modeling_outputs import Transformer2DModelOutput
@@ -17,6 +18,7 @@ from diffusers.models.transformers.transformer_qwenimage import (
QwenImageTransformer2DModel,
QwenImageTransformerBlock,
)
from diffusers.utils import logging as diffusers_logging
from huggingface_hub import utils
from ...utils import get_precision
@@ -26,6 +28,8 @@ from ..linear import AWQW4A16Linear, SVDQW4A4Linear
from ..utils import CPUOffloadManager, fuse_linears
from .utils import NunchakuModelLoaderMixin
logger = diffusers_logging.get_logger(__name__)
class NunchakuQwenAttention(NunchakuBaseAttention):
"""
@@ -467,19 +471,20 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuM
txt_seq_lens: Optional[List[int]] = None,
guidance: torch.Tensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_block_samples=None,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
Forward pass for the quantized QwenImage transformer model.
Forward pass for the Nunchaku QwenImage transformer model with ControlNet support.
Parameters
----------
hidden_states : torch.Tensor
Image stream input.
Image stream input of shape `(batch_size, image_sequence_length, in_channels)`.
encoder_hidden_states : torch.Tensor, optional
Text stream input.
Text stream input of shape `(batch_size, text_sequence_length, joint_attention_dim)`.
encoder_hidden_states_mask : torch.Tensor, optional
Mask for encoder hidden states.
Mask for encoder hidden states of shape `(batch_size, text_sequence_length)`.
timestep : torch.LongTensor, optional
Timestep for temporal embedding.
img_shapes : list of tuple, optional
@@ -489,14 +494,17 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuM
guidance : torch.Tensor, optional
Guidance tensor (for classifier-free guidance).
attention_kwargs : dict, optional
Additional attention arguments.
Additional attention arguments: a kwargs dictionary that, if specified, is passed along to the `AttentionProcessor`.
controlnet_block_samples : list of torch.Tensor, optional
ControlNet block samples added as residuals to the image-stream hidden states.
return_dict : bool, default=True
Whether to return a dict or tuple.
Returns
-------
torch.Tensor or Transformer2DModelOutput
Model output.
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
device = hidden_states.device
if self.offload:
@@ -526,14 +534,32 @@ class NunchakuQwenImageTransformer2DModel(QwenImageTransformer2DModel, NunchakuM
with torch.cuda.stream(compute_stream):
if self.offload:
block = self.offload_manager.get_block(block_idx)
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=attention_kwargs,
)
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
encoder_hidden_states_mask,
temb,
image_rotary_emb,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=attention_kwargs,
)
# controlnet residual - same logic as in diffusers QwenImageTransformer2DModel
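# (illustrative numbers) e.g. 60 transformer blocks and 5 controlnet samples give
# interval_control = ceil(60 / 5) = 12, so blocks 0-11 add sample 0, blocks 12-23 add sample 1, etc.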
if controlnet_block_samples is not None:
interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
interval_control = int(np.ceil(interval_control))
hidden_states = hidden_states + controlnet_block_samples[block_idx // interval_control]
if self.offload:
self.offload_manager.step(compute_stream)
......