"packaging/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "df8d7767d0f47f7e6869b9d2f92a902c5cb6e03d"
Unverified commit 82188cef authored by Yuxuan Zhang, committed by GitHub

CogView4 Control Block (#10809)

* cogview4 control training

---------

Co-authored-by: OleehyO <leehy0357@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail.com>

parent cc19726f
# Training CogView4 Control
This (experimental) example shows how to train Control LoRAs with [CogView4](https://huggingface.co/THUDM/CogView4-6B) by conditioning it on additional structural controls (like depth maps, poses, etc.). We also provide a script for full fine-tuning; refer to [this section](#full-fine-tuning). To learn more about CogView4, refer to the [model page](https://huggingface.co/THUDM/CogView4-6B).

To incorporate the additional condition latents, we expand the input features of CogView4 from 64 to 128. The first 64 channels correspond to the original input latents to be denoised, while the latter 64 channels correspond to control latents. This expansion happens in the `patch_embed` layer, where the combined latents are projected to the expected feature dimension of the rest of the network. Inference is performed using the `CogView4ControlPipeline`.
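The shape bookkeeping is easier to see with a small, self-contained sketch (this is not the diffusers implementation; the hidden size below is assumed for illustration). With `patch_size=2` and 16 latent channels, each patch flattens to 64 features, and concatenating the control latents channel-wise doubles that to 128 before the `patch_embed` projection:

```py
import torch

# Toy sketch of the channel expansion, assuming 16 latent channels and patch_size=2.
batch, channels, height, width = 1, 16, 128, 128
noisy_latents = torch.randn(batch, channels, height, width)
control_latents = torch.randn(batch, channels, height, width)  # VAE-encoded control image

# Channel-wise concatenation: 16 + 16 = 32 latent channels.
combined = torch.cat([noisy_latents, control_latents], dim=1)

# Patchify: each 2x2 patch now carries 32 * 2 * 2 = 128 features instead of 64.
patch_size = 2
patches = combined.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
patches = patches.permute(0, 2, 3, 1, 4, 5).flatten(3)  # (1, 64, 64, 128)

# The patch_embed projection therefore needs 128 input features (the hidden size here is assumed).
proj = torch.nn.Linear(128, 4096)
tokens = proj(patches.flatten(1, 2))
print(tokens.shape)  # torch.Size([1, 4096, 4096])
```
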
> [!NOTE]
> **Gated model**
>
> As the model is gated, before using it with diffusers you first need to go to the [CogView4 Hugging Face page](https://huggingface.co/THUDM/CogView4-6B), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
```bash
huggingface-cli login
```
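If you prefer to authenticate from Python instead of the CLI, `huggingface_hub` provides an equivalent helper (the token can also be supplied via the `HF_TOKEN` environment variable):

```py
from huggingface_hub import login

# Prompts for a token interactively; alternatively pass login(token="hf_...").
login()
```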
The example command below shows how to launch fine-tuning for pose conditions. The dataset ([`raulc0399/open_pose_controlnet`](https://huggingface.co/datasets/raulc0399/open_pose_controlnet)) being used here already has the pose conditions of the original images, so we don't have to compute them.
```bash
accelerate launch train_control_lora_cogview4.py \
  --pretrained_model_name_or_path="THUDM/CogView4-6B" \
  --dataset_name="raulc0399/open_pose_controlnet" \
  --output_dir="pose-control-lora" \
  --mixed_precision="bf16" \
  --train_batch_size=1 \
  --rank=64 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=5000 \
  --validation_image="openpose.png" \
  --validation_prompt="A couple, 4k photo, highly detailed" \
  --offload \
  --seed="0" \
  --push_to_hub
```
`openpose.png` comes from [here](https://huggingface.co/Adapter/t2iadapter/resolve/main/openpose.png).
You need to install `diffusers` from the branch in [this PR](https://github.com/huggingface/diffusers/pull/9999). Once it is merged, install `diffusers` from `main`.
The training script exposes additional CLI args that might be useful to experiment with:
* `use_lora_bias`: When set, additionally trains the biases of the `lora_B` layers.
* `train_norm_layers`: When set, additionally trains the normalization scales; saving and loading of these parameters is handled automatically.
* `lora_layers`: Specifies the layers to apply LoRA to. If you pass `"all-linear"`, LoRA is attached to all linear layers (see the example below).
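For example, to also train the LoRA biases and normalization scales while restricting LoRA to the attention projections, you could extend the launch command from above roughly like this (the layer names are illustrative; check how the script parses `--lora_layers` for the exact format it expects):

```bash
accelerate launch train_control_lora_cogview4.py \
  --pretrained_model_name_or_path="THUDM/CogView4-6B" \
  --dataset_name="raulc0399/open_pose_controlnet" \
  --output_dir="pose-control-lora" \
  --rank=64 \
  --use_lora_bias \
  --train_norm_layers \
  --lora_layers="attn1.to_k,attn1.to_q,attn1.to_v,attn1.to_out.0" \
  --mixed_precision="bf16" \
  --max_train_steps=5000
```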
### Training with DeepSpeed
It's possible to train with [DeepSpeed](https://github.com/microsoft/DeepSpeed), specifically leveraging its ZeRO Stage 2 optimization. To use it, save the following config to a YAML file (feel free to modify as needed):
```yaml
compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: false
  zero_stage: 2
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 1
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
```
Then, when launching training, pass the config file:
```bash
accelerate launch --config_file=CONFIG_FILE.yaml ...
```
### Inference
The pose images in our dataset were computed using the [`controlnet_aux`](https://github.com/huggingface/controlnet_aux) library. Let's install it first:
```bash
pip install controlnet_aux
```
And then we are ready:
```py
from controlnet_aux import OpenposeDetector
from diffusers import CogView4ControlPipeline
from diffusers.utils import load_image
from PIL import Image
import numpy as np
import torch

pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda")
pipe.load_lora_weights("...")  # change this.

open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")

# prepare pose condition.
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
image = load_image(url)
image = open_pose(image, detect_resolution=512, image_resolution=1024)
image = np.array(image)[:, :, ::-1]
image = Image.fromarray(np.uint8(image))

prompt = "A couple, 4k photo, highly detailed"

gen_images = pipe(
    prompt=prompt,
    control_image=image,
    num_inference_steps=50,
    attention_kwargs={"scale": 0.9},  # LoRA scale
    guidance_scale=25.0,
).images[0]
gen_images.save("output.png")
```
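If the full bf16 pipeline does not fit on your GPU, you can trade some speed for memory by replacing `.to("cuda")` with model CPU offloading (a standard `DiffusionPipeline` feature). A minimal variation of the setup above:

```py
pipe = CogView4ControlPipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("...")  # change this.
pipe.enable_model_cpu_offload()  # moves each sub-model to the GPU only while it is needed
```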
## Full fine-tuning
We provide a non-LoRA version of the training script `train_control_cogview4.py`. Here is an example command:
```bash
accelerate launch --config_file=accelerate_ds2.yaml train_control_cogview4.py \
  --pretrained_model_name_or_path="THUDM/CogView4-6B" \
  --dataset_name="raulc0399/open_pose_controlnet" \
  --output_dir="pose-control" \
  --mixed_precision="bf16" \
  --train_batch_size=2 \
  --dataloader_num_workers=4 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --use_8bit_adam \
  --proportion_empty_prompts=0.2 \
  --learning_rate=5e-5 \
  --adam_weight_decay=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="cosine" \
  --lr_warmup_steps=1000 \
  --checkpointing_steps=1000 \
  --max_train_steps=10000 \
  --validation_steps=200 \
  --validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
  --validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
  --offload \
  --seed="0" \
  --push_to_hub
```
Change the `validation_image` and `validation_prompt` as needed.
For inference with the fully fine-tuned model, we will run:
```py
from controlnet_aux import OpenposeDetector
from diffusers import CogView4ControlPipeline, CogView4Transformer2DModel
from diffusers.utils import load_image
from PIL import Image
import numpy as np
import torch

transformer = CogView4Transformer2DModel.from_pretrained("...", torch_dtype=torch.bfloat16)  # change this.
pipe = CogView4ControlPipeline.from_pretrained(
    "THUDM/CogView4-6B", transformer=transformer, torch_dtype=torch.bfloat16
).to("cuda")

open_pose = OpenposeDetector.from_pretrained("lllyasviel/Annotators")

# prepare pose condition.
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/people.jpg"
image = load_image(url)
image = open_pose(image, detect_resolution=512, image_resolution=1024)
image = np.array(image)[:, :, ::-1]
image = Image.fromarray(np.uint8(image))

prompt = "A couple, 4k photo, highly detailed"

gen_images = pipe(
    prompt=prompt,
    control_image=image,
    num_inference_steps=50,
    guidance_scale=25.0,
).images[0]
gen_images.save("output.png")
```
## Things to note
* The scripts provided in this directory are experimental and educational. This means you may have to tweak things to get good results for a given condition. We believe this is best done with the community 🤗
* The scripts are not memory-optimized, but when `--offload` is specified, the VAE and the text encoder are offloaded to the CPU while they are not in use.
* It is possible to extract LoRAs from a fully fine-tuned model. While we currently don't provide any utilities for that, you can refer to [this script](https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/master/control_lora_create.py), which provides similar functionality; a minimal sketch follows below.
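As a rough illustration of the idea behind that script (this is not part of the training scripts and is untested here): the LoRA factors can be recovered by taking a truncated SVD of the difference between the fine-tuned and base weights of each linear layer.

```py
import torch


def extract_lora(base_weight: torch.Tensor, tuned_weight: torch.Tensor, rank: int = 64):
    """Approximate (tuned_weight - base_weight) with a rank-`rank` product lora_B @ lora_A."""
    delta = (tuned_weight - base_weight).float()
    U, S, Vh = torch.linalg.svd(delta, full_matrices=False)
    lora_B = U[:, :rank] * S[:rank]  # (out_features, rank)
    lora_A = Vh[:rank, :]            # (rank, in_features)
    return lora_A, lora_B


# Iterate over matching linear weights of the base and fine-tuned transformer,
# run extract_lora on each pair, and store the factors in a PEFT-style state dict.
```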
transformers==4.47.0
wandb
torch
torchvision
accelerate==1.2.0
peft>=0.14.0
@@ -53,8 +53,18 @@ args = parser.parse_args()
 # this is specific to `AdaLayerNormContinuous`:
 # diffusers implementation split the linear projection into the scale, shift while CogView4 split it tino shift, scale
 def swap_scale_shift(weight, dim):
-    shift, scale = weight.chunk(2, dim=0)
-    new_weight = torch.cat([scale, shift], dim=0)
+    """
+    Swap the scale and shift components in the weight tensor.
+
+    Args:
+        weight (torch.Tensor): The original weight tensor.
+        dim (int): The dimension along which to split.
+
+    Returns:
+        torch.Tensor: The modified weight tensor with scale and shift swapped.
+    """
+    shift, scale = weight.chunk(2, dim=dim)
+    new_weight = torch.cat([scale, shift], dim=dim)
     return new_weight

@@ -200,6 +210,7 @@ def main(args):
         "norm_num_groups": 32,
         "sample_size": 1024,
         "scaling_factor": 1.0,
+        "shift_factor": 0.0,
         "force_upcast": True,
         "use_quant_conv": False,
         "use_post_quant_conv": False,
......
@@ -25,9 +25,15 @@ import argparse
 import torch
 from tqdm import tqdm
-from transformers import GlmForCausalLM, PreTrainedTokenizerFast
+from transformers import GlmModel, PreTrainedTokenizerFast

-from diffusers import AutoencoderKL, CogView4Pipeline, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
+from diffusers import (
+    AutoencoderKL,
+    CogView4ControlPipeline,
+    CogView4Pipeline,
+    CogView4Transformer2DModel,
+    FlowMatchEulerDiscreteScheduler,
+)
 from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint

@@ -112,6 +118,12 @@ parser.add_argument(
     default=128,
     help="Maximum size for positional embeddings.",
 )
+parser.add_argument(
+    "--control",
+    action="store_true",
+    default=False,
+    help="Whether to use control model.",
+)
 args = parser.parse_args()

@@ -150,13 +162,15 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
     Returns:
         dict: The converted state dictionary compatible with Diffusers.
     """
-    ckpt = torch.load(ckpt_path, map_location="cpu")
+    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
     mega = ckpt["model"]
     new_state_dict = {}

     # Patch Embedding
-    new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(hidden_size, 64)
+    new_state_dict["patch_embed.proj.weight"] = mega["encoder_expand_linear.weight"].reshape(
+        hidden_size, 128 if args.control else 64
+    )
     new_state_dict["patch_embed.proj.bias"] = mega["encoder_expand_linear.bias"]
     new_state_dict["patch_embed.text_proj.weight"] = mega["text_projector.weight"]
     new_state_dict["patch_embed.text_proj.bias"] = mega["text_projector.bias"]

@@ -189,14 +203,8 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
         block_prefix = f"transformer_blocks.{i}."

         # AdaLayerNorm
-        new_state_dict[block_prefix + "norm1.linear.weight"] = swap_scale_shift(
-            mega[f"decoder.layers.{i}.adaln.weight"], dim=0
-        )
-        new_state_dict[block_prefix + "norm1.linear.bias"] = swap_scale_shift(
-            mega[f"decoder.layers.{i}.adaln.bias"], dim=0
-        )
+        new_state_dict[block_prefix + "norm1.linear.weight"] = mega[f"decoder.layers.{i}.adaln.weight"]
+        new_state_dict[block_prefix + "norm1.linear.bias"] = mega[f"decoder.layers.{i}.adaln.bias"]

-        # QKV
         qkv_weight = mega[f"decoder.layers.{i}.self_attention.linear_qkv.weight"]
         qkv_bias = mega[f"decoder.layers.{i}.self_attention.linear_qkv.bias"]

@@ -221,7 +229,7 @@ def convert_megatron_transformer_checkpoint_to_diffusers(
         # Attention Output
         new_state_dict[block_prefix + "attn1.to_out.0.weight"] = mega[
             f"decoder.layers.{i}.self_attention.linear_proj.weight"
-        ].T
+        ]
         new_state_dict[block_prefix + "attn1.to_out.0.bias"] = mega[
             f"decoder.layers.{i}.self_attention.linear_proj.bias"
         ]

@@ -252,7 +260,7 @@ def convert_cogview4_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
     Returns:
         dict: The converted VAE state dictionary compatible with Diffusers.
     """
-    original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
+    original_state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)["state_dict"]
     return convert_ldm_vae_checkpoint(original_state_dict, vae_config)

@@ -286,7 +294,7 @@ def main(args):
     )
     transformer = CogView4Transformer2DModel(
         patch_size=2,
-        in_channels=16,
+        in_channels=32 if args.control else 16,
         num_layers=args.num_layers,
         attention_head_dim=args.attention_head_dim,
         num_attention_heads=args.num_heads,

@@ -317,6 +325,7 @@ def main(args):
         "norm_num_groups": 32,
         "sample_size": 1024,
         "scaling_factor": 1.0,
+        "shift_factor": 0.0,
         "force_upcast": True,
         "use_quant_conv": False,
         "use_post_quant_conv": False,

@@ -331,7 +340,7 @@ def main(args):
     # Load the text encoder and tokenizer
     text_encoder_id = "THUDM/glm-4-9b-hf"
     tokenizer = PreTrainedTokenizerFast.from_pretrained(text_encoder_id)
-    text_encoder = GlmForCausalLM.from_pretrained(
+    text_encoder = GlmModel.from_pretrained(
         text_encoder_id,
         cache_dir=args.text_encoder_cache_dir,
         torch_dtype=torch.bfloat16 if args.dtype == "bf16" else torch.float32,

@@ -345,13 +354,22 @@ def main(args):
     )

     # Create the pipeline
-    pipe = CogView4Pipeline(
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        vae=vae,
-        transformer=transformer,
-        scheduler=scheduler,
-    )
+    if args.control:
+        pipe = CogView4ControlPipeline(
+            tokenizer=tokenizer,
+            text_encoder=text_encoder,
+            vae=vae,
+            transformer=transformer,
+            scheduler=scheduler,
+        )
+    else:
+        pipe = CogView4Pipeline(
+            tokenizer=tokenizer,
+            text_encoder=text_encoder,
+            vae=vae,
+            transformer=transformer,
+            scheduler=scheduler,
+        )

     # Save the converted pipeline
     pipe.save_pretrained(
......
@@ -345,6 +345,7 @@ else:
             "CogVideoXPipeline",
             "CogVideoXVideoToVideoPipeline",
             "CogView3PlusPipeline",
+            "CogView4ControlPipeline",
             "CogView4Pipeline",
             "ConsisIDPipeline",
             "CycleDiffusionPipeline",

@@ -889,6 +890,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             CogVideoXPipeline,
             CogVideoXVideoToVideoPipeline,
             CogView3PlusPipeline,
+            CogView4ControlPipeline,
             CogView4Pipeline,
             ConsisIDPipeline,
             CycleDiffusionPipeline,
......
@@ -23,6 +23,7 @@ from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import FeedForward
 from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin

@@ -126,7 +127,8 @@ class CogView4AttnProcessor:
         attention_mask: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        text_seq_length = encoder_hidden_states.size(1)
+        batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+        batch_size, image_seq_length, embed_dim = hidden_states.shape
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

         # 1. QKV projections

@@ -156,6 +158,15 @@ class CogView4AttnProcessor:
         )

         # 4. Attention
+        if attention_mask is not None:
+            text_attention_mask = attention_mask.float().to(query.device)
+            actual_text_seq_length = text_attention_mask.size(1)
+            new_attention_mask = torch.zeros((batch_size, text_seq_length + image_seq_length), device=query.device)
+            new_attention_mask[:, :actual_text_seq_length] = text_attention_mask
+            new_attention_mask = new_attention_mask.unsqueeze(2)
+            attention_mask_matrix = new_attention_mask @ new_attention_mask.transpose(1, 2)
+            attention_mask = (attention_mask_matrix > 0).unsqueeze(1).to(query.dtype)
+
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )

@@ -203,6 +214,8 @@ class CogView4TransformerBlock(nn.Module):
         encoder_hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         # 1. Timestep conditioning
         (

@@ -223,6 +236,8 @@ class CogView4TransformerBlock(nn.Module):
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
+            **kwargs,
         )
         hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)

@@ -289,7 +304,7 @@ class CogView4RotaryPosEmbed(nn.Module):
         return (freqs.cos(), freqs.sin())

-class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
     r"""
     Args:
         patch_size (`int`, defaults to `2`):

@@ -386,6 +401,8 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         crop_coords: torch.Tensor,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()

@@ -421,11 +438,11 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                    hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
                 )

         # 4. Output norm & projection
......
...@@ -154,7 +154,7 @@ else: ...@@ -154,7 +154,7 @@ else:
"CogVideoXFunControlPipeline", "CogVideoXFunControlPipeline",
] ]
_import_structure["cogview3"] = ["CogView3PlusPipeline"] _import_structure["cogview3"] = ["CogView3PlusPipeline"]
_import_structure["cogview4"] = ["CogView4Pipeline"] _import_structure["cogview4"] = ["CogView4Pipeline", "CogView4ControlPipeline"]
_import_structure["consisid"] = ["ConsisIDPipeline"] _import_structure["consisid"] = ["ConsisIDPipeline"]
_import_structure["controlnet"].extend( _import_structure["controlnet"].extend(
[ [
...@@ -511,7 +511,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -511,7 +511,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogVideoXVideoToVideoPipeline, CogVideoXVideoToVideoPipeline,
) )
from .cogview3 import CogView3PlusPipeline from .cogview3 import CogView3PlusPipeline
from .cogview4 import CogView4Pipeline from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
from .consisid import ConsisIDPipeline from .consisid import ConsisIDPipeline
from .controlnet import ( from .controlnet import (
BlipDiffusionControlNetPipeline, BlipDiffusionControlNetPipeline,
......
@@ -22,7 +22,7 @@ from ..models.controlnets import ControlNetUnionModel
 from ..utils import is_sentencepiece_available
 from .aura_flow import AuraFlowPipeline
 from .cogview3 import CogView3PlusPipeline
-from .cogview4 import CogView4Pipeline
+from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
 from .controlnet import (
     StableDiffusionControlNetImg2ImgPipeline,
     StableDiffusionControlNetInpaintPipeline,

@@ -145,6 +145,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
         ("lumina2", Lumina2Pipeline),
         ("cogview3", CogView3PlusPipeline),
         ("cogview4", CogView4Pipeline),
+        ("cogview4-control", CogView4ControlPipeline),
     ]
 )
......
@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
     _import_structure["pipeline_cogview4"] = ["CogView4Pipeline"]
+    _import_structure["pipeline_cogview4_control"] = ["CogView4ControlPipeline"]

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
         if not (is_transformers_available() and is_torch_available()):

@@ -31,6 +32,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
         from .pipeline_cogview4 import CogView4Pipeline
+        from .pipeline_cogview4_control import CogView4ControlPipeline

 else:
     import sys
......
@@ -389,14 +389,18 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
     def num_timesteps(self):
         return self._num_timesteps

-    @property
-    def interrupt(self):
-        return self._interrupt
-
     @property
     def attention_kwargs(self):
         return self._attention_kwargs

+    @property
+    def current_timestep(self):
+        return self._current_timestep
+
+    @property
+    def interrupt(self):
+        return self._interrupt
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(

@@ -533,6 +537,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
         )

         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
+        self._current_timestep = None
         self._interrupt = False

         # Default call parameters

@@ -610,6 +615,7 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
                 if self.interrupt:
                     continue

+                self._current_timestep = t
                 latent_model_input = latents.to(transformer_dtype)

                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML

@@ -661,6 +667,8 @@ class CogView4Pipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
             if XLA_AVAILABLE:
                 xm.mark_step()

+        self._current_timestep = None
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
             image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
......
@@ -362,6 +362,21 @@ class CogView3PlusPipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])


+class CogView4ControlPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class CogView4Pipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
......