Commit 8d81564b authored by Yuxuan.Zhang, committed by GitHub

CogView3Plus DiT (#9570)

* merge 9588

* max_shard_size="5GB" for colab running

* conversion script updates; modeling test; refactor transformer

* make fix-copies

* Update convert_cogview3_to_diffusers.py

* initial pipeline draft

* make style

* fight bugs 🐛

🪳

* add example

* add tests; refactor

* make style

* make fix-copies

* add co-author

YiYi Xu <yixu310@gmail.com>

* remove files

* add docs

* add co-author
Co-Authored-By: YiYi Xu <yixu310@gmail.com>

* fight docs

* address reviews

* make style

* make model work

* remove qkv fusion

* remove qkv fusion tests

* address review comments

* fix make fix-copies error

* remove None and TODO

* for FP16(draft)

* make style

* remove dynamic cfg

* remove pooled_projection_dim as a parameter

* fix tests

---------
Co-authored-by: Aryan <aryan@huggingface.co>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 68d16f78
......@@ -242,6 +242,8 @@
title: AuraFlowTransformer2DModel
- local: api/models/cogvideox_transformer3d
title: CogVideoXTransformer3DModel
- local: api/models/cogview3plus_transformer2d
title: CogView3PlusTransformer2DModel
- local: api/models/dit_transformer2d
title: DiTTransformer2DModel
- local: api/models/flux_transformer
......@@ -320,6 +322,8 @@
title: BLIP-Diffusion
- local: api/pipelines/cogvideox
title: CogVideoX
- local: api/pipelines/cogview3
title: CogView3
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet
......
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# CogView3PlusTransformer2DModel
The Diffusion Transformer model for 2D data used in [CogView3Plus](https://github.com/THUDM/CogView3) was introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) by Tsinghua University & ZhipuAI.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import CogView3PlusTransformer2DModel

transformer = CogView3PlusTransformer2DModel.from_pretrained("THUDM/CogView3Plus-3b", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## CogView3PlusTransformer2DModel
[[autodoc]] CogView3PlusTransformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# CogView3Plus
[CogView3: Finer and Faster Text-to-Image Generation via Relay Diffusion](https://huggingface.co/papers/2403.05121) from Tsinghua University & ZhipuAI, by Wendi Zheng, Jiayan Teng, Zhuoyi Yang, Weihan Wang, Jidong Chen, Xiaotao Gu, Yuxiao Dong, Ming Ding, Jie Tang.
The abstract from the paper is:
*Recent advancements in text-to-image generative systems have been largely driven by diffusion models. However, single-stage text-to-image diffusion models still face challenges, in terms of computational efficiency and the refinement of image details. To tackle the issue, we propose CogView3, an innovative cascaded framework that enhances the performance of text-to-image diffusion. CogView3 is the first model implementing relay diffusion in the realm of text-to-image generation, executing the task by first creating low-resolution images and subsequently applying relay-based super-resolution. This methodology not only results in competitive text-to-image outputs but also greatly reduces both training and inference costs. Our experimental results demonstrate that CogView3 outperforms SDXL, the current state-of-the-art open-source text-to-image diffusion model, by 77.0% in human evaluations, all while requiring only about 1/2 of the inference time. The distilled variant of CogView3 achieves comparable performance while only utilizing 1/10 of the inference time by SDXL.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/THUDM). The original weights can be found under [hf.co/THUDM](https://huggingface.co/THUDM).
## CogView3PlusPipeline
[[autodoc]] CogView3PlusPipeline
- all
- __call__
## CogView3PipelineOutput
[[autodoc]] pipelines.cogview3.pipeline_output.CogView3PipelineOutput
"""
Convert a CogView3 checkpoint to the Diffusers format.
This script converts a CogView3 checkpoint to the Diffusers format, which can then be used
with the Diffusers library.
Example usage:
python scripts/convert_cogview3_to_diffusers.py \
--transformer_checkpoint_path 'your path/cogview3plus_3b/1/mp_rank_00_model_states.pt' \
--vae_checkpoint_path 'your path/3plus_ae/imagekl_ch16.pt' \
--output_path "/raid/yiyi/cogview3_diffusers" \
--dtype "bf16"
Arguments:
--transformer_checkpoint_path: Path to Transformer state dict.
--vae_checkpoint_path: Path to VAE state dict.
--output_path: The path to save the converted model.
--push_to_hub: Whether to push the converted checkpoint to the HF Hub or not. Defaults to `False`.
--text_encoder_cache_dir: Cache directory where text encoder is located. Defaults to None, which means HF_HOME will be used
--dtype: The dtype to save the model in (default: "bf16", options: "fp16", "bf16", "fp32"). Any other value raises a ValueError.
Default is "bf16" because CogView3 uses bfloat16 for training.
Note: Provide --transformer_checkpoint_path and --vae_checkpoint_path; a component whose path is omitted is not converted.
"""
import argparse
from contextlib import nullcontext
import torch
from accelerate import init_empty_weights
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
from diffusers.loaders.single_file_utils import convert_ldm_vae_checkpoint
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available() else nullcontext
TOKENIZER_MAX_LENGTH = 224
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
parser.add_argument("--vae_checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", required=True, type=str)
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
parser.add_argument("--dtype", type=str, default="bf16")
args = parser.parse_args()
# this is specific to `AdaLayerNormContinuous`:
# the diffusers implementation splits the linear projection into scale, shift, while CogView3 splits it into shift, scale
def swap_scale_shift(weight, dim):
    shift, scale = weight.chunk(2, dim=0)
    new_weight = torch.cat([scale, shift], dim=0)
    return new_weight
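The reordering done by `swap_scale_shift` can be illustrated without tensors. The sketch below is a plain-Python stand-in (lists instead of `torch.Tensor`; `swap_halves` is a hypothetical name, not part of the script) that splits the rows into two halves and exchanges them, assuming an even number of rows.

```python
def swap_halves(rows):
    """Return rows with the first and second halves exchanged.

    Mirrors chunk(2, dim=0) followed by cat([second, first], dim=0),
    assuming len(rows) is even.
    """
    if len(rows) % 2 != 0:
        raise ValueError("expected an even number of rows")
    half = len(rows) // 2
    first, second = rows[:half], rows[half:]
    return second + first


# Original layout: shift rows first, then scale rows.
original = ["shift_0", "shift_1", "scale_0", "scale_1"]
converted = swap_halves(original)
print(converted)  # ['scale_0', 'scale_1', 'shift_0', 'shift_1']
```

Applying the swap twice restores the original order, which is a quick sanity check when debugging a conversion.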
def convert_cogview3_transformer_checkpoint_to_diffusers(ckpt_path):
    original_state_dict = torch.load(ckpt_path, map_location="cpu")
    original_state_dict = original_state_dict["module"]
    original_state_dict = {k.replace("model.diffusion_model.", ""): v for k, v in original_state_dict.items()}

    new_state_dict = {}

    # Convert patch_embed
    new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("mixins.patch_embed.proj.weight")
    new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("mixins.patch_embed.proj.bias")
    new_state_dict["patch_embed.text_proj.weight"] = original_state_dict.pop("mixins.patch_embed.text_proj.weight")
    new_state_dict["patch_embed.text_proj.bias"] = original_state_dict.pop("mixins.patch_embed.text_proj.bias")

    # Convert time_condition_embed
    new_state_dict["time_condition_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop(
        "time_embed.0.weight"
    )
    new_state_dict["time_condition_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop(
        "time_embed.0.bias"
    )
    new_state_dict["time_condition_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop(
        "time_embed.2.weight"
    )
    new_state_dict["time_condition_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop(
        "time_embed.2.bias"
    )
    new_state_dict["time_condition_embed.condition_embedder.linear_1.weight"] = original_state_dict.pop(
        "label_emb.0.0.weight"
    )
    new_state_dict["time_condition_embed.condition_embedder.linear_1.bias"] = original_state_dict.pop(
        "label_emb.0.0.bias"
    )
    new_state_dict["time_condition_embed.condition_embedder.linear_2.weight"] = original_state_dict.pop(
        "label_emb.0.2.weight"
    )
    new_state_dict["time_condition_embed.condition_embedder.linear_2.bias"] = original_state_dict.pop(
        "label_emb.0.2.bias"
    )

    # Convert transformer blocks
    for i in range(30):
        block_prefix = f"transformer_blocks.{i}."
        old_prefix = f"transformer.layers.{i}."
        adaln_prefix = f"mixins.adaln.adaln_modules.{i}."

        new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(adaln_prefix + "1.weight")
        new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(adaln_prefix + "1.bias")

        qkv_weight = original_state_dict.pop(old_prefix + "attention.query_key_value.weight")
        qkv_bias = original_state_dict.pop(old_prefix + "attention.query_key_value.bias")
        q, k, v = qkv_weight.chunk(3, dim=0)
        q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0)

        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
        new_state_dict[block_prefix + "attn1.to_q.bias"] = q_bias
        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
        new_state_dict[block_prefix + "attn1.to_k.bias"] = k_bias
        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
        new_state_dict[block_prefix + "attn1.to_v.bias"] = v_bias

        new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
            old_prefix + "attention.dense.weight"
        )
        new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(
            old_prefix + "attention.dense.bias"
        )

        new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(
            old_prefix + "mlp.dense_h_to_4h.weight"
        )
        new_state_dict[block_prefix + "ff.net.0.proj.bias"] = original_state_dict.pop(
            old_prefix + "mlp.dense_h_to_4h.bias"
        )
        new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(
            old_prefix + "mlp.dense_4h_to_h.weight"
        )
        new_state_dict[block_prefix + "ff.net.2.bias"] = original_state_dict.pop(old_prefix + "mlp.dense_4h_to_h.bias")

    # Convert final norm and projection
    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
        original_state_dict.pop("mixins.final_layer.adaln.1.weight"), dim=0
    )
    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(
        original_state_dict.pop("mixins.final_layer.adaln.1.bias"), dim=0
    )
    new_state_dict["proj_out.weight"] = original_state_dict.pop("mixins.final_layer.linear.weight")
    new_state_dict["proj_out.bias"] = original_state_dict.pop("mixins.final_layer.linear.bias")

    return new_state_dict
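For each block, the conversion pops the fused `query_key_value` parameter, splits it into three equal parts along the first dimension, and stores them under Diffusers-style `attn1.to_q`, `attn1.to_k`, and `attn1.to_v` keys. The sketch below reproduces that bookkeeping with plain lists instead of tensors; `split_fused_qkv` and `convert_block_attention` are hypothetical helpers, and only the weight (not the bias) is handled.

```python
def split_fused_qkv(fused_rows):
    """Split a fused QKV parameter into equal query, key, and value parts.

    Stand-in for tensor.chunk(3, dim=0); assumes the row count is a multiple of 3.
    """
    if len(fused_rows) % 3 != 0:
        raise ValueError("row count must be divisible by 3")
    size = len(fused_rows) // 3
    return fused_rows[:size], fused_rows[size:2 * size], fused_rows[2 * size:]


def convert_block_attention(original, layer_index):
    """Pop the fused weight for one layer and emit Diffusers-style keys."""
    old_key = f"transformer.layers.{layer_index}.attention.query_key_value.weight"
    new_prefix = f"transformer_blocks.{layer_index}.attn1."
    q, k, v = split_fused_qkv(original.pop(old_key))
    return {
        new_prefix + "to_q.weight": q,
        new_prefix + "to_k.weight": k,
        new_prefix + "to_v.weight": v,
    }


# Toy checkpoint: 6 "rows", i.e. 2 rows each for q, k, and v.
toy = {"transformer.layers.0.attention.query_key_value.weight": ["q0", "q1", "k0", "k1", "v0", "v1"]}
converted = convert_block_attention(toy, 0)
leftover_keys = list(toy)  # empty once every original key has been consumed
```

Because keys are popped rather than read, any entries left in the original dict afterwards would point to weights that were never mapped. The script itself does not perform that check.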
def convert_cogview3_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
    original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
def main(args):
    if args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    transformer = None
    vae = None

    if args.transformer_checkpoint_path is not None:
        converted_transformer_state_dict = convert_cogview3_transformer_checkpoint_to_diffusers(
            args.transformer_checkpoint_path
        )
        transformer = CogView3PlusTransformer2DModel()
        transformer.load_state_dict(converted_transformer_state_dict, strict=True)
        if dtype is not None:
            # Cast the converted weights to the requested dtype
            transformer = transformer.to(dtype=dtype)

    if args.vae_checkpoint_path is not None:
        vae_config = {
            "in_channels": 3,
            "out_channels": 3,
            "down_block_types": ("DownEncoderBlock2D",) * 4,
            "up_block_types": ("UpDecoderBlock2D",) * 4,
            "block_out_channels": (128, 512, 1024, 1024),
            "layers_per_block": 3,
            "act_fn": "silu",
            "latent_channels": 16,
            "norm_num_groups": 32,
            "sample_size": 1024,
            "scaling_factor": 1.0,
            "force_upcast": True,
            "use_quant_conv": False,
            "use_post_quant_conv": False,
            "mid_block_add_attention": False,
        }
        converted_vae_state_dict = convert_cogview3_vae_checkpoint_to_diffusers(args.vae_checkpoint_path, vae_config)
        vae = AutoencoderKL(**vae_config)
        vae.load_state_dict(converted_vae_state_dict, strict=True)
        if dtype is not None:
            vae = vae.to(dtype=dtype)

    text_encoder_id = "google/t5-v1_1-xxl"
    tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    # Apparently, the conversion does not work anymore without this :shrug:
    for param in text_encoder.parameters():
        param.data = param.data.contiguous()

    scheduler = CogVideoXDDIMScheduler.from_config(
        {
            "snr_shift_scale": 4.0,
            "beta_end": 0.012,
            "beta_schedule": "scaled_linear",
            "beta_start": 0.00085,
            "clip_sample": False,
            "num_train_timesteps": 1000,
            "prediction_type": "v_prediction",
            "rescale_betas_zero_snr": True,
            "set_alpha_to_one": True,
            "timestep_spacing": "trailing",
        }
    )

    pipe = CogView3PlusPipeline(
        tokenizer=tokenizer,
        text_encoder=text_encoder,
        vae=vae,
        transformer=transformer,
        scheduler=scheduler,
    )

    # This is necessary for users with insufficient memory, such as those using Colab and notebooks, as it can
    # save some memory used for model loading.
    pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)


if __name__ == "__main__":
    main(args)
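`main` maps the `--dtype` string with an if/elif chain and raises on anything else. One alternative, shown here as a sketch rather than as what the script does, is to declare the allowed values with argparse's `choices` so that invalid input is rejected while parsing. The dtype names are kept as strings so the snippet runs without `torch`; `build_parser` and `DTYPE_NAMES` are hypothetical.

```python
import argparse

# Assumed mapping from the CLI flag to torch dtype attribute names.
DTYPE_NAMES = {"fp16": "float16", "bf16": "bfloat16", "fp32": "float32"}


def build_parser():
    parser = argparse.ArgumentParser()
    parser.add_argument("--output_path", required=True, type=str)
    parser.add_argument("--dtype", type=str, default="bf16", choices=sorted(DTYPE_NAMES))
    return parser


args = build_parser().parse_args(["--output_path", "out", "--dtype", "fp16"])
dtype_name = DTYPE_NAMES[args.dtype]  # with torch installed: getattr(torch, dtype_name)

# An unsupported value makes argparse exit with an error instead of reaching main().
try:
    build_parser().parse_args(["--output_path", "out", "--dtype", "int8"])
    rejected = False
except SystemExit:
    rejected = True
```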
......@@ -84,6 +84,7 @@ else:
"AutoencoderOobleck",
"AutoencoderTiny",
"CogVideoXTransformer3DModel",
"CogView3PlusTransformer2DModel",
"ConsistencyDecoderVAE",
"ControlNetModel",
"ControlNetXSAdapter",
......@@ -258,6 +259,7 @@ else:
"CogVideoXImageToVideoPipeline",
"CogVideoXPipeline",
"CogVideoXVideoToVideoPipeline",
"CogView3PlusPipeline",
"CycleDiffusionPipeline",
"FluxControlNetImg2ImgPipeline",
"FluxControlNetInpaintPipeline",
......@@ -559,6 +561,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderOobleck,
AutoencoderTiny,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
ConsistencyDecoderVAE,
ControlNetModel,
ControlNetXSAdapter,
......@@ -711,6 +714,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogVideoXImageToVideoPipeline,
CogVideoXPipeline,
CogVideoXVideoToVideoPipeline,
CogView3PlusPipeline,
CycleDiffusionPipeline,
FluxControlNetImg2ImgPipeline,
FluxControlNetInpaintPipeline,
......
......@@ -54,6 +54,7 @@ if is_torch_available():
_import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
......@@ -98,6 +99,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .transformers import (
AuraFlowTransformer2DModel,
CogVideoXTransformer3DModel,
CogView3PlusTransformer2DModel,
DiTTransformer2DModel,
DualTransformer2DModel,
FluxTransformer2DModel,
......
......@@ -122,6 +122,7 @@ class Attention(nn.Module):
out_dim: int = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
):
super().__init__()
......@@ -179,8 +180,8 @@ class Attention(nn.Module):
self.norm_q = None
self.norm_k = None
elif qk_norm == "layer_norm":
self.norm_q = nn.LayerNorm(dim_head, eps=eps)
self.norm_k = nn.LayerNorm(dim_head, eps=eps)
self.norm_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
self.norm_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
elif qk_norm == "fp32_layer_norm":
self.norm_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
self.norm_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
......
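The new `elementwise_affine` argument is forwarded to `nn.LayerNorm` for the `"layer_norm"` QK normalization. With `elementwise_affine=False` the layer has no learnable scale or bias and only standardizes each vector. A plain-Python sketch of that computation follows; the function name is illustrative, and the biased variance estimate is assumed, as in `nn.LayerNorm`.

```python
import math


def layer_norm_no_affine(values, eps=1e-6):
    """Normalize a vector to zero mean and unit (biased) variance."""
    mean = sum(values) / len(values)
    variance = sum((v - mean) ** 2 for v in values) / len(values)
    return [(v - mean) / math.sqrt(variance + eps) for v in values]


normed = layer_norm_no_affine([1.0, 2.0, 3.0, 4.0])
mean_after = sum(normed) / len(normed)
variance_after = sum(v * v for v in normed) / len(normed)
```

After normalization the mean is zero and the variance is just under one, the small gap coming from `eps`.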
......@@ -442,6 +442,60 @@ class CogVideoXPatchEmbed(nn.Module):
return embeds
class CogView3PlusPatchEmbed(nn.Module):
    def __init__(
        self,
        in_channels: int = 16,
        hidden_size: int = 2560,
        patch_size: int = 2,
        text_hidden_size: int = 4096,
        pos_embed_max_size: int = 128,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.text_hidden_size = text_hidden_size
        self.pos_embed_max_size = pos_embed_max_size
        # Linear projection for image patches
        self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)

        # Linear projection for text embeddings
        self.text_proj = nn.Linear(text_hidden_size, hidden_size)

        pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
        pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)

    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        batch_size, channel, height, width = hidden_states.shape

        if height % self.patch_size != 0 or width % self.patch_size != 0:
            raise ValueError("Height and width must be divisible by patch size")

        height = height // self.patch_size
        width = width // self.patch_size
        hidden_states = hidden_states.view(batch_size, channel, height, self.patch_size, width, self.patch_size)
        hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).contiguous()
        hidden_states = hidden_states.view(batch_size, height * width, channel * self.patch_size * self.patch_size)

        # Project the patches
        hidden_states = self.proj(hidden_states)
        encoder_hidden_states = self.text_proj(encoder_hidden_states)
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        # Calculate text_length
        text_length = encoder_hidden_states.shape[1]

        image_pos_embed = self.pos_embed[:height, :width].reshape(height * width, -1)
        text_pos_embed = torch.zeros(
            (text_length, self.hidden_size), dtype=image_pos_embed.dtype, device=image_pos_embed.device
        )
        pos_embed = torch.cat([text_pos_embed, image_pos_embed], dim=0)[None, ...]

        return (hidden_states + pos_embed).to(hidden_states.dtype)
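`CogView3PlusPatchEmbed.forward` turns a `(batch, channel, height, width)` latent into a sequence of flattened patches before projecting it. The shape bookkeeping can be checked with a small helper; `patch_sequence_shape` is hypothetical and not part of the library, and the 128 x 128 latent is only an example input.

```python
def patch_sequence_shape(channels, height, width, patch_size):
    """Return (num_patches, patch_dim) for non-overlapping square patches."""
    if height % patch_size != 0 or width % patch_size != 0:
        raise ValueError("Height and width must be divisible by patch size")
    num_patches = (height // patch_size) * (width // patch_size)
    patch_dim = channels * patch_size * patch_size
    return num_patches, patch_dim


# Defaults from the class: 16 latent channels, patch size 2.
num_patches, patch_dim = patch_sequence_shape(16, 128, 128, 2)
print(num_patches, patch_dim)  # 4096 64
```

Since the projected text tokens are concatenated in front of the image patches, the sequence that leaves the layer has `text_length + num_patches` entries.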
def get_3d_rotary_pos_embed(
embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
......@@ -1080,6 +1134,39 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
return conditioning
class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
    def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
        super().__init__()

        self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
        self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(
        self,
        timestep: torch.Tensor,
        original_size: torch.Tensor,
        target_size: torch.Tensor,
        crop_coords: torch.Tensor,
        hidden_dtype: torch.dtype,
    ) -> torch.Tensor:
        timesteps_proj = self.time_proj(timestep)

        original_size_proj = self.condition_proj(original_size.flatten()).view(original_size.size(0), -1)
        crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
        target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)

        # (B, 3 * 2 * condition_dim)
        condition_proj = torch.cat([original_size_proj, crop_coords_proj, target_size_proj], dim=1)

        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)
        condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype))  # (B, embedding_dim)

        conditioning = timesteps_emb + condition_emb
        return conditioning
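Each of the three size conditions holds two numbers. Every number is embedded to `condition_dim` values, so one condition contributes `2 * condition_dim` values and the concatenation has `3 * 2 * condition_dim`. The sketch below checks that count in plain Python. It assumes the common sin/cos formulation with a maximum period of 10000 and cosine terms first; the exact frequency schedule inside `Timesteps` may differ in detail, and both helper names are hypothetical.

```python
import math


def sincos_embedding(value, dim, max_period=10000.0):
    """Embed one scalar as [cos..., sin...] of length dim (dim must be even)."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    angles = [value * f for f in freqs]
    return [math.cos(a) for a in angles] + [math.sin(a) for a in angles]


def embed_condition(pair, condition_dim):
    """Embed a 2-element condition such as (height, width) and concatenate."""
    out = []
    for scalar in pair:
        out.extend(sincos_embedding(scalar, condition_dim))
    return out


condition_dim = 256
original_size, crop_coords, target_size = (1024, 1024), (0, 0), (1024, 1024)
combined = (
    embed_condition(original_size, condition_dim)
    + embed_condition(crop_coords, condition_dim)
    + embed_condition(target_size, condition_dim)
)
print(len(combined))  # 1536, i.e. 3 * 2 * condition_dim
```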
class HunyuanDiTAttentionPool(nn.Module):
# Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
......
......@@ -355,6 +355,51 @@ class LuminaLayerNormContinuous(nn.Module):
return x
class CogView3PlusAdaLayerNormZeroTextImage(nn.Module):
    r"""
    Adaptive layer norm zero (adaLN-Zero) applied jointly to image and text hidden states.

    Parameters:
        embedding_dim (`int`): The size of the conditioning embedding.
        dim (`int`): The hidden size of the states being normalized.
    """

    def __init__(self, embedding_dim: int, dim: int):
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
        self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
        self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)

    def forward(
        self,
        x: torch.Tensor,
        context: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, ...]:
        emb = self.linear(self.silu(emb))
        (
            shift_msa,
            scale_msa,
            gate_msa,
            shift_mlp,
            scale_mlp,
            gate_mlp,
            c_shift_msa,
            c_scale_msa,
            c_gate_msa,
            c_shift_mlp,
            c_scale_mlp,
            c_gate_mlp,
        ) = emb.chunk(12, dim=1)
        normed_x = self.norm_x(x)
        normed_context = self.norm_c(context)
        x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None]
        context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp
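The layer projects the conditioning embedding to `12 * dim` values, splits them into twelve chunks (six for the image stream, six prefixed `c_` for the text stream), and applies `normed * (1 + scale) + shift`. A list-based sketch of the split and the modulation, using hypothetical helper names and a single sample rather than a batch:

```python
NAMES = [
    "shift_msa", "scale_msa", "gate_msa", "shift_mlp", "scale_mlp", "gate_mlp",
    "c_shift_msa", "c_scale_msa", "c_gate_msa", "c_shift_mlp", "c_scale_mlp", "c_gate_mlp",
]


def split_modulation(emb, dim):
    """Split a flat vector of length 12 * dim into 12 named chunks of length dim."""
    if len(emb) != len(NAMES) * dim:
        raise ValueError("expected 12 * dim values")
    return {name: emb[i * dim:(i + 1) * dim] for i, name in enumerate(NAMES)}


def modulate(normed, scale, shift):
    """Apply normed * (1 + scale) + shift element-wise."""
    return [n * (1 + s) + b for n, s, b in zip(normed, scale, shift)]


dim = 2
emb = [float(i) for i in range(12 * dim)]
parts = split_modulation(emb, dim)
# A zero scale and zero shift leave the normalized input unchanged.
identity = modulate([0.5, -0.5], [0.0, 0.0], [0.0, 0.0])
```

The `1 + scale` form means that a projection producing zeros leaves the normalized input unchanged.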
class CogVideoXLayerNormZero(nn.Module):
def __init__(
self,
......
......@@ -14,6 +14,7 @@ if is_torch_available():
from .stable_audio_transformer import StableAudioDiTModel
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Union
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
from ...models.attention_processor import (
Attention,
AttentionProcessor,
CogVideoXAttnProcessor2_0,
)
from ...models.modeling_utils import ModelMixin
from ...models.normalization import AdaLayerNormContinuous
from ...utils import is_torch_version, logging
from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class CogView3PlusTransformerBlock(nn.Module):
r"""
Transformer block used in the [CogView3Plus](https://github.com/THUDM/CogView3) model.
Args:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
time_embed_dim (`int`):
The number of channels in timestep embedding.
"""
def __init__(
self,
dim: int = 2560,
num_attention_heads: int = 64,
attention_head_dim: int = 40,
time_embed_dim: int = 512,
):
super().__init__()
self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
out_dim=dim,
bias=True,
qk_norm="layer_norm",
elementwise_affine=False,
eps=1e-6,
processor=CogVideoXAttnProcessor2_0(),
)
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
emb: torch.Tensor,
) -> torch.Tensor:
text_seq_length = encoder_hidden_states.size(1)
# norm & modulate
(
norm_hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
norm_encoder_hidden_states,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,
c_gate_mlp,
) = self.norm1(hidden_states, encoder_hidden_states, emb)
# attention
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
)
hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
# norm & modulate
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
# feed-forward
norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
if hidden_states.dtype == torch.float16:
hidden_states = hidden_states.clip(-65504, 65504)
if encoder_hidden_states.dtype == torch.float16:
encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
return hidden_states, encoder_hidden_states
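The block clips float16 activations to ±65504, which is the largest finite value in IEEE 754 half precision. The standard library can confirm that bound: `struct` supports a half-precision format code (`"e"`) and refuses to pack values that overflow it. The helper names below are illustrative only.

```python
import struct

FP16_MAX = 65504.0  # largest finite IEEE 754 half-precision value


def clip_to_fp16_range(value):
    """Clamp a float into [-65504, 65504], as the block does for float16 tensors."""
    return max(-FP16_MAX, min(FP16_MAX, value))


def fits_in_fp16(value):
    """Return True if struct can pack the value as a half-precision float."""
    try:
        struct.pack("<e", value)
        return True
    except OverflowError:
        return False


print(fits_in_fp16(70000.0))                      # False
print(fits_in_fp16(clip_to_fp16_range(70000.0)))  # True
```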
class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
r"""
The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay
Diffusion](https://huggingface.co/papers/2403.05121).
Args:
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
num_layers (`int`, defaults to `30`):
The number of layers of Transformer blocks to use.
attention_head_dim (`int`, defaults to `40`):
The number of channels in each head.
num_attention_heads (`int`, defaults to `64`):
The number of heads to use for multi-head attention.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
condition_dim (`int`, defaults to `256`):
The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
crop_coords).
pos_embed_max_size (`int`, defaults to `128`):
The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
patch_size => 128 * 8 * 2 => 2048`.
sample_size (`int`, defaults to `128`):
The base resolution of input latents. If height/width is not provided during generation, this value is used
to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
num_layers: int = 30,
attention_head_dim: int = 40,
num_attention_heads: int = 64,
out_channels: int = 16,
text_embed_dim: int = 4096,
time_embed_dim: int = 512,
condition_dim: int = 256,
pos_embed_max_size: int = 128,
sample_size: int = 128,
):
super().__init__()
self.out_channels = out_channels
self.inner_dim = num_attention_heads * attention_head_dim
# CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
# Each of these are sincos embeddings of shape 2 * condition_dim
self.pooled_projection_dim = 3 * 2 * condition_dim
self.patch_embed = CogView3PlusPatchEmbed(
in_channels=in_channels,
hidden_size=self.inner_dim,
patch_size=patch_size,
text_hidden_size=text_embed_dim,
pos_embed_max_size=pos_embed_max_size,
)
self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
embedding_dim=time_embed_dim,
condition_dim=condition_dim,
pooled_projection_dim=self.pooled_projection_dim,
timesteps_dim=self.inner_dim,
)
self.transformer_blocks = nn.ModuleList(
[
CogView3PlusTransformerBlock(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
time_embed_dim=time_embed_dim,
)
for _ in range(num_layers)
]
)
self.norm_out = AdaLayerNormContinuous(
embedding_dim=self.inner_dim,
conditioning_embedding_dim=time_embed_dim,
elementwise_affine=False,
eps=1e-6,
)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
original_size: torch.Tensor,
target_size: torch.Tensor,
crop_coords: torch.Tensor,
return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
"""
The [`CogView3PlusTransformer2DModel`] forward method.
Args:
hidden_states (`torch.Tensor`):
Input `hidden_states` of shape `(batch_size, channel, height, width)`.
encoder_hidden_states (`torch.Tensor`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape
`(batch_size, sequence_len, text_embed_dim)`
timestep (`torch.LongTensor`):
Used to indicate denoising step.
original_size (`torch.Tensor`):
CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
target_size (`torch.Tensor`):
CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
crop_coords (`torch.Tensor`):
CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of
[https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
`torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
The denoised latents using provided inputs as conditioning.
"""
height, width = hidden_states.shape[-2:]
text_seq_length = encoder_hidden_states.shape[1]
hidden_states = self.patch_embed(
hidden_states, encoder_hidden_states
) # takes care of adding positional embeddings too.
emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
encoder_hidden_states = hidden_states[:, :text_seq_length]
hidden_states = hidden_states[:, text_seq_length:]
for index_block, block in enumerate(self.transformer_blocks):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
emb=emb,
)
hidden_states = self.norm_out(hidden_states, emb)
hidden_states = self.proj_out(hidden_states) # (batch_size, num_patches, patch_size*patch_size*out_channels)
# unpatchify
patch_size = self.config.patch_size
height = height // patch_size
width = width // patch_size
hidden_states = hidden_states.reshape(
shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
)
hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
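The unpatchify step at the end of `forward` can be checked in isolation. The following is a minimal sketch, not the model's code: it uses NumPy as a stand-in for the `torch` calls, and the `patchify` helper is a hypothetical inverse written only for the round-trip check. It assumes each token stores its patch in `(channel, patch_row, patch_col)` order, which is what the reshape in `forward` implies.

```python
import numpy as np


def unpatchify(tokens, height, width, patch_size, out_channels):
    """Rearrange (batch, num_patches, c*p*p) tokens into (batch, c, H, W)."""
    batch = tokens.shape[0]
    h, w = height // patch_size, width // patch_size
    # Split the token axis into a patch grid and the feature axis into (c, p, q).
    x = tokens.reshape(batch, h, w, out_channels, patch_size, patch_size)
    # Same index permutation as in forward: put channels first, interleave grid and patch axes.
    x = np.einsum("nhwcpq->nchpwq", x)
    return x.reshape(batch, out_channels, h * patch_size, w * patch_size)


def patchify(images, patch_size):
    """Hypothetical inverse of unpatchify, used only for the round-trip check."""
    n, c, full_h, full_w = images.shape
    h, w = full_h // patch_size, full_w // patch_size
    x = images.reshape(n, c, h, patch_size, w, patch_size)
    x = np.einsum("nchpwq->nhwcpq", x)
    return x.reshape(n, h * w, c * patch_size * patch_size)


images = np.arange(2 * 4 * 8 * 8, dtype=np.float32).reshape(2, 4, 8, 8)
tokens = patchify(images, patch_size=2)
restored = unpatchify(tokens, height=8, width=8, patch_size=2, out_channels=4)
```

With `patch_size=2` on an 8x8 input, `tokens` holds 16 patches of 16 features each, and `restored` matches `images` element for element.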
......@@ -145,6 +145,7 @@ else:
"CogVideoXImageToVideoPipeline",
"CogVideoXVideoToVideoPipeline",
]
_import_structure["cogview3"] = ["CogView3PlusPipeline"]
_import_structure["controlnet"].extend(
[
"BlipDiffusionControlNetPipeline",
......@@ -470,6 +471,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .aura_flow import AuraFlowPipeline
from .blip_diffusion import BlipDiffusionPipeline
from .cogvideo import CogVideoXImageToVideoPipeline, CogVideoXPipeline, CogVideoXVideoToVideoPipeline
from .cogview3 import CogView3PlusPipeline
from .controlnet import (
BlipDiffusionControlNetPipeline,
StableDiffusionControlNetImg2ImgPipeline,
......
......@@ -20,6 +20,7 @@ from huggingface_hub.utils import validate_hf_hub_args
from ..configuration_utils import ConfigMixin
from ..utils import is_sentencepiece_available
from .aura_flow import AuraFlowPipeline
from .cogview3 import CogView3PlusPipeline
from .controlnet import (
StableDiffusionControlNetImg2ImgPipeline,
StableDiffusionControlNetInpaintPipeline,
......@@ -119,6 +120,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("flux", FluxPipeline),
("flux-controlnet", FluxControlNetPipeline),
("lumina", LuminaText2ImgPipeline),
("cogview3", CogView3PlusPipeline),
]
)
......
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["CogView3PipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_cogview3plus"] = ["CogView3PlusPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_cogview3plus import CogView3PlusPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)
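The `_import_structure` dictionary above maps submodule names to the names they export, so that the heavy imports happen only on first attribute access. A rough sketch of that idea, using only the standard library, could look like the following. The `LazyModule` class here is a hypothetical, simplified illustration and is not diffusers' `_LazyModule`; the `json` and `math` entries are placeholder modules chosen so the example runs anywhere.

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Hypothetical simplified lazy module; not diffusers' `_LazyModule`."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported name to the module that defines it.
        self._name_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        # Called only when normal attribute lookup fails, i.e. on first access.
        if attr not in self._name_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        value = getattr(importlib.import_module(self._name_to_module[attr]), attr)
        setattr(self, attr, value)  # cache so later lookups bypass __getattr__
        return value


# Placeholder structure: stdlib modules stand in for the pipeline submodules.
lazy = LazyModule("lazy_demo", {"json": ["dumps"], "math": ["sqrt"]})
```

Accessing `lazy.sqrt` triggers the import of `math` and caches the function on the module object; names missing from the structure raise `AttributeError`.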
This diff is collapsed.
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class CogView3PipelineOutput(BaseOutput):
"""
Output class for CogView3 pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. The PIL images or numpy array represent the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]
......@@ -122,6 +122,21 @@ class CogVideoXTransformer3DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class CogView3PlusTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ConsistencyDecoderVAE(metaclass=DummyObject):
_backends = ["torch"]
......
......@@ -317,6 +317,21 @@ class CogVideoXVideoToVideoPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class CogView3PlusPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class CycleDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CogView3PlusTransformer2DModel
from diffusers.utils.testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogView3PlusTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"original_size": original_size,
"target_size": target_size,
"crop_coords": crop_coords,
"timestep": timestep,
}
@property
def input_shape(self):
return (1, 4, 8, 8)
@property
def output_shape(self):
return (1, 4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 4,
"num_layers": 1,
"attention_head_dim": 4,
"num_attention_heads": 2,
"out_channels": 4,
"text_embed_dim": 8,
"time_embed_dim": 8,
"condition_dim": 2,
"pos_embed_max_size": 8,
"sample_size": 8,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
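The shapes this test configuration produces can be worked out by hand. The sketch below only redoes the arithmetic from `init_dict` and `dummy_input`; it does not instantiate the model, and the variable names are illustrative rather than taken from the library.

```python
# Values copied from the test's init_dict and dummy_input above.
patch_size = 2
out_channels = 4
condition_dim = 2
height = width = 8  # latent height and width of the dummy hidden_states

# One image token per non-overlapping patch_size x patch_size patch.
num_image_tokens = (height // patch_size) * (width // patch_size)

# proj_out maps each token to patch_size * patch_size * out_channels values.
proj_out_features = patch_size * patch_size * out_channels

# Three size conditions (original_size, target_size, crop_coords), each a
# sincos embedding of width 2 * condition_dim.
pooled_projection_dim = 3 * 2 * condition_dim

# Unpatchify redistributes those values into (out_channels, height, width).
values_per_sample = out_channels * height * width
```

The token count times the per-token feature width equals `values_per_sample`, which is why the final reshape in `forward` returns the `(1, 4, 8, 8)` per-sample layout declared in `output_shape`.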