Commit 08a21d59 authored by chenpangpang's avatar chenpangpang
Browse files

feat: 初始提交

parent 1a6b26f1
Pipeline #2165 failed with stages
in 0 seconds
{
"last_node_id": 6,
"last_link_id": 3,
"nodes": [
{
"id": 1,
"type": "Ruyi_LoadModel",
"pos": {
"0": 210,
"1": 162
},
"size": {
"0": 315,
"1": 82
},
"flags": {},
"order": 0,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "ruyi_model",
"type": "RUYI_MODEL",
"links": [
1
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "Ruyi_LoadModel"
},
"widgets_values": [
"Ruyi-Mini-7B",
"yes",
"yes"
]
},
{
"id": 4,
"type": "VHS_VideoCombine",
"pos": {
"0": 1045,
"1": 133
},
"size": [
404.73553466796875,
601.8645528157551
],
"flags": {},
"order": 3,
"mode": 0,
"inputs": [
{
"name": "images",
"type": "IMAGE",
"link": 3
},
{
"name": "audio",
"type": "AUDIO",
"link": null,
"shape": 7
},
{
"name": "meta_batch",
"type": "VHS_BatchManager",
"link": null,
"shape": 7
},
{
"name": "vae",
"type": "VAE",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "Filenames",
"type": "VHS_FILENAMES",
"links": null
}
],
"properties": {
"Node name for S&R": "VHS_VideoCombine"
},
"widgets_values": {
"frame_rate": 24,
"loop_count": 0,
"filename_prefix": "Ruyi-I2V-StartFrame",
"format": "video/h264-mp4",
"pix_fmt": "yuv420p",
"crf": 19,
"save_metadata": true,
"pingpong": false,
"save_output": true,
"videopreview": {
"hidden": false,
"paused": false,
"params": {
"filename": "Ruyi-I2V-StartFrame_00001.mp4",
"subfolder": "",
"type": "output",
"format": "video/h264-mp4",
"frame_rate": 24
},
"muted": false
}
}
},
{
"id": 3,
"type": "LoadImage",
"pos": {
"0": 200,
"1": 439
},
"size": {
"0": 315,
"1": 314
},
"flags": {},
"order": 1,
"mode": 0,
"inputs": [],
"outputs": [
{
"name": "IMAGE",
"type": "IMAGE",
"links": [
2
],
"slot_index": 0
},
{
"name": "MASK",
"type": "MASK",
"links": null
}
],
"properties": {
"Node name for S&R": "LoadImage"
},
"widgets_values": [
"example_03.jpg",
"image"
]
},
{
"id": 2,
"type": "Ruyi_I2VSampler",
"pos": {
"0": 628,
"1": 284
},
"size": {
"0": 327.5999755859375,
"1": 338
},
"flags": {},
"order": 2,
"mode": 0,
"inputs": [
{
"name": "ruyi_model",
"type": "RUYI_MODEL",
"link": 1
},
{
"name": "start_img",
"type": "IMAGE",
"link": 2
},
{
"name": "end_img",
"type": "IMAGE",
"link": null,
"shape": 7
}
],
"outputs": [
{
"name": "images",
"type": "IMAGE",
"links": [
3
],
"slot_index": 0
}
],
"properties": {
"Node name for S&R": "Ruyi_I2VSampler"
},
"widgets_values": [
120,
512,
925247271358454,
"randomize",
25,
7,
"DDIM",
"2",
"static",
"normal_mode",
"5"
]
}
],
"links": [
[
1,
1,
0,
2,
0,
"RUYI_MODEL"
],
[
2,
3,
0,
2,
1,
"IMAGE"
],
[
3,
2,
0,
4,
0,
"IMAGE"
]
],
"groups": [],
"config": {},
"extra": {
"ds": {
"scale": 1,
"offset": [
0,
0
]
}
},
"version": 0.4
}
transformer_additional_kwargs:
basic_block_type: "basic"
after_norm: false
time_position_encoding: true
noise_scheduler_kwargs:
beta_start: 0.00085
beta_end: 0.03
beta_schedule: "scaled_linear"
steps_offset: 1
prediction_type: "v_prediction"
clip_sample: false
vae_kwargs:
enable_magvit: true
import os
import torch
from PIL import Image
from diffusers import (EulerDiscreteScheduler, EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler, PNDMScheduler, DDIMScheduler)
from omegaconf import OmegaConf
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from safetensors.torch import load_file as load_safetensors
from huggingface_hub import snapshot_download
from ruyi.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
from ruyi.models.autoencoder_magvit import AutoencoderKLMagvit
from ruyi.models.transformer3d import HunyuanTransformer3DModel
from ruyi.pipeline.pipeline_ruyi_inpaint import RuyiInpaintPipeline
from ruyi.utils.lora_utils import merge_lora, unmerge_lora
from ruyi.utils.utils import get_image_to_video_latent, save_videos_grid
# Input and output
start_image_path = "assets/girl_01.jpg"
end_image_path = "assets/girl_02.jpg" # Can be None for start-image-to-video
output_video_path = "outputs/example_01.mp4"
# Video settings
video_length = 120 # The max video length is 120 frames (24 frames per second)
base_resolution = 512 # # The pixels in the generated video are approximately 512 x 512. Values in the range of [384, 896] typically produce good video quality.
video_size = None # Override base_resolution. Format: [height, width], e.g., [384, 672]
# Control settings
aspect_ratio = "16:9" # Choose in ["16:9", "9:16"], note that this is only the hint
motion = "auto" # Motion control, choose in ["1", "2", "3", "4", "auto"]
camera_direction = "auto" # Camera control, choose in ["static", "left", "right", "up", "down", "auto"]
# Sampler settings
steps = 25
cfg = 7.0
scheduler_name = "DDIM" # Choose in ["Euler", "Euler A", "DPM++", "PNDM","DDIM"]
# GPU memory settings
low_gpu_memory_mode = False # Low gpu memory mode
gpu_offload_steps = 5 # Choose in [0, 10, 7, 5, 1], the latter number requires less GPU memory but longer time
# Random seed
seed = 42 # The Answer to the Ultimate Question of Life, The Universe, and Everything
# Model settings
config_path = "config/default.yaml"
model_name = "Ruyi-Mini-7B"
model_type = "Inpaint"
model_path = f"models/{model_name}" # (Down)load mode in this path
auto_download = True # Automatically download the model if the pipeline creation fails
auto_update = True # If auto_download is enabled, check for updates and update the model if necessary
# LoRA settings
lora_path = None
lora_weight = 1.0
# Other settings
weight_dtype = torch.bfloat16
device = torch.device("cuda")
def get_control_embeddings(pipeline, aspect_ratio, motion, camera_direction):
# Default keys
p_default_key = "p.default"
n_default_key = "n.default"
# Load embeddings
if motion == "auto":
motion = "0"
p_key = f"p.{aspect_ratio.replace(':', 'x')}movie{motion}{camera_direction}"
embeddings = pipeline.embeddings
# Get embeddings
positive_embeds = embeddings.get(f"{p_key}.emb1", embeddings[f"{p_default_key}.emb1"])
positive_attention_mask = embeddings.get(f"{p_key}.mask1", embeddings[f"{p_default_key}.mask1"])
positive_embeds_2 = embeddings.get(f"{p_key}.emb2", embeddings[f"{p_default_key}.emb2"])
positive_attention_mask_2 = embeddings.get(f"{p_key}.mask2", embeddings[f"{p_default_key}.mask2"])
negative_embeds = embeddings[f"{n_default_key}.emb1"]
negative_attention_mask = embeddings[f"{n_default_key}.mask1"]
negative_embeds_2 = embeddings[f"{n_default_key}.emb2"]
negative_attention_mask_2 = embeddings[f"{n_default_key}.mask2"]
return {
"positive_embeds": positive_embeds,
"positive_attention_mask": positive_attention_mask,
"positive_embeds_2": positive_embeds_2,
"positive_attention_mask_2": positive_attention_mask_2,
"negative_embeds": negative_embeds,
"negative_attention_mask": negative_attention_mask,
"negative_embeds_2": negative_embeds_2,
"negative_attention_mask_2": negative_attention_mask_2,
}
def try_setup_pipeline(model_path, weight_dtype, config):
try:
# Get Vae
vae = AutoencoderKLMagvit.from_pretrained(
model_path,
subfolder="vae"
).to(weight_dtype)
print("Vae loaded ...")
# Get Transformer
transformer_additional_kwargs = OmegaConf.to_container(config['transformer_additional_kwargs'])
transformer = HunyuanTransformer3DModel.from_pretrained_2d(
model_path,
subfolder="transformer",
transformer_additional_kwargs=transformer_additional_kwargs
).to(weight_dtype)
print("Transformer loaded ...")
# Load Clip
clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
model_path, subfolder="image_encoder"
).to(weight_dtype)
clip_image_processor = CLIPImageProcessor.from_pretrained(
model_path, subfolder="image_encoder"
)
# Load sampler and create pipeline
Choosen_Scheduler = DDIMScheduler
scheduler = Choosen_Scheduler.from_pretrained(
model_path,
subfolder="scheduler"
)
pipeline = RuyiInpaintPipeline.from_pretrained(
model_path,
vae=vae,
transformer=transformer,
scheduler=scheduler,
torch_dtype=weight_dtype,
clip_image_encoder=clip_image_encoder,
clip_image_processor=clip_image_processor,
)
# Load embeddings
embeddings = load_safetensors(os.path.join(model_path, "embeddings.safetensors"))
pipeline.embeddings = embeddings
print("Pipeline loaded ...")
return pipeline
except Exception as e:
print("[Ruyi] Setup pipeline failed:", e)
return None
# Load config
config = OmegaConf.load(config_path)
# Load images
start_img = [Image.open(start_image_path).convert("RGB")]
end_img = [Image.open(end_image_path).convert("RGB")] if end_image_path is not None else None
# Check for update
repo_id = f"IamCreateAI/{model_name}"
if auto_download and auto_update:
print(f"Checking for {model_name} updates ...")
# Download the model
snapshot_download(repo_id=repo_id, local_dir=model_path)
# Init model
pipeline = try_setup_pipeline(model_path, weight_dtype, config)
if pipeline is None and auto_download:
print(f"Downloading {model_name} ...")
# Download the model
snapshot_download(repo_id=repo_id, local_dir=model_path)
pipeline = try_setup_pipeline(model_path, weight_dtype, config)
if pipeline is None:
message = (f"[Load Model Failed] "
f"Please download Ruyi model from huggingface repo '{repo_id}', "
f"And put it into '{model_path}'.")
if not auto_download:
message += "\nOr just set auto_download to 'True'."
raise FileNotFoundError(message)
# Setup GPU memory mode
if low_gpu_memory_mode:
pipeline.enable_sequential_cpu_offload()
else:
pipeline.enable_model_cpu_offload()
# Prepare LoRA config
loras = {
'models': [lora_path] if lora_path is not None else [],
'weights': [lora_weight] if lora_path is not None else [],
}
# Count most suitable height and width
if video_size is None:
aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
original_width, original_height = start_img[0].size if type(start_img) is list else Image.open(start_img).size
closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
height, width = [int(x / 16) * 16 for x in closest_size]
else:
height, width = video_size
# Set hidden states offload steps
pipeline.transformer.hidden_cache_size = gpu_offload_steps
# Load Sampler
if scheduler_name == "DPM++":
noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder='scheduler')
elif scheduler_name == "Euler":
noise_scheduler = EulerDiscreteScheduler.from_pretrained(model_path, subfolder='scheduler')
elif scheduler_name == "Euler A":
noise_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder='scheduler')
elif scheduler_name == "PNDM":
noise_scheduler = PNDMScheduler.from_pretrained(model_path, subfolder='scheduler')
elif scheduler_name == "DDIM":
noise_scheduler = DDIMScheduler.from_pretrained(model_path, subfolder='scheduler')
pipeline.scheduler = noise_scheduler
# Set random seed
generator= torch.Generator(device).manual_seed(seed)
# Load control embeddings
embeddings = get_control_embeddings(pipeline, aspect_ratio, motion, camera_direction)
with torch.no_grad():
video_length = int(video_length // pipeline.vae.mini_batch_encoder * pipeline.vae.mini_batch_encoder) if video_length != 1 else 1
input_video, input_video_mask, clip_image = get_image_to_video_latent(start_img, end_img, video_length=video_length, sample_size=(height, width))
for _lora_path, _lora_weight in zip(loras.get("models", []), loras.get("weights", [])):
pipeline = merge_lora(pipeline, _lora_path, _lora_weight)
sample = pipeline(
prompt_embeds = embeddings["positive_embeds"],
prompt_attention_mask = embeddings["positive_attention_mask"],
prompt_embeds_2 = embeddings["positive_embeds_2"],
prompt_attention_mask_2 = embeddings["positive_attention_mask_2"],
negative_prompt_embeds = embeddings["negative_embeds"],
negative_prompt_attention_mask = embeddings["negative_attention_mask"],
negative_prompt_embeds_2 = embeddings["negative_embeds_2"],
negative_prompt_attention_mask_2 = embeddings["negative_attention_mask_2"],
video_length = video_length,
height = height,
width = width,
generator = generator,
guidance_scale = cfg,
num_inference_steps = steps,
video = input_video,
mask_video = input_video_mask,
clip_image = clip_image,
).videos
for _lora_path, _lora_weight in zip(loras.get("models", []), loras.get("weights", [])):
pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight)
# Save the video
output_folder = os.path.dirname(output_video_path)
if output_folder != '':
os.makedirs(output_folder, exist_ok=True)
save_videos_grid(sample, output_video_path, fps=24)
import os
import torch
from PIL import Image
from diffusers import (EulerDiscreteScheduler, EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler, PNDMScheduler, DDIMScheduler)
from omegaconf import OmegaConf
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from safetensors.torch import load_file as load_safetensors
from huggingface_hub import snapshot_download
from ruyi.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
from ruyi.models.autoencoder_magvit import AutoencoderKLMagvit
from ruyi.models.transformer3d import HunyuanTransformer3DModel
from ruyi.pipeline.pipeline_ruyi_inpaint import RuyiInpaintPipeline
from ruyi.utils.lora_utils import merge_lora, unmerge_lora
from ruyi.utils.utils import get_image_to_video_latent, save_videos_grid
# Input and output
start_image_path = "assets/girl_01.jpg"
end_image_path = "assets/girl_02.jpg" # Can be None for start-image-to-video
output_video_path = "outputs/example_01.mp4"
# Video settings
video_length = 120 # The max video length is 120 frames (24 frames per second)
base_resolution = 512 # # The pixels in the generated video are approximately 512 x 512. Values in the range of [384, 896] typically produce good video quality.
video_size = None # Override base_resolution. Format: [height, width], e.g., [384, 672]
# Control settings
aspect_ratio = "9:16" # Choose in ["16:9", "9:16"], note that this is only the hint
motion = "auto" # Motion control, choose in ["1", "2", "3", "4", "auto"]
camera_direction = "auto" # Camera control, choose in ["static", "left", "right", "up", "down", "auto"]
# Sampler settings
steps = 25
cfg = 7.0
scheduler_name = "DDIM" # Choose in ["Euler", "Euler A", "DPM++", "PNDM","DDIM"]
# GPU memory settings
low_gpu_memory_mode = False # Low gpu memory mode
gpu_offload_steps = 0 # Choose in [0, 10, 7, 5, 1], the latter number requires less GPU memory but longer time
# Random seed
seed = 42 # The Answer to the Ultimate Question of Life, The Universe, and Everything
# Model settings
config_path = "config/default.yaml"
model_name = "Ruyi-Mini-7B"
model_type = "Inpaint"
model_path = f"models/{model_name}" # (Down)load mode in this path
auto_download = True # Automatically download the model if the pipeline creation fails
auto_update = True # If auto_download is enabled, check for updates and update the model if necessary
# LoRA settings
lora_path = None
lora_weight = 1.0
# Other settings
weight_dtype = torch.bfloat16
device = torch.device("cuda")
def get_control_embeddings(pipeline, aspect_ratio, motion, camera_direction):
# Default keys
p_default_key = "p.default"
n_default_key = "n.default"
# Load embeddings
if motion == "auto":
motion = "0"
p_key = f"p.{aspect_ratio.replace(':', 'x')}movie{motion}{camera_direction}"
embeddings = pipeline.embeddings
# Get embeddings
positive_embeds = embeddings.get(f"{p_key}.emb1", embeddings[f"{p_default_key}.emb1"])
positive_attention_mask = embeddings.get(f"{p_key}.mask1", embeddings[f"{p_default_key}.mask1"])
positive_embeds_2 = embeddings.get(f"{p_key}.emb2", embeddings[f"{p_default_key}.emb2"])
positive_attention_mask_2 = embeddings.get(f"{p_key}.mask2", embeddings[f"{p_default_key}.mask2"])
negative_embeds = embeddings[f"{n_default_key}.emb1"]
negative_attention_mask = embeddings[f"{n_default_key}.mask1"]
negative_embeds_2 = embeddings[f"{n_default_key}.emb2"]
negative_attention_mask_2 = embeddings[f"{n_default_key}.mask2"]
return {
"positive_embeds": positive_embeds,
"positive_attention_mask": positive_attention_mask,
"positive_embeds_2": positive_embeds_2,
"positive_attention_mask_2": positive_attention_mask_2,
"negative_embeds": negative_embeds,
"negative_attention_mask": negative_attention_mask,
"negative_embeds_2": negative_embeds_2,
"negative_attention_mask_2": negative_attention_mask_2,
}
def try_setup_pipeline(model_path, weight_dtype, config):
try:
# Get Vae
vae = AutoencoderKLMagvit.from_pretrained(
model_path,
subfolder="vae"
).to(weight_dtype)
print("Vae loaded ...")
# Get Transformer
transformer_additional_kwargs = OmegaConf.to_container(config['transformer_additional_kwargs'])
transformer = HunyuanTransformer3DModel.from_pretrained_2d(
model_path,
subfolder="transformer",
transformer_additional_kwargs=transformer_additional_kwargs
).to(weight_dtype)
print("Transformer loaded ...")
# Load Clip
clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
model_path, subfolder="image_encoder"
).to(weight_dtype)
clip_image_processor = CLIPImageProcessor.from_pretrained(
model_path, subfolder="image_encoder"
)
# Load sampler and create pipeline
Choosen_Scheduler = DDIMScheduler
scheduler = Choosen_Scheduler.from_pretrained(
model_path,
subfolder="scheduler"
)
pipeline = RuyiInpaintPipeline.from_pretrained(
model_path,
vae=vae,
transformer=transformer,
scheduler=scheduler,
torch_dtype=weight_dtype,
clip_image_encoder=clip_image_encoder,
clip_image_processor=clip_image_processor,
)
# Load embeddings
embeddings = load_safetensors(os.path.join(model_path, "embeddings.safetensors"))
pipeline.embeddings = embeddings
print("Pipeline loaded ...")
return pipeline
except Exception as e:
print("[Ruyi] Setup pipeline failed:", e)
return None
# Load config
config = OmegaConf.load(config_path)
# Load images
start_img = [Image.open(start_image_path).convert("RGB")]
end_img = [Image.open(end_image_path).convert("RGB")] if end_image_path is not None else None
# Check for update
repo_id = f"IamCreateAI/{model_name}"
if auto_download and auto_update:
print(f"Checking for {model_name} updates ...")
# Download the model
snapshot_download(repo_id=repo_id, local_dir=model_path)
# Init model
pipeline = try_setup_pipeline(model_path, weight_dtype, config)
if pipeline is None and auto_download:
print(f"Downloading {model_name} ...")
# Download the model
snapshot_download(repo_id=repo_id, local_dir=model_path)
pipeline = try_setup_pipeline(model_path, weight_dtype, config)
if pipeline is None:
message = (f"[Load Model Failed] "
f"Please download Ruyi model from huggingface repo '{repo_id}', "
f"And put it into '{model_path}'.")
if not auto_download:
message += "\nOr just set auto_download to 'True'."
raise FileNotFoundError(message)
# Setup GPU memory mode
if low_gpu_memory_mode:
pipeline.enable_sequential_cpu_offload()
else:
pipeline.enable_model_cpu_offload()
# Prepare LoRA config
loras = {
'models': [lora_path] if lora_path is not None else [],
'weights': [lora_weight] if lora_path is not None else [],
}
# Count most suitable height and width
if video_size is None:
aspect_ratio_sample_size = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
original_width, original_height = start_img[0].size if type(start_img) is list else Image.open(start_img).size
closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
height, width = [int(x / 16) * 16 for x in closest_size]
else:
height, width = video_size
# Set hidden states offload steps
pipeline.transformer.hidden_cache_size = gpu_offload_steps
# Load Sampler
if scheduler_name == "DPM++":
noise_scheduler = DPMSolverMultistepScheduler.from_pretrained(model_path, subfolder='scheduler')
elif scheduler_name == "Euler":
noise_scheduler = EulerDiscreteScheduler.from_pretrained(model_path, subfolder='scheduler')
elif scheduler_name == "Euler A":
noise_scheduler = EulerAncestralDiscreteScheduler.from_pretrained(model_path, subfolder='scheduler')
elif scheduler_name == "PNDM":
noise_scheduler = PNDMScheduler.from_pretrained(model_path, subfolder='scheduler')
elif scheduler_name == "DDIM":
noise_scheduler = DDIMScheduler.from_pretrained(model_path, subfolder='scheduler')
pipeline.scheduler = noise_scheduler
# Set random seed
generator= torch.Generator(device).manual_seed(seed)
# Load control embeddings
embeddings = get_control_embeddings(pipeline, aspect_ratio, motion, camera_direction)
with torch.no_grad():
video_length = int(video_length // pipeline.vae.mini_batch_encoder * pipeline.vae.mini_batch_encoder) if video_length != 1 else 1
input_video, input_video_mask, clip_image = get_image_to_video_latent(start_img, end_img, video_length=video_length, sample_size=(height, width))
for _lora_path, _lora_weight in zip(loras.get("models", []), loras.get("weights", [])):
pipeline = merge_lora(pipeline, _lora_path, _lora_weight)
sample = pipeline(
prompt_embeds = embeddings["positive_embeds"],
prompt_attention_mask = embeddings["positive_attention_mask"],
prompt_embeds_2 = embeddings["positive_embeds_2"],
prompt_attention_mask_2 = embeddings["positive_attention_mask_2"],
negative_prompt_embeds = embeddings["negative_embeds"],
negative_prompt_attention_mask = embeddings["negative_attention_mask"],
negative_prompt_embeds_2 = embeddings["negative_embeds_2"],
negative_prompt_attention_mask_2 = embeddings["negative_attention_mask_2"],
video_length = video_length,
height = height,
width = width,
generator = generator,
guidance_scale = cfg,
num_inference_steps = steps,
video = input_video,
mask_video = input_video_mask,
clip_image = clip_image,
).videos
for _lora_path, _lora_weight in zip(loras.get("models", []), loras.get("weights", [])):
pipeline = unmerge_lora(pipeline, _lora_path, _lora_weight)
# Save the video
output_folder = os.path.dirname(output_video_path)
if output_folder != '':
os.makedirs(output_folder, exist_ok=True)
save_videos_grid(sample, output_video_path, fps=24)
[project]
name = "ruyi-models"
description = "ComfyUI wrapper nodes for Ruyi, an image-to-video model by CreateAI."
version = "1.0.1"
license = {file = "LICENSE"}
dependencies = ["Pillow", "einops", "safetensors", "timm", "tomesd", "torch", "torchdiffeq", "torchsde", "decord", "datasets", "numpy", "scikit-image", "opencv-python", "omegaconf", "SentencePiece", "albumentations", "imageio[ffmpeg]", "imageio[pyav]", "tensorboard", "beautifulsoup4", "ftfy", "func_timeout", "huggingface_hub", "accelerate>=0.26.0", "diffusers>=0.28.2", "transformers>=4.37.2"]
[project.urls]
Repository = "https://github.com/IamCreateAI/Ruyi-Models"
# Used by Comfy Registry https://comfyregistry.org
[tool.comfy]
PublisherId = "CreateAI"
DisplayName = "Ruyi-Models"
Icon = ""
Pillow
einops
safetensors
timm
tomesd
torch
torchdiffeq
torchsde
decord
datasets
numpy
scikit-image
opencv-python
omegaconf
SentencePiece
albumentations
imageio[ffmpeg]
imageio[pyav]
tensorboard
beautifulsoup4
ftfy
func_timeout
huggingface_hub
accelerate>=0.26.0
diffusers>=0.28.2
transformers>=4.37.2
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
Sized, TypeVar, Union)
import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data import BatchSampler, Dataset, Sampler
ASPECT_RATIO_512 = {
'0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
'0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
'0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
'0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
'0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
'1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
'1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
'1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
'2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
'3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
}
ASPECT_RATIO_RANDOM_CROP_512 = {
'0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
'0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
'0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
'1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
'2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
}
ASPECT_RATIO_RANDOM_CROP_PROB = [
1, 2,
4, 4, 4, 4,
8, 8, 8,
4, 4, 4, 4,
2, 1
]
ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
aspect_ratio = height / width
closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
return ratios[closest_ratio], float(closest_ratio)
def get_image_size_without_loading(path):
with Image.open(path) as img:
return img.size # (width, height)
class RandomSampler(Sampler[int]):
r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
If with replacement, then user can specify :attr:`num_samples` to draw.
Args:
data_source (Dataset): dataset to sample from
replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
num_samples (int): number of samples to draw, default=`len(dataset)`.
generator (Generator): Generator used in sampling.
"""
data_source: Sized
replacement: bool
def __init__(self, data_source: Sized, replacement: bool = False,
num_samples: Optional[int] = None, generator=None) -> None:
self.data_source = data_source
self.replacement = replacement
self._num_samples = num_samples
self.generator = generator
self._pos_start = 0
if not isinstance(self.replacement, bool):
raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
if not isinstance(self.num_samples, int) or self.num_samples <= 0:
raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
@property
def num_samples(self) -> int:
# dataset size might change at runtime
if self._num_samples is None:
return len(self.data_source)
return self._num_samples
def __iter__(self) -> Iterator[int]:
n = len(self.data_source)
if self.generator is None:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator = torch.Generator()
generator.manual_seed(seed)
else:
generator = self.generator
if self.replacement:
for _ in range(self.num_samples // 32):
yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
else:
for _ in range(self.num_samples // n):
xx = torch.randperm(n, generator=generator).tolist()
if self._pos_start >= n:
self._pos_start = 0
print("xx top 10", xx[:10], self._pos_start)
for idx in range(self._pos_start, n):
yield xx[idx]
self._pos_start = (self._pos_start + 1) % n
self._pos_start = 0
yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
def __len__(self) -> int:
return self.num_samples
class AspectRatioBatchImageSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch.
Args:
sampler (Sampler): Base sampler.
dataset (Dataset): Dataset providing data information.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
aspect_ratios (dict): The predefined aspect ratios.
"""
def __init__(
self,
sampler: Sampler,
dataset: Dataset,
batch_size: int,
train_folder: str = None,
aspect_ratios: dict = ASPECT_RATIO_512,
drop_last: bool = False,
config=None,
**kwargs
) -> None:
if not isinstance(sampler, Sampler):
raise TypeError('sampler should be an instance of ``Sampler``, '
f'but got {sampler}')
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError('batch_size should be a positive integer value, '
f'but got batch_size={batch_size}')
self.sampler = sampler
self.dataset = dataset
self.train_folder = train_folder
self.batch_size = batch_size
self.aspect_ratios = aspect_ratios
self.drop_last = drop_last
self.config = config
# buckets for each aspect ratio
self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
# [str(k) for k, v in aspect_ratios]
self.current_available_bucket_keys = list(aspect_ratios.keys())
def __iter__(self):
for idx in self.sampler:
try:
image_dict = self.dataset[idx]
width, height = image_dict.get("weight", None), image_dict.get("height", None)
if width is None or height is None:
image_id, name = image_dict['file_path'], image_dict['text']
if self.train_folder is None:
image_dir = image_id
else:
image_dir = os.path.join(self.train_folder, image_id)
width, height = get_image_size_without_loading(image_dir)
ratio = height / width # self.dataset[idx]
else:
height = int(height)
width = int(width)
ratio = height / width # self.dataset[idx]
except Exception as e:
print(e)
continue
# find the closest aspect ratio
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
if closest_ratio not in self.current_available_bucket_keys:
continue
bucket = self._aspect_ratio_buckets[closest_ratio]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
class AspectRatioBatchSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch.
Args:
sampler (Sampler): Base sampler.
dataset (Dataset): Dataset providing data information.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
aspect_ratios (dict): The predefined aspect ratios.
"""
def __init__(
self,
sampler: Sampler,
dataset: Dataset,
batch_size: int,
video_folder: str = None,
train_data_format: str = "webvid",
aspect_ratios: dict = ASPECT_RATIO_512,
drop_last: bool = False,
config=None,
**kwargs
) -> None:
if not isinstance(sampler, Sampler):
raise TypeError('sampler should be an instance of ``Sampler``, '
f'but got {sampler}')
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError('batch_size should be a positive integer value, '
f'but got batch_size={batch_size}')
self.sampler = sampler
self.dataset = dataset
self.video_folder = video_folder
self.train_data_format = train_data_format
self.batch_size = batch_size
self.aspect_ratios = aspect_ratios
self.drop_last = drop_last
self.config = config
# buckets for each aspect ratio
self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
# [str(k) for k, v in aspect_ratios]
self.current_available_bucket_keys = list(aspect_ratios.keys())
def __iter__(self):
for idx in self.sampler:
try:
video_dict = self.dataset[idx]
width, more = video_dict.get("width", None), video_dict.get("height", None)
if width is None or height is None:
if self.train_data_format == "normal":
video_id, name = video_dict['file_path'], video_dict['text']
if self.video_folder is None:
video_dir = video_id
else:
video_dir = os.path.join(self.video_folder, video_id)
else:
videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
cap = cv2.VideoCapture(video_dir)
# 获取视频尺寸
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数
ratio = height / width # self.dataset[idx]
else:
height = int(height)
width = int(width)
ratio = height / width # self.dataset[idx]
except Exception as e:
print(e)
continue
# find the closest aspect ratio
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
if closest_ratio not in self.current_available_bucket_keys:
continue
bucket = self._aspect_ratio_buckets[closest_ratio]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
class AspectRatioBatchImageVideoSampler(BatchSampler):
"""A sampler wrapper for grouping images with similar aspect ratio into a same batch.
Args:
sampler (Sampler): Base sampler.
dataset (Dataset): Dataset providing data information.
batch_size (int): Size of mini-batch.
drop_last (bool): If ``True``, the sampler will drop the last batch if
its size would be less than ``batch_size``.
aspect_ratios (dict): The predefined aspect ratios.
"""
def __init__(self,
sampler: Sampler,
dataset: Dataset,
batch_size: int,
train_folder: str = None,
aspect_ratios: dict = ASPECT_RATIO_512,
drop_last: bool = False
) -> None:
if not isinstance(sampler, Sampler):
raise TypeError('sampler should be an instance of ``Sampler``, '
f'but got {sampler}')
if not isinstance(batch_size, int) or batch_size <= 0:
raise ValueError('batch_size should be a positive integer value, '
f'but got batch_size={batch_size}')
self.sampler = sampler
self.dataset = dataset
self.train_folder = train_folder
self.batch_size = batch_size
self.aspect_ratios = aspect_ratios
self.drop_last = drop_last
# buckets for each aspect ratio
self.current_available_bucket_keys = list(aspect_ratios.keys())
self.bucket = {
'image':{ratio: [] for ratio in aspect_ratios},
'video':{ratio: [] for ratio in aspect_ratios}
}
def __iter__(self):
for idx in self.sampler:
content_type = self.dataset[idx].get('type', 'image')
if content_type == 'image':
try:
image_dict = self.dataset[idx]
width, height = image_dict.get("width", None), image_dict.get("height", None)
if width is None or height is None:
image_id, name = image_dict['file_path'], image_dict['text']
if self.train_folder is None:
image_dir = image_id
else:
image_dir = os.path.join(self.train_folder, image_id)
width, height = get_image_size_without_loading(image_dir)
ratio = height / width # self.dataset[idx]
else:
height = int(height)
width = int(width)
ratio = height / width # self.dataset[idx]
except Exception as e:
print(e)
continue
# find the closest aspect ratio
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
if closest_ratio not in self.current_available_bucket_keys:
continue
bucket = self.bucket['image'][closest_ratio]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
else:
try:
video_dict = self.dataset[idx]
width, height = video_dict.get("width", None), video_dict.get("height", None)
if width is None or height is None:
video_id, name = video_dict['file_path'], video_dict['text']
if self.train_folder is None:
video_dir = video_id
else:
video_dir = os.path.join(self.train_folder, video_id)
cap = cv2.VideoCapture(video_dir)
# 获取视频尺寸
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数
ratio = height / width # self.dataset[idx]
else:
height = int(height)
width = int(width)
ratio = height / width # self.dataset[idx]
except Exception as e:
print(e)
continue
# find the closest aspect ratio
closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
if closest_ratio not in self.current_available_bucket_keys:
continue
bucket = self.bucket['video'][closest_ratio]
bucket.append(idx)
# yield a batch of indices in the same aspect ratio group
if len(bucket) == self.batch_size:
yield bucket[:]
del bucket[:]
\ No newline at end of file
from .attention import *
from .transformer2d import *
from .transformer3d import *
from .autoencoder_magvit import *
from .embeddings import *
from .motion_module import *
from .norm import *
from .patch import *
from .resampler import *
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
import diffusers
import pkg_resources
import torch
import torch.nn.functional as F
import torch.nn.init as init
installed_version = diffusers.__version__
if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
from diffusers.models.attention_processor import (Attention,
AttnProcessor2_0,
HunyuanAttnProcessor2_0)
else:
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers.models.attention import AdaLayerNorm, FeedForward
from diffusers.models.embeddings import SinusoidalPositionalEmbedding
from diffusers.models.normalization import AdaLayerNormZero
from diffusers.utils import USE_PEFT_BACKEND
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import maybe_allow_in_graph
from einops import rearrange, repeat
from torch import nn
from .motion_module import PositionalEncoding, get_motion_module
from .norm import FP32LayerNorm, AdaLayerNormShift
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
r"""
A gated self-attention dense layer that combines visual features and object features.
Parameters:
query_dim (`int`): The number of channels in the query.
context_dim (`int`): The number of channels in the context.
n_heads (`int`): The number of heads to use for attention.
d_head (`int`): The number of channels in each head.
"""
def __init__(self, query_dim: int, context_dim: int, n_heads: int, d_head: int):
super().__init__()
# we need a linear projection since we need cat visual feature and obj feature
self.linear = nn.Linear(context_dim, query_dim)
self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
self.ff = FeedForward(query_dim, activation_fn="geglu")
self.norm1 = FP32LayerNorm(query_dim)
self.norm2 = FP32LayerNorm(query_dim)
self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
self.enabled = True
def forward(self, x: torch.Tensor, objs: torch.Tensor) -> torch.Tensor:
if not self.enabled:
return x
n_visual = x.shape[1]
objs = self.linear(objs)
x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))
return x
def zero_module(module):
# Zero out the parameters of a module and return it.
for p in module.parameters():
p.detach().zero_()
return module
class KVCompressionCrossAttention(nn.Module):
r"""
A cross attention layer.
Parameters:
query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
"""
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
):
super().__init__()
inner_dim = dim_head * heads
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.scale = dim_head**-0.5
self.heads = heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
# You can set slice_size with `set_attention_slice`
self.sliceable_head_dim = heads
self._slice_size = None
self._use_memory_efficient_attention_xformers = True
self.added_kv_proj_dim = added_kv_proj_dim
if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
else:
self.group_norm = None
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.kv_compression = nn.Conv2d(
query_dim,
query_dim,
groups=query_dim,
kernel_size=2,
stride=2,
bias=True
)
self.kv_compression_norm = FP32LayerNorm(query_dim)
init.constant_(self.kv_compression.weight, 1 / 4)
if self.kv_compression.bias is not None:
init.constant_(self.kv_compression.bias, 0)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Dropout(dropout))
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
return tensor
def reshape_batch_dim_to_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
head_size = self.heads
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def set_attention_slice(self, slice_size):
if slice_size is not None and slice_size > self.sliceable_head_dim:
raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
self._slice_size = slice_size
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, num_frames: int = 16, height: int = 32, width: int = 32):
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
if self.added_kv_proj_dim is not None:
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
key = rearrange(key, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
key = self.kv_compression(key)
key = rearrange(key, "(b f) c h w -> b (f h w) c", f=num_frames)
key = self.kv_compression_norm(key)
key = key.to(query.dtype)
value = rearrange(value, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
value = self.kv_compression(value)
value = rearrange(value, "(b f) c h w -> b (f h w) c", f=num_frames)
value = self.kv_compression_norm(value)
value = value.to(query.dtype)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
else:
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
key = rearrange(key, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
key = self.kv_compression(key)
key = rearrange(key, "(b f) c h w -> b (f h w) c", f=num_frames)
key = self.kv_compression_norm(key)
key = key.to(query.dtype)
value = rearrange(value, "b (f h w) c -> (b f) c h w", f=num_frames, h=height, w=width)
value = self.kv_compression(value)
value = rearrange(value, "(b f) c h w -> b (f h w) c", f=num_frames)
value = self.kv_compression_norm(value)
value = value.to(query.dtype)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
def _attention(self, query, key, value, attention_mask=None):
if self.upcast_attention:
query = query.float()
key = key.float()
attention_scores = torch.baddbmm(
torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
query,
key.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
if attention_mask is not None:
attention_scores = attention_scores + attention_mask
if self.upcast_softmax:
attention_scores = attention_scores.float()
attention_probs = attention_scores.softmax(dim=-1)
# cast back to the original dtype
attention_probs = attention_probs.to(value.dtype)
# compute attention output
hidden_states = torch.bmm(attention_probs, value)
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
batch_size_attention = query.shape[0]
hidden_states = torch.zeros(
(batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
)
slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
query_slice = query[start_idx:end_idx]
key_slice = key[start_idx:end_idx]
if self.upcast_attention:
query_slice = query_slice.float()
key_slice = key_slice.float()
attn_slice = torch.baddbmm(
torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
query_slice,
key_slice.transpose(-1, -2),
beta=0,
alpha=self.scale,
)
if attention_mask is not None:
attn_slice = attn_slice + attention_mask[start_idx:end_idx]
if self.upcast_softmax:
attn_slice = attn_slice.float()
attn_slice = attn_slice.softmax(dim=-1)
# cast back to the original dtype
attn_slice = attn_slice.to(value.dtype)
attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
# TODO attention_mask
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
@maybe_allow_in_graph
class TemporalTransformerBlock(nn.Module):
r"""
A Temporal Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
# kv compression
kvcompression: Optional[bool] = False,
# motion module kwargs
motion_module_type = "VanillaGrid",
motion_module_kwargs = None,
qk_norm = False,
after_norm = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.kvcompression = kvcompression
if kvcompression:
self.attn1 = KVCompressionCrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
else:
if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
qk_norm="layer_norm" if qk_norm else None,
processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
)
else:
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
self.attn_temporal = get_motion_module(
in_channels = dim,
motion_module_type = motion_module_type,
motion_module_kwargs = motion_module_kwargs,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
)
if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
qk_norm="layer_norm" if qk_norm else None,
processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
) # is self-attn if encoder_hidden_states is none
else:
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if not self.use_ada_layer_norm_single:
self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
if after_norm:
self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
else:
self.norm4 = None
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
# 5. Scale-shift for PixArt-Alpha.
if self.use_ada_layer_norm_single:
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
num_frames: int = 16,
height: int = 32,
width: int = 32,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
elif self.use_layer_norm:
norm_hidden_states = self.norm1(hidden_states)
elif self.use_ada_layer_norm_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 2. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
norm_hidden_states = rearrange(norm_hidden_states, "b (f d) c -> (b f) d c", f=num_frames)
if self.kvcompression:
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
num_frames=1,
height=height,
width=width,
**cross_attention_kwargs,
)
else:
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
attn_output = rearrange(attn_output, "(b f) d c -> b (f d) c", f=num_frames)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.use_ada_layer_norm_single:
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 2.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 2.75. Temp-Attention
if self.attn_temporal is not None:
attn_output = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=num_frames, h=height, w=width)
attn_output = self.attn_temporal(attn_output)
hidden_states = rearrange(attn_output, "b c f h w -> b (f h w) c")
# 3. Cross-Attention
if self.attn2 is not None:
if self.use_ada_layer_norm:
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
norm_hidden_states = self.norm2(hidden_states)
elif self.use_ada_layer_norm_single:
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
if norm_hidden_states.dtype != encoder_hidden_states.dtype or norm_hidden_states.dtype != encoder_attention_mask.dtype:
norm_hidden_states = norm_hidden_states.to(encoder_hidden_states.dtype)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
if not self.use_ada_layer_norm_single:
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self.use_ada_layer_norm_single:
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[
self.ff(hid_slice, scale=lora_scale)
for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
],
dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
if self.norm4 is not None:
ff_output = self.norm4(ff_output)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.use_ada_layer_norm_single:
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
@maybe_allow_in_graph
class SelfAttentionTemporalTransformerBlock(nn.Module):
r"""
A Temporal Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
qk_norm = False,
after_norm = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
qk_norm="layer_norm" if qk_norm else None,
processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
)
else:
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
)
if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
qk_norm="layer_norm" if qk_norm else None,
processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
) # is self-attn if encoder_hidden_states is none
else:
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if not self.use_ada_layer_norm_single:
self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
if after_norm:
self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
else:
self.norm4 = None
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
# 5. Scale-shift for PixArt-Alpha.
if self.use_ada_layer_norm_single:
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
elif self.use_layer_norm:
norm_hidden_states = self.norm1(hidden_states)
elif self.use_ada_layer_norm_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 2. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.use_ada_layer_norm_single:
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 2.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 3. Cross-Attention
if self.attn2 is not None:
if self.use_ada_layer_norm:
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
norm_hidden_states = self.norm2(hidden_states)
elif self.use_ada_layer_norm_single:
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
if not self.use_ada_layer_norm_single:
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self.use_ada_layer_norm_single:
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[
self.ff(hid_slice, scale=lora_scale)
for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
],
dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
if self.norm4 is not None:
ff_output = self.norm4(ff_output)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.use_ada_layer_norm_single:
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
@maybe_allow_in_graph
class KVCompressionTransformerBlock(nn.Module):
r"""
A Temporal Transformer block.
Parameters:
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
only_cross_attention (`bool`, *optional*):
Whether to use only cross-attention layers. In this case two cross attention layers are used.
double_self_attention (`bool`, *optional*):
Whether to use two self-attention layers. In this case no cross attention layers are used.
upcast_attention (`bool`, *optional*):
Whether to upcast the attention computation to float32. This is useful for mixed precision training.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_type (`str`, *optional*, defaults to `"layer_norm"`):
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
attention_type (`str`, *optional*, defaults to `"default"`):
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
positional_embeddings (`str`, *optional*, defaults to `None`):
The type of positional embeddings to apply to.
num_positional_embeddings (`int`, *optional*, defaults to `None`):
The maximum number of positional embeddings to apply.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
norm_eps: float = 1e-5,
final_dropout: bool = False,
attention_type: str = "default",
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
kvcompression: Optional[bool] = False,
qk_norm = False,
after_norm = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
self.use_layer_norm = norm_type == "layer_norm"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
if positional_embeddings and (num_positional_embeddings is None):
raise ValueError(
"If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
)
if positional_embeddings == "sinusoidal":
self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
else:
self.pos_embed = None
# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.kvcompression = kvcompression
if kvcompression:
self.attn1 = KVCompressionCrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
else:
if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
qk_norm="layer_norm" if qk_norm else None,
processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
)
else:
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
# 2. Cross-Attn
if cross_attention_dim is not None or double_self_attention:
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
# the second cross attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
)
if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
qk_norm="layer_norm" if qk_norm else None,
processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
) # is self-attn if encoder_hidden_states is none
else:
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim if not double_self_attention else None,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
) # is self-attn if encoder_hidden_states is none
else:
self.norm2 = None
self.attn2 = None
# 3. Feed-forward
if not self.use_ada_layer_norm_single:
self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
if after_norm:
self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
else:
self.norm4 = None
# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
# 5. Scale-shift for PixArt-Alpha.
if self.use_ada_layer_norm_single:
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
timestep: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
class_labels: Optional[torch.LongTensor] = None,
num_frames: int = 16,
height: int = 32,
width: int = 32,
use_reentrant: bool = False,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_size = hidden_states.shape[0]
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
elif self.use_layer_norm:
norm_hidden_states = self.norm1(hidden_states)
elif self.use_ada_layer_norm_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
else:
raise ValueError("Incorrect norm used")
if self.pos_embed is not None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
# 1. Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 2. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
if self.kvcompression:
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
num_frames=num_frames,
height=height,
width=width,
**cross_attention_kwargs,
)
else:
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if self.only_cross_attention else None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
elif self.use_ada_layer_norm_single:
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 2.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 3. Cross-Attention
if self.attn2 is not None:
if self.use_ada_layer_norm:
norm_hidden_states = self.norm2(hidden_states, timestep)
elif self.use_ada_layer_norm_zero or self.use_layer_norm:
norm_hidden_states = self.norm2(hidden_states)
elif self.use_ada_layer_norm_single:
# For PixArt norm2 isn't applied here:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
norm_hidden_states = hidden_states
else:
raise ValueError("Incorrect norm")
if self.pos_embed is not None and self.use_ada_layer_norm_single is None:
norm_hidden_states = self.pos_embed(norm_hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
**cross_attention_kwargs,
)
hidden_states = attn_output + hidden_states
# 4. Feed-forward
if not self.use_ada_layer_norm_single:
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
if self.use_ada_layer_norm_single:
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)
num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[
self.ff(hid_slice, scale=lora_scale)
for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
],
dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states, scale=lora_scale)
if self.norm4 is not None:
ff_output = self.norm4(ff_output)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
elif self.use_ada_layer_norm_single:
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
return hidden_states
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out, norm_elementwise_affine):
super().__init__()
self.norm = FP32LayerNorm(dim_in, dim_in, norm_elementwise_affine)
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(self.norm(x)).chunk(2, dim=-1)
return x * F.gelu(gate)
@maybe_allow_in_graph
class HunyuanDiTBlock(nn.Module):
r"""
Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and
QKNorm
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of headsto use for multi-head attention.
cross_attention_dim (`int`,*optional*):
The size of the encoder_hidden_states vector for cross attention.
dropout(`float`, *optional*, defaults to 0.0):
The dropout probability to use.
activation_fn (`str`,*optional*, defaults to `"geglu"`):
Activation function to be used in feed-forward. .
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, *optional*, defaults to 1e-6):
A small constant added to the denominator in normalization layers to prevent division by zero.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*):
The size of the hidden layer in the feed-forward block. Defaults to `None`.
ff_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the feed-forward block.
skip (`bool`, *optional*, defaults to `False`):
Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
qk_norm (`bool`, *optional*, defaults to `True`):
Whether to use normalization in QK calculation. Defaults to `True`.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
cross_attention_dim: int = 1024,
dropout=0.0,
activation_fn: str = "geglu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-6,
final_dropout: bool = False,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
skip: bool = False,
qk_norm: bool = True,
time_position_encoding: bool = False,
after_norm: bool = False,
is_local_attention: bool = False,
local_attention_frames: int = 2,
enable_inpaint: bool = False,
):
super().__init__()
# Define 3 blocks. Each block has its own normalization layer.
# NOTE: when new version comes, check norm2 and norm 3
# 1. Self-Attn
self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.t_embed = PositionalEncoding(dim, dropout=0., max_len=512) if time_position_encoding else nn.Identity()
self.attn1 = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // num_attention_heads,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=True,
processor=HunyuanAttnProcessor2_0(),
)
# 2. Cross-Attn
self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.is_local_attention = is_local_attention
self.local_attention_frames = local_attention_frames
if self.is_local_attention:
from mamba_ssm import Mamba2
self.mamba_norm_in = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.in_linear = nn.Linear(dim, 1536)
self.mamba_norm_1 = FP32LayerNorm(1536, norm_eps, norm_elementwise_affine)
self.mamba_norm_2 = FP32LayerNorm(1536, norm_eps, norm_elementwise_affine)
self.mamba_block_1 = Mamba2(
d_model=1536,
d_state=64,
d_conv=4,
expand=2,
)
self.mamba_block_2 = Mamba2(
d_model=1536,
d_state=64,
d_conv=4,
expand=2,
)
self.mamba_norm_after_mamba_block = FP32LayerNorm(1536, norm_eps, norm_elementwise_affine)
self.out_linear = nn.Linear(1536, dim)
self.out_linear = zero_module(self.out_linear)
self.mamba_norm_out = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
dim_head=dim // num_attention_heads,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=True,
processor=HunyuanAttnProcessor2_0(),
)
if enable_inpaint:
self.norm_clip = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn_clip = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
dim_head=dim // num_attention_heads,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=True,
processor=HunyuanAttnProcessor2_0(),
)
self.gate_clip = GEGLU(dim, dim, norm_elementwise_affine)
self.norm_clip_out = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
else:
self.attn_clip = None
self.norm_clip = None
self.gate_clip = None
self.norm_clip_out = None
# 3. Feed-forward
self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.ff = FeedForward(
dim,
dropout=dropout, ### 0.0
activation_fn=activation_fn, ### approx GeLU
final_dropout=final_dropout, ### 0.0
inner_dim=ff_inner_dim, ### int(dim * mlp_ratio)
bias=ff_bias,
)
# 4. Skip Connection
if skip:
self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True)
self.skip_linear = nn.Linear(2 * dim, dim)
else:
self.skip_linear = None
if after_norm:
print("add after norm")
self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
else:
self.norm4 = None
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
self.is_local_attention = is_local_attention
self.local_attention_frames = local_attention_frames
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb=None,
skip=None,
num_frames: int = 1,
height: int = 32,
width: int = 32,
clip_encoder_hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Long Skip Connection
if self.skip_linear is not None:
cat = torch.cat([hidden_states, skip], dim=-1)
cat = self.skip_norm(cat)
hidden_states = self.skip_linear(cat)
if num_frames != 1:
image_rotary_emb = (torch.cat([image_rotary_emb[0] for i in range(num_frames)], dim=0), torch.cat([image_rotary_emb[1] for i in range(num_frames)], dim=0))
# add time embedding
hidden_states = rearrange(hidden_states, "b (f d) c -> (b d) f c", f=num_frames)
if self.t_embed is not None:
hidden_states = self.t_embed(hidden_states)
hidden_states = rearrange(hidden_states, "(b d) f c -> b (f d) c", d=height * width)
# 1. Self-Attention
norm_hidden_states = self.norm1(hidden_states, temb) ### checked: self.norm1 is correct
if num_frames > 2 and self.is_local_attention:
attn1_image_rotary_emb = (image_rotary_emb[0][:int(height * width * 2)], image_rotary_emb[1][:int(height * width * 2)])
norm_hidden_states_1 = rearrange(norm_hidden_states, "b (f d) c -> b f d c", d=height * width)
norm_hidden_states_1 = rearrange(norm_hidden_states_1, "b (f p) d c -> (b f) (p d) c", p = 2)
attn_output = self.attn1(
norm_hidden_states_1,
image_rotary_emb=attn1_image_rotary_emb,
)
attn_output = rearrange(attn_output, "(b f) (p d) c -> b (f p) d c", p = 2, f = num_frames // 2)
norm_hidden_states_2 = rearrange(norm_hidden_states, "b (f d) c -> b f d c", d = height * width)[:, 1:-1]
local_attention_frames_num = norm_hidden_states_2.size()[1] // 2
norm_hidden_states_2 = rearrange(norm_hidden_states_2, "b (f p) d c -> (b f) (p d) c", p = 2)
attn_output_2 = self.attn1(
norm_hidden_states_2,
image_rotary_emb=attn1_image_rotary_emb,
)
attn_output_2 = rearrange(attn_output_2, "(b f) (p d) c -> b (f p) d c", p = 2, f = local_attention_frames_num)
attn_output[:, 1:-1] = (attn_output[:, 1:-1] + attn_output_2) / 2
attn_output = rearrange(attn_output, "b f d c -> b (f d) c")
else:
attn_output = self.attn1(
norm_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + attn_output
if num_frames > 2 and self.is_local_attention:
hidden_states_in = self.in_linear(self.mamba_norm_in(hidden_states))
hidden_states = hidden_states + self.mamba_norm_out(
self.out_linear(
self.mamba_norm_after_mamba_block(
self.mamba_block_1(
self.mamba_norm_1(hidden_states_in)
) +
self.mamba_block_2(
self.mamba_norm_2(hidden_states_in.flip(1))
).flip(1)
)
)
)
# 2. Cross-Attention
hidden_states = hidden_states + self.attn2(
self.norm2(hidden_states),
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
if self.attn_clip is not None:
hidden_states = hidden_states + self.norm_clip_out(
self.gate_clip(
self.attn_clip(
self.norm_clip(hidden_states),
encoder_hidden_states=clip_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
)
)
# FFN Layer ### TODO: switch norm2 and norm3 in the state dict
mlp_inputs = self.norm3(hidden_states)
if self.norm4 is not None:
hidden_states = hidden_states + self.norm4(self.ff(mlp_inputs))
else:
hidden_states = hidden_states + self.ff(mlp_inputs)
return hidden_states
@maybe_allow_in_graph
class HunyuanTemporalTransformerBlock(nn.Module):
r"""
Transformer block used in Hunyuan-DiT model (https://github.com/Tencent/HunyuanDiT). Allow skip connection and
QKNorm
Parameters:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of headsto use for multi-head attention.
cross_attention_dim (`int`,*optional*):
The size of the encoder_hidden_states vector for cross attention.
dropout(`float`, *optional*, defaults to 0.0):
The dropout probability to use.
activation_fn (`str`,*optional*, defaults to `"geglu"`):
Activation function to be used in feed-forward. .
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, *optional*, defaults to 1e-6):
A small constant added to the denominator in normalization layers to prevent division by zero.
final_dropout (`bool` *optional*, defaults to False):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*):
The size of the hidden layer in the feed-forward block. Defaults to `None`.
ff_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the feed-forward block.
skip (`bool`, *optional*, defaults to `False`):
Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
qk_norm (`bool`, *optional*, defaults to `True`):
Whether to use normalization in QK calculation. Defaults to `True`.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
cross_attention_dim: int = 1024,
dropout=0.0,
activation_fn: str = "geglu",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-6,
final_dropout: bool = False,
ff_inner_dim: Optional[int] = None,
ff_bias: bool = True,
skip: bool = False,
qk_norm: bool = True,
after_norm: bool = False,
# motion module kwargs
motion_module_type = "VanillaGrid",
motion_module_kwargs = None,
use_reentrant: bool = False,
):
super().__init__()
# Define 3 blocks. Each block has its own normalization layer.
# NOTE: when new version comes, check norm2 and norm 3
# 1. Self-Attn
self.norm1 = AdaLayerNormShift(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // num_attention_heads,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=True,
processor=HunyuanAttnProcessor2_0(),
)
self.attn_temporal = get_motion_module(
in_channels = dim,
motion_module_type = motion_module_type,
motion_module_kwargs = motion_module_kwargs,
)
# 2. Cross-Attn
self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
dim_head=dim // num_attention_heads,
heads=num_attention_heads,
qk_norm="layer_norm" if qk_norm else None,
eps=1e-6,
bias=True,
processor=HunyuanAttnProcessor2_0(),
)
# 3. Feed-forward
self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.ff = FeedForward(
dim,
dropout=dropout, ### 0.0
activation_fn=activation_fn, ### approx GeLU
final_dropout=final_dropout, ### 0.0
inner_dim=ff_inner_dim, ### int(dim * mlp_ratio)
bias=ff_bias,
)
# 4. Skip Connection
if skip:
self.skip_norm = FP32LayerNorm(2 * dim, norm_eps, elementwise_affine=True)
self.skip_linear = nn.Linear(2 * dim, dim)
else:
self.skip_linear = None
if after_norm:
print("add after norm")
self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
else:
self.norm4 = None
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb=None,
skip=None,
num_frames: int = 16,
height: int = 32,
width: int = 32,
use_reentrant: bool = False,
) -> torch.Tensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Long Skip Connection
if self.skip_linear is not None:
cat = torch.cat([hidden_states, skip], dim=-1)
cat = self.skip_norm(cat)
hidden_states = self.skip_linear(cat)
# 1. Self-Attention
norm_hidden_states = self.norm1(hidden_states, temb) ### checked: self.norm1 is correct
norm_hidden_states = rearrange(norm_hidden_states, "b (f d) c -> (b f) d c", f=num_frames)
attn_output = self.attn1(
norm_hidden_states,
image_rotary_emb=image_rotary_emb,
)
attn_output = rearrange(attn_output, "(b f) d c -> b (f d) c", f=num_frames)
hidden_states = hidden_states + attn_output
# 1.5. Temp-Attention
if self.attn_temporal is not None:
attn_output = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=num_frames, h=height, w=width)
attn_output = self.attn_temporal(attn_output)
hidden_states = rearrange(attn_output, "b c f h w -> b (f h w) c")
if num_frames != 1:
image_rotary_emb = (torch.cat([image_rotary_emb[0] for i in range(num_frames)], dim=0), torch.cat([image_rotary_emb[1] for i in range(num_frames)], dim=0))
# 2. Cross-Attention
hidden_states = hidden_states + self.attn2(
self.norm2(hidden_states),
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
# FFN Layer ### TODO: switch norm2 and norm3 in the state dict
mlp_inputs = self.norm3(hidden_states)
if self.norm4 is not None:
hidden_states = hidden_states + self.norm4(self.ff(mlp_inputs))
else:
hidden_states = hidden_states + self.ff(mlp_inputs)
return hidden_states
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
try:
from diffusers.loaders import FromOriginalVAEMixin
except:
from diffusers.loaders import FromOriginalModelMixin as FromOriginalVAEMixin
from diffusers.models.attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention,
AttentionProcessor, AttnAddedKVProcessor, AttnProcessor)
from diffusers.models.autoencoders.vae import (DecoderOutput,
DiagonalGaussianDistribution)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.accelerate_utils import apply_forward_hook
from torch import nn
from ..vae.ldm.models.omnigen_enc_dec import Decoder as omnigen_Mag_Decoder
from ..vae.ldm.models.omnigen_enc_dec import Encoder as omnigen_Mag_Encoder
def str_eval(item):
if type(item) == str:
return eval(item)
else:
return item
class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
Tuple of downsample block types.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): Sample input size.
scaling_factor (`float`, *optional*, defaults to 0.18215):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
force_upcast (`bool`, *optional*, default to `True`):
If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
can be fine-tuned / trained to a lower range without loosing too much precision in which case
`force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
ch = 128,
ch_mult = [ 1,2,4,4 ],
use_gc_blocks = None,
down_block_types: tuple = None,
up_block_types: tuple = None,
mid_block_type: str = "MidBlock3D",
mid_block_use_attention: bool = True,
mid_block_attention_type: str = "3d",
mid_block_num_attention_heads: int = 1,
layers_per_block: int = 2,
act_fn: str = "silu",
num_attention_heads: int = 1,
latent_channels: int = 4,
norm_num_groups: int = 32,
scaling_factor: float = 0.1825,
slice_mag_vae=True,
slice_compression_vae=False,
cache_compression_vae=False,
use_tiling=False, # True
use_tiling_encoder=False,
use_tiling_decoder=False,
mini_batch_encoder=9,
mini_batch_decoder=3,
upcast_vae=False,
spatial_group_norm=False,
tile_sample_min_size=384,
tile_overlap_factor=0.25,
):
super().__init__()
down_block_types = str_eval(down_block_types)
up_block_types = str_eval(up_block_types)
self.encoder = omnigen_Mag_Encoder(
in_channels=in_channels,
out_channels=latent_channels,
down_block_types=down_block_types,
ch = ch,
ch_mult = ch_mult,
use_gc_blocks=use_gc_blocks,
mid_block_type=mid_block_type,
mid_block_use_attention=mid_block_use_attention,
mid_block_attention_type=mid_block_attention_type,
mid_block_num_attention_heads=mid_block_num_attention_heads,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
num_attention_heads=num_attention_heads,
double_z=True,
slice_mag_vae=slice_mag_vae,
slice_compression_vae=slice_compression_vae,
cache_compression_vae=cache_compression_vae,
mini_batch_encoder=mini_batch_encoder,
spatial_group_norm=spatial_group_norm,
)
self.decoder = omnigen_Mag_Decoder(
in_channels=latent_channels,
out_channels=out_channels,
up_block_types=up_block_types,
ch = ch,
ch_mult = ch_mult,
use_gc_blocks=use_gc_blocks,
mid_block_type=mid_block_type,
mid_block_use_attention=mid_block_use_attention,
mid_block_attention_type=mid_block_attention_type,
mid_block_num_attention_heads=mid_block_num_attention_heads,
layers_per_block=layers_per_block,
norm_num_groups=norm_num_groups,
act_fn=act_fn,
num_attention_heads=num_attention_heads,
slice_mag_vae=slice_mag_vae,
slice_compression_vae=slice_compression_vae,
cache_compression_vae=cache_compression_vae,
mini_batch_decoder=mini_batch_decoder,
spatial_group_norm=spatial_group_norm,
)
self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
self.slice_mag_vae = slice_mag_vae
self.slice_compression_vae = slice_compression_vae
self.cache_compression_vae = cache_compression_vae
self.mini_batch_encoder = mini_batch_encoder
self.mini_batch_decoder = mini_batch_decoder
self.use_slicing = False
self.use_tiling = use_tiling
self.use_tiling_encoder = use_tiling_encoder
self.use_tiling_decoder = use_tiling_decoder
self.upcast_vae = upcast_vae
self.tile_sample_min_size = tile_sample_min_size
self.tile_overlap_factor = tile_overlap_factor
self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(ch_mult) - 1)))
self.scaling_factor = scaling_factor
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (omnigen_Mag_Encoder, omnigen_Mag_Decoder)):
module.gradient_checkpointing = value
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors["%s.processor"%(name)] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors("%s.%s"%(name, sub_name), child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop("%s.processor"%(name)))
for sub_name, child in module.named_children():
fn_recursive_attn_processor("%s.%s"%(name, sub_name), child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
@apply_forward_hook
def encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.upcast_vae: # False
x = x.float()
self.encoder = self.encoder.float()
self.quant_conv = self.quant_conv.float()
if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): # True, Almost False
x = self.tiled_encode(x, return_dict=return_dict)
return x
if self.use_tiling_encoder and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): # False, Almost False
x = self.tiled_encode(x, return_dict=return_dict)
return x
if self.use_slicing and x.shape[0] > 1: # False, False when B=1
encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict: # False
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
if self.upcast_vae:
z = z.float()
self.decoder = self.decoder.float()
self.post_quant_conv = self.post_quant_conv.float()
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
if self.use_tiling_decoder and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
z = self.post_quant_conv(z)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.FloatTensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
1 - y / blend_extent
) + b[:, :, :, y, :] * (y / blend_extent)
return b
def blend_h(
self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
1 - x / blend_extent
) + b[:, :, :, :, x] * (x / blend_extent)
return b
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split the image into 512x512 tiles and encode them separately.
rows = []
for i in range(0, x.shape[3], overlap_size):
row = []
for j in range(0, x.shape[4], overlap_size):
tile = x[
:,
:,
:,
i : i + self.tile_sample_min_size,
j : j + self.tile_sample_min_size,
]
tile = self.encoder(tile)
tile = self.quant_conv(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=4))
moments = torch.cat(result_rows, dim=3)
posterior = DiagonalGaussianDistribution(moments)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
# Split z into overlapping 64x64 tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[3], overlap_size):
row = []
for j in range(0, z.shape[4], overlap_size):
tile = z[
:,
:,
:,
i : i + self.tile_latent_min_size,
j : j + self.tile_latent_min_size,
]
tile = self.post_quant_conv(tile)
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=4))
dec = torch.cat(result_rows, dim=3)
# Handle the lower right corner tile separately
lower_right_original = z[
:,
:,
:,
-self.tile_latent_min_size:,
-self.tile_latent_min_size:
]
quantized_lower_right = self.decoder(self.post_quant_conv(lower_right_original))
# Combine
H, W = quantized_lower_right.size(-2), quantized_lower_right.size(-1)
x_weights = torch.linspace(0, 1, W).unsqueeze(0).repeat(H, 1)
y_weights = torch.linspace(0, 1, H).unsqueeze(1).repeat(1, W)
weights = torch.min(x_weights, y_weights)
if len(dec.size()) == 4:
weights = weights.unsqueeze(0).unsqueeze(0)
elif len(dec.size()) == 5:
weights = weights.unsqueeze(0).unsqueeze(0).unsqueeze(0)
weights = weights.to(dec.device)
quantized_area = dec[:, :, :, -H:, -W:]
combined = weights * quantized_lower_right + (1 - weights) * quantized_area
dec[:, :, :, -H:, -W:] = combined
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.FloatTensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.FloatTensor]:
r"""
Args:
sample (`torch.FloatTensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
@classmethod
def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
import json
import os
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f"{config_file} does not exist")
with open(config_file, "r") as f:
config = json.load(f)
model = cls.from_config(config, **vae_additional_kwargs)
from diffusers.utils import WEIGHTS_NAME
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
model_file_safetensors = model_file.replace(".bin", ".safetensors")
if os.path.exists(model_file_safetensors):
from safetensors.torch import load_file, safe_open
state_dict = load_file(model_file_safetensors)
else:
if not os.path.isfile(model_file):
raise RuntimeError(f"{model_file} does not exist")
state_dict = torch.load(model_file, map_location="cpu")
m, u = model.load_state_dict(state_dict, strict=False)
return model
import math
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.utils import deprecate
from diffusers.models.activations import FP32SiLU, get_activation
from diffusers.models.attention_processor import Attention
def get_timestep_embedding(
timesteps: torch.Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
embeddings. :return: an [N x dim] Tensor of positional embeddings.
"""
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
half_dim = embedding_dim // 2
exponent = -math.log(max_period) * torch.arange(
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
)
exponent = exponent / (half_dim - downscale_freq_shift)
emb = torch.exp(exponent)
emb = timesteps[:, None].float() * emb[None, :]
# scale embeddings
emb = scale * emb
# concat sine and cosine embeddings
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
# flip sine and cosine embeddings
if flip_sin_to_cos:
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
# zero pad
if embedding_dim % 2 == 1:
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
class Timesteps(nn.Module):
def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
super().__init__()
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift
def forward(self, timesteps):
t_emb = get_timestep_embedding(
timesteps,
self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
return t_emb
class TimestepEmbedding(nn.Module):
def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: int = None,
post_act_fn: Optional[str] = None,
cond_proj_dim=None,
sample_proj_bias=True,
):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
if cond_proj_dim is not None:
self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None
self.act = get_activation(act_fn)
if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)
if post_act_fn is None:
self.post_act = None
else:
self.post_act = get_activation(post_act_fn)
def forward(self, sample, condition=None):
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)
if self.act is not None:
sample = self.act(sample)
sample = self.linear_2(sample)
if self.post_act is not None:
sample = self.post_act(sample)
return sample
class PixArtAlphaTextProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
super().__init__()
if out_features is None:
out_features = hidden_size
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
if act_fn == "gelu_tanh":
self.act_1 = nn.GELU(approximate="tanh")
elif act_fn == "silu_fp32":
self.act_1 = FP32SiLU()
else:
raise ValueError(f"Unknown activation function: {act_fn}")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
import torch
import torch.nn as nn
import torch.nn.functional as F
class HunyuanDiTAttentionPool(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = torch.cat([x.mean(dim=1, keepdim=True), x], dim=1)
x = x + self.positional_embedding[None, :, :].to(x.dtype)
query = self.q_proj(x[:, :1])
key = self.k_proj(x)
value = self.v_proj(x)
batch_size, _, _ = query.size()
query = query.reshape(batch_size, -1, self.num_heads, query.size(-1) // self.num_heads).transpose(1, 2) # (1, H, N, E/H)
key = key.reshape(batch_size, -1, self.num_heads, key.size(-1) // self.num_heads).transpose(1, 2) # (L+1, H, N, E/H)
value = value.reshape(batch_size, -1, self.num_heads, value.size(-1) // self.num_heads).transpose(1, 2) # (L+1, H, N, E/H)
x = F.scaled_dot_product_attention(query=query, key=key, value=value, attn_mask=None, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, 1, -1)
x = x.to(query.dtype)
x = self.c_proj(x)
return x.squeeze(1)
class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
def __init__(self, embedding_dim, pooled_projection_dim=1024, seq_len=256, cross_attention_dim=2048):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.pooler = HunyuanDiTAttentionPool(
seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
)
# Here we use a default learned embedder layer for future extension.
self.style_embedder = nn.Embedding(1, embedding_dim)
extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
self.extra_embedder = PixArtAlphaTextProjection(
in_features=extra_in_dim,
hidden_size=embedding_dim * 4,
out_features=embedding_dim,
act_fn="silu_fp32",
)
def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, 256)
# extra condition1: text
pooled_projections = self.pooler(encoder_hidden_states) # (N, 1024)
# extra condition2: image meta size embdding
image_meta_size = get_timestep_embedding(image_meta_size.view(-1), 256, True, 0)
image_meta_size = image_meta_size.to(dtype=hidden_dtype)
image_meta_size = image_meta_size.view(-1, 6 * 256) # (N, 1536)
# extra condition3: style embedding
style_embedding = self.style_embedder(style) # (N, embedding_dim)
# Concatenate all extra vectors
extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
conditioning = timesteps_emb + self.extra_embedder(extra_cond) # [B, D]
return conditioning
\ No newline at end of file
"""Modified from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
"""
import math
import diffusers
import pkg_resources
import torch
installed_version = diffusers.__version__
if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
from diffusers.models.attention_processor import (Attention,
AttnProcessor2_0,
HunyuanAttnProcessor2_0)
else:
from diffusers.models.attention_processor import Attention, AttnProcessor2_0
from diffusers.models.attention import FeedForward
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
from torch import nn
from .norm import FP32LayerNorm
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
def zero_module(module):
# Zero out the parameters of a module and return it.
for p in module.parameters():
p.detach().zero_()
return module
def get_motion_module(
in_channels,
motion_module_type: str,
motion_module_kwargs: dict,
):
if motion_module_type == "Vanilla":
return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
elif motion_module_type == "VanillaGrid":
return VanillaTemporalModule(in_channels=in_channels, grid=True, **motion_module_kwargs,)
else:
raise ValueError
class VanillaTemporalModule(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads = 8,
num_transformer_block = 2,
attention_block_types =( "Temporal_Self", "Temporal_Self" ),
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 4096,
temporal_attention_dim_div = 1,
zero_initialize = True,
block_size = 1,
grid = False,
remove_time_embedding_in_photo = False,
global_num_attention_heads = 16,
global_attention = False,
qk_norm = False,
):
super().__init__()
self.temporal_transformer = TemporalTransformer3DModel(
in_channels=in_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
num_layers=num_transformer_block,
attention_block_types=attention_block_types,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
grid=grid,
block_size=block_size,
remove_time_embedding_in_photo=remove_time_embedding_in_photo,
qk_norm=qk_norm,
)
self.global_transformer = GlobalTransformer3DModel(
in_channels=in_channels,
num_attention_heads=global_num_attention_heads,
attention_head_dim=in_channels // global_num_attention_heads // temporal_attention_dim_div,
qk_norm=qk_norm,
) if global_attention else None
if zero_initialize:
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
if global_attention:
self.global_transformer.proj_out = zero_module(self.global_transformer.proj_out)
def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None, anchor_frame_idx=None):
hidden_states = input_tensor
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
if self.global_transformer is not None:
hidden_states = self.global_transformer(hidden_states)
output = hidden_states
return output
class GlobalTransformer3DModel(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads,
attention_head_dim,
dropout = 0.0,
attention_bias = False,
upcast_attention = False,
qk_norm = False,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
self.norm1 = FP32LayerNorm(inner_dim)
self.proj_in = nn.Linear(in_channels, inner_dim)
self.norm2 = FP32LayerNorm(inner_dim)
if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
self.attention = Attention(
query_dim=inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
qk_norm="layer_norm" if qk_norm else None,
processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
)
else:
self.attention = Attention(
query_dim=inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(self, hidden_states):
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length, height, width = hidden_states.shape[2], hidden_states.shape[3], hidden_states.shape[4]
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.proj_in(hidden_states)
# Attention Blocks
hidden_states = self.norm2(hidden_states)
hidden_states = self.attention(hidden_states)
hidden_states = self.proj_out(hidden_states)
output = hidden_states + residual
output = rearrange(output, "b (f h w) c -> b c f h w", f=video_length, h=height, w=width)
return output
class TemporalTransformer3DModel(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads,
attention_head_dim,
num_layers,
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
dropout = 0.0,
norm_num_groups = 32,
cross_attention_dim = 768,
activation_fn = "geglu",
attention_bias = False,
upcast_attention = False,
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 4096,
grid = False,
block_size = 1,
remove_time_embedding_in_photo = False,
qk_norm = False,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
self.block_size = block_size
self.transformer_blocks = nn.ModuleList(
[
TemporalTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
attention_block_types=attention_block_types,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
block_size=block_size,
grid=grid,
remove_time_embedding_in_photo=remove_time_embedding_in_photo,
qk_norm=qk_norm
)
for d in range(num_layers)
]
)
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
hidden_states = self.proj_in(hidden_states)
# Transformer Blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, height=height, weight=weight)
# output
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
return output
class TemporalTransformerBlock(nn.Module):
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
dropout = 0.0,
norm_num_groups = 32,
cross_attention_dim = 768,
activation_fn = "geglu",
attention_bias = False,
upcast_attention = False,
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 4096,
block_size = 1,
grid = False,
remove_time_embedding_in_photo = False,
qk_norm = False,
):
super().__init__()
attention_blocks = []
norms = []
for block_name in attention_block_types:
attention_blocks.append(
VersatileAttention(
attention_mode=block_name.split("_")[0],
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
block_size=block_size,
grid=grid,
remove_time_embedding_in_photo=remove_time_embedding_in_photo,
qk_norm="layer_norm" if qk_norm else None,
processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
) if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2") else \
VersatileAttention(
attention_mode=block_name.split("_")[0],
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
block_size=block_size,
grid=grid,
remove_time_embedding_in_photo=remove_time_embedding_in_photo,
)
)
norms.append(FP32LayerNorm(dim))
self.attention_blocks = nn.ModuleList(attention_blocks)
self.norms = nn.ModuleList(norms)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.ff_norm = FP32LayerNorm(dim)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
for attention_block, norm in zip(self.attention_blocks, self.norms):
norm_hidden_states = norm(hidden_states)
hidden_states = attention_block(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
video_length=video_length,
height=height,
weight=weight,
) + hidden_states
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
output = hidden_states
return output
class PositionalEncoding(nn.Module):
def __init__(
self,
d_model,
dropout = 0.,
max_len = 4096
):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
class VersatileAttention(Attention):
def __init__(
self,
attention_mode = None,
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 4096,
grid = False,
block_size = 1,
remove_time_embedding_in_photo = False,
*args, **kwargs
):
super().__init__(*args, **kwargs)
assert attention_mode == "Temporal" or attention_mode == "Global"
self.attention_mode = attention_mode
self.is_cross_attention = kwargs["cross_attention_dim"] is not None
self.block_size = block_size
self.grid = grid
self.remove_time_embedding_in_photo = remove_time_embedding_in_photo
self.pos_encoder = PositionalEncoding(
kwargs["query_dim"],
dropout=0.,
max_len=temporal_position_encoding_max_len
) if (temporal_position_encoding and attention_mode == "Temporal") or (temporal_position_encoding and attention_mode == "Global") else None
def extra_repr(self):
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
batch_size, sequence_length, _ = hidden_states.shape
if self.attention_mode == "Temporal":
# for add pos_encoder
_, before_d, _c = hidden_states.size()
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
if self.remove_time_embedding_in_photo:
if self.pos_encoder is not None and video_length > 1:
hidden_states = self.pos_encoder(hidden_states)
else:
if self.pos_encoder is not None:
hidden_states = self.pos_encoder(hidden_states)
if self.grid:
hidden_states = rearrange(hidden_states, "(b d) f c -> b f d c", f=video_length, d=before_d)
hidden_states = rearrange(hidden_states, "b f (h w) c -> b f h w c", h=height, w=weight)
hidden_states = rearrange(hidden_states, "b f (h n) (w m) c -> (b h w) (f n m) c", n=self.block_size, m=self.block_size)
d = before_d // self.block_size // self.block_size
else:
d = before_d
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
elif self.attention_mode == "Global":
# for add pos_encoder
_, d, _c = hidden_states.size()
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
if self.pos_encoder is not None:
hidden_states = self.pos_encoder(hidden_states)
hidden_states = rearrange(hidden_states, "(b d) f c -> b (f d) c", f=video_length, d=d)
else:
raise NotImplementedError
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
bs = 512
new_hidden_states = []
for i in range(0, hidden_states.shape[0], bs):
__hidden_states = super().forward(
hidden_states[i : i + bs],
encoder_hidden_states=encoder_hidden_states[i : i + bs],
attention_mask=attention_mask
)
new_hidden_states.append(__hidden_states)
hidden_states = torch.cat(new_hidden_states, dim = 0)
if self.attention_mode == "Temporal":
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
if self.grid:
hidden_states = rearrange(hidden_states, "(b f n m) (h w) c -> (b f) h n w m c", f=video_length, n=self.block_size, m=self.block_size, h=height // self.block_size, w=weight // self.block_size)
hidden_states = rearrange(hidden_states, "b h n w m c -> b (h n) (w m) c")
hidden_states = rearrange(hidden_states, "b h w c -> b (h w) c")
elif self.attention_mode == "Global":
hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=video_length, d=d)
return hidden_states
\ No newline at end of file
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn.functional as F
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from torch import nn
def zero_module(module):
# Zero out the parameters of a module and return it.
for p in module.parameters():
p.detach().zero_()
return module
class FP32LayerNorm(nn.LayerNorm):
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
origin_dtype = inputs.dtype
if hasattr(self, 'weight') and self.weight is not None:
return F.layer_norm(
inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps
).to(origin_dtype)
else:
return F.layer_norm(
inputs.float(), self.normalized_shape, None, None, self.eps
).to(origin_dtype)
class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
"""
For PixArt-Alpha.
Reference:
https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
"""
def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
super().__init__()
self.outdim = size_emb_dim
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.use_additional_conditions = use_additional_conditions
if use_additional_conditions:
self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
self.resolution_embedder.linear_2 = zero_module(self.resolution_embedder.linear_2)
self.aspect_ratio_embedder.linear_2 = zero_module(self.aspect_ratio_embedder.linear_2)
def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
if self.use_additional_conditions:
resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
else:
conditioning = timesteps_emb
return conditioning
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
Parameters:
embedding_dim (`int`): The size of each embedding vector.
use_additional_conditions (`bool`): To use additional conditions for normalization or not.
"""
def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
super().__init__()
self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
def forward(
self,
timestep: torch.Tensor,
added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
batch_size: Optional[int] = None,
hidden_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# No modulation happening here.
embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
return self.linear(self.silu(embedded_timestep)), embedded_timestep
class AdaLayerNormShift(nn.Module):
r"""
Norm layer modified to incorporate timestep embeddings.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
num_embeddings (`int`): The size of the embeddings dictionary.
"""
def __init__(self, embedding_dim: int, elementwise_affine=True, eps=1e-6):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_dim)
self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=elementwise_affine, eps=eps)
def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
shift = self.linear(self.silu(emb.to(torch.float32)).to(emb.dtype))
x = self.norm(x) + shift.unsqueeze(dim=1)
return x
\ No newline at end of file
import math
from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.init as init
from einops import rearrange
from torch import nn
def get_2d_sincos_pos_embed(
embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
"""
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, int):
grid_size = (grid_size, grid_size)
grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
"""
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class Patch1D(nn.Module):
def __init__(
self,
channels: int,
use_conv: bool = False,
out_channels: Optional[int] = None,
stride: int = 2,
padding: int = 0,
name: str = "conv",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
self.name = name
if use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, stride, stride=stride, padding=padding)
init.constant_(self.conv.weight, 0.0)
with torch.no_grad():
for i in range(len(self.conv.weight)): self.conv.weight[i, i] = 1 / stride
init.constant_(self.conv.bias, 0.0)
else:
assert self.channels == self.out_channels
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
assert inputs.shape[1] == self.channels
return self.conv(inputs)
class UnPatch1D(nn.Module):
def __init__(
self,
channels: int,
use_conv: bool = False,
use_conv_transpose: bool = False,
out_channels: Optional[int] = None,
name: str = "conv",
):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
self.conv = None
if use_conv_transpose:
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
elif use_conv:
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
assert inputs.shape[1] == self.channels
if self.use_conv_transpose:
return self.conv(inputs)
outputs = F.interpolate(inputs, scale_factor=2.0, mode="nearest")
if self.use_conv:
outputs = self.conv(outputs)
return outputs
class Upsampler(nn.Module):
def __init__(
self,
spatial_upsample_factor: int = 1,
temporal_upsample_factor: int = 1,
):
super().__init__()
self.spatial_upsample_factor = spatial_upsample_factor
self.temporal_upsample_factor = temporal_upsample_factor
class TemporalUpsampler3D(Upsampler):
def __init__(self):
super().__init__(
spatial_upsample_factor=1,
temporal_upsample_factor=2,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.shape[2] > 1:
first_frame, x = x[:, :, :1], x[:, :, 1:]
x = F.interpolate(x, scale_factor=(2, 1, 1), mode="trilinear")
x = torch.cat([first_frame, x], dim=2)
return x
def cast_tuple(t, length = 1):
return t if isinstance(t, tuple) else ((t,) * length)
def divisible_by(num, den):
return (num % den) == 0
def is_odd(n):
return not divisible_by(n, 2)
class CausalConv3d(nn.Conv3d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size=3, # : int | tuple[int, int, int],
stride=1, # : int | tuple[int, int, int] = 1,
padding=1, # : int | tuple[int, int, int], # TODO: change it to 0.
dilation=1, # : int | tuple[int, int, int] = 1,
**kwargs,
):
kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size,) * 3
assert len(kernel_size) == 3, f"Kernel size must be a 3-tuple, got {kernel_size} instead."
stride = stride if isinstance(stride, tuple) else (stride,) * 3
assert len(stride) == 3, f"Stride must be a 3-tuple, got {stride} instead."
dilation = dilation if isinstance(dilation, tuple) else (dilation,) * 3
assert len(dilation) == 3, f"Dilation must be a 3-tuple, got {dilation} instead."
t_ks, h_ks, w_ks = kernel_size
_, h_stride, w_stride = stride
t_dilation, h_dilation, w_dilation = dilation
t_pad = (t_ks - 1) * t_dilation
# TODO: align with SD
if padding is None:
h_pad = math.ceil(((h_ks - 1) * h_dilation + (1 - h_stride)) / 2)
w_pad = math.ceil(((w_ks - 1) * w_dilation + (1 - w_stride)) / 2)
elif isinstance(padding, int):
h_pad = w_pad = padding
else:
assert NotImplementedError
self.temporal_padding = t_pad
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=dilation,
padding=(0, h_pad, w_pad),
**kwargs,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (B, C, T, H, W)
x = F.pad(
x,
pad=(0, 0, 0, 0, self.temporal_padding, 0),
mode="replicate", # TODO: check if this is necessary
)
return super().forward(x)
class PatchEmbed3D(nn.Module):
"""3D Image to Patch Embedding"""
def __init__(
self,
height=224,
width=224,
patch_size=16,
time_patch_size=4,
in_channels=3,
embed_dim=768,
layer_norm=False,
flatten=True,
bias=True,
interpolation_scale=1,
):
super().__init__()
num_patches = (height // patch_size) * (width // patch_size)
self.flatten = flatten
self.layer_norm = layer_norm
self.proj = nn.Conv3d(
in_channels, embed_dim, kernel_size=(time_patch_size, patch_size, patch_size), stride=(time_patch_size, patch_size, patch_size), bias=bias
)
if layer_norm:
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
else:
self.norm = None
self.patch_size = patch_size
# See:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
self.height, self.width = height // patch_size, width // patch_size
self.base_size = height // patch_size
self.interpolation_scale = interpolation_scale
pos_embed = get_2d_sincos_pos_embed(
embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
def forward(self, latent):
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
latent = self.proj(latent)
latent = rearrange(latent, "b c f h w -> (b f) c h w")
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCFHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
# Interpolate positional embeddings if needed.
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
if self.height != height or self.width != width:
pos_embed = get_2d_sincos_pos_embed(
embed_dim=self.pos_embed.shape[-1],
grid_size=(height, width),
base_size=self.base_size,
interpolation_scale=self.interpolation_scale,
)
pos_embed = torch.from_numpy(pos_embed)
pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
else:
pos_embed = self.pos_embed
return (latent + pos_embed).to(latent.dtype)
class PatchEmbedF3D(nn.Module):
"""Fake 3D Image to Patch Embedding"""
def __init__(
self,
height=224,
width=224,
patch_size=16,
in_channels=3,
embed_dim=768,
layer_norm=False,
flatten=True,
bias=True,
interpolation_scale=1,
):
super().__init__()
num_patches = (height // patch_size) * (width // patch_size)
self.flatten = flatten
self.layer_norm = layer_norm
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
self.proj_t = Patch1D(
embed_dim, True, stride=patch_size
)
if layer_norm:
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
else:
self.norm = None
self.patch_size = patch_size
# See:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
self.height, self.width = height // patch_size, width // patch_size
self.base_size = height // patch_size
self.interpolation_scale = interpolation_scale
pos_embed = get_2d_sincos_pos_embed(
embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
def forward(self, latent):
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
b, c, f, h, w = latent.size()
latent = rearrange(latent, "b c f h w -> (b f) c h w")
latent = self.proj(latent)
latent = rearrange(latent, "(b f) c h w -> b c f h w", f=f)
latent = rearrange(latent, "b c f h w -> (b h w) c f")
latent = self.proj_t(latent)
latent = rearrange(latent, "(b h w) c f -> b c f h w", h=h//2, w=w//2)
latent = rearrange(latent, "b c f h w -> (b f) c h w")
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCFHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
# Interpolate positional embeddings if needed.
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
if self.height != height or self.width != width:
pos_embed = get_2d_sincos_pos_embed(
embed_dim=self.pos_embed.shape[-1],
grid_size=(height, width),
base_size=self.base_size,
interpolation_scale=self.interpolation_scale,
)
pos_embed = torch.from_numpy(pos_embed)
pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
else:
pos_embed = self.pos_embed
return (latent + pos_embed).to(latent.dtype)
class CasualPatchEmbed3D(nn.Module):
"""3D Image to Patch Embedding"""
def __init__(
self,
height=224,
width=224,
patch_size=16,
time_patch_size=4,
in_channels=3,
embed_dim=768,
layer_norm=False,
flatten=True,
bias=True,
interpolation_scale=1,
):
super().__init__()
num_patches = (height // patch_size) * (width // patch_size)
self.flatten = flatten
self.layer_norm = layer_norm
self.proj = CausalConv3d(
in_channels, embed_dim, kernel_size=(time_patch_size, patch_size, patch_size), stride=(time_patch_size, patch_size, patch_size), bias=bias, padding=None
)
if layer_norm:
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
else:
self.norm = None
self.patch_size = patch_size
# See:
# https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
self.height, self.width = height // patch_size, width // patch_size
self.base_size = height // patch_size
self.interpolation_scale = interpolation_scale
pos_embed = get_2d_sincos_pos_embed(
embed_dim, int(num_patches**0.5), base_size=self.base_size, interpolation_scale=self.interpolation_scale
)
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
def forward(self, latent):
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
latent = self.proj(latent)
latent = rearrange(latent, "b c f h w -> (b f) c h w")
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCFHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
# Interpolate positional embeddings if needed.
# (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
if self.height != height or self.width != width:
pos_embed = get_2d_sincos_pos_embed(
embed_dim=self.pos_embed.shape[-1],
grid_size=(height, width),
base_size=self.base_size,
interpolation_scale=self.interpolation_scale,
)
pos_embed = torch.from_numpy(pos_embed)
pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
else:
pos_embed = self.pos_embed
return (latent + pos_embed).to(latent.dtype)
# Copyright (c) Alibaba Cloud.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.init import normal_
def get_abs_pos(abs_pos, tgt_size):
# abs_pos: L, C
# tgt_size: M
# return: M, C
src_size = int(math.sqrt(abs_pos.size(0)))
tgt_size = int(math.sqrt(tgt_size))
dtype = abs_pos.dtype
if src_size != tgt_size:
return F.interpolate(
abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
size=(tgt_size, tgt_size),
mode="bicubic",
align_corners=False,
).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
else:
return abs_pos
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token:
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float32)
omega /= embed_dim / 2.
omega = 1. / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
class Resampler(nn.Module):
"""
A 2D perceiver-resampler network with one cross attention layers by
(grid_size**2) learnable queries and 2d sincos pos_emb
Outputs:
A tensor with the shape of (grid_size**2, embed_dim)
"""
def __init__(
self,
grid_size,
embed_dim,
num_heads,
kv_dim=None,
norm_layer=nn.LayerNorm
):
super().__init__()
self.num_queries = grid_size ** 2
self.embed_dim = embed_dim
self.num_heads = num_heads
self.pos_embed = nn.Parameter(
torch.from_numpy(get_2d_sincos_pos_embed(embed_dim, grid_size)).float()
).requires_grad_(False)
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
normal_(self.query, std=.02)
if kv_dim is not None and kv_dim != embed_dim:
self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False)
else:
self.kv_proj = nn.Identity()
self.attn = nn.MultiheadAttention(embed_dim, num_heads)
self.ln_q = norm_layer(embed_dim)
self.ln_kv = norm_layer(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def forward(self, x, key_padding_mask=None):
pos_embed = get_abs_pos(self.pos_embed, x.size(1))
x = self.kv_proj(x)
x = self.ln_kv(x).permute(1, 0, 2)
N = x.shape[1]
q = self.ln_q(self.query)
out = self.attn(
self._repeat(q, N) + self.pos_embed.unsqueeze(1),
x + pos_embed.unsqueeze(1),
x,
key_padding_mask=key_padding_mask)[0]
return out.permute(1, 0, 2)
def _repeat(self, query, N: int):
return query.unsqueeze(1).repeat(1, N, 1)
\ No newline at end of file
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.init as init
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import BasicTransformerBlock
from diffusers.models.embeddings import ImagePositionalEmbeddings, PatchEmbed
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormSingle
from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate,
is_torch_version)
from einops import rearrange
from torch import nn
try:
from diffusers.models.embeddings import PixArtAlphaTextProjection
except:
from diffusers.models.embeddings import \
CaptionProjection as PixArtAlphaTextProjection
from .attention import (KVCompressionTransformerBlock,
SelfAttentionTemporalTransformerBlock,
TemporalTransformerBlock)
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
The output of [`Transformer2DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
distributions for the unnoised latent pixels.
"""
sample: torch.FloatTensor
class Transformer2DModel(ModelMixin, ConfigMixin):
"""
A 2D Transformer model for image-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
num_vector_embeds (`int`, *optional*):
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states.
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
attention_type: str = "default",
caption_channels: int = None,
# block type
basic_block_type: str = "basic",
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.basic_block_type = basic_block_type
inner_dim = num_attention_heads * attention_head_dim
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
self.is_input_vectorized = num_vector_embeds is not None
self.is_input_patches = in_channels is not None and patch_size is not None
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
deprecation_message = (
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
)
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
norm_type = "ada_norm"
if self.is_input_continuous and self.is_input_vectorized:
raise ValueError(
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
" sure that either `in_channels` or `num_vector_embeds` is None."
)
elif self.is_input_vectorized and self.is_input_patches:
raise ValueError(
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
" sure that either `num_vector_embeds` or `num_patches` is None."
)
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
raise ValueError(
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
)
# 2. Define input layers
if self.is_input_continuous:
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = linear_cls(in_channels, inner_dim)
else:
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
self.height = sample_size
self.width = sample_size
self.num_vector_embeds = num_vector_embeds
self.num_latent_pixels = self.height * self.width
self.latent_image_embedding = ImagePositionalEmbeddings(
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
)
elif self.is_input_patches:
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
self.height = sample_size
self.width = sample_size
self.patch_size = patch_size
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
interpolation_scale = max(interpolation_scale, 1)
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
basic_block = {
"basic": BasicTransformerBlock,
"kvcompression": KVCompressionTransformerBlock,
}[self.basic_block_type]
if self.basic_block_type == "kvcompression":
self.transformer_blocks = nn.ModuleList(
[
basic_block(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
double_self_attention=double_self_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
kvcompression=False if d < 14 else True,
)
for d in range(num_layers)
]
)
else:
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
double_self_attention=double_self_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
)
for d in range(num_layers)
]
)
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = linear_cls(inner_dim, in_channels)
else:
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
elif self.is_input_patches and norm_type != "ada_norm_single":
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
elif self.is_input_patches and norm_type == "ada_norm_single":
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
# 5. PixArt-Alpha blocks.
self.adaln_single = None
self.use_additional_conditions = False
if norm_type == "ada_norm_single":
self.use_additional_conditions = self.config.sample_size == 128
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
# additional conditions until we find better name
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
self.caption_projection = None
if caption_channels is not None:
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
The [`Transformer2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
attention_mask ( `torch.Tensor`, *optional*):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 1. Input
if self.is_input_continuous:
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states = self.pos_embed(hidden_states)
if self.adaln_single is not None:
if self.use_additional_conditions and added_cond_kwargs is None:
raise ValueError(
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
)
batch_size = hidden_states.shape[0]
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = hidden_states.shape[0]
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing:
args = {
"basic": [],
"kvcompression": [1, height, width],
}[self.basic_block_type]
hidden_states = torch.utils.checkpoint.checkpoint(
block,
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
class_labels,
*args,
use_reentrant=False,
)
else:
kwargs = {
"basic": {},
"kvcompression": {"num_frames":1, "height":height, "width":width},
}[self.basic_block_type]
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
**kwargs
)
# 3. Output
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
else:
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
elif self.is_input_vectorized:
hidden_states = self.norm_out(hidden_states)
logits = self.out(hidden_states)
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
logits = logits.permute(0, 2, 1)
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
if self.is_input_patches:
if self.config.norm_type != "ada_norm_single":
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
elif self.config.norm_type == "ada_norm_single":
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
if self.adaln_single is None:
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@classmethod
def from_pretrained(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}):
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
print(f"loaded 2D transformer's pretrained weights from {pretrained_model_path} ...")
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f"{config_file} does not exist")
with open(config_file, "r") as f:
config = json.load(f)
from diffusers.utils import WEIGHTS_NAME
model = cls.from_config(config, **transformer_additional_kwargs)
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
model_file_safetensors = model_file.replace(".bin", ".safetensors")
if os.path.exists(model_file_safetensors):
from safetensors.torch import load_file, safe_open
state_dict = load_file(model_file_safetensors)
else:
if not os.path.isfile(model_file):
raise RuntimeError(f"{model_file} does not exist")
state_dict = torch.load(model_file, map_location="cpu")
if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
new_shape = model.state_dict()['pos_embed.proj.weight'].size()
state_dict['pos_embed.proj.weight'] = torch.tile(state_dict['proj_out.weight'], [1, 2, 1, 1])
if model.state_dict()['proj_out.weight'].size() != state_dict['proj_out.weight'].size():
new_shape = model.state_dict()['proj_out.weight'].size()
state_dict['proj_out.weight'] = torch.tile(state_dict['proj_out.weight'], [patch_size, 1])
if model.state_dict()['proj_out.bias'].size() != state_dict['proj_out.bias'].size():
new_shape = model.state_dict()['proj_out.bias'].size()
state_dict['proj_out.bias'] = torch.tile(state_dict['proj_out.bias'], [patch_size])
tmp_state_dict = {}
for key in state_dict:
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
tmp_state_dict[key] = state_dict[key]
else:
print(key, "Size don't match, skip")
state_dict = tmp_state_dict
m, u = model.load_state_dict(state_dict, strict=False)
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()]
print(f"### Postion Parameters: {sum(params) / 1e6} M")
return model
\ No newline at end of file
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import os
from dataclasses import dataclass
from typing import Any, Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.init as init
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.attention import BasicTransformerBlock, FeedForward
from diffusers.models.embeddings import (PatchEmbed,
PixArtAlphaTextProjection, TimestepEmbedding, Timesteps)
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous
from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, is_torch_version,
logging)
from diffusers.utils.torch_utils import maybe_allow_in_graph
from einops import rearrange
from torch import nn
from .attention import (HunyuanDiTBlock, HunyuanTemporalTransformerBlock,
SelfAttentionTemporalTransformerBlock,
TemporalTransformerBlock)
from .embeddings import HunyuanCombinedTimestepTextSizeStyleEmbedding
from .norm import AdaLayerNormSingle
from .patch import (CasualPatchEmbed3D, Patch1D, PatchEmbed3D, PatchEmbedF3D,
TemporalUpsampler3D, UnPatch1D)
from .resampler import Resampler
try:
from diffusers.models.embeddings import PixArtAlphaTextProjection
except:
from diffusers.models.embeddings import \
CaptionProjection as PixArtAlphaTextProjection
def zero_module(module):
# Zero out the parameters of a module and return it.
for p in module.parameters():
p.detach().zero_()
return module
class CLIPProjection(nn.Module):
"""
Projects caption embeddings. Also handles dropout for classifier-free guidance.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_features, hidden_size, num_tokens=120):
super().__init__()
self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
self.act_1 = nn.GELU(approximate="tanh")
self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
self.linear_2 = zero_module(self.linear_2)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
class TimePositionalEncoding(nn.Module):
def __init__(
self,
d_model,
dropout = 0.,
max_len = 24
):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
b, c, f, h, w = x.size()
x = rearrange(x, "b c f h w -> (b h w) f c")
x = x + self.pe[:, :x.size(1)]
x = rearrange(x, "(b h w) f c -> b c f h w", b=b, h=h, w=w)
return self.dropout(x)
@dataclass
class Transformer3DModelOutput(BaseOutput):
"""
The output of [`Transformer2DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
distributions for the unnoised latent pixels.
"""
sample: torch.FloatTensor
class Transformer3DModel(ModelMixin, ConfigMixin):
"""
A 3D Transformer model for image-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
num_vector_embeds (`int`, *optional*):
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states.
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
attention_type: str = "default",
caption_channels: int = None,
# block type
basic_block_type: str = "motionmodule",
# enable_uvit
enable_uvit: bool = False,
# 3d patch params
patch_3d: bool = False,
fake_3d: bool = False,
time_patch_size: Optional[int] = None,
casual_3d: bool = False,
casual_3d_upsampler_index: Optional[list] = None,
# motion module kwargs
motion_module_type = "VanillaGrid",
motion_module_kwargs = None,
motion_module_kwargs_odd = None,
motion_module_kwargs_even = None,
# time position encoding
time_position_encoding_before_transformer = False,
qk_norm = False,
after_norm = False,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
self.enable_uvit = enable_uvit
inner_dim = num_attention_heads * attention_head_dim
self.basic_block_type = basic_block_type
self.patch_3d = patch_3d
self.fake_3d = fake_3d
self.casual_3d = casual_3d
self.casual_3d_upsampler_index = casual_3d_upsampler_index
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
assert sample_size is not None, "Transformer3DModel over patched input must provide sample_size"
self.height = sample_size
self.width = sample_size
self.patch_size = patch_size
self.time_patch_size = self.patch_size if time_patch_size is None else time_patch_size
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
interpolation_scale = max(interpolation_scale, 1)
if self.casual_3d:
self.pos_embed = CasualPatchEmbed3D(
height=sample_size,
width=sample_size,
patch_size=patch_size,
time_patch_size=self.time_patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
elif self.patch_3d:
if self.fake_3d:
self.pos_embed = PatchEmbedF3D(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
else:
self.pos_embed = PatchEmbed3D(
height=sample_size,
width=sample_size,
patch_size=patch_size,
time_patch_size=self.time_patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
else:
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
# 3. Define transformers blocks
if self.basic_block_type == "motionmodule":
self.transformer_blocks = nn.ModuleList(
[
TemporalTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
double_self_attention=double_self_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
qk_norm=qk_norm,
after_norm=after_norm,
)
for d in range(num_layers)
]
)
elif self.basic_block_type == "global_motionmodule":
self.transformer_blocks = nn.ModuleList(
[
TemporalTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
double_self_attention=double_self_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs_even if d % 2 == 0 else motion_module_kwargs_odd,
qk_norm=qk_norm,
after_norm=after_norm,
)
for d in range(num_layers)
]
)
elif self.basic_block_type == "kvcompression_motionmodule":
self.transformer_blocks = nn.ModuleList(
[
TemporalTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
double_self_attention=double_self_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
kvcompression=False if d < 14 else True,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
qk_norm=qk_norm,
after_norm=after_norm,
)
for d in range(num_layers)
]
)
elif self.basic_block_type == "selfattentiontemporal":
self.transformer_blocks = nn.ModuleList(
[
SelfAttentionTemporalTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
double_self_attention=double_self_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
qk_norm=qk_norm,
after_norm=after_norm,
)
for d in range(num_layers)
]
)
else:
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
double_self_attention=double_self_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
)
for d in range(num_layers)
]
)
if self.casual_3d:
self.unpatch1d = TemporalUpsampler3D()
elif self.patch_3d and self.fake_3d:
self.unpatch1d = UnPatch1D(inner_dim, True)
if self.enable_uvit:
self.long_connect_fc = nn.ModuleList(
[
nn.Linear(inner_dim, inner_dim, True) for d in range(13)
]
)
for index in range(13):
self.long_connect_fc[index] = zero_module(self.long_connect_fc[index])
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
if norm_type != "ada_norm_single":
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
if self.patch_3d and not self.fake_3d:
self.proj_out_2 = nn.Linear(inner_dim, self.time_patch_size * patch_size * patch_size * self.out_channels)
else:
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
elif norm_type == "ada_norm_single":
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
if self.patch_3d and not self.fake_3d:
self.proj_out = nn.Linear(inner_dim, self.time_patch_size * patch_size * patch_size * self.out_channels)
else:
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
# 5. PixArt-Alpha blocks.
self.adaln_single = None
self.use_additional_conditions = False
if norm_type == "ada_norm_single":
self.use_additional_conditions = self.config.sample_size == 128
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
# additional conditions until we find better name
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
self.caption_projection = None
self.clip_projection = None
if caption_channels is not None:
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
if in_channels == 12:
self.clip_projection = CLIPProjection(in_features=768, hidden_size=inner_dim * 8)
self.gradient_checkpointing = False
self.time_position_encoding_before_transformer = time_position_encoding_before_transformer
if self.time_position_encoding_before_transformer:
self.t_pos = TimePositionalEncoding(max_len = 4096, d_model = inner_dim)
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
inpaint_latents: torch.Tensor = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
clip_encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
clip_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
The [`Transformer2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
attention_mask ( `torch.Tensor`, *optional*):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer3DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
hidden_states = hidden_states.to(encoder_hidden_states.dtype)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
if clip_attention_mask is not None:
encoder_attention_mask = torch.cat([encoder_attention_mask, clip_attention_mask], dim=1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
if inpaint_latents is not None:
hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
# 1. Input
if self.casual_3d:
video_length, height, width = (hidden_states.shape[-3] - 1) // self.time_patch_size + 1, hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
elif self.patch_3d:
video_length, height, width = hidden_states.shape[-3] // self.time_patch_size, hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
else:
video_length, height, width = hidden_states.shape[-3], hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states = rearrange(hidden_states, "b c f h w ->(b f) c h w")
hidden_states = self.pos_embed(hidden_states)
if self.adaln_single is not None:
if self.use_additional_conditions and added_cond_kwargs is None:
raise ValueError(
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
)
batch_size = hidden_states.shape[0] // video_length
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
hidden_states = rearrange(hidden_states, "(b f) (h w) c -> b c f h w", f=video_length, h=height, w=width)
# hidden_states
# bs, c, f, h, w => b (f h w ) c
if self.time_position_encoding_before_transformer:
hidden_states = self.t_pos(hidden_states)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
# 2. Blocks
if self.caption_projection is not None:
batch_size = hidden_states.shape[0]
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
if clip_encoder_hidden_states is not None and encoder_hidden_states is not None:
batch_size = hidden_states.shape[0]
clip_encoder_hidden_states = self.clip_projection(clip_encoder_hidden_states)
clip_encoder_hidden_states = clip_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states = torch.cat([encoder_hidden_states, clip_encoder_hidden_states], dim = 1)
skips = []
skip_index = 0
for index, block in enumerate(self.transformer_blocks):
if self.enable_uvit:
if index >= 15:
long_connect = self.long_connect_fc[skip_index](skips.pop())
hidden_states = hidden_states + long_connect
skip_index += 1
if self.casual_3d_upsampler_index is not None and index in self.casual_3d_upsampler_index:
hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=video_length, h=height, w=width)
hidden_states = self.unpatch1d(hidden_states)
video_length = (video_length - 1) * 2 + 1
hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c", f=video_length, h=height, w=width)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
args = {
"basic": [],
"motionmodule": [video_length, height, width],
"global_motionmodule": [video_length, height, width],
"selfattentiontemporal": [],
"kvcompression_motionmodule": [video_length, height, width],
}[self.basic_block_type]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
class_labels,
*args,
**ckpt_kwargs,
)
else:
kwargs = {
"basic": {},
"motionmodule": {"num_frames":video_length, "height":height, "width":width},
"global_motionmodule": {"num_frames":video_length, "height":height, "width":width},
"selfattentiontemporal": {},
"kvcompression_motionmodule": {"num_frames":video_length, "height":height, "width":width},
}[self.basic_block_type]
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
**kwargs
)
if self.enable_uvit:
if index < 13:
skips.append(hidden_states)
if self.fake_3d and self.patch_3d:
hidden_states = rearrange(hidden_states, "b (f h w) c -> (b h w) c f", f=video_length, w=width, h=height)
hidden_states = self.unpatch1d(hidden_states)
hidden_states = rearrange(hidden_states, "(b h w) c f -> b (f h w) c", w=width, h=height)
# 3. Output
if self.config.norm_type != "ada_norm_single":
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
elif self.config.norm_type == "ada_norm_single":
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
if self.adaln_single is None:
height = width = int(hidden_states.shape[1] ** 0.5)
if self.patch_3d:
if self.fake_3d:
hidden_states = hidden_states.reshape(
shape=(-1, video_length * self.patch_size, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states)
else:
hidden_states = hidden_states.reshape(
shape=(-1, video_length, height, width, self.time_patch_size, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nfhwopqc->ncfohpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, video_length * self.time_patch_size, height * self.patch_size, width * self.patch_size)
)
else:
hidden_states = hidden_states.reshape(
shape=(-1, video_length, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, video_length, height * self.patch_size, width * self.patch_size)
)
if not return_dict:
return (output,)
return Transformer3DModelOutput(sample=output)
@classmethod
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}):
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f"{config_file} does not exist")
with open(config_file, "r") as f:
config = json.load(f)
from diffusers.utils import WEIGHTS_NAME
model = cls.from_config(config, **transformer_additional_kwargs)
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
model_file_safetensors = model_file.replace(".bin", ".safetensors")
if os.path.exists(model_file_safetensors):
from safetensors.torch import load_file, safe_open
state_dict = load_file(model_file_safetensors)
else:
if not os.path.isfile(model_file):
raise RuntimeError(f"{model_file} does not exist")
state_dict = torch.load(model_file, map_location="cpu")
if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
new_shape = model.state_dict()['pos_embed.proj.weight'].size()
if len(new_shape) == 5:
state_dict['pos_embed.proj.weight'] = state_dict['pos_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
state_dict['pos_embed.proj.weight'][:, :, :-1] = 0
else:
model.state_dict()['pos_embed.proj.weight'][:, :4, :, :] = state_dict['pos_embed.proj.weight']
model.state_dict()['pos_embed.proj.weight'][:, 4:, :, :] = 0
state_dict['pos_embed.proj.weight'] = model.state_dict()['pos_embed.proj.weight']
if model.state_dict()['proj_out.weight'].size() != state_dict['proj_out.weight'].size():
new_shape = model.state_dict()['proj_out.weight'].size()
state_dict['proj_out.weight'] = torch.tile(state_dict['proj_out.weight'], [patch_size, 1])
if model.state_dict()['proj_out.bias'].size() != state_dict['proj_out.bias'].size():
new_shape = model.state_dict()['proj_out.bias'].size()
state_dict['proj_out.bias'] = torch.tile(state_dict['proj_out.bias'], [patch_size])
tmp_state_dict = {}
for key in state_dict:
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
tmp_state_dict[key] = state_dict[key]
# else:
# print(key, "Size don't match, skip")
state_dict = tmp_state_dict
m, u = model.load_state_dict(state_dict, strict=False)
print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
params = [p.numel() if "attn_temporal." in n else 0 for n, p in model.named_parameters()]
print(f"### Attn temporal Parameters: {sum(params) / 1e6} M")
return model
class HunyuanTransformer3DModel(ModelMixin, ConfigMixin):
"""
HunYuanDiT: Diffusion model with a Transformer backbone.
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88):
The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
patch_size (`int`, *optional*):
The size of the patch to use for the input.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward.
sample_size (`int`, *optional*):
The width of the latent images. This is fixed during training since it is used to learn a number of
position embeddings.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
cross_attention_dim (`int`, *optional*):
The number of dimension in the clip text embedding.
hidden_size (`int`, *optional*):
The size of hidden layer in the conditioning embedding layers.
num_layers (`int`, *optional*, defaults to 1):
The number of layers of Transformer blocks to use.
mlp_ratio (`float`, *optional*, defaults to 4.0):
The ratio of the hidden layer size to the input size.
learn_sigma (`bool`, *optional*, defaults to `True`):
Whether to predict variance.
cross_attention_dim_t5 (`int`, *optional*):
The number dimensions in t5 text embedding.
pooled_projection_dim (`int`, *optional*):
The size of the pooled projection.
text_len (`int`, *optional*):
The length of the clip text embedding.
text_len_t5 (`int`, *optional*):
The length of the T5 text embedding.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
patch_size: Optional[int] = None,
n_query=16,
projection_dim=768,
activation_fn: str = "gelu-approximate",
sample_size=32,
hidden_size=1152,
num_layers: int = 28,
mlp_ratio: float = 4.0,
learn_sigma: bool = True,
cross_attention_dim: int = 1024,
norm_type: str = "layer_norm",
cross_attention_dim_t5: int = 2048,
pooled_projection_dim: int = 1024,
text_len: int = 77,
text_len_t5: int = 256,
# block type
basic_block_type: str = "basic",
# motion module kwargs
motion_module_type = "VanillaGrid",
motion_module_kwargs = None,
motion_module_kwargs_odd = None,
motion_module_kwargs_even = None,
time_position_encoding = False,
after_norm = False,
):
super().__init__()
# 4. Define output layers
if learn_sigma:
self.out_channels = in_channels * 2 if out_channels is None else out_channels
else:
self.out_channels = in_channels if out_channels is None else out_channels
self.enable_inpaint = in_channels * 2 != self.out_channels if learn_sigma else in_channels != self.out_channels
self.num_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.basic_block_type = basic_block_type
self.text_embedder = PixArtAlphaTextProjection(
in_features=cross_attention_dim_t5,
hidden_size=cross_attention_dim_t5 * 4,
out_features=cross_attention_dim,
act_fn="silu_fp32",
)
self.text_embedding_padding = nn.Parameter(
torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
)
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
in_channels=in_channels,
embed_dim=hidden_size,
patch_size=patch_size,
pos_embed_type=None,
)
self.time_extra_emb = HunyuanCombinedTimestepTextSizeStyleEmbedding(
hidden_size,
pooled_projection_dim=pooled_projection_dim,
seq_len=text_len_t5,
cross_attention_dim=cross_attention_dim_t5,
)
# 3. Define transformers blocks
if self.basic_block_type == "motionmodule":
self.blocks = nn.ModuleList(
[
HunyuanTemporalTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
activation_fn=activation_fn,
ff_inner_dim=int(self.inner_dim * mlp_ratio),
cross_attention_dim=cross_attention_dim,
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
skip=layer > num_layers // 2,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
after_norm=after_norm,
)
for layer in range(num_layers)
]
)
elif self.basic_block_type == "global_motionmodule":
self.blocks = nn.ModuleList(
[
HunyuanTemporalTransformerBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
activation_fn=activation_fn,
ff_inner_dim=int(self.inner_dim * mlp_ratio),
cross_attention_dim=cross_attention_dim,
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
skip=layer > num_layers // 2,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs_even if layer % 2 == 0 else motion_module_kwargs_odd,
after_norm=after_norm,
)
for layer in range(num_layers)
]
)
elif self.basic_block_type == "hybrid_attention":
self.blocks = nn.ModuleList(
[
HunyuanDiTBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
activation_fn=activation_fn,
ff_inner_dim=int(self.inner_dim * mlp_ratio),
cross_attention_dim=cross_attention_dim,
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
skip=layer > num_layers // 2,
after_norm=after_norm,
time_position_encoding=time_position_encoding,
is_local_attention=False if layer % 2 == 0 else True,
local_attention_frames=2,
enable_inpaint=self.enable_inpaint,
)
for layer in range(num_layers)
]
)
else:
# HunyuanDiT Blocks
self.blocks = nn.ModuleList(
[
HunyuanDiTBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
activation_fn=activation_fn,
ff_inner_dim=int(self.inner_dim * mlp_ratio),
cross_attention_dim=cross_attention_dim,
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
skip=layer > num_layers // 2,
after_norm=after_norm,
time_position_encoding=time_position_encoding,
enable_inpaint=self.enable_inpaint,
)
for layer in range(num_layers)
]
)
self.n_query = n_query
if self.enable_inpaint:
self.clip_padding = nn.Parameter(
torch.randn((self.n_query, cross_attention_dim)) * 0.02
)
self.clip_projection = Resampler(
int(math.sqrt(n_query)),
embed_dim=cross_attention_dim,
num_heads=self.config.num_attention_heads,
kv_dim=projection_dim,
norm_layer=nn.LayerNorm,
)
else:
self.clip_padding = None
self.clip_projection = None
self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
self.gradient_checkpointing = False
self.hidden_cache_size = 0
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states,
timestep,
encoder_hidden_states=None,
text_embedding_mask=None,
encoder_hidden_states_t5=None,
text_embedding_mask_t5=None,
image_meta_size=None,
style=None,
image_rotary_emb=None,
inpaint_latents=None,
clip_encoder_hidden_states: Optional[torch.Tensor]=None,
clip_attention_mask: Optional[torch.Tensor]=None,
return_dict=True,
):
"""
The [`HunyuanDiT2DModel`] forward method.
Args:
hidden_states (`torch.Tensor` of shape `(batch size, dim, height, width)`):
The input tensor.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step.
encoder_hidden_states ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. This is the output of `BertModel`.
text_embedding_mask: torch.Tensor
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
of `BertModel`.
encoder_hidden_states_t5 ( `torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. This is the output of T5 Text Encoder.
text_embedding_mask_t5: torch.Tensor
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. This is the output
of T5 Text Encoder.
image_meta_size (torch.Tensor):
Conditional embedding indicate the image sizes
style: torch.Tensor:
Conditional embedding indicate the style
image_rotary_emb (`torch.Tensor`):
The image rotary embeddings to apply on query and key tensors during attention calculation.
return_dict: bool
Whether to return a dictionary.
"""
if inpaint_latents is not None:
hidden_states = torch.concat([hidden_states, inpaint_latents], 1)
# unpatchify: (N, out_channels, H, W)
patch_size = self.pos_embed.patch_size
video_length, height, width = hidden_states.shape[-3], hidden_states.shape[-2] // patch_size, hidden_states.shape[-1] // patch_size
hidden_states = rearrange(hidden_states, "b c f h w ->(b f) c h w")
hidden_states = self.pos_embed(hidden_states)
hidden_states = rearrange(hidden_states, "(b f) (h w) c -> b c f h w", f=video_length, h=height, w=width)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
temb = self.time_extra_emb(
timestep, encoder_hidden_states_t5, image_meta_size, style, hidden_dtype=timestep.dtype
) # [B, D]
# text projection
batch_size, sequence_length, _ = encoder_hidden_states_t5.shape
encoder_hidden_states_t5 = self.text_embedder(
encoder_hidden_states_t5.view(-1, encoder_hidden_states_t5.shape[-1])
)
encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, sequence_length, -1)
encoder_hidden_states = torch.cat([encoder_hidden_states, encoder_hidden_states_t5], dim=1)
text_embedding_mask = torch.cat([text_embedding_mask, text_embedding_mask_t5], dim=-1)
text_embedding_mask = text_embedding_mask.unsqueeze(2).bool()
encoder_hidden_states = torch.where(text_embedding_mask, encoder_hidden_states, self.text_embedding_padding)
if clip_encoder_hidden_states is not None:
batch_size = encoder_hidden_states.shape[0]
clip_encoder_hidden_states = self.clip_projection(clip_encoder_hidden_states)
clip_encoder_hidden_states = clip_encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])
clip_attention_mask = clip_attention_mask.unsqueeze(2).bool()
clip_encoder_hidden_states = torch.where(clip_attention_mask, clip_encoder_hidden_states, self.clip_padding)
skips = []
skips_cpu_cache = []
skips_cache_size = self.hidden_cache_size # Equal or bigger than this value, skips will be moved to cpu
if skips_cache_size <= 0:
skips_cache_size = self.config.num_layers
for layer, block in enumerate(self.blocks):
if layer > self.config.num_layers // 2:
if skips_cache_size == 1:
skip_cpu = skips.pop()
skip = skip_cpu.to(hidden_states.device)
else:
if len(skips) == 0:
skips_cache_block = skips_cpu_cache.pop()
for si in range(len(skips_cache_block)):
skips.append(skips_cache_block[si].to(hidden_states.device))
skip = skips.pop()
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
args = {
"basic": [video_length, height, width, clip_encoder_hidden_states],
"hybrid_attention": [video_length, height, width, clip_encoder_hidden_states],
"motionmodule": [video_length, height, width],
"global_motionmodule": [video_length, height, width],
}[self.basic_block_type]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
skip,
*args,
**ckpt_kwargs,
)
else:
kwargs = {
"basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
"hybrid_attention": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
"motionmodule": {"num_frames":video_length, "height":height, "width":width},
"global_motionmodule": {"num_frames":video_length, "height":height, "width":width},
}[self.basic_block_type]
hidden_states = block(
hidden_states,
temb=temb,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
skip=skip,
**kwargs
) # (N, L, D)
else:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
args = {
"basic": [None, video_length, height, width, clip_encoder_hidden_states],
"hybrid_attention": [None, video_length, height, width, clip_encoder_hidden_states],
"motionmodule": [None, video_length, height, width],
"global_motionmodule": [None, video_length, height, width],
}[self.basic_block_type]
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
*args,
**ckpt_kwargs,
)
else:
kwargs = {
"basic": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
"hybrid_attention": {"num_frames":video_length, "height":height, "width":width, "clip_encoder_hidden_states":clip_encoder_hidden_states},
"motionmodule": {"num_frames":video_length, "height":height, "width":width},
"global_motionmodule": {"num_frames":video_length, "height":height, "width":width},
}[self.basic_block_type]
hidden_states = block(
hidden_states,
temb=temb,
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**kwargs
) # (N, L, D)
if layer < (self.config.num_layers // 2 - 1):
if skips_cache_size == 1:
skips.append(hidden_states.to("cpu"))
else:
skips.append(hidden_states)
if len(skips) >= skips_cache_size:
skips_cache_block = []
for si in range(len(skips)):
skips_cache_block.append(skips[si].to("cpu"))
skips = []
skips_cpu_cache.append(skips_cache_block)
# final layer
hidden_states = self.norm_out(hidden_states, temb.to(torch.float32))
hidden_states = self.proj_out(hidden_states)
# (N, L, patch_size ** 2 * out_channels)
hidden_states = hidden_states.reshape(
shape=(hidden_states.shape[0], video_length, height, width, patch_size, patch_size, self.out_channels)
)
hidden_states = torch.einsum("nfhwpqc->ncfhpwq", hidden_states)
output = hidden_states.reshape(
shape=(hidden_states.shape[0], self.out_channels, video_length, height * patch_size, width * patch_size)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
@classmethod
def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, patch_size=2, transformer_additional_kwargs={}):
if subfolder is not None:
pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
# print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
config_file = os.path.join(pretrained_model_path, 'config.json')
if not os.path.isfile(config_file):
raise RuntimeError(f"{config_file} does not exist")
with open(config_file, "r") as f:
config = json.load(f)
from diffusers.utils import WEIGHTS_NAME
model = cls.from_config(config, **transformer_additional_kwargs)
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
model_file_safetensors = model_file.replace(".bin", ".safetensors")
if os.path.exists(model_file_safetensors):
from safetensors.torch import load_file, safe_open
state_dict = load_file(model_file_safetensors)
else:
if not os.path.isfile(model_file):
raise RuntimeError(f"{model_file} does not exist")
state_dict = torch.load(model_file, map_location="cpu")
if model.state_dict()['pos_embed.proj.weight'].size() != state_dict['pos_embed.proj.weight'].size():
new_shape = model.state_dict()['pos_embed.proj.weight'].size()
if len(new_shape) == 5:
state_dict['pos_embed.proj.weight'] = state_dict['pos_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
state_dict['pos_embed.proj.weight'][:, :, :-1] = 0
if model.state_dict()['proj_out.bias'].size() != state_dict['proj_out.bias'].size():
if model.state_dict()['proj_out.bias'].size()[0] > state_dict['proj_out.bias'].size()[0]:
model.state_dict()['proj_out.bias'][:state_dict['proj_out.bias'].size()[0]] = state_dict['proj_out.bias']
state_dict['proj_out.bias'] = model.state_dict()['proj_out.bias']
tmp_state_dict = {}
for key in state_dict:
if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
tmp_state_dict[key] = state_dict[key]
state_dict = tmp_state_dict
m, u = model.load_state_dict(state_dict, strict=False)
# print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
# print(m)
params = [p.numel() if "mamba" in n else 0 for n, p in model.named_parameters()]
# print(f"### Mamba Parameters: {sum(params) / 1e6} M")
params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
# print(f"### attn1 Parameters: {sum(params) / 1e6} M")
params = [p.numel() for n, p in model.named_parameters()]
# print(f"### Total Parameters: {sum(params) / 1e6} M")
return model
\ No newline at end of file
# Copyright 2024 HunyuanDiT Authors and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import gc
import html
import inspect
import re
import urllib.parse as ul
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import DiffusionPipeline, ImagePipelineOutput
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, HunyuanDiT2DModel
from diffusers.models.embeddings import get_2d_rotary_pos_embed
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.safety_checker import \
StableDiffusionSafetyChecker
from diffusers.schedulers import DDPMScheduler, DPMSolverMultistepScheduler
from diffusers.utils import (BACKENDS_MAPPING, BaseOutput, deprecate,
is_bs4_available, is_ftfy_available,
is_torch_xla_available, logging,
replace_example_docstring)
from diffusers.utils.torch_utils import randn_tensor
from einops import rearrange
from PIL import Image
from tqdm import tqdm
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import HunyuanDiTPipeline
>>> pipe = HunyuanDiTPipeline.from_pretrained("Tencent-Hunyuan/HunyuanDiT", torch_dtype=torch.float16)
>>> pipe.to("cuda")
>>> # You may also use English prompt as HunyuanDiT supports both English and Chinese
>>> # prompt = "An astronaut riding a horse"
>>> prompt = "一个宇航员在骑马"
>>> image = pipe(prompt).images[0]
```
"""
STANDARD_RATIO = np.array(
[
1.0, # 1:1
4.0 / 3.0, # 4:3
3.0 / 4.0, # 3:4
16.0 / 9.0, # 16:9
9.0 / 16.0, # 9:16
]
)
STANDARD_SHAPE = [
[(1024, 1024), (1280, 1280)], # 1:1
[(1024, 768), (1152, 864), (1280, 960)], # 4:3
[(768, 1024), (864, 1152), (960, 1280)], # 3:4
[(1280, 720)], # 16:9
[(720, 1280)], # 9:16
]
STANDARD_AREA = [np.array([w * h for w, h in shapes]) for shapes in STANDARD_SHAPE]
SUPPORTED_SHAPE = [
(1024, 1024),
(1280, 1280), # 1:1
(1024, 768),
(1152, 864),
(1280, 960), # 4:3
(768, 1024),
(864, 1152),
(960, 1280), # 3:4
(1280, 720), # 16:9
(720, 1280), # 9:16
]
def map_to_standard_shapes(target_width, target_height):
target_ratio = target_width / target_height
closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio))
closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height))
width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx]
return width, height
def get_resize_crop_region_for_grid(src, tgt_size):
th = tw = tgt_size
h, w = src
r = h / w
# resize
if r > 1:
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
return noise_cfg
@dataclass
class RuyiPipelineOutput(BaseOutput):
videos: Union[torch.Tensor, np.ndarray]
class RuyiInpaintPipeline(DiffusionPipeline):
r"""
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vae ([`AutoencoderKLMagvit`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
transformer ([`HunyuanTransformer3DModel`]):
The HunyuanDiT model designed by Tencent Hunyuan.
scheduler ([`DDPMScheduler`]):
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
"""
model_cpu_offload_seq = "clip_image_encoder->transformer->vae"
_optional_components = [
"safety_checker",
"feature_extractor",
"clip_image_encoder",
]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
"negative_prompt_embeds",
"prompt_embeds_2",
"negative_prompt_embeds_2",
]
def __init__(
self,
vae: AutoencoderKL,
transformer: HunyuanDiT2DModel,
scheduler: DDPMScheduler,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
clip_image_processor:CLIPImageProcessor = None,
clip_image_encoder:CLIPVisionModelWithProjection = None,
):
super().__init__()
self.register_modules(
vae=vae,
transformer=transformer,
scheduler=scheduler,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
clip_image_processor=clip_image_processor,
clip_image_encoder=clip_image_encoder,
)
if safety_checker is None and requires_safety_checker:
logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
" that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
" results in services or applications open to the public. Both the diffusers team and Hugging Face"
" strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
" it only for use-cases that involve analyzing network behavior or auditing its results. For more"
" information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
)
if safety_checker is not None and feature_extractor is None:
raise ValueError(
"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
self.default_sample_size = self.transformer.config.sample_size
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)
self.model_cpu_offload_flag = False
def enable_sequential_cpu_offload(self, *args, **kwargs):
super().enable_sequential_cpu_offload(*args, **kwargs)
self.model_cpu_offload_flag = False
if hasattr(self.transformer, "clip_projection") and self.transformer.clip_projection is not None:
import accelerate
accelerate.hooks.remove_hook_from_module(self.transformer.clip_projection, recurse=True)
self.transformer.clip_projection = self.transformer.clip_projection.to("cuda")
def enable_model_cpu_offload(self, *args, **kwargs):
super().enable_model_cpu_offload(*args, **kwargs)
self.model_cpu_offload_flag = True
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
def run_safety_checker(self, image, device, dtype):
if self.safety_checker is None:
has_nsfw_concept = None
else:
if torch.is_tensor(image):
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
else:
feature_extractor_input = self.image_processor.numpy_to_pil(image)
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
image, has_nsfw_concept = self.safety_checker(
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
)
return image, has_nsfw_concept
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(
self,
prompt,
height,
width,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
prompt_embeds_2=None,
negative_prompt_embeds_2=None,
prompt_attention_mask_2=None,
negative_prompt_attention_mask_2=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is None and prompt_embeds_2 is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if prompt_embeds_2 is not None and prompt_attention_mask_2 is None:
raise ValueError("Must provide `prompt_attention_mask_2` when specifying `prompt_embeds_2`.")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if negative_prompt_embeds_2 is not None and negative_prompt_attention_mask_2 is None:
raise ValueError(
"Must provide `negative_prompt_attention_mask_2` when specifying `negative_prompt_embeds_2`."
)
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_embeds_2 is not None and negative_prompt_embeds_2 is not None:
if prompt_embeds_2.shape != negative_prompt_embeds_2.shape:
raise ValueError(
"`prompt_embeds_2` and `negative_prompt_embeds_2` must have the same shape when passed directly, but"
f" got: `prompt_embeds_2` {prompt_embeds_2.shape} != `negative_prompt_embeds_2`"
f" {negative_prompt_embeds_2.shape}."
)
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
t_start = max(num_inference_steps - init_timestep, 0)
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
return timesteps, num_inference_steps - t_start
def prepare_mask_latents(
self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
video_length = mask.shape[2]
mask = mask.to(device=device, dtype=self.vae.dtype)
if self.vae.quant_conv.weight.ndim==5:
bs = 1
mini_batch_encoder = self.vae.mini_batch_encoder
new_mask = []
for i in range(0, mask.shape[0], bs):
mask_bs = mask[i : i + bs]
mask_bs = self.vae.encode(mask_bs)[0]
mask_bs = mask_bs.sample()
new_mask.append(mask_bs)
mask = torch.cat(new_mask, dim = 0)
mask = mask * self.vae.config.scaling_factor
else:
if mask.shape[1] == 4:
mask = mask
else:
video_length = mask.shape[2]
mask = rearrange(mask, "b c f h w -> (b f) c h w")
mask = self._encode_vae_image(mask, generator=generator)
mask = rearrange(mask, "(b f) c h w -> b c f h w", f=video_length)
masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
if self.vae.quant_conv.weight.ndim==5:
bs = 1
new_mask_pixel_values = []
for i in range(0, masked_image.shape[0], bs):
mask_pixel_values_bs = masked_image[i : i + bs]
mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
mask_pixel_values_bs = mask_pixel_values_bs.sample()
new_mask_pixel_values.append(mask_pixel_values_bs)
masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
else:
if masked_image.shape[1] == 4:
masked_image_latents = masked_image
else:
video_length = mask.shape[2]
masked_image = rearrange(masked_image, "b c f h w -> (b f) c h w")
masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
masked_image_latents = rearrange(masked_image_latents, "(b f) c h w -> b c f h w", f=video_length)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
video_length,
dtype,
device,
generator,
latents=None,
video=None,
timestep=None,
is_strength_max=True,
return_noise=False,
return_video_latents=False,
):
video_latents = None
if self.vae.quant_conv.weight.ndim==5:
mini_batch_encoder = self.vae.mini_batch_encoder
mini_batch_decoder = self.vae.mini_batch_decoder
shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
else:
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if return_video_latents or (latents is None and not is_strength_max):
video = video.to(device=device, dtype=self.vae.dtype)
if self.vae.quant_conv.weight.ndim==5:
bs = 1
new_video = []
for i in range(0, video.shape[0], bs):
video_bs = video[i : i + bs]
video_bs = self.vae.encode(video_bs)[0]
video_bs = video_bs.sample()
new_video.append(video_bs)
video = torch.cat(new_video, dim = 0)
video = video * self.vae.config.scaling_factor
else:
if video.shape[1] == 4:
video = video
else:
video_length = video.shape[2]
video = rearrange(video, "b c f h w -> (b f) c h w")
video = self._encode_vae_image(video, generator=generator)
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
video_latents = video_latents.to(device=device, dtype=dtype)
if latents is None:
# TODO: a fast but brute force fix, sometimes the computed shape is not equals to the video latent's shape
if video_latents is not None:
shape = video_latents.shape
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
# if strength is 1. then initialise the latents to noise, else initial to image + noise
latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
# if pure noise then scale the initial latents by the Scheduler's init sigma
latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
else:
noise = latents.to(device)
latents = noise * self.scheduler.init_noise_sigma
# scale the initial noise by the standard deviation required by the scheduler
outputs = (latents,)
if return_noise:
outputs += (noise,)
if return_video_latents:
outputs += (video_latents,)
return outputs
def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
if video.size()[2] <= mini_batch_encoder:
return video
prefix_index_before = mini_batch_encoder // 2
prefix_index_after = mini_batch_encoder - prefix_index_before
pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
# Encode middle videos
latents = self.vae.encode(pixel_values)[0]
latents = latents.mode()
middle_video = self.vae.decode(latents)[0]
video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
return video
def decode_latents(self, latents):
video_length = latents.shape[2]
latents = 1 / self.vae.config.scaling_factor * latents
if self.vae.quant_conv.weight.ndim==5:
mini_batch_encoder = self.vae.mini_batch_encoder
mini_batch_decoder = self.vae.mini_batch_decoder
video = self.vae.decode(latents)[0]
video = video.clamp(-1, 1)
if not self.vae.cache_compression_vae:
video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
else:
latents = rearrange(latents, "b c f h w -> (b f) c h w")
video = []
for frame_idx in tqdm(range(latents.shape[0])):
video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
video = torch.cat(video)
video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
video = (video / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
video = video.cpu().float().numpy()
return video
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
video_length: Optional[int] = None,
video: Union[torch.FloatTensor] = None,
mask_video: Union[torch.FloatTensor] = None,
masked_video_latents: Union[torch.FloatTensor] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 5.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: Optional[float] = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_embeds_2: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds_2: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
prompt_attention_mask_2: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask_2: Optional[torch.Tensor] = None,
output_type: Optional[str] = "latent",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
guidance_rescale: float = 0.0,
original_size: Optional[Tuple[int, int]] = (1024, 1024),
target_size: Optional[Tuple[int, int]] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
use_resolution_binning: bool = False,
clip_image: Image = None,
clip_apply_ratio: float = 0.40,
strength: float = 1.0,
comfyui_progressbar: bool = False,
):
r"""
The call function to the pipeline for generation with HunyuanDiT.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference. This parameter is modulated by `strength`.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
negative_prompt_embeds_2 (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds` is passed directly.
prompt_attention_mask_2 (`torch.Tensor`, *optional*):
Attention mask for the prompt. Required when `prompt_embeds_2` is passed directly.
negative_prompt_attention_mask (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds` is passed directly.
negative_prompt_attention_mask_2 (`torch.Tensor`, *optional*):
Attention mask for the negative prompt. Required when `negative_prompt_embeds_2` is passed directly.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback_on_step_end (`Callable[[int, int, Dict], None]`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A callback function or a list of callback functions to be called at the end of each denoising step.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
A list of tensor inputs that should be passed to the callback function. If not defined, all tensor
inputs will be passed.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Rescale the noise_cfg according to `guidance_rescale`. Based on findings of [Common Diffusion Noise
Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
original_size (`Tuple[int, int]`, *optional*, defaults to `(1024, 1024)`):
The original size of the image. Used to calculate the time ids.
target_size (`Tuple[int, int]`, *optional*):
The target size of the image. Used to calculate the time ids.
crops_coords_top_left (`Tuple[int, int]`, *optional*, defaults to `(0, 0)`):
The top left coordinates of the crop. Used to calculate the time ids.
use_resolution_binning (`bool`, *optional*, defaults to `True`):
Whether to use resolution binning or not. If `True`, the input resolution will be mapped to the closest
standard resolution. Supported resolutions are 1024x1024, 1280x1280, 1024x768, 1152x864, 1280x960,
768x1024, 864x1152, 960x1280, 1280x768, and 768x1280. It is recommended to set this to `True`.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
otherwise a `tuple` is returned where the first element is a list with the generated images and the
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. default height and width
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
height = int(height // 16 * 16)
width = int(width // 16 * 16)
if use_resolution_binning and (height, width) not in SUPPORTED_SHAPE:
width, height = map_to_standard_shapes(width, height)
height = int(height)
width = int(width)
logger.warning(f"Reshaped to (height, width)=({height}, {width}), Supported shapes are {SUPPORTED_SHAPE}")
# 2. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_attention_mask_2,
negative_prompt_attention_mask_2,
callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._interrupt = False
# 3. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 4. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps, num_inference_steps = self.get_timesteps(
num_inference_steps=num_inference_steps, strength=strength, device=device
)
if comfyui_progressbar:
from comfy.utils import ProgressBar
pbar = ProgressBar(num_inference_steps + 3)
# at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
# create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
is_strength_max = strength == 1.0
if video is not None:
video_length = video.shape[2]
init_video = self.image_processor.preprocess(rearrange(video, "b c f h w -> (b f) c h w"), height=height, width=width)
init_video = init_video.to(dtype=torch.float32)
init_video = rearrange(init_video, "(b f) c h w -> b c f h w", f=video_length)
else:
init_video = None
# Prepare latent variables
num_channels_latents = self.vae.config.latent_channels
num_channels_transformer = self.transformer.config.in_channels
return_image_latents = num_channels_transformer == num_channels_latents
# Make vae to cuda
if self.model_cpu_offload_flag:
self.vae = self.vae.to(device)
torch.cuda.empty_cache()
# 5. Prepare latents.
latents_outputs = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
video_length,
prompt_embeds.dtype,
device,
generator,
latents,
video=init_video,
timestep=latent_timestep,
is_strength_max=is_strength_max,
return_noise=True,
return_video_latents=return_image_latents,
)
if return_image_latents:
latents, noise, image_latents = latents_outputs
else:
latents, noise = latents_outputs
latents_dtype = latents.dtype
if comfyui_progressbar:
pbar.update(1)
if clip_image is not None:
inputs = self.clip_image_processor(images=clip_image, return_tensors="pt")
inputs["pixel_values"] = inputs["pixel_values"].to(latents.device, dtype=latents.dtype)
clip_encoder_hidden_states = self.clip_image_encoder(**inputs).last_hidden_state[:, 1:]
clip_encoder_hidden_states_neg = torch.zeros(
[
batch_size,
int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2,
int(self.clip_image_encoder.config.hidden_size)
]
).to(latents.device, dtype=latents.dtype)
clip_attention_mask = torch.ones([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype)
clip_attention_mask_neg = torch.zeros([batch_size, self.transformer.n_query]).to(latents.device, dtype=latents.dtype)
clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states_neg, clip_encoder_hidden_states]) if self.do_classifier_free_guidance else clip_encoder_hidden_states
clip_attention_mask_input = torch.cat([clip_attention_mask_neg, clip_attention_mask]) if self.do_classifier_free_guidance else clip_attention_mask
elif clip_image is None and num_channels_transformer != num_channels_latents:
clip_encoder_hidden_states = torch.zeros(
[
batch_size,
int(self.clip_image_encoder.config.image_size / self.clip_image_encoder.config.patch_size) ** 2,
int(self.clip_image_encoder.config.hidden_size)
]
).to(latents.device, dtype=latents.dtype)
clip_attention_mask = torch.zeros([batch_size, self.transformer.n_query])
clip_attention_mask = clip_attention_mask.to(latents.device, dtype=latents.dtype)
clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states] * 2) if self.do_classifier_free_guidance else clip_encoder_hidden_states
clip_attention_mask_input = torch.cat([clip_attention_mask] * 2) if self.do_classifier_free_guidance else clip_attention_mask
else:
clip_encoder_hidden_states_input = None
clip_attention_mask_input = None
if comfyui_progressbar:
pbar.update(1)
if mask_video is not None:
if (mask_video == 255).all():
mask_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
masked_video_latents_input = (
torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
)
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
else:
# Prepare mask latent variables
video_length = video.shape[2]
mask_condition = self.mask_processor.preprocess(rearrange(mask_video, "b c f h w -> (b f) c h w"), height=height, width=width)
mask_condition = mask_condition.to(dtype=torch.float32)
mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
if num_channels_transformer != num_channels_latents:
mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
if masked_video_latents is None:
masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
else:
masked_video = masked_video_latents
mask_latents, masked_video_latents = self.prepare_mask_latents(
mask_condition_tile,
masked_video,
batch_size,
height,
width,
prompt_embeds.dtype,
device,
generator,
self.do_classifier_free_guidance,
)
mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
mask_input = torch.cat([mask_latents] * 2) if self.do_classifier_free_guidance else mask_latents
masked_video_latents_input = (
torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
)
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
else:
mask = torch.tile(mask_condition, [1, num_channels_latents, 1, 1, 1])
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
inpaint_latents = None
else:
if num_channels_transformer != num_channels_latents:
mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
mask_input = torch.cat([mask] * 2) if self.do_classifier_free_guidance else mask
masked_video_latents_input = (
torch.cat([masked_video_latents] * 2) if self.do_classifier_free_guidance else masked_video_latents
)
inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
else:
mask = torch.zeros_like(init_video[:, :1])
mask = torch.tile(mask, [1, num_channels_latents, 1, 1, 1])
mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
inpaint_latents = None
if comfyui_progressbar:
pbar.update(1)
# Check that sizes of mask, masked image and latents match
if num_channels_transformer == 48:
# default case for runwayml/stable-diffusion-inpainting
num_channels_mask = mask_latents.shape[1]
num_channels_masked_image = masked_video_latents.shape[1]
if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels:
raise ValueError(
f"Incorrect configuration settings! The config of `pipeline.transformer`: {self.transformer.config} expects"
f" {self.transformer.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
" `pipeline.transformer` or your `mask_image` or `image` input."
)
elif num_channels_transformer != num_channels_latents:
raise ValueError(
f"The transformer {self.transformer.__class__} should have 4 input channels, not {self.transformer.config.in_channels}."
)
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7 create image_rotary_emb, style embedding & time ids
grid_height = height // 8 // self.transformer.config.patch_size
grid_width = width // 8 // self.transformer.config.patch_size
base_size = 512 // 8 // self.transformer.config.patch_size
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
image_rotary_emb = get_2d_rotary_pos_embed(
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
)
style = torch.tensor([0], device=device)
target_size = target_size or (height, width)
add_time_ids = list(original_size + target_size + crops_coords_top_left)
add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2])
add_time_ids = torch.cat([add_time_ids] * 2, dim=0)
style = torch.cat([style] * 2, dim=0)
prompt_embeds = prompt_embeds.to(device=device)
prompt_attention_mask = prompt_attention_mask.to(device=device)
prompt_embeds_2 = prompt_embeds_2.to(device=device)
prompt_attention_mask_2 = prompt_attention_mask_2.to(device=device)
add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=device).repeat(
batch_size * num_images_per_prompt, 1
)
style = style.to(device=device).repeat(batch_size * num_images_per_prompt)
# Empty vae cache
if self.model_cpu_offload_flag:
self.vae = self.vae.to("cpu")
self.transformer = self.transformer.to(device)
torch.cuda.empty_cache()
# 10. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
if i < len(timesteps) * (1 - clip_apply_ratio) and clip_encoder_hidden_states_input is not None:
clip_encoder_hidden_states_actual_input = torch.zeros_like(clip_encoder_hidden_states_input)
clip_attention_mask_actual_input = torch.zeros_like(clip_attention_mask_input)
else:
clip_encoder_hidden_states_actual_input = clip_encoder_hidden_states_input
clip_attention_mask_actual_input = clip_attention_mask_input
current_timestep = t
if not torch.is_tensor(current_timestep):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
if isinstance(current_timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
elif len(current_timestep.shape) == 0:
current_timestep = current_timestep[None].to(latent_model_input.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
current_timestep = current_timestep.expand(latent_model_input.shape[0]).to(
dtype=latent_model_input.dtype
)
# predict the noise residual
noise_pred = self.transformer(
latent_model_input,
current_timestep,
encoder_hidden_states=prompt_embeds,
text_embedding_mask=prompt_attention_mask,
encoder_hidden_states_t5=prompt_embeds_2,
text_embedding_mask_t5=prompt_attention_mask_2,
image_meta_size=add_time_ids,
style=style,
image_rotary_emb=image_rotary_emb,
inpaint_latents=inpaint_latents,
clip_encoder_hidden_states=clip_encoder_hidden_states_actual_input,
clip_attention_mask=clip_attention_mask_actual_input,
return_dict=False,
)[0]
noise_pred, _ = noise_pred.chunk(2, dim=1)
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if num_channels_transformer == 4:
init_latents_proper = image_latents
init_mask = mask
if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1]
init_latents_proper = self.scheduler.add_noise(
init_latents_proper, noise, torch.tensor([noise_timestep])
)
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
prompt_embeds_2 = callback_outputs.pop("prompt_embeds_2", prompt_embeds_2)
negative_prompt_embeds_2 = callback_outputs.pop(
"negative_prompt_embeds_2", negative_prompt_embeds_2
)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if comfyui_progressbar:
pbar.update(1)
# Make vae to cuda
if self.model_cpu_offload_flag:
self.transformer = self.transformer.to("cpu")
self.vae = self.vae.to(device)
torch.cuda.empty_cache()
# Post-processing
video = self.decode_latents(latents)
if self.model_cpu_offload_flag:
# Make vae to cpu
self.vae = self.vae.to("cpu")
torch.cuda.empty_cache()
# Convert to tensor
if output_type == "latent":
video = torch.from_numpy(video)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return video
return RuyiPipelineOutput(videos=video)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment