Unverified Commit c5376c56 authored by Raul Ciotescu's avatar Raul Ciotescu Committed by GitHub
Browse files

adds the pipeline for pixart alpha controlnet (#8857)



* add the controlnet pipeline for pixart alpha

---------
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatarjunsongc <cjs1020440147@icloud.com>
parent 743a5697
...@@ -73,6 +73,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif ...@@ -73,6 +73,7 @@ Please also check out our [Community Scripts](https://github.com/huggingface/dif
| Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) | | Stable Diffusion BoxDiff Pipeline | Training-free controlled generation with bounding boxes using [BoxDiff](https://github.com/showlab/BoxDiff) | [Stable Diffusion BoxDiff Pipeline](#stable-diffusion-boxdiff) | - | [Jingyang Zhang](https://github.com/zjysteven/) |
| FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) | | FRESCO V2V Pipeline | Implementation of [[CVPR 2024] FRESCO: Spatial-Temporal Correspondence for Zero-Shot Video Translation](https://arxiv.org/abs/2403.12962) | [FRESCO V2V Pipeline](#fresco) | - | [Yifan Zhou](https://github.com/SingleZombie) |
| AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) | | AnimateDiff IPEX Pipeline | Accelerate AnimateDiff inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [AnimateDiff on IPEX](#animatediff-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
PIXART-α Controlnet pipeline | Implementation of the controlnet model for pixart alpha and its diffusers pipeline | [PIXART-α Controlnet pipeline](#pixart-α-controlnet-pipeline) | - | [Raul Ciotescu](https://github.com/raulc0399/) |
| HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) | | HunyuanDiT Differential Diffusion Pipeline | Applies [Differential Diffusion](https://github.com/exx8/differential-diffusion) to [HunyuanDiT](https://github.com/huggingface/diffusers/pull/8240). | [HunyuanDiT with Differential Diffusion](#hunyuandit-with-differential-diffusion) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1v44a5fpzyr4Ffr4v2XBQ7BajzG874N4P?usp=sharing) | [Monjoy Choudhury](https://github.com/MnCSSJ4x) |
| [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/pcuenq/mdm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) | | [🪆Matryoshka Diffusion Models](https://huggingface.co/papers/2310.15111) | A diffusion process that denoises inputs at multiple resolutions jointly and uses a NestedUNet architecture where features and parameters for small scale inputs are nested within those of the large scales. See [original codebase](https://github.com/apple/ml-mdm). | [🪆Matryoshka Diffusion Models](#matryoshka-diffusion-models) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/pcuenq/mdm) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/tolgacangoz/1f54875fc7aeaabcf284ebde64820966/matryoshka_hf.ipynb) | [M. Tolga Cangöz](https://github.com/tolgacangoz) |
...@@ -4445,3 +4446,94 @@ grid_image.save(grid_dir + "sample.png") ...@@ -4445,3 +4446,94 @@ grid_image.save(grid_dir + "sample.png")
`pag_scale` : guidance scale of PAG (ex: 5.0) `pag_scale` : guidance scale of PAG (ex: 5.0)
`pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0']) `pag_applied_layers_index` : index of the layer to apply perturbation (ex: ['m0'])
# PIXART-α Controlnet pipeline
[Project](https://pixart-alpha.github.io/) / [GitHub](https://github.com/PixArt-alpha/PixArt-alpha/blob/master/asset/docs/pixart_controlnet.md)
This the implementation of the controlnet model and the pipelne for the Pixart-alpha model, adapted to use the HuggingFace Diffusers.
## Example Usage
This example uses the Pixart HED Controlnet model, converted from the control net model as trained by the authors of the paper.
```py
import sys
import os
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline
from diffusers.utils import load_image
from diffusers.image_processor import PixArtImageProcessor
from controlnet_aux import HEDdetector
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel
controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet"
weight_dtype = torch.float16
image_size = 1024
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
# load controlnet
controlnet = PixArtControlNetAdapterModel.from_pretrained(
controlnet_repo_id,
torch_dtype=weight_dtype,
use_safetensors=True,
).to(device)
pipe = PixArtAlphaControlnetPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
controlnet=controlnet,
torch_dtype=weight_dtype,
use_safetensors=True,
).to(device)
images_path = "images"
control_image_file = "0_7.jpg"
prompt = "battleship in space, galaxy in background"
control_image_name = control_image_file.split('.')[0]
control_image = load_image(f"{images_path}/{control_image_file}")
print(control_image.size)
height, width = control_image.size
hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
condition_transform = T.Compose([
T.Lambda(lambda img: img.convert('RGB')),
T.CenterCrop([image_size, image_size]),
])
control_image = condition_transform(control_image)
hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size)
hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg")
# run pipeline
with torch.no_grad():
out = pipe(
prompt=prompt,
image=hed_edge,
num_inference_steps=14,
guidance_scale=4.5,
height=image_size,
width=image_size,
)
out.images[0].save(f"{images_path}//{control_image_name}_output.jpg")
```
In the folder examples/pixart there is also a script that can be used to train new models.
Please check the script `train_controlnet_hf_diffusers.sh` on how to start the training.
\ No newline at end of file
images/
output/
\ No newline at end of file
from typing import Any, Dict, Optional
import torch
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import PixArtTransformer2DModel
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.torch_utils import is_torch_version
class PixArtControlNetAdapterBlock(nn.Module):
def __init__(
self,
block_index,
# taken from PixArtTransformer2DModel
num_attention_heads: int = 16,
attention_head_dim: int = 72,
dropout: float = 0.0,
cross_attention_dim: Optional[int] = 1152,
attention_bias: bool = True,
activation_fn: str = "gelu-approximate",
num_embeds_ada_norm: Optional[int] = 1000,
upcast_attention: bool = False,
norm_type: str = "ada_norm_single",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
attention_type: Optional[str] = "default",
):
super().__init__()
self.block_index = block_index
self.inner_dim = num_attention_heads * attention_head_dim
# the first block has a zero before layer
if self.block_index == 0:
self.before_proj = nn.Linear(self.inner_dim, self.inner_dim)
nn.init.zeros_(self.before_proj.weight)
nn.init.zeros_(self.before_proj.bias)
self.transformer_block = BasicTransformerBlock(
self.inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
)
self.after_proj = nn.Linear(self.inner_dim, self.inner_dim)
nn.init.zeros_(self.after_proj.weight)
nn.init.zeros_(self.after_proj.bias)
def train(self, mode: bool = True):
self.transformer_block.train(mode)
if self.block_index == 0:
self.before_proj.train(mode)
self.after_proj.train(mode)
def forward(
self,
hidden_states: torch.Tensor,
controlnet_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
):
if self.block_index == 0:
controlnet_states = self.before_proj(controlnet_states)
controlnet_states = hidden_states + controlnet_states
controlnet_states_down = self.transformer_block(
hidden_states=controlnet_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
class_labels=None,
)
controlnet_states_left = self.after_proj(controlnet_states_down)
return controlnet_states_left, controlnet_states_down
class PixArtControlNetAdapterModel(ModelMixin, ConfigMixin):
# N=13, as specified in the paper https://arxiv.org/html/2401.05252v1/#S4 ControlNet-Transformer
@register_to_config
def __init__(self, num_layers=13) -> None:
super().__init__()
self.num_layers = num_layers
self.controlnet_blocks = nn.ModuleList(
[PixArtControlNetAdapterBlock(block_index=i) for i in range(num_layers)]
)
@classmethod
def from_transformer(cls, transformer: PixArtTransformer2DModel):
control_net = PixArtControlNetAdapterModel()
# copied the specified number of blocks from the transformer
for depth in range(control_net.num_layers):
control_net.controlnet_blocks[depth].transformer_block.load_state_dict(
transformer.transformer_blocks[depth].state_dict()
)
return control_net
def train(self, mode: bool = True):
for block in self.controlnet_blocks:
block.train(mode)
class PixArtControlNetTransformerModel(ModelMixin, ConfigMixin):
def __init__(
self,
transformer: PixArtTransformer2DModel,
controlnet: PixArtControlNetAdapterModel,
blocks_num=13,
init_from_transformer=False,
training=False,
):
super().__init__()
self.blocks_num = blocks_num
self.gradient_checkpointing = False
self.register_to_config(**transformer.config)
self.training = training
if init_from_transformer:
# copies the specified number of blocks from the transformer
controlnet.from_transformer(transformer, self.blocks_num)
self.transformer = transformer
self.controlnet = controlnet
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
controlnet_cond: Optional[torch.Tensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
if self.transformer.use_additional_conditions and added_cond_kwargs is None:
raise ValueError("`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`.")
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Input
batch_size = hidden_states.shape[0]
height, width = (
hidden_states.shape[-2] // self.transformer.config.patch_size,
hidden_states.shape[-1] // self.transformer.config.patch_size,
)
hidden_states = self.transformer.pos_embed(hidden_states)
timestep, embedded_timestep = self.transformer.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
if self.transformer.caption_projection is not None:
encoder_hidden_states = self.transformer.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
controlnet_states_down = None
if controlnet_cond is not None:
controlnet_states_down = self.transformer.pos_embed(controlnet_cond)
# 2. Blocks
for block_index, block in enumerate(self.transformer.transformer_blocks):
if self.training and self.gradient_checkpointing:
# rc todo: for training and gradient checkpointing
print("Gradient checkpointing is not supported for the controlnet transformer model, yet.")
exit(1)
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
None,
**ckpt_kwargs,
)
else:
# the control nets are only used for the blocks 1 to self.blocks_num
if block_index > 0 and block_index <= self.blocks_num and controlnet_states_down is not None:
controlnet_states_left, controlnet_states_down = self.controlnet.controlnet_blocks[
block_index - 1
](
hidden_states=hidden_states, # used only in the first block
controlnet_states=controlnet_states_down,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
)
hidden_states = hidden_states + controlnet_states_left
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=None,
)
# 3. Output
shift, scale = (
self.transformer.scale_shift_table[None]
+ embedded_timestep[:, None].to(self.transformer.scale_shift_table.device)
).chunk(2, dim=1)
hidden_states = self.transformer.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale.to(hidden_states.device)) + shift.to(hidden_states.device)
hidden_states = self.transformer.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
hidden_states = hidden_states.reshape(
shape=(
-1,
height,
width,
self.transformer.config.patch_size,
self.transformer.config.patch_size,
self.transformer.out_channels,
)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(
-1,
self.transformer.out_channels,
height * self.transformer.config.patch_size,
width * self.transformer.config.patch_size,
)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
# Copyright 2024 PixArt-Alpha Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
import inspect
import re
import urllib.parse as ul
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from controlnet_pixart_alpha import PixArtControlNetAdapterModel, PixArtControlNetTransformerModel
from transformers import T5EncoderModel, T5Tokenizer
from diffusers.image_processor import PipelineImageInput, PixArtImageProcessor
from diffusers.models import AutoencoderKL, PixArtTransformer2DModel
from diffusers.pipelines import DiffusionPipeline, ImagePipelineOutput
from diffusers.schedulers import DPMSolverMultistepScheduler
from diffusers.utils import (
BACKENDS_MAPPING,
deprecate,
is_bs4_available,
is_ftfy_available,
logging,
replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_bs4_available():
from bs4 import BeautifulSoup
if is_ftfy_available():
import ftfy
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import PixArtAlphaPipeline
>>> # You can replace the checkpoint id with "PixArt-alpha/PixArt-XL-2-512x512" too.
>>> pipe = PixArtAlphaPipeline.from_pretrained("PixArt-alpha/PixArt-XL-2-1024-MS", torch_dtype=torch.float16)
>>> # Enable memory optimizations.
>>> pipe.enable_model_cpu_offload()
>>> prompt = "A small cactus with a happy face in the Sahara desert."
>>> image = pipe(prompt).images[0]
```
"""
ASPECT_RATIO_1024_BIN = {
"0.25": [512.0, 2048.0],
"0.28": [512.0, 1856.0],
"0.32": [576.0, 1792.0],
"0.33": [576.0, 1728.0],
"0.35": [576.0, 1664.0],
"0.4": [640.0, 1600.0],
"0.42": [640.0, 1536.0],
"0.48": [704.0, 1472.0],
"0.5": [704.0, 1408.0],
"0.52": [704.0, 1344.0],
"0.57": [768.0, 1344.0],
"0.6": [768.0, 1280.0],
"0.68": [832.0, 1216.0],
"0.72": [832.0, 1152.0],
"0.78": [896.0, 1152.0],
"0.82": [896.0, 1088.0],
"0.88": [960.0, 1088.0],
"0.94": [960.0, 1024.0],
"1.0": [1024.0, 1024.0],
"1.07": [1024.0, 960.0],
"1.13": [1088.0, 960.0],
"1.21": [1088.0, 896.0],
"1.29": [1152.0, 896.0],
"1.38": [1152.0, 832.0],
"1.46": [1216.0, 832.0],
"1.67": [1280.0, 768.0],
"1.75": [1344.0, 768.0],
"2.0": [1408.0, 704.0],
"2.09": [1472.0, 704.0],
"2.4": [1536.0, 640.0],
"2.5": [1600.0, 640.0],
"3.0": [1728.0, 576.0],
"4.0": [2048.0, 512.0],
}
ASPECT_RATIO_512_BIN = {
"0.25": [256.0, 1024.0],
"0.28": [256.0, 928.0],
"0.32": [288.0, 896.0],
"0.33": [288.0, 864.0],
"0.35": [288.0, 832.0],
"0.4": [320.0, 800.0],
"0.42": [320.0, 768.0],
"0.48": [352.0, 736.0],
"0.5": [352.0, 704.0],
"0.52": [352.0, 672.0],
"0.57": [384.0, 672.0],
"0.6": [384.0, 640.0],
"0.68": [416.0, 608.0],
"0.72": [416.0, 576.0],
"0.78": [448.0, 576.0],
"0.82": [448.0, 544.0],
"0.88": [480.0, 544.0],
"0.94": [480.0, 512.0],
"1.0": [512.0, 512.0],
"1.07": [512.0, 480.0],
"1.13": [544.0, 480.0],
"1.21": [544.0, 448.0],
"1.29": [576.0, 448.0],
"1.38": [576.0, 416.0],
"1.46": [608.0, 416.0],
"1.67": [640.0, 384.0],
"1.75": [672.0, 384.0],
"2.0": [704.0, 352.0],
"2.09": [736.0, 352.0],
"2.4": [768.0, 320.0],
"2.5": [800.0, 320.0],
"3.0": [864.0, 288.0],
"4.0": [1024.0, 256.0],
}
ASPECT_RATIO_256_BIN = {
"0.25": [128.0, 512.0],
"0.28": [128.0, 464.0],
"0.32": [144.0, 448.0],
"0.33": [144.0, 432.0],
"0.35": [144.0, 416.0],
"0.4": [160.0, 400.0],
"0.42": [160.0, 384.0],
"0.48": [176.0, 368.0],
"0.5": [176.0, 352.0],
"0.52": [176.0, 336.0],
"0.57": [192.0, 336.0],
"0.6": [192.0, 320.0],
"0.68": [208.0, 304.0],
"0.72": [208.0, 288.0],
"0.78": [224.0, 288.0],
"0.82": [224.0, 272.0],
"0.88": [240.0, 272.0],
"0.94": [240.0, 256.0],
"1.0": [256.0, 256.0],
"1.07": [256.0, 240.0],
"1.13": [272.0, 240.0],
"1.21": [272.0, 224.0],
"1.29": [288.0, 224.0],
"1.38": [288.0, 208.0],
"1.46": [304.0, 208.0],
"1.67": [320.0, 192.0],
"1.75": [336.0, 192.0],
"2.0": [352.0, 176.0],
"2.09": [368.0, 176.0],
"2.4": [384.0, 160.0],
"2.5": [400.0, 160.0],
"3.0": [432.0, 144.0],
"4.0": [512.0, 128.0],
}
def get_closest_hw(width, height, image_size):
if image_size == 1024:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
elif image_size == 512:
aspect_ratio_bin = ASPECT_RATIO_512_BIN
else:
raise ValueError("Invalid image size")
height, width = PixArtImageProcessor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
return width, height
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class PixArtAlphaControlnetPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using PixArt-Alpha.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder. PixArt-Alpha uses
[T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the
[t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant.
tokenizer (`T5Tokenizer`):
Tokenizer of class
[T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
transformer ([`PixArtTransformer2DModel`]):
A text conditioned `PixArtTransformer2DModel` to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
bad_punct_regex = re.compile(
r"["
+ "#®•©™&@·º½¾¿¡§~"
+ r"\)"
+ r"\("
+ r"\]"
+ r"\["
+ r"\}"
+ r"\{"
+ r"\|"
+ "\\"
+ r"\/"
+ r"\*"
+ r"]{1,}"
) # noqa
_optional_components = ["tokenizer", "text_encoder"]
model_cpu_offload_seq = "text_encoder->transformer->vae"
def __init__(
self,
tokenizer: T5Tokenizer,
text_encoder: T5EncoderModel,
vae: AutoencoderKL,
transformer: PixArtTransformer2DModel,
controlnet: PixArtControlNetAdapterModel,
scheduler: DPMSolverMultistepScheduler,
):
super().__init__()
# change to the controlnet transformer model
transformer = PixArtControlNetTransformerModel(transformer=transformer, controlnet=controlnet)
self.register_modules(
tokenizer=tokenizer,
text_encoder=text_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
controlnet=controlnet,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.control_image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
# Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
do_classifier_free_guidance: bool = True,
negative_prompt: str = "",
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
clean_caption: bool = False,
max_sequence_length: int = 120,
**kwargs,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
PixArt-Alpha, this should be "".
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha, it's should be the embeddings of the ""
string.
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 120): Maximum sequence length to use for the prompt.
"""
if "mask_feature" in kwargs:
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
if device is None:
device = self._execution_device
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# See Section 3.1. of the paper.
max_length = max_sequence_length
if prompt_embeds is None:
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
text_input_ids, untruncated_ids
):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because T5 can only handle sequences up to"
f" {max_length} tokens: {removed_text}"
)
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0]
if self.text_encoder is not None:
dtype = self.text_encoder.dtype
elif self.transformer is not None:
dtype = self.transformer.controlnet.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens = [negative_prompt] * batch_size
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
negative_prompt_attention_mask = uncond_input.attention_mask
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
else:
negative_prompt_embeds = None
negative_prompt_attention_mask = None
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt,
callback_steps,
image=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
if image is not None:
self.check_image(image, prompt, prompt_embeds)
# Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
def check_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
image_is_tensor = isinstance(image, torch.Tensor)
image_is_np = isinstance(image, np.ndarray)
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
if (
not image_is_pil
and not image_is_tensor
and not image_is_np
and not image_is_pil_list
and not image_is_tensor_list
and not image_is_np_list
):
raise TypeError(
f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
)
if image_is_pil:
image_batch_size = 1
else:
image_batch_size = len(image)
if prompt is not None and isinstance(prompt, str):
prompt_batch_size = 1
elif prompt is not None and isinstance(prompt, list):
prompt_batch_size = len(prompt)
elif prompt_embeds is not None:
prompt_batch_size = prompt_embeds.shape[0]
if image_batch_size != 1 and image_batch_size != prompt_batch_size:
raise ValueError(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
if clean_caption and not is_bs4_available():
logger.warning(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if clean_caption and not is_ftfy_available():
logger.warning(BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`"))
logger.warning("Setting `clean_caption` to False...")
clean_caption = False
if not isinstance(text, (tuple, list)):
text = [text]
def process(text: str):
if clean_caption:
text = self._clean_caption(text)
text = self._clean_caption(text)
else:
text = text.lower().strip()
return text
return [process(t) for t in text]
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption
def _clean_caption(self, caption):
caption = str(caption)
caption = ul.unquote_plus(caption)
caption = caption.strip().lower()
caption = re.sub("<person>", "person", caption)
# urls:
caption = re.sub(
r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
caption = re.sub(
r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa
"",
caption,
) # regex for urls
# html:
caption = BeautifulSoup(caption, features="html.parser").text
# @<nickname>
caption = re.sub(r"@[\w\d]+\b", "", caption)
# 31C0—31EF CJK Strokes
# 31F0—31FF Katakana Phonetic Extensions
# 3200—32FF Enclosed CJK Letters and Months
# 3300—33FF CJK Compatibility
# 3400—4DBF CJK Unified Ideographs Extension A
# 4DC0—4DFF Yijing Hexagram Symbols
# 4E00—9FFF CJK Unified Ideographs
caption = re.sub(r"[\u31c0-\u31ef]+", "", caption)
caption = re.sub(r"[\u31f0-\u31ff]+", "", caption)
caption = re.sub(r"[\u3200-\u32ff]+", "", caption)
caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
#######################################################
# все виды тире / all types of dash --> "-"
caption = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa
"-",
caption,
)
# кавычки к одному стандарту
caption = re.sub(r"[`´«»“”¨]", '"', caption)
caption = re.sub(r"[‘’]", "'", caption)
# &quot;
caption = re.sub(r"&quot;?", "", caption)
# &amp
caption = re.sub(r"&amp", "", caption)
# ip adresses:
caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption)
# article ids:
caption = re.sub(r"\d:\d\d\s+$", "", caption)
# \n
caption = re.sub(r"\\n", " ", caption)
# "#123"
caption = re.sub(r"#\d{1,3}\b", "", caption)
# "#12345.."
caption = re.sub(r"#\d{5,}\b", "", caption)
# "123456.."
caption = re.sub(r"\b\d{6,}\b", "", caption)
# filenames:
caption = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption)
#
caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT"""
caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT"""
caption = re.sub(self.bad_punct_regex, r" ", caption) # ***AUSVERKAUFT***, #AUSVERKAUFT
caption = re.sub(r"\s+\.\s+", r" ", caption) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, caption)) > 3:
caption = re.sub(regex2, " ", caption)
caption = ftfy.fix_text(caption)
caption = html.unescape(html.unescape(caption))
caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640
caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc
caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231
caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption)
caption = re.sub(r"(free\s)?download(\sfree)?", "", caption)
caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption)
caption = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption)
caption = re.sub(r"\bpage\s+\d+\b", "", caption)
caption = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption) # j2d1a2a...
caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption)
caption = re.sub(r"\b\s+\:\s+", r": ", caption)
caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption)
caption = re.sub(r"\s+", " ", caption)
caption.strip()
caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption)
caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
caption = re.sub(r"^\.\S+$", "", caption)
return caption.strip()
def prepare_image(
self,
image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance=False,
):
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
image_batch_size = image.shape[0]
if image_batch_size == 1:
repeat_by = batch_size
else:
# image batch size is the same as prompt batch size
repeat_by = num_images_per_prompt
image = image.repeat_interleave(repeat_by, dim=0)
image = image.to(device=device, dtype=dtype)
if do_classifier_free_guidance:
image = torch.cat([image] * 2)
return image
# based on pipeline_pixart_inpaiting.py
def prepare_image_latents(self, image, device, dtype):
image = image.to(device=device, dtype=dtype)
image_latents = self.vae.encode(image).latent_dist.sample()
image_latents = image_latents * self.vae.config.scaling_factor
return image_latents
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: PipelineImageInput = None,
negative_prompt: str = "",
num_inference_steps: int = 20,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 4.5,
num_images_per_prompt: Optional[int] = 1,
height: Optional[int] = None,
width: Optional[int] = None,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
# rc todo: controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
# rc todo: control_guidance_start = 0.0,
# rc todo: control_guidance_end = 1.0,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.Tensor], None]] = None,
callback_steps: int = 1,
clean_caption: bool = True,
use_resolution_binning: bool = True,
max_sequence_length: int = 120,
**kwargs,
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 4.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
height (`int`, *optional*, defaults to self.unet.config.sample_size):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size):
The width in pixels of the generated image.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*): Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Alpha this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
clean_caption (`bool`, *optional*, defaults to `True`):
Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to
be installed. If the dependencies are not installed, the embeddings will be created from the raw
prompt.
use_resolution_binning (`bool` defaults to `True`):
If set to `True`, the requested height and width are first mapped to the closest resolutions using
`ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to
the requested resolution. Useful for generating non-square images.
max_sequence_length (`int` defaults to 120): Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is a list with the generated images
"""
if "mask_feature" in kwargs:
deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version."
deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False)
# 1. Check inputs. Raise error if not correct
height = height or self.transformer.config.sample_size * self.vae_scale_factor
width = width or self.transformer.config.sample_size * self.vae_scale_factor
if use_resolution_binning:
if self.transformer.config.sample_size == 128:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
elif self.transformer.config.sample_size == 64:
aspect_ratio_bin = ASPECT_RATIO_512_BIN
elif self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_256_BIN
else:
raise ValueError("Invalid sample size")
orig_height, orig_width = height, width
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
self.check_inputs(
prompt,
height,
width,
negative_prompt,
callback_steps,
image,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
)
# 2. Default height and width to transformer
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. Encode input prompt
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt,
do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
)
# 4.1 Prepare image
image_latents = None
if image is not None:
image = self.prepare_image(
image=image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.transformer.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
image_latents = self.prepare_image_latents(image, device, self.transformer.controlnet.dtype)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
latent_channels,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 6.1 Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
if self.transformer.config.sample_size == 128:
resolution = torch.tensor([height, width]).repeat(batch_size * num_images_per_prompt, 1)
aspect_ratio = torch.tensor([float(height / width)]).repeat(batch_size * num_images_per_prompt, 1)
resolution = resolution.to(dtype=prompt_embeds.dtype, device=device)
aspect_ratio = aspect_ratio.to(dtype=prompt_embeds.dtype, device=device)
if do_classifier_free_guidance:
resolution = torch.cat([resolution, resolution], dim=0)
aspect_ratio = torch.cat([aspect_ratio, aspect_ratio], dim=0)
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
current_timestep = t
if not torch.is_tensor(current_timestep):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
latent_model_input,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=current_timestep,
controlnet_cond=image_latents,
# rc todo: controlnet_conditioning_scale=1.0,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
noise_pred = noise_pred.chunk(2, dim=1)[0]
else:
noise_pred = noise_pred
# compute previous image: x_t -> x_t-1
if num_inference_steps == 1:
# For DMD one step sampling: https://arxiv.org/abs/2311.18828
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
else:
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
else:
image = latents
if not output_type == "latent":
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
transformers
SentencePiece
torchvision
controlnet-aux
datasets
# wandb
\ No newline at end of file
import torch
import torchvision.transforms as T
from controlnet_aux import HEDdetector
from diffusers.utils import load_image
from examples.research_projects.pixart.controlnet_pixart_alpha import PixArtControlNetAdapterModel
from examples.research_projects.pixart.pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline
controlnet_repo_id = "raulc0399/pixart-alpha-hed-controlnet"
weight_dtype = torch.float16
image_size = 1024
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)
# load controlnet
controlnet = PixArtControlNetAdapterModel.from_pretrained(
controlnet_repo_id,
torch_dtype=weight_dtype,
use_safetensors=True,
).to(device)
pipe = PixArtAlphaControlnetPipeline.from_pretrained(
"PixArt-alpha/PixArt-XL-2-1024-MS",
controlnet=controlnet,
torch_dtype=weight_dtype,
use_safetensors=True,
).to(device)
images_path = "images"
control_image_file = "0_7.jpg"
# prompt = "cinematic photo of superman in action . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
# prompt = "yellow modern car, city in background, beautiful rainy day"
# prompt = "modern villa, clear sky, suny day . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
# prompt = "robot dog toy in park . 35mm photograph, film, bokeh, professional, 4k, highly detailed"
# prompt = "purple car, on highway, beautiful sunny day"
# prompt = "realistical photo of a loving couple standing in the open kitchen of the living room, cooking ."
prompt = "battleship in space, galaxy in background"
control_image_name = control_image_file.split(".")[0]
control_image = load_image(f"{images_path}/{control_image_file}")
print(control_image.size)
height, width = control_image.size
hed = HEDdetector.from_pretrained("lllyasviel/Annotators")
condition_transform = T.Compose(
[
T.Lambda(lambda img: img.convert("RGB")),
T.CenterCrop([image_size, image_size]),
]
)
control_image = condition_transform(control_image)
hed_edge = hed(control_image, detect_resolution=image_size, image_resolution=image_size)
hed_edge.save(f"{images_path}/{control_image_name}_hed.jpg")
# run pipeline
with torch.no_grad():
out = pipe(
prompt=prompt,
image=hed_edge,
num_inference_steps=14,
guidance_scale=4.5,
height=image_size,
width=image_size,
)
out.images[0].save(f"{images_path}//{control_image_name}_output.jpg")
#!/bin/bash
# run
# accelerate config
# check with
# accelerate env
export MODEL_DIR="PixArt-alpha/PixArt-XL-2-512x512"
export OUTPUT_DIR="output/pixart-controlnet-hf-diffusers-test"
accelerate launch ./train_pixart_controlnet_hf.py --mixed_precision="fp16" \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--dataset_name=fusing/fill50k \
--resolution=512 \
--learning_rate=1e-5 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--report_to="wandb" \
--seed=42 \
--dataloader_num_workers=8
# --lr_scheduler="cosine" --lr_warmup_steps=0 \
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fine-tuning script for Stable Diffusion for text2image with HuggingFace diffusers."""
import argparse
import gc
import logging
import math
import os
import random
import shutil
from pathlib import Path
import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from PIL import Image
from pipeline_pixart_alpha_controlnet import PixArtAlphaControlnetPipeline
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import T5EncoderModel, T5Tokenizer
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler
from diffusers.models import PixArtTransformer2DModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import compute_snr
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
from examples.research_projects.pixart.controlnet_pixart_alpha import (
PixArtControlNetAdapterModel,
PixArtControlNetTransformerModel,
)
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.29.2")
logger = get_logger(__name__, log_level="INFO")
def log_validation(
vae,
transformer,
controlnet,
tokenizer,
scheduler,
text_encoder,
args,
accelerator,
weight_dtype,
step,
is_final_validation=False,
):
if weight_dtype == torch.float16 or weight_dtype == torch.bfloat16:
raise ValueError(
"Validation is not supported with mixed precision training, disable validation and use the validation script, that will generate images from the saved checkpoints."
)
if not is_final_validation:
logger.info(f"Running validation step {step} ... ")
controlnet = accelerator.unwrap_model(controlnet)
pipeline = PixArtAlphaControlnetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
transformer=transformer,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
controlnet=controlnet,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
else:
logger.info("Running validation - final ... ")
controlnet = PixArtControlNetAdapterModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype)
pipeline = PixArtAlphaControlnetPipeline.from_pretrained(
args.pretrained_model_name_or_path,
controlnet=controlnet,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
if args.enable_xformers_memory_efficient_attention:
pipeline.enable_xformers_memory_efficient_attention()
if args.seed is None:
generator = None
else:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
if len(args.validation_image) == len(args.validation_prompt):
validation_images = args.validation_image
validation_prompts = args.validation_prompt
elif len(args.validation_image) == 1:
validation_images = args.validation_image * len(args.validation_prompt)
validation_prompts = args.validation_prompt
elif len(args.validation_prompt) == 1:
validation_images = args.validation_image
validation_prompts = args.validation_prompt * len(args.validation_image)
else:
raise ValueError(
"number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`"
)
image_logs = []
for validation_prompt, validation_image in zip(validation_prompts, validation_images):
validation_image = Image.open(validation_image).convert("RGB")
validation_image = validation_image.resize((args.resolution, args.resolution))
images = []
for _ in range(args.num_validation_images):
image = pipeline(
prompt=validation_prompt, image=validation_image, num_inference_steps=20, generator=generator
).images[0]
images.append(image)
image_logs.append(
{"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt}
)
tracker_key = "test" if is_final_validation else "validation"
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
for log in image_logs:
images = log["images"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images = []
formatted_images.append(np.asarray(validation_image))
for image in images:
formatted_images.append(np.asarray(image))
formatted_images = np.stack(formatted_images)
tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC")
elif tracker.name == "wandb":
formatted_images = []
for log in image_logs:
images = log["images"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning"))
for image in images:
image = wandb.Image(image, caption=validation_prompt)
formatted_images.append(image)
tracker.log({tracker_key: formatted_images})
else:
logger.warning(f"image logging not implemented for {tracker.name}")
del pipeline
gc.collect()
torch.cuda.empty_cache()
logger.info("Validation done!!")
return image_logs
def save_model_card(repo_id: str, image_logs=None, base_model=str, dataset_name=str, repo_folder=None):
img_str = ""
if image_logs is not None:
img_str = "You can find some example images below.\n\n"
for i, log in enumerate(image_logs):
images = log["images"]
validation_prompt = log["validation_prompt"]
validation_image = log["validation_image"]
validation_image.save(os.path.join(repo_folder, "image_control.png"))
img_str += f"prompt: {validation_prompt}\n"
images = [validation_image] + images
make_image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
img_str += f"![images_{i})](./images_{i}.png)\n"
model_description = f"""
# controlnet-{repo_id}
These are controlnet weights trained on {base_model} with new type of conditioning.
{img_str}
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="openrail++",
base_model=base_model,
model_description=model_description,
inference=True,
)
tags = [
"pixart-alpha",
"pixart-alpha-diffusers",
"text-to-image",
"diffusers",
"controlnet",
"diffusers-training",
]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--variant",
type=str,
default=None,
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
)
parser.add_argument(
"--controlnet_model_name_or_path",
type=str,
default=None,
help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
" If not specified controlnet weights are initialized from the transformer.",
)
parser.add_argument(
"--dataset_name",
type=str,
default=None,
help=(
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
" or to a folder containing files that 🤗 Datasets can understand."
),
)
parser.add_argument(
"--dataset_config_name",
type=str,
default=None,
help="The config of the Dataset, leave as None if there's only one config.",
)
parser.add_argument(
"--train_data_dir",
type=str,
default=None,
help=(
"A folder containing the training data. Folder contents must follow the structure described in"
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
),
)
parser.add_argument(
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
)
parser.add_argument(
"--conditioning_image_column",
type=str,
default="conditioning_image",
help="The column of the dataset containing the controlnet conditioning image.",
)
parser.add_argument(
"--caption_column",
type=str,
default="text",
help="The column of the dataset containing a caption or a list of captions.",
)
parser.add_argument(
"--validation_prompt",
type=str,
nargs="+",
default=None,
help="One or more prompts to be evaluated every `--validation_steps`."
" Provide either a matching number of `--validation_image`s, a single `--validation_image`"
" to be used with all prompts, or a single prompt that will be used with all `--validation_image`s.",
)
parser.add_argument(
"--validation_image",
type=str,
default=None,
nargs="+",
help=(
"A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`"
" and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a"
" a single `--validation_prompt` to be used with all `--validation_image`s, or a single"
" `--validation_image` that will be used with all `--validation_prompt`s."
),
)
parser.add_argument(
"--num_validation_images",
type=int,
default=4,
help="Number of images that should be generated during validation with `validation_prompt`.",
)
parser.add_argument(
"--validation_steps",
type=int,
default=100,
help=(
"Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`."
),
)
parser.add_argument(
"--max_train_samples",
type=int,
default=None,
help=(
"For debugging purposes or quicker training, truncate the number of training examples to this "
"value if set."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="pixart-controlnet",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--cache_dir",
type=str,
default=None,
help="The directory where the downloaded models and datasets will be stored.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-6,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--snr_gamma",
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
# ----Diffusion Training Arguments----
parser.add_argument(
"--proportion_empty_prompts",
type=float,
default=0,
help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).",
)
parser.add_argument(
"--prediction_type",
type=str,
default=None,
help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediciton_type` is chosen.",
)
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
parser.add_argument(
"--tracker_project_name",
type=str,
default="pixart_controlnet",
help=(
"The `project_name` argument passed to Accelerator.init_trackers for"
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
args = parser.parse_args()
# Sanity checks
if args.dataset_name is None and args.train_data_dir is None:
raise ValueError("Need either a dataset name or a training folder.")
if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1:
raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].")
return args
def main():
args = parse_args()
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
"You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token."
" Please use `huggingface-cli login` to authenticate with the Hub."
)
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
if args.push_to_hub:
repo_id = create_repo(
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
).repo_id
# See Section 3.1. of the paper.
max_length = 120
# For mixed precision training we cast all non-trainable weigths (vae, text_encoder) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(
args.pretrained_model_name_or_path, subfolder="scheduler", torch_dtype=weight_dtype
)
tokenizer = T5Tokenizer.from_pretrained(
args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, torch_dtype=weight_dtype
)
text_encoder = T5EncoderModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, torch_dtype=weight_dtype
)
text_encoder.requires_grad_(False)
text_encoder.to(accelerator.device)
vae = AutoencoderKL.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="vae",
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
vae.requires_grad_(False)
vae.to(accelerator.device)
transformer = PixArtTransformer2DModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="transformer")
transformer.to(accelerator.device)
transformer.requires_grad_(False)
if args.controlnet_model_name_or_path:
logger.info("Loading existing controlnet weights")
controlnet = PixArtControlNetAdapterModel.from_pretrained(args.controlnet_model_name_or_path)
else:
logger.info("Initializing controlnet weights from transformer.")
controlnet = PixArtControlNetAdapterModel.from_transformer(transformer)
transformer.to(dtype=weight_dtype)
controlnet.to(accelerator.device)
controlnet.train()
def unwrap_model(model, keep_fp32_wrapper=True):
model = accelerator.unwrap_model(model, keep_fp32_wrapper=keep_fp32_wrapper)
model = model._orig_mod if is_compiled_module(model) else model
return model
# 10. Handle saving and loading of checkpoints
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
for _, model in enumerate(models):
if isinstance(model, PixArtControlNetTransformerModel):
print(f"Saving model {model.__class__.__name__} to {output_dir}")
model.controlnet.save_pretrained(os.path.join(output_dir, "controlnet"))
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
def load_model_hook(models, input_dir):
# rc todo: test and load the controlenet adapter and transformer
raise ValueError("load model hook not tested")
for i in range(len(models)):
# pop models so that they are not loaded again
model = models.pop()
if isinstance(model, PixArtControlNetTransformerModel):
load_model = PixArtControlNetAdapterModel.from_pretrained(input_dir, subfolder="controlnet")
model.register_to_config(**load_model.config)
model.load_state_dict(load_model.state_dict())
del load_model
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warn(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
transformer.enable_xformers_memory_efficient_attention()
controlnet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
if unwrap_model(controlnet).dtype != torch.float32:
raise ValueError(
f"Transformer loaded as datatype {unwrap_model(controlnet).dtype}. The trainable parameters should be in torch.float32."
)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()
controlnet.enable_gradient_checkpointing()
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Initialize the optimizer
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
)
optimizer_cls = bnb.optim.AdamW8bit
else:
optimizer_cls = torch.optim.AdamW
params_to_optimize = controlnet.parameters()
optimizer = optimizer_cls(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# Get the datasets: you can either provide your own training and evaluation files (see below)
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
# download the dataset.
if args.dataset_name is not None:
# Downloading and loading a dataset from the hub.
dataset = load_dataset(
args.dataset_name,
args.dataset_config_name,
cache_dir=args.cache_dir,
data_dir=args.train_data_dir,
)
else:
data_files = {}
if args.train_data_dir is not None:
data_files["train"] = os.path.join(args.train_data_dir, "**")
dataset = load_dataset(
"imagefolder",
data_files=data_files,
cache_dir=args.cache_dir,
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets.
# We need to tokenize inputs and targets.
column_names = dataset["train"].column_names
# 6. Get the column names for input/target.
if args.image_column is None:
image_column = column_names[0]
else:
image_column = args.image_column
if image_column not in column_names:
raise ValueError(
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
)
if args.caption_column is None:
caption_column = column_names[1]
else:
caption_column = args.caption_column
if caption_column not in column_names:
raise ValueError(
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
)
if args.conditioning_image_column is None:
conditioning_image_column = column_names[2]
logger.info(f"conditioning image column defaulting to {conditioning_image_column}")
else:
conditioning_image_column = args.conditioning_image_column
if conditioning_image_column not in column_names:
raise ValueError(
f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
)
# Preprocessing the datasets.
# We need to tokenize input captions and transform the images.
def tokenize_captions(examples, is_train=True, proportion_empty_prompts=0.0, max_length=120):
captions = []
for caption in examples[caption_column]:
if random.random() < proportion_empty_prompts:
captions.append("")
elif isinstance(caption, str):
captions.append(caption)
elif isinstance(caption, (list, np.ndarray)):
# take a random caption if there are multiple
captions.append(random.choice(caption) if is_train else caption[0])
else:
raise ValueError(
f"Caption column `{caption_column}` should contain either strings or lists of strings."
)
inputs = tokenizer(captions, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")
return inputs.input_ids, inputs.attention_mask
# Preprocessing the datasets.
train_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
conditioning_image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
]
)
def preprocess_train(examples):
images = [image.convert("RGB") for image in examples[image_column]]
examples["pixel_values"] = [train_transforms(image) for image in images]
conditioning_images = [image.convert("RGB") for image in examples[args.conditioning_image_column]]
examples["conditioning_pixel_values"] = [conditioning_image_transforms(image) for image in conditioning_images]
examples["input_ids"], examples["prompt_attention_mask"] = tokenize_captions(
examples, proportion_empty_prompts=args.proportion_empty_prompts, max_length=max_length
)
return examples
with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples])
conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = torch.stack([example["input_ids"] for example in examples])
prompt_attention_mask = torch.stack([example["prompt_attention_mask"] for example in examples])
return {
"pixel_values": pixel_values,
"conditioning_pixel_values": conditioning_pixel_values,
"input_ids": input_ids,
"prompt_attention_mask": prompt_attention_mask,
}
# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)
# Prepare everything with our `accelerator`.
controlnet_transformer = PixArtControlNetTransformerModel(transformer, controlnet, training=True)
controlnet_transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
controlnet_transformer, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers(args.tracker_project_name, config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
else:
initial_global_step = 0
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)
latent_channels = transformer.config.in_channels
for epoch in range(first_epoch, args.num_train_epochs):
controlnet_transformer.controlnet.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(controlnet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Convert control images to latent space
controlnet_image_latents = vae.encode(
batch["conditioning_pixel_values"].to(dtype=weight_dtype)
).latent_dist.sample()
controlnet_image_latents = controlnet_image_latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn(
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
prompt_embeds = text_encoder(batch["input_ids"], attention_mask=batch["prompt_attention_mask"])[0]
prompt_attention_mask = batch["prompt_attention_mask"]
# Get the target for loss depending on the prediction type
if args.prediction_type is not None:
# set prediction_type of scheduler if defined
noise_scheduler.register_to_config(prediction_type=args.prediction_type)
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Prepare micro-conditions.
added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
if getattr(transformer, "module", transformer).config.sample_size == 128:
resolution = torch.tensor([args.resolution, args.resolution]).repeat(bsz, 1)
aspect_ratio = torch.tensor([float(args.resolution / args.resolution)]).repeat(bsz, 1)
resolution = resolution.to(dtype=weight_dtype, device=latents.device)
aspect_ratio = aspect_ratio.to(dtype=weight_dtype, device=latents.device)
added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
# Predict the noise residual and compute loss
model_pred = controlnet_transformer(
noisy_latents,
encoder_hidden_states=prompt_embeds,
encoder_attention_mask=prompt_attention_mask,
timestep=timesteps,
controlnet_cond=controlnet_image_latents,
added_cond_kwargs=added_cond_kwargs,
return_dict=False,
)[0]
if transformer.config.out_channels // 2 == latent_channels:
model_pred = model_pred.chunk(2, dim=1)[0]
else:
model_pred = model_pred
if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(noise_scheduler, timesteps)
if noise_scheduler.config.prediction_type == "v_prediction":
# Velocity objective requires that we add one to SNR values before we divide by them.
snr = snr + 1
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = controlnet_transformer.controlnet.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
train_loss = 0.0
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
if args.validation_prompt is not None and global_step % args.validation_steps == 0:
log_validation(
vae,
transformer,
controlnet_transformer.controlnet,
tokenizer,
noise_scheduler,
text_encoder,
args,
accelerator,
weight_dtype,
global_step,
is_final_validation=False,
)
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
if global_step >= args.max_train_steps:
break
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
controlnet = unwrap_model(controlnet_transformer.controlnet, keep_fp32_wrapper=False)
controlnet.save_pretrained(os.path.join(args.output_dir, "controlnet"))
image_logs = None
if args.validation_prompt is not None:
image_logs = log_validation(
vae,
transformer,
controlnet,
tokenizer,
noise_scheduler,
text_encoder,
args,
accelerator,
weight_dtype,
global_step,
is_final_validation=True,
)
if args.push_to_hub:
save_model_card(
repo_id,
image_logs=image_logs,
base_model=args.pretrained_model_name_or_path,
dataset_name=args.dataset_name,
repo_folder=args.output_dir,
)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
commit_message="End of training",
ignore_patterns=["step_*", "epoch_*"],
)
accelerator.end_training()
if __name__ == "__main__":
main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment