Unverified Commit c5376c56 authored by Raul Ciotescu's avatar Raul Ciotescu Committed by GitHub
Browse files

adds the pipeline for pixart alpha controlnet (#8857)



* add the controlnet pipeline for pixart alpha

---------
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarjunsongc <cjs1020440147@icloud.com>
parent 743a5697
...@@ -73,6 +73,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif ...@@ -73,6 +73,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) | | Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
| FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) | | FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) | | AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) | | HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/pcuenq/mdm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) | | [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/pcuenq/mdm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
...@@ -4445,3 +4446,94 @@ grid_image.save(grid_dir + "sample.png") ...@@ -4445,3 +4446,94 @@ grid_image.save(grid_dir + "sample.png")
`pag_scale` : guidance scale of PAG (ex: 5.0) `pag_scale` : guidance scale of PAG (ex: 5.0)
`pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0']) `pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0'])
# PIXART-α Controlnet pipeline
[Project](https://pixart-alpha.github.io/) / [GitHub](https://github.com/PixArt-alpha/PixArt-alpha/blob/master/asset/docs/pixart_controlnet.md)
This the implementation of the controlnet model and the pipelne for the Pixart-alpha model, adapted to use the HuggingFace Diffusers.
## Example Usage
This example uses the Pixart HED Controlnet model, converted from the control net model as trained by the authors of the paper.
```py
import sys
import os
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline
from diffusers.utils import load_image
from diffusers.image_processor import PixArtImageProcessor
from controlnet_aux import HEDdetector
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel
controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet"
weight_dtype = torch.float16
image_size = 1024
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
# load controlnet
controlnet = PixArtControlNetAdapterModel.from_pretrained(
controlnet_repo_id,
torch_dtype=weight_dtype,
use_safetensors=True,
).to(device)
pipe = PixArtAlphaControlnetPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
controlnet=controlnet,
torch_dtype=weight_dtype,
use_safetensors=True,
).to(device)
images_path = "images"
control_image_file = "0_7.jpg"
prompt = "battleship in space, galaxy in background"
control_image_name = control_image_file.split('.')[0]
control_image = load_image(f"{images_path}/{control_image_file}")
print(control_image.size)
height, width = control_image.size
hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
condition_transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB')),
T.CenterCrop([image_size, image_size]),
])
control_image = condition_transform(control_image)
hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size)
hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg")
# run pipeline
with torch.no_grad():
out = pipe(
prompt=prompt,
image=hed_edge,
num_inference_steps=14,
guidance_scale=4.5,
height=image_size,
width=image_size,
)
out.images[0].save(f"{images_path}//{control_image_name}_output.jpg")
```
In the folder examples/pixart there is also a script that can be used to train new models.
Please check the script `train_controlnet_hf_diffusers.sh` on how to start the training.
\ No newline at end of file
images/
output/
\ No newline at end of file
from typing import Any, Dict, Optional
import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import PixArtTransformer2DModel
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.torch_utils import is_torch_version
class PixArtControlNetAdapterBlock(nn.Module):
def __init__(
self,
block_index,
# taken from PixArtTransformer2DModel
num_attention_heads: int = 16,
attention_head_dim: int = 72,
dropout: float = 0.0,
cross_attention_dim: Optional[int] = 1152,
attention_bias: bool = True,
activation_fn: str = "gelu-approximate",
num_embeds_ada_norm: Optional[int] = 1000,
upcast_attention: bool = False,
norm_type: str = "ada_norm_single",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
attention_type: Optional[str] = "default",
):
super().__init__()
self.block_index = block_index
self.inner_dim = num_attention_heads * attention_head_dim
# the first block has a zero before layer
if self.block_index == 0:
self.before_proj = nn.Linear(self.inner_dim, self.inner_dim)
nn.init.zeros_(self.before_proj.weight)
nn.init.zeros_(self.before_proj.bias)
self.transformer_block = BasicTransformerBlock(
self.inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
)
self.after_proj = nn.Linear(self.inner_dim, self.inner_dim)
nn.init.zeros_(self.after_proj.weight)
nn.init.zeros_(self.after_proj.bias)
def train(self, mode: bool = True):
self.transformer_block.train(mode)
if self.block_index == 0:
self.before_proj.train(mode)
self.after_proj.train(mode)
def forward(
self,
hidden_states: torch.Tensor,
controlnet_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
):
if self.block_index == 0:
controlnet_states = self.before_proj(controlnet_states)
controlnet_states = hidden_states + controlnet_states
controlnet_states_down = self.transformer_block(
hidden_states=controlnet_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
class_labels=None,
)
controlnet_states_left = self.after_proj(controlnet_states_down)
return controlnet_states_left, controlnet_states_down
class PixArtControlNetAdapterModel(ModelMixin, ConfigMixin):
# N=13, as specified in the paper https://arxiv.org/html/2401.05252v1/#S4 ControlNet-Transformer
@register_to_config
def __init__(self, num_layers=13) -> None:
super().__init__()
self.num_layers = num_layers
self.controlnet_blocks = nn.ModuleList(
[PixArtControlNetAdapterBlock(block_index=i) for i in range(num_layers)]
)
@classmethod
def from_transformer(cls, transformer: PixArtTransformer2DModel):
control_net = PixArtControlNetAdapterModel()
# copied the specified number of blocks from the transformer
for depth in range(control_net.num_layers):
control_net.controlnet_blocks[depth].transformer_block.load_state_dict(
transformer.transformer_blocks[depth].state_dict()
)
return control_net
def train(self, mode: bool = True):
for block in self.controlnet_blocks:
block.train(mode)
class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
def __init__(
self,
transformer: PixArtTransformer2DModel,
controlnet: PixArtControlNetAdapterModel,
blocks_num=13,
init_from_transformer=False,
training=False,
):
super().__init__()
self.blocks_num = blocks_num
self.gradient_checkpointing = False
self.register_to_config(**transformer.config)
self.training = training
if init_from_transformer:
# copies the specified number of blocks from the transformer
controlnet.from_transformer(transformer, self.blocks_num)
self.transformer = transformer
self.controlnet = controlnet
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
controlnet_cond: Optional[torch.Tensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
if self.transformer.use_additional_conditions and added_cond_kwargs is None:
raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Input
batch_size = hidden_states.shape[0]
height, width = (
hidden_states.shape[-2] // self.transformer.config.patch_size,
hidden_states.shape[-1] // self.transformer.config.patch_size,
)
hidden_states = self.transformer.pos_embed(hidden_states)
timestep, embedded_timestep = self.transformer.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
if self.transformer.caption_projection is not None:
encoder_hidden_states = self.transformer.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
controlnet_states_down = None
if controlnet_cond is not None:
controlnet_states_down = self.transformer.pos_embed(controlnet_cond)
# 2. Blocks
for block_index, block in enumerate(self.transformer.transformer_blocks):
if self.training and self.gradient_checkpointing:
# rc todo: for training and gradient checkpointing
print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
exit(1)
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
None,
**ckpt_kwargs,
)
else:
# the control nets are only used for the blocks 1 to self.blocks_num
if block_index > 0 and block_index <= self.blocks_num and controlnet_states_down is not None:
controlnet_states_left, controlnet_states_down = self.controlnet.controlnet_blocks[
block_index - 1
](
hidden_states=hidden_states, # used only in the first block
controlnet_states=controlnet_states_down,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
)
hidden_states = hidden_states + controlnet_states_left
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=None,
)
# 3. Output
shift, scale = (
self.transformer.scale_shift_table[None]
+ embedded_timestep[:, None].to(self.transformer.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.transformer.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
hidden_states = self.transformer.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
hidden_states = hidden_states.reshape(
shape=(
-1,
height,
width,
self.transformer.config.patch_size,
self.transformer.config.patch_size,
self.transformer.out_channels,
)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(
-1,
self.transformer.out_channels,
height * self.transformer.config.patch_size,
width * self.transformer.config.patch_size,
)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
transformers
SentencePiece
torchvision
controlnet-aux
datasets
# wandb
\ No newline at end of file
import torch
import torchvision.transforms as T
from controlnet_aux import HEDdetector
from diffusers.utils import load_image
from examples.research_projects.pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel
from examples.research_projects.pixart.pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline
controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet"
weight_dtype = torch.float16
image_size = 1024
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
# load controlnet
controlnet = PixArtControlNetAdapterModel.from_pretrained(
controlnet_repo_id,
torch_dtype=weight_dtype,
use_safetensors=True,
).to(device)
pipe = PixArtAlphaControlnetPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
controlnet=controlnet,
torch_dtype=weight_dtype,
use_safetensors=True,
).to(device)
images_path = "images"
control_image_file = "0_7.jpg"
# prompt = "cinematic photo of superman in action . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
# prompt = "yellow modern car, city in background, beautiful rainy day"
# prompt = "modern villa, clear sky, suny day . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
# prompt = "robot dog toy in park . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
# prompt = "purple car, on highway, beautiful sunny day"
# prompt = "realistical photo of a loving couple standing in the open kitchen of the living room, cooking ."
prompt = "battleship in space, galaxy in background"
control_image_name = control_image_file.split(".")[0]
control_image = load_image(f"{images_path}/{control_image_file}")
print(control_image.size)
height, width = control_image.size
hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
condition_transform = T.Compose(
[
T.Lambda(lambda img: img.convert("RGB")),
T.CenterCrop([image_size, image_size]),
]
)
control_image = condition_transform(control_image)
hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size)
hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg")
# run pipeline
with torch.no_grad():
out = pipe(
prompt=prompt,
image=hed_edge,
num_inference_steps=14,
guidance_scale=4.5,
height=image_size,
width=image_size,
)
out.images[0].save(f"{images_path}//{control_image_name}_output.jpg")
#!/bin/bash
# run
# accelerate config
# check with
# accelerate env
export MODEL_DIR="PixArt-alpha/PixArt-XL-2-512x512"
export OUTPUT_DIR="output/pixart-controlnet-hf-diffusers-test"
accelerate launch ./train_pixart_controlnet_hf.py --mixed_precision="fp16" \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--dataset_name=fusing/fill50k \
--resolution=512 \
--learning_rate=1e-5 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--report_to="wandb" \
--seed=42 \
--dataloader_num_workers=8
# --lr_scheduler="cosine" --lr_warmup_steps=0 \
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment