Commit c07946d8 authored by hepj

dit & video
#!/bin/python3
# isort: skip_file
import argparse
import math
import os
import time
from collections import deque
from copy import deepcopy
import torch
import torch.distributed as dist
import wandb
from accelerate.utils import set_seed
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from peft import LoraConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm.auto import tqdm
from fastvideo.dataset.latent_datasets import (LatentDataset, latent_collate_function)
from fastvideo.distill.discriminator import Discriminator
from fastvideo.distill.solver import EulerSolver, extract_into_tensor
from fastvideo.models.mochi_hf.mochi_latents_utils import normalize_dit_input
from fastvideo.models.mochi_hf.pipeline_mochi import linear_quadratic_schedule
from fastvideo.utils.checkpoint import (resume_lora_optimizer, resume_training_generator_discriminator, save_checkpoint,
save_lora_checkpoint)
from fastvideo.utils.communications import (broadcast, sp_parallel_dataloader_wrapper)
from fastvideo.utils.dataset_utils import LengthGroupedSampler
from fastvideo.utils.fsdp_util import (apply_fsdp_checkpointing, get_discriminator_fsdp_kwargs, get_dit_fsdp_kwargs)
from fastvideo.utils.load import load_transformer
from fastvideo.utils.logging_ import main_print
from fastvideo.utils.parallel_states import (destroy_sequence_parallel_group, get_sequence_parallel_state,
initialize_sequence_parallel_state)
from fastvideo.utils.validation import log_validation
# Will raise an error if the minimum version of diffusers is not installed. Remove at your own risk.
check_min_version("0.31.0")
def gan_d_loss(
discriminator,
teacher_transformer,
sample_fake,
sample_real,
timestep,
encoder_hidden_states,
encoder_attention_mask,
weight,
discriminator_head_stride,
):
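# Hinge-style discriminator loss: teacher features are extracted without gradients for both the
# student's (fake) and the target (real) samples; each discriminator head is trained to score
# real features above +1 (relu(1 - real)) and fake features below -1 (relu(fake + 1)),
# averaged over all heads.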
loss = 0.0
# collate sample_fake and sample_real
with torch.no_grad():
fake_features = teacher_transformer(
sample_fake,
encoder_hidden_states,
timestep,
encoder_attention_mask,
output_features=True,
output_features_stride=discriminator_head_stride,
return_dict=False,
)[1]
real_features = teacher_transformer(
sample_real,
encoder_hidden_states,
timestep,
encoder_attention_mask,
output_features=True,
output_features_stride=discriminator_head_stride,
return_dict=False,
)[1]
fake_outputs = discriminator(fake_features)
real_outputs = discriminator(real_features)
for fake_output, real_output in zip(fake_outputs, real_outputs):
loss += (torch.mean(weight * torch.relu(fake_output.float() + 1)) + torch.mean(
weight * torch.relu(1 - real_output.float()))) / (discriminator.head_num * discriminator.num_h_per_head)
return loss
def gan_g_loss(
discriminator,
teacher_transformer,
sample_fake,
timestep,
encoder_hidden_states,
encoder_attention_mask,
weight,
discriminator_head_stride,
):
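# Generator-side hinge loss: the teacher extracts features of the student's (fake) samples with
# gradients enabled so the loss can flow back into the student's prediction; relu(1 - fake_output)
# pushes each discriminator head to score those samples as real (>= +1), averaged over heads.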
loss = 0.0
features = teacher_transformer(
sample_fake,
encoder_hidden_states,
timestep,
encoder_attention_mask,
output_features=True,
output_features_stride=discriminator_head_stride,
return_dict=False,
)[1]
fake_outputs = discriminator(features)
for fake_output in fake_outputs:
loss += torch.mean(
weight * torch.relu(1 - fake_output.float())) / (discriminator.head_num * discriminator.num_h_per_head)
return loss
def distill_one_step_adv(
transformer,
model_type,
teacher_transformer,
optimizer,
discriminator,
discriminator_optimizer,
lr_scheduler,
loader,
noise_scheduler,
solver,
noise_random_generator,
sp_size,
max_grad_norm,
uncond_prompt_embed,
uncond_prompt_mask,
num_euler_timesteps,
multiphase,
not_apply_cfg_solver,
distill_cfg,
adv_weight,
discriminator_head_stride,
):
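# One adversarial distillation step:
#   1. sample a discrete Euler index and noise the latents with the flow-matching interpolation,
#   2. run the student for a multiphase consistency prediction,
#   3. build the teacher target with classifier-free guidance plus one Euler step,
#   4. optimize the student with a pseudo-Huber consistency loss and a GAN generator loss,
#   5. then update the discriminator on re-noised real/fake predictions.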
optimizer.zero_grad()
discriminator_optimizer.zero_grad()
(
latents,
encoder_hidden_states,
latents_attention_mask,
encoder_attention_mask,
) = next(loader)
model_input = normalize_dit_input(model_type, latents)
noise = torch.randn_like(model_input)
bsz = model_input.shape[0]
index = torch.randint(0, num_euler_timesteps, (bsz, ), device=model_input.device).long()
if sp_size > 1:
broadcast(index)
# Add noise according to flow matching.
# sigmas = get_sigmas(start_timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
sigmas = extract_into_tensor(solver.sigmas, index, model_input.shape)
sigmas_prev = extract_into_tensor(solver.sigmas_prev, index, model_input.shape)
timesteps = (sigmas * noise_scheduler.config.num_train_timesteps).view(-1)
# .view(-1) keeps the timesteps 1-D even when the batch dimension would otherwise squeeze away.
timesteps_prev = (sigmas_prev * noise_scheduler.config.num_train_timesteps).view(-1)
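# Flow-matching interpolation between data and noise: x_sigma = sigma * noise + (1 - sigma) * x_0.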
noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
# Predict the noise residual
with torch.autocast("cuda", dtype=torch.bfloat16):
model_pred = transformer(
noisy_model_input,
encoder_hidden_states,
timesteps,
encoder_attention_mask, # B, L
return_dict=False,
)[0]
# if accelerator.is_main_process:
model_pred, end_index = solver.euler_style_multiphase_pred(noisy_model_input, model_pred, index, multiphase)
# # simplified flow matching aka 0-rectified flow matching loss
# # target = model_input - noise
# target = model_input
adv_index = torch.empty_like(end_index)
for i in range(end_index.size(0)):
adv_index[i] = torch.randint(
end_index[i].item(),
end_index[i].item() + num_euler_timesteps // multiphase,
(1, ),
dtype=end_index.dtype,
device=end_index.device,
)
sigmas_end = extract_into_tensor(solver.sigmas_prev, end_index, model_input.shape)
sigmas_adv = extract_into_tensor(solver.sigmas_prev, adv_index, model_input.shape)
timesteps_adv = (sigmas_adv * noise_scheduler.config.num_train_timesteps).view(-1)
with torch.no_grad():
w = distill_cfg
with torch.autocast("cuda", dtype=torch.bfloat16):
cond_teacher_output = teacher_transformer(
noisy_model_input,
encoder_hidden_states,
timesteps,
encoder_attention_mask, # B, L
return_dict=False,
)[0].float()
if not_apply_cfg_solver:
uncond_teacher_output = cond_teacher_output
else:
# Get teacher model prediction on noisy_latents and unconditional embedding
with torch.autocast("cuda", dtype=torch.bfloat16):
uncond_teacher_output = teacher_transformer(
noisy_model_input,
uncond_prompt_embed.unsqueeze(0).expand(bsz, -1, -1),
timesteps,
uncond_prompt_mask.unsqueeze(0).expand(bsz, -1),
return_dict=False,
)[0].float()
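# Classifier-free guidance on the teacher: w = distill_cfg scales the (cond - uncond) direction.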
teacher_output = cond_teacher_output + w * (cond_teacher_output - uncond_teacher_output)
x_prev = solver.euler_step(noisy_model_input, teacher_output, index)
# Get the target LCM prediction on x_prev, w, c, t_n
with torch.no_grad():
with torch.autocast("cuda", dtype=torch.bfloat16):
target_pred = transformer(
x_prev.float(),
encoder_hidden_states,
timesteps_prev,
encoder_attention_mask, # B, L
return_dict=False,
)[0]
target, end_index = solver.euler_style_multiphase_pred(x_prev, target_pred, index, multiphase, True)
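# Re-noise both the teacher-derived target ("real") and the student prediction ("fake") to a random
# later phase timestep (sigmas_adv) so the discriminator compares them at the same noise level.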
real_adv = ((1 - sigmas_adv) * target + (sigmas_adv - sigmas_end) * torch.randn_like(target)) / (1 - sigmas_end)
fake_adv = ((1 - sigmas_adv) * model_pred +
(sigmas_adv - sigmas_end) * torch.randn_like(model_pred)) / (1 - sigmas_end)
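# Pseudo-Huber consistency loss between the student's multiphase prediction and the teacher target:
# sqrt(err^2 + c^2) - c behaves like L2 near zero error and like L1 for large errors.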
huber_c = 0.001
g_loss = torch.mean(torch.sqrt((model_pred.float() - target.float())**2 + huber_c**2) - huber_c)
discriminator.requires_grad_(False)
with torch.autocast("cuda", dtype=torch.bfloat16):
g_gan_loss = adv_weight * gan_g_loss(
discriminator,
teacher_transformer,
fake_adv.float(),
timesteps_adv,
encoder_hidden_states.float(),
encoder_attention_mask,
1.0,
discriminator_head_stride,
)
g_loss += g_gan_loss
g_loss.backward()
g_loss = g_loss.detach().clone()
dist.all_reduce(g_loss, op=dist.ReduceOp.AVG)
g_grad_norm = transformer.clip_grad_norm_(max_grad_norm).item()
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
discriminator_optimizer.zero_grad()
discriminator.requires_grad_(True)
with torch.autocast("cuda", dtype=torch.bfloat16):
d_loss = gan_d_loss(
discriminator,
teacher_transformer,
fake_adv.detach(),
real_adv.detach(),
timesteps_adv,
encoder_hidden_states,
encoder_attention_mask,
1.0,
discriminator_head_stride,
)
d_loss.backward()
d_grad_norm = discriminator.clip_grad_norm_(max_grad_norm).item()
discriminator_optimizer.step()
discriminator_optimizer.zero_grad()
return g_loss, g_grad_norm, d_loss, d_grad_norm
def main(args):
torch.backends.cuda.matmul.allow_tf32 = True
local_rank = int(os.environ["LOCAL_RANK"])
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
dist.init_process_group("nccl")
torch.cuda.set_device(local_rank)
device = torch.cuda.current_device()
initialize_sequence_parallel_state(args.sp_size)
# If passed along, set the training seed now.
if args.seed is not None:
# TODO: t within the same seq parallel group should be the same. Noise should be different.
set_seed(args.seed + rank)
# We use different seeds for the noise generation in each process to ensure that the noise is different in a batch.
noise_random_generator = None
# Handle the repository creation
if rank <= 0 and args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# For mixed precision training we cast all non-trainable weights to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
# Create model:
main_print(f"--> loading model from {args.pretrained_model_name_or_path}")
# keep the master weight to float32
transformer = load_transformer(
args.model_type,
args.dit_model_name_or_path,
args.pretrained_model_name_or_path,
torch.float32 if args.master_weight_type == "fp32" else torch.bfloat16,
)
teacher_transformer = deepcopy(transformer)
discriminator = Discriminator(
args.discriminator_head_stride,
total_layers=48 if args.model_type == "mochi" else 40,
)
if args.use_lora:
transformer.requires_grad_(False)
transformer_lora_config = LoraConfig(
r=args.lora_rank,
lora_alpha=args.lora_alpha,
init_lora_weights=True,
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
transformer.add_adapter(transformer_lora_config)
main_print(
f" Total transformer parameters = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e6} M"
)
# discriminator
main_print(
f" Total discriminator parameters = {sum(p.numel() for p in discriminator.parameters() if p.requires_grad) / 1e6} M"
)
main_print(f"--> Initializing FSDP with sharding strategy: {args.fsdp_sharding_startegy}")
fsdp_kwargs, no_split_modules = get_dit_fsdp_kwargs(
transformer,
args.fsdp_sharding_startegy,
args.use_lora,
args.use_cpu_offload,
args.master_weight_type,
)
discriminator_fsdp_kwargs = get_discriminator_fsdp_kwargs(args.master_weight_type)
if args.use_lora:
assert args.model_type == "mochi", "LoRA is only supported for Mochi model."
transformer.config.lora_rank = args.lora_rank
transformer.config.lora_alpha = args.lora_alpha
transformer.config.lora_target_modules = ["to_k", "to_q", "to_v", "to_out.0"]
transformer._no_split_modules = no_split_modules
fsdp_kwargs["auto_wrap_policy"] = fsdp_kwargs["auto_wrap_policy"](transformer)
transformer = FSDP(
transformer,
**fsdp_kwargs,
)
teacher_transformer = FSDP(
teacher_transformer,
**fsdp_kwargs,
)
discriminator = FSDP(
discriminator,
**discriminator_fsdp_kwargs,
)
main_print("--> model loaded")
if args.gradient_checkpointing:
apply_fsdp_checkpointing(transformer, no_split_modules, args.selective_checkpointing)
apply_fsdp_checkpointing(teacher_transformer, no_split_modules, args.selective_checkpointing)
# Set model as trainable.
transformer.train()
teacher_transformer.requires_grad_(False)
noise_scheduler = FlowMatchEulerDiscreteScheduler(shift=args.shift)
if args.scheduler_type == "pcm_linear_quadratic":
sigmas = linear_quadratic_schedule(noise_scheduler.config.num_train_timesteps, args.linear_quadratic_threshold)
sigmas = torch.tensor(sigmas).to(dtype=torch.float32)
else:
sigmas = noise_scheduler.sigmas
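# EulerSolver discretizes the (reversed) sigma schedule into num_euler_timesteps steps and exposes
# the sigmas / sigmas_prev pairs plus the multiphase Euler prediction used in the distillation step.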
solver = EulerSolver(
sigmas.numpy()[::-1],
noise_scheduler.config.num_train_timesteps,
euler_timesteps=args.num_euler_timesteps,
)
solver.to(device)
params_to_optimize = transformer.parameters()
params_to_optimize = list(filter(lambda p: p.requires_grad, params_to_optimize))
optimizer = torch.optim.AdamW(
params_to_optimize,
lr=args.learning_rate,
betas=(0.9, 0.999),
weight_decay=args.weight_decay,
eps=1e-8,
)
discriminator_optimizer = torch.optim.AdamW(
discriminator.parameters(),
lr=args.discriminator_learning_rate,
betas=(0, 0.999),
weight_decay=args.weight_decay,
eps=1e-8,
)
init_steps = 0
if args.resume_from_lora_checkpoint:
transformer, optimizer, init_steps = resume_lora_optimizer(transformer, args.resume_from_lora_checkpoint,
optimizer)
elif args.resume_from_checkpoint:
(
transformer,
optimizer,
discriminator,
discriminator_optimizer,
init_steps,
) = resume_training_generator_discriminator(
transformer,
optimizer,
discriminator,
discriminator_optimizer,
args.resume_from_checkpoint,
rank,
)
main_print(f"optimizer: {optimizer}")
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * world_size,
num_training_steps=args.max_train_steps * world_size,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
last_epoch=init_steps - 1,
)
train_dataset = LatentDataset(args.data_json_path, args.num_latent_t, args.cfg)
uncond_prompt_embed = train_dataset.uncond_prompt_embed
uncond_prompt_mask = train_dataset.uncond_prompt_mask
sampler = (LengthGroupedSampler(
args.train_batch_size,
rank=rank,
world_size=world_size,
lengths=train_dataset.lengths,
group_frame=args.group_frame,
group_resolution=args.group_resolution,
) if (args.group_frame or args.group_resolution) else DistributedSampler(
train_dataset, rank=rank, num_replicas=world_size, shuffle=False))
train_dataloader = DataLoader(
train_dataset,
sampler=sampler,
collate_fn=latent_collate_function,
pin_memory=True,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
drop_last=True,
)
assert args.gradient_accumulation_steps == 1
num_update_steps_per_epoch = math.ceil(
len(train_dataloader) / args.gradient_accumulation_steps * args.sp_size / args.train_sp_batch_size)
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if rank <= 0:
project = args.tracker_project_name or "fastvideo"
wandb.init(project=project, config=args)
# Train!
total_batch_size = (world_size * args.gradient_accumulation_steps / args.sp_size * args.train_sp_batch_size)
main_print("***** Running training *****")
main_print(f" Num examples = {len(train_dataset)}")
main_print(f" Dataloader size = {len(train_dataloader)}")
main_print(f" Num Epochs = {args.num_train_epochs}")
main_print(f" Resume training from step {init_steps}")
main_print(f" Instantaneous batch size per device = {args.train_batch_size}")
main_print(f" Total train batch size (w. data & sequence parallel, accumulation) = {total_batch_size}")
main_print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
main_print(f" Total optimization steps = {args.max_train_steps}")
main_print(
f" Total training parameters per FSDP shard = {sum(p.numel() for p in transformer.parameters() if p.requires_grad) / 1e9} B"
)
# print dtype
main_print(f" Master weight dtype: {transformer.parameters().__next__().dtype}")
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=init_steps,
desc="Steps",
# Only show the progress bar once on each machine.
disable=local_rank > 0,
)
loader = sp_parallel_dataloader_wrapper(
train_dataloader,
device,
args.train_batch_size,
args.sp_size,
args.train_sp_batch_size,
)
step_times = deque(maxlen=100)
# log_validation(args, transformer, device,
# torch.bfloat16, 0, scheduler_type=args.scheduler_type, shift=args.shift, num_euler_timesteps=args.num_euler_timesteps, linear_quadratic_threshold=args.linear_quadratic_threshold,ema=False)
def get_num_phases(multi_phased_distill_schedule, step):
# Schedule format: comma-separated "step-phase" pairs, e.g. "2000-4,4000-2" (values illustrative).
multi_phases = multi_phased_distill_schedule.split(",")
phase = multi_phases[-1].split("-")[-1]
for step_phases in multi_phases:
phase_step, phase = step_phases.split("-")
if step <= int(phase_step):
return int(phase)
return int(phase)
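# Illustrative example: with schedule "2000-4,4000-2,6000-1", step 2500 falls in the second range,
# so get_num_phases returns 2 (the schedule values here are made up for illustration).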
for i in range(init_steps):
_ = next(loader)
for step in range(init_steps + 1, args.max_train_steps + 1):
assert args.multi_phased_distill_schedule is not None
num_phases = get_num_phases(args.multi_phased_distill_schedule, step)
start_time = time.time()
(
generator_loss,
generator_grad_norm,
discriminator_loss,
discriminator_grad_norm,
) = distill_one_step_adv(
transformer,
args.model_type,
teacher_transformer,
optimizer,
discriminator,
discriminator_optimizer,
lr_scheduler,
loader,
noise_scheduler,
solver,
noise_random_generator,
args.sp_size,
args.max_grad_norm,
uncond_prompt_embed,
uncond_prompt_mask,
args.num_euler_timesteps,
num_phases,
args.not_apply_cfg_solver,
args.distill_cfg,
args.adv_weight,
args.discriminator_head_stride,
)
step_time = time.time() - start_time
step_times.append(step_time)
avg_step_time = sum(step_times) / len(step_times)
progress_bar.set_postfix({
"g_loss": f"{generator_loss:.4f}",
"d_loss": f"{discriminator_loss:.4f}",
"g_grad_norm": generator_grad_norm,
"d_grad_norm": discriminator_grad_norm,
"step_time": f"{step_time:.2f}s",
})
progress_bar.update(1)
# if rank <= 0:
# wandb.log(
# {
# "generator_loss": generator_loss,
# "discriminator_loss": discriminator_loss,
# "generator_grad_norm": generator_grad_norm,
# "discriminator_grad_norm": discriminator_grad_norm,
# "learning_rate": lr_scheduler.get_last_lr()[0],
# "step_time": step_time,
# "avg_step_time": avg_step_time,
# },
# step=step,
# )
if step % args.checkpointing_steps == 0:
main_print(f"--> saving checkpoint at step {step}")
if args.use_lora:
# Save LoRA weights
save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, step)
else:
# Your existing checkpoint saving code
# TODO
# save_checkpoint_generator_discriminator(
# transformer,
# optimizer,
# discriminator,
# discriminator_optimizer,
# rank,
# args.output_dir,
# step,
# )
save_checkpoint(transformer, rank, args.output_dir, step)
main_print(f"--> checkpoint saved at step {step}")
dist.barrier()
if args.log_validation and step % args.validation_steps == 0:
log_validation(
args,
transformer,
device,
torch.bfloat16,
step,
scheduler_type=args.scheduler_type,
shift=args.shift,
num_euler_timesteps=args.num_euler_timesteps,
linear_quadratic_threshold=args.linear_quadratic_threshold,
linear_range=args.linear_range,
ema=False,
)
if args.use_lora:
save_lora_checkpoint(transformer, optimizer, rank, args.output_dir, args.max_train_steps)
else:
save_checkpoint(transformer, rank, args.output_dir, args.max_train_steps)
if get_sequence_parallel_state():
destroy_sequence_parallel_group()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model_type", type=str, default="mochi", help="The type of model to train.")
# dataset & dataloader
parser.add_argument("--data_json_path", type=str, required=True)
parser.add_argument("--num_height", type=int, default=480)
parser.add_argument("--num_width", type=int, default=848)
parser.add_argument("--num_frames", type=int, default=163)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=10,
help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
)
parser.add_argument(
"--train_batch_size",
type=int,
default=16,
help="Batch size (per device) for the training dataloader.",
)
parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.")
parser.add_argument("--group_frame", action="store_true") # TODO
parser.add_argument("--group_resolution", action="store_true") # TODO
# text encoder & vae & diffusion model
parser.add_argument("--pretrained_model_name_or_path", type=str)
parser.add_argument("--dit_model_name_or_path", type=str)
parser.add_argument("--cache_dir", type=str, default="./cache_dir")
# diffusion setting
parser.add_argument("--ema_decay", type=float, default=0.999)
parser.add_argument("--ema_start_step", type=int, default=0)
parser.add_argument("--cfg", type=float, default=0.1)
# validation & logs
parser.add_argument("--validation_sampling_steps", type=str, default="64")
parser.add_argument("--validation_guidance_scale", type=str, default="4.5")
parser.add_argument("--validation_steps", type=float, default=64)
parser.add_argument("--log_validation", action="store_true")
parser.add_argument("--tracker_project_name", type=str, default=None)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=("Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
" training using `--resume_from_checkpoint`."),
)
parser.add_argument("--validation_prompt_dir", type=str)
parser.add_argument("--shift", type=float, default=1.0)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=("Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
)
parser.add_argument(
"--resume_from_lora_checkpoint",
type=str,
default=None,
help=("Whether training should be resumed from a previous lora checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'),
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
)
# optimizer & scheduler & Training
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--discriminator_learning_rate",
type=float,
default=1e-5,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_warmup_steps",
type=int,
default=10,
help="Number of steps for the warmup in the lr scheduler.",
)
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument("--selective_checkpointing", type=float, default=1.0)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=("Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."),
)
parser.add_argument(
"--use_cpu_offload",
action="store_true",
help="Whether to use CPU offload for param & gradient & optimizer states.",
)
parser.add_argument("--sp_size", type=int, default=1, help="For sequence parallel")
parser.add_argument(
"--train_sp_batch_size",
type=int,
default=1,
help="Batch size for sequence parallel training",
)
parser.add_argument(
"--use_lora",
action="store_true",
default=False,
help="Whether to use LoRA for finetuning.",
)
parser.add_argument("--lora_alpha", type=int, default=256, help="Alpha parameter for LoRA.")
parser.add_argument("--lora_rank", type=int, default=128, help="LoRA rank parameter. ")
parser.add_argument("--fsdp_sharding_startegy", default="full")
parser.add_argument("--multi_phased_distill_schedule", type=str, default=None)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
# lr_scheduler
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=('The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'),
)
parser.add_argument("--num_euler_timesteps", type=int, default=100)
parser.add_argument(
"--lr_num_cycles",
type=int,
default=1,
help="Number of cycles in the learning rate scheduler.",
)
parser.add_argument(
"--lr_power",
type=float,
default=1.0,
help="Power factor of the polynomial scheduler.",
)
parser.add_argument(
"--not_apply_cfg_solver",
action="store_true",
help="Whether to apply the cfg_solver.",
)
parser.add_argument("--distill_cfg", type=float, default=3.0, help="Distillation coefficient.")
# ["euler_linear_quadratic", "pcm", "pcm_linear_qudratic"]
parser.add_argument("--scheduler_type", type=str, default="pcm", help="The scheduler type to use.")
parser.add_argument(
"--adv_weight",
type=float,
default=0.1,
help="The weight of the adversarial loss.",
)
parser.add_argument(
"--discriminator_head_stride",
type=int,
default=2,
help="The stride of the discriminator head.",
)
parser.add_argument(
"--linear_range",
type=float,
default=0.5,
help="Range for linear quadratic scheduler.",
)
parser.add_argument("--weight_decay", type=float, default=0.001, help="Weight decay to apply.")
parser.add_argument(
"--linear_quadratic_threshold",
type=float,
default=0.025,
help="The threshold of the linear quadratic scheduler.",
)
parser.add_argument(
"--master_weight_type",
type=str,
default="fp32",
help="Weight type to use - fp32 or bf16.",
)
args = parser.parse_args()
main(args)
from einops import rearrange
from flash_attn import flash_attn_varlen_qkvpacked_func
from flash_attn.bert_padding import pad_input, unpad_input
import torch
# @torch._dynamo.disable
def flash_attn_no_pad(qkv, key_padding_mask, causal=False, dropout_p=0.0, softmax_scale=None):
# adapted from https://github.com/Dao-AILab/flash-attention/blob/13403e81157ba37ca525890f2f0f2137edf75311/flash_attn/flash_attention.py#L27
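# qkv: (batch, seqlen, 3, nheads, headdim); key_padding_mask marks the valid (non-padded) tokens.
# Padded tokens are removed with unpad_input, attention runs on the packed sequences via
# flash_attn_varlen_qkvpacked_func, and pad_input restores the original (batch, seqlen) layout.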
batch_size = qkv.shape[0]
seqlen = qkv.shape[1]
nheads = qkv.shape[-2]
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads)
output_unpad = flash_attn_varlen_qkvpacked_func(
x_unpad,
cu_seqlens,
max_s,
dropout_p,
softmax_scale=softmax_scale,
causal=causal,
)
output = rearrange(
pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen),
"b s (h d) -> b s h d",
h=nheads,
)
return output
import os
import torch
__all__ = [
"C_SCALE",
"PROMPT_TEMPLATE",
"MODEL_BASE",
"PRECISIONS",
"NORMALIZATION_TYPE",
"ACTIVATION_TYPE",
"VAE_PATH",
"TEXT_ENCODER_PATH",
"TOKENIZER_PATH",
"TEXT_PROJECTION",
"DATA_TYPE",
"NEGATIVE_PROMPT",
]
PRECISION_TO_TYPE = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
# =================== Constant Values =====================
# Computation scale factor: 1 PFLOP = 1_000_000_000_000_000 FLOPs. TensorBoard displays the value in
# PetaFLOPS to avoid overflow when logging large FLOP counts.
C_SCALE = 1_000_000_000_000_000
# When using decoder-only models, we must provide a prompt template to instruct the text encoder
# on how to generate the text.
# --------------------------------------------------------------------
PROMPT_TEMPLATE_ENCODE = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
PROMPT_TEMPLATE_ENCODE_VIDEO = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
PROMPT_TEMPLATE = {
"dit-llm-encode": {
"template": PROMPT_TEMPLATE_ENCODE,
"crop_start": 36,
},
"dit-llm-encode-video": {
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
"crop_start": 95,
},
}
# ======================= Model ======================
PRECISIONS = {"fp32", "fp16", "bf16"}
NORMALIZATION_TYPE = {"layer", "rms"}
ACTIVATION_TYPE = {"relu", "silu", "gelu", "gelu_tanh"}
# =================== Model Path =====================
MODEL_BASE = os.getenv("MODEL_BASE", "./data/hunyuan")
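# Set the MODEL_BASE environment variable to point at a different checkpoint root,
# e.g. `export MODEL_BASE=/path/to/ckpts` (path illustrative).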
# =================== Data =======================
DATA_TYPE = {"image", "video", "image_video"}
# 3D VAE
VAE_PATH = {"884-16c-hy": f"{MODEL_BASE}/hunyuan-video-t2v-720p/vae"}
# Text Encoder
TEXT_ENCODER_PATH = {
"clipL": f"{MODEL_BASE}/text_encoder_2",
"llm": f"{MODEL_BASE}/text_encoder",
}
# Tokenizer
TOKENIZER_PATH = {
"clipL": f"{MODEL_BASE}/text_encoder_2",
"llm": f"{MODEL_BASE}/text_encoder",
}
TEXT_PROJECTION = {
"linear", # Default, an nn.Linear() layer
"single_refiner", # Single TokenRefiner. Refer to LI-DiT
}
# ruff: noqa: F401
from .pipelines import HunyuanVideoPipeline
from .schedulers import FlowMatchDiscreteScheduler
# ruff: noqa: F401
from .pipeline_hunyuan_video import HunyuanVideoPipeline
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================
import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderKL
from diffusers.models.lora import adjust_lora_scale_text_encoder
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, deprecate, logging, replace_example_docstring,
scale_lora_layers)
from diffusers.utils.torch_utils import randn_tensor
from einops import rearrange
from fastvideo.utils.communications import all_gather
from fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info
from ...constants import PRECISION_TO_TYPE
from ...modules import HYVideoDiffusionTransformer
from ...text_encoder import TextEncoder
from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """"""
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
"""
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
# rescale the results from guidance (fixes overexposure)
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
noise_cfg = (guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg)
return noise_cfg
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler.")
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler.")
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
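# Illustrative usage: `timesteps, num_inference_steps = retrieve_timesteps(scheduler, 50, device)`,
# or pass `timesteps=[...]` / `sigmas=[...]` (but not both) to override the scheduler's spacing.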
@dataclass
class HunyuanVideoPipelineOutput(BaseOutput):
videos: Union[torch.Tensor, np.ndarray]
class HunyuanVideoPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-video generation using HunyuanVideo.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
text_encoder ([`TextEncoder`]):
Frozen text-encoder.
text_encoder_2 ([`TextEncoder`]):
Frozen text-encoder_2.
transformer ([`HYVideoDiffusionTransformer`]):
A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
"""
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
_optional_components = ["text_encoder_2"]
_exclude_from_cpu_offload = ["transformer"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
vae: AutoencoderKL,
text_encoder: TextEncoder,
transformer: HYVideoDiffusionTransformer,
scheduler: KarrasDiffusionSchedulers,
text_encoder_2: Optional[TextEncoder] = None,
progress_bar_config: Dict[str, Any] = None,
args=None,
):
super().__init__()
# ==========================================================================================
if progress_bar_config is None:
progress_bar_config = {}
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
self._progress_bar_config.update(progress_bar_config)
self.args = args
# ==========================================================================================
if (hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1):
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might led to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file")
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if (hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True):
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file")
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
text_encoder=text_encoder,
transformer=transformer,
scheduler=scheduler,
text_encoder_2=text_encoder_2,
)
self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels) - 1)
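# Spatial downsampling factor of the VAE: a factor of 2 for every downsampling block after the first.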
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
def encode_prompt(
self,
prompt,
device,
num_videos_per_prompt,
do_classifier_free_guidance,
negative_prompt=None,
prompt_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_attention_mask: Optional[torch.Tensor] = None,
lora_scale: Optional[float] = None,
clip_skip: Optional[int] = None,
text_encoder: Optional[TextEncoder] = None,
data_type: Optional[str] = "image",
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`):
torch device
num_videos_per_prompt (`int`):
number of videos that should be generated per prompt
do_classifier_free_guidance (`bool`):
whether to use classifier free guidance or not
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the video generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
attention_mask (`torch.Tensor`, *optional*):
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
negative_attention_mask (`torch.Tensor`, *optional*):
lora_scale (`float`, *optional*):
A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
text_encoder (TextEncoder, *optional*):
data_type (`str`, *optional*):
"""
if text_encoder is None:
text_encoder = self.text_encoder
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, LoraLoaderMixin):
self._lora_scale = lora_scale
# dynamically adjust the LoRA scale
if not USE_PEFT_BACKEND:
adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
else:
scale_lora_layers(text_encoder.model, lora_scale)
if prompt_embeds is None:
# textual inversion: process multi-vector tokens if necessary
if isinstance(self, TextualInversionLoaderMixin):
prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)
text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
if clip_skip is None:
prompt_outputs = text_encoder.encode(text_inputs, data_type=data_type, device=device)
prompt_embeds = prompt_outputs.hidden_state
else:
prompt_outputs = text_encoder.encode(
text_inputs,
output_hidden_states=True,
data_type=data_type,
device=device,
)
# Access the `hidden_states` first, that contains a tuple of
# all the hidden states from the encoder layers. Then index into
# the tuple to access the hidden states from the desired layer.
prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
# We also need to apply the final LayerNorm here to not mess with the
# representations. The `last_hidden_states` that we typically use for
# obtaining the final prompt representations passes through the LayerNorm
# layer.
prompt_embeds = text_encoder.model.text_model.final_layer_norm(prompt_embeds)
attention_mask = prompt_outputs.attention_mask
if attention_mask is not None:
attention_mask = attention_mask.to(device)
bs_embed, seq_len = attention_mask.shape
attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len)
if text_encoder is not None:
prompt_embeds_dtype = text_encoder.dtype
elif self.transformer is not None:
prompt_embeds_dtype = self.transformer.dtype
else:
prompt_embeds_dtype = prompt_embeds.dtype
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
if prompt_embeds.ndim == 2:
bs_embed, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
else:
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
return (
prompt_embeds,
negative_prompt_embeds,
attention_mask,
negative_attention_mask,
)
def decode_latents(self, latents, enable_tiling=True):
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
latents = 1 / self.vae.config.scaling_factor * latents
if enable_tiling:
self.vae.enable_tiling()
image = self.vae.decode(latents, return_dict=False)[0]
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
if image.ndim == 4:
image = image.cpu().permute(0, 2, 3, 1).float()
else:
image = image.cpu().float()
return image
def prepare_extra_func_kwargs(self, func, kwargs):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
extra_step_kwargs = {}
for k, v in kwargs.items():
accepts = k in set(inspect.signature(func).parameters.keys())
if accepts:
extra_step_kwargs[k] = v
return extra_step_kwargs
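# Illustrative usage: self.prepare_extra_func_kwargs(self.scheduler.step, {"generator": generator, "eta": eta})
# forwards only the kwargs that the target function actually accepts.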
def check_inputs(
self,
prompt,
height,
width,
video_length,
callback_steps,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
vae_ver="88-4c-sd",
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if video_length is not None:
if "884" in vae_ver:
if video_length != 1 and (video_length - 1) % 4 != 0:
raise ValueError(f"`video_length` has to be 1 or a multiple of 4 but is {video_length}.")
elif "888" in vae_ver:
if video_length != 1 and (video_length - 1) % 8 != 0:
raise ValueError(f"`video_length` has to be 1 or a multiple of 8 but is {video_length}.")
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}.")
if callback_on_step_end_tensor_inputs is not None and not all(k in self._callback_tensor_inputs
for k in callback_on_step_end_tensor_inputs):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two.")
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined.")
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}.")
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
video_length,
dtype,
device,
generator,
latents=None,
):
shape = (
batch_size,
num_channels_latents,
video_length,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators.")
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
# Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
if hasattr(self.scheduler, "init_noise_sigma"):
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
def get_guidance_scale_embedding(
self,
w: torch.Tensor,
embedding_dim: int = 512,
dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
Args:
w (`torch.Tensor`):
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
embedding_dim (`int`, *optional*, defaults to 512):
Dimension of the embeddings to generate.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
Data type of the generated embeddings.
Returns:
`torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
"""
assert len(w.shape) == 1
w = w * 1000.0
half_dim = embedding_dim // 2
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
emb = w.to(dtype)[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1))
assert emb.shape == (w.shape[0], embedding_dim)
return emb
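# Illustrative: get_guidance_scale_embedding(torch.tensor([7.5]), embedding_dim=256) returns a (1, 256)
# sinusoidal embedding of the guidance weight.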
@property
def guidance_scale(self):
return self._guidance_scale
@property
def guidance_rescale(self):
return self._guidance_rescale
@property
def clip_skip(self):
return self._clip_skip
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
# return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
return self._guidance_scale > 1
@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]],
height: int,
width: int,
video_length: int,
data_type: str = "video",
num_inference_steps: int = 50,
timesteps: List[int] = None,
sigmas: List[float] = None,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_videos_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Union[Callable[[int, int, Dict], None], PipelineCallback,
MultiPipelineCallbacks, ]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
vae_ver: str = "88-4c-sd",
enable_tiling: bool = False,
enable_vae_sp: bool = False,
n_tokens: Optional[int] = None,
embedded_guidance_scale: Optional[float] = None,
mask_strategy: Optional[Dict[str, list]] = None,
**kwargs,
):
r"""
The call function to the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
height (`int`):
The height in pixels of the generated image.
width (`int`):
The width in pixels of the generated image.
video_length (`int`):
The number of frames in the generated video.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 7.5):
A higher guidance scale value encourages the model to generate images closely linked to the text
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
provided, text embeddings are generated from the `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
plain tuple.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
Examples:
Returns:
[`~HunyuanVideoPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned; otherwise the
generated video frames are returned directly.
"""
callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)
if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
)
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 0. Default height and width to unet
# height = height or self.transformer.config.sample_size * self.vae_scale_factor
# width = width or self.transformer.config.sample_size * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
height,
width,
video_length,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
vae_ver=vae_ver,
)
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = (torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device)
# 3. Encode input prompt
lora_scale = (self.cross_attention_kwargs.get("scale", None)
if self.cross_attention_kwargs is not None else None)
(
prompt_embeds,
negative_prompt_embeds,
prompt_mask,
negative_prompt_mask,
) = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
attention_mask=attention_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_attention_mask=negative_attention_mask,
lora_scale=lora_scale,
clip_skip=self.clip_skip,
data_type=data_type,
)
if self.text_encoder_2 is not None:
(
prompt_embeds_2,
negative_prompt_embeds_2,
prompt_mask_2,
negative_prompt_mask_2,
) = self.encode_prompt(
prompt,
device,
num_videos_per_prompt,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=None,
attention_mask=None,
negative_prompt_embeds=None,
negative_attention_mask=None,
lora_scale=lora_scale,
clip_skip=self.clip_skip,
text_encoder=self.text_encoder_2,
data_type=data_type,
)
else:
prompt_embeds_2 = None
negative_prompt_embeds_2 = None
prompt_mask_2 = None
negative_prompt_mask_2 = None
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
if prompt_mask is not None:
prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
if prompt_embeds_2 is not None:
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
if prompt_mask_2 is not None:
prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
# 4. Prepare timesteps
extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(self.scheduler.set_timesteps,
{"n_tokens": n_tokens})
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
**extra_set_timesteps_kwargs,
)
if "884" in vae_ver:
video_length = (video_length - 1) // 4 + 1
elif "888" in vae_ver:
video_length = (video_length - 1) // 8 + 1
else:
video_length = video_length
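# Worked example of the temporal mapping above (illustrative values): with an "884" VAE and
# video_length=129 pixel frames, the latent sequence has (129 - 1) // 4 + 1 = 33 frames;
# with an "888" VAE it would be (129 - 1) // 8 + 1 = 17 frames.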
# 5. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
video_length,
prompt_embeds.dtype,
device,
generator,
latents,
)
world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group
if get_sequence_parallel_state():
latents = rearrange(latents, "b t (n s) h w -> b t n s h w", n=world_size).contiguous()
latents = latents[:, :, rank, :, :, :]
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_func_kwargs(
self.scheduler.step,
{
"generator": generator,
"eta": eta
},
)
target_dtype = PRECISION_TO_TYPE[self.args.precision]
autocast_enabled = (target_dtype != torch.float32) and not self.args.disable_autocast
vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
vae_autocast_enabled = (vae_dtype != torch.float32) and not self.args.disable_autocast
# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
def dict_to_3d_list(mask_strategy, t_max=50, l_max=60, h_max=24):
result = [[[None for _ in range(h_max)] for _ in range(l_max)] for _ in range(t_max)]
if mask_strategy is None:
return result
for key, value in mask_strategy.items():
t, l, h = map(int, key.split('_'))
result[t][l][h] = value
return result
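# Illustrative sketch (hypothetical values): mask_strategy is expected to be a dict whose keys
# are "t_l_h" strings indexing (timestep, layer, head), e.g. {"0_3_7": window} places `window`
# at result[0][3][7]; entries left as None fall back to full attention in the transformer blocks.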
mask_strategy = dict_to_3d_list(mask_strategy)
# if is_progress_bar:
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# expand the latents if we are doing classifier free guidance
latent_model_input = (torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
t_expand = t.repeat(latent_model_input.shape[0])
guidance_expand = (torch.tensor(
[embedded_guidance_scale] * latent_model_input.shape[0],
dtype=torch.float32,
device=device,
).to(target_dtype) * 1000.0 if embedded_guidance_scale is not None else None)
# predict the noise residual
with torch.autocast(device_type="cuda", dtype=target_dtype, enabled=autocast_enabled):
# concat prompt_embeds_2 and prompt_embeds. Mismatch fill with zeros
if prompt_embeds_2.shape[-1] != prompt_embeds.shape[-1]:
prompt_embeds_2 = F.pad(
prompt_embeds_2,
(0, prompt_embeds.shape[2] - prompt_embeds_2.shape[1]),
value=0,
).unsqueeze(1)
encoder_hidden_states = torch.cat([prompt_embeds_2, prompt_embeds], dim=1)
noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
latent_model_input,
encoder_hidden_states,
t_expand,
prompt_mask,
mask_strategy=mask_strategy[i],
guidance=guidance_expand,
return_dict=False,
)[0]
# perform guidance
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(
noise_pred,
noise_pred_text,
guidance_rescale=self.guidance_rescale,
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
if progress_bar is not None:
progress_bar.update()
if callback is not None and i % callback_steps == 0:
step_idx = i // getattr(self.scheduler, "order", 1)
callback(step_idx, t, latents)
if get_sequence_parallel_state():
latents = all_gather(latents, dim=2)
if not output_type == "latent":
expand_temporal_dim = False
if len(latents.shape) == 4:
if isinstance(self.vae, AutoencoderKLCausal3D):
latents = latents.unsqueeze(2)
expand_temporal_dim = True
elif len(latents.shape) == 5:
pass
else:
raise ValueError(
f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")
if (hasattr(self.vae.config, "shift_factor") and self.vae.config.shift_factor):
latents = (latents / self.vae.config.scaling_factor + self.vae.config.shift_factor)
else:
latents = latents / self.vae.config.scaling_factor
with torch.autocast(device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled):
if enable_tiling:
self.vae.enable_tiling()
if enable_vae_sp:
self.vae.enable_parallel()
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
if expand_temporal_dim or image.shape[2] == 1:
image = image.squeeze(2)
else:
image = latents
image = (image / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
image = image.cpu().float()
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return image
return HunyuanVideoPipelineOutput(videos=image)
# ruff: noqa: F401
from .scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modified from diffusers==0.29.2
#
# ==============================================================================
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import torch
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers.utils import BaseOutput, logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class FlowMatchDiscreteSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's `step` function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.FloatTensor
class FlowMatchDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
Euler scheduler.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
timestep_spacing (`str`, defaults to `"linspace"`):
The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and
Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information.
shift (`float`, defaults to 1.0):
The shift value for the timestep schedule.
reverse (`bool`, defaults to `True`):
Whether to reverse the timestep schedule.
"""
_compatibles = []
order = 1
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
shift: float = 1.0,
reverse: bool = True,
solver: str = "euler",
n_tokens: Optional[int] = None,
):
sigmas = torch.linspace(1, 0, num_train_timesteps + 1)
if not reverse:
sigmas = sigmas.flip(0)
self.sigmas = sigmas
# the value fed to model
self.timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32)
self._step_index = None
self._begin_index = None
self.supported_solver = ["euler"]
if solver not in self.supported_solver:
raise ValueError(f"Solver {solver} not supported. Supported solvers: {self.supported_solver}")
@property
def step_index(self):
"""
The index counter for the current timestep. It will increase by 1 after each scheduler step.
"""
return self._step_index
@property
def begin_index(self):
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
def set_timesteps(
self,
num_inference_steps: int,
device: Union[str, torch.device] = None,
n_tokens: int = None,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
n_tokens (`int`, *optional*):
Number of tokens in the input sequence.
"""
self.num_inference_steps = num_inference_steps
sigmas = torch.linspace(1, 0, num_inference_steps + 1)
sigmas = self.sd3_time_shift(sigmas)
if not self.config.reverse:
sigmas = 1 - sigmas
self.sigmas = sigmas
self.timesteps = (sigmas[:-1] * self.config.num_train_timesteps).to(dtype=torch.float32, device=device)
# Reset step index
self._step_index = None
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
return sample
def sd3_time_shift(self, t: torch.Tensor):
return (self.config.shift * t) / (1 + (self.config.shift - 1) * t)
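# Worked example of the shift above (illustrative shift value): with shift = 7.0, sd3_time_shift
# maps t = 0.5 to 7 * 0.5 / (1 + 6 * 0.5) = 0.875, pushing the sampled sigmas toward the
# high-noise end of the schedule; shift = 1.0 leaves t unchanged.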
def step(
self,
model_output: torch.FloatTensor,
timestep: Union[float, torch.FloatTensor],
sample: torch.FloatTensor,
return_dict: bool = True,
) -> Union[FlowMatchDiscreteSchedulerOutput, Tuple]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`):
The direct output from learned diffusion model.
timestep (`float`):
The current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
A current instance of a sample created by the diffusion process.
return_dict (`bool`):
Whether or not to return a [`FlowMatchDiscreteSchedulerOutput`] or tuple.
Returns:
[`FlowMatchDiscreteSchedulerOutput`] or `tuple`:
If `return_dict` is `True`, [`FlowMatchDiscreteSchedulerOutput`] is returned, otherwise a tuple is
returned where the first element is the sample tensor.
"""
if (isinstance(timestep, int) or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)):
raise ValueError(("Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."), )
if self.step_index is None:
self._init_step_index(timestep)
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
if self.config.solver == "euler":
prev_sample = sample + model_output.to(torch.float32) * dt
else:
raise ValueError(f"Solver {self.config.solver} not supported. Supported solvers: {self.supported_solver}")
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample, )
return FlowMatchDiscreteSchedulerOutput(prev_sample=prev_sample)
def __len__(self):
return self.config.num_train_timesteps
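# Usage sketch (illustrative values, not part of the pipeline): one Euler step advances the sample
# along the sigma differences produced by set_timesteps.
#   scheduler = FlowMatchDiscreteScheduler(shift=7.0, reverse=True, solver="euler")
#   scheduler.set_timesteps(num_inference_steps=50, device="cpu")
#   sample = torch.randn(1, 16, 33, 24, 42)      # hypothetical latent shape
#   model_output = torch.zeros_like(sample)      # stand-in for the transformer output
#   sample = scheduler.step(model_output, scheduler.timesteps[0], sample).prev_sample
#   # internally: prev_sample = sample + model_output * (sigmas[i + 1] - sigmas[i])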
# ruff: noqa: F405, F403
import argparse
import re
from .constants import *
from .modules.models import HUNYUAN_VIDEO_CONFIG
def parse_args(namespace=None):
parser = argparse.ArgumentParser(description="HunyuanVideo inference script")
parser = add_network_args(parser)
parser = add_extra_models_args(parser)
parser = add_denoise_schedule_args(parser)
parser = add_inference_args(parser)
parser = add_parallel_args(parser)
args = parser.parse_args(namespace=namespace)
args = sanity_check_args(args)
return args
def add_network_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="HunyuanVideo network args")
# Main model
group.add_argument(
"--model",
type=str,
choices=list(HUNYUAN_VIDEO_CONFIG.keys()),
default="HYVideo-T/2-cfgdistill",
)
group.add_argument(
"--latent-channels",
type=int,
default=16,
help="Number of latent channels of DiT. If None, it will be determined by `vae`. If provided, "
"it still needs to match the latent channels of the VAE model.",
)
group.add_argument(
"--precision",
type=str,
default="bf16",
choices=PRECISIONS,
help="Precision mode. Options: fp32, fp16, bf16. Applied to the backbone model and optimizer.",
)
# RoPE
group.add_argument("--rope-theta", type=int, default=256, help="Theta used in RoPE.")
return parser
def add_extra_models_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Extra models args, including vae, text encoders and tokenizers)")
# - VAE
group.add_argument(
"--vae",
type=str,
default="884-16c-hy",
choices=list(VAE_PATH),
help="Name of the VAE model.",
)
group.add_argument(
"--vae-precision",
type=str,
default="fp16",
choices=PRECISIONS,
help="Precision mode for the VAE model.",
)
group.add_argument(
"--vae-tiling",
action="store_true",
help="Enable tiling for the VAE model to save GPU memory.",
)
group.set_defaults(vae_tiling=True)
group.add_argument(
"--text-encoder",
type=str,
default="llm",
choices=list(TEXT_ENCODER_PATH),
help="Name of the text encoder model.",
)
group.add_argument(
"--text-encoder-precision",
type=str,
default="fp16",
choices=PRECISIONS,
help="Precision mode for the text encoder model.",
)
group.add_argument(
"--text-states-dim",
type=int,
default=4096,
help="Dimension of the text encoder hidden states.",
)
group.add_argument("--text-len", type=int, default=256, help="Maximum length of the text input.")
group.add_argument(
"--tokenizer",
type=str,
default="llm",
choices=list(TOKENIZER_PATH),
help="Name of the tokenizer model.",
)
group.add_argument(
"--prompt-template",
type=str,
default="dit-llm-encode",
choices=PROMPT_TEMPLATE,
help="Image prompt template for the decoder-only text encoder model.",
)
group.add_argument(
"--prompt-template-video",
type=str,
default="dit-llm-encode-video",
choices=PROMPT_TEMPLATE,
help="Video prompt template for the decoder-only text encoder model.",
)
group.add_argument(
"--hidden-state-skip-layer",
type=int,
default=2,
help="Skip layer for hidden states.",
)
group.add_argument(
"--apply-final-norm",
action="store_true",
help="Apply final normalization to the used text encoder hidden states.",
)
# - CLIP
group.add_argument(
"--text-encoder-2",
type=str,
default="clipL",
choices=list(TEXT_ENCODER_PATH),
help="Name of the second text encoder model.",
)
group.add_argument(
"--text-encoder-precision-2",
type=str,
default="fp16",
choices=PRECISIONS,
help="Precision mode for the second text encoder model.",
)
group.add_argument(
"--text-states-dim-2",
type=int,
default=768,
help="Dimension of the second text encoder hidden states.",
)
group.add_argument(
"--tokenizer-2",
type=str,
default="clipL",
choices=list(TOKENIZER_PATH),
help="Name of the second tokenizer model.",
)
group.add_argument(
"--text-len-2",
type=int,
default=77,
help="Maximum length of the second text input.",
)
return parser
def add_denoise_schedule_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Denoise schedule args")
group.add_argument(
"--denoise-type",
type=str,
default="flow",
help="Denoise type for noised inputs.",
)
# Flow Matching
group.add_argument(
"--flow-shift",
type=float,
default=7.0,
help="Shift factor for flow matching schedulers.",
)
group.add_argument(
"--flow-reverse",
action="store_true",
help="If reverse, learning/sampling from t=1 -> t=0.",
)
group.add_argument(
"--flow-solver",
type=str,
default="euler",
help="Solver for flow matching.",
)
group.add_argument(
"--use-linear-quadratic-schedule",
action="store_true",
help="Use linear quadratic schedule for flow matching."
"Following MovieGen (https://ai.meta.com/static-resource/movie-gen-research-paper)",
)
group.add_argument(
"--linear-schedule-end",
type=int,
default=25,
help="End step for linear quadratic schedule for flow matching.",
)
return parser
def add_inference_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Inference args")
# ======================== Model loads ========================
group.add_argument(
"--model-base",
type=str,
default="ckpts",
help="Root path of all the models, including t2v models and extra models.",
)
group.add_argument(
"--dit-weight",
type=str,
default="ckpts/hunyuan-video-t2v-720p/transformers/mp_rank_00_model_states.pt",
help="Path to the HunyuanVideo model. If None, search the model in the args.model_root."
"1. If it is a file, load the model directly."
"2. If it is a directory, search the model in the directory. Support two types of models: "
"1) named `pytorch_model_*.pt`"
"2) named `*_model_states.pt`, where * can be `mp_rank_00`.",
)
group.add_argument(
"--model-resolution",
type=str,
default="540p",
choices=["540p", "720p"],
help="Root path of all the models, including t2v models and extra models.",
)
group.add_argument(
"--load-key",
type=str,
default="module",
help="Key to load the model states. 'module' for the main model, 'ema' for the EMA model.",
)
group.add_argument(
"--use-cpu-offload",
action="store_true",
help="Use CPU offload for the model load.",
)
# ======================== Inference general setting ========================
group.add_argument(
"--batch-size",
type=int,
default=1,
help="Batch size for inference and evaluation.",
)
group.add_argument(
"--infer-steps",
type=int,
default=50,
help="Number of denoising steps for inference.",
)
group.add_argument(
"--disable-autocast",
action="store_true",
help="Disable autocast for denoising loop and vae decoding in pipeline sampling.",
)
group.add_argument(
"--save-path",
type=str,
default="./results",
help="Path to save the generated samples.",
)
group.add_argument(
"--save-path-suffix",
type=str,
default="",
help="Suffix for the directory of saved samples.",
)
group.add_argument(
"--name-suffix",
type=str,
default="",
help="Suffix for the names of saved samples.",
)
group.add_argument(
"--num-videos",
type=int,
default=1,
help="Number of videos to generate for each prompt.",
)
# ---sample size---
group.add_argument(
"--video-size",
type=int,
nargs="+",
default=(720, 1280),
help="Video size for training. If a single value is provided, it will be used for both height "
"and width. If two values are provided, they will be used for height and width "
"respectively.",
)
group.add_argument(
"--video-length",
type=int,
default=129,
help="How many frames to sample from a video. if using 3d vae, the number should be 4n+1",
)
# --- prompt ---
group.add_argument(
"--prompt",
type=str,
default=None,
help="Prompt for sampling during evaluation.",
)
group.add_argument(
"--seed-type",
type=str,
default="auto",
choices=["file", "random", "fixed", "auto"],
help="Seed type for evaluation. If file, use the seed from the CSV file. If random, generate a "
"random seed. If fixed, use the fixed seed given by `--seed`. If auto, `csv` will use the "
"seed column if available, otherwise use the fixed `seed` value. `prompt` will use the "
"fixed `seed` value.",
)
group.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
# Classifier-Free Guidance
group.add_argument("--neg-prompt", type=str, default=None, help="Negative prompt for sampling.")
group.add_argument("--cfg-scale", type=float, default=1.0, help="Classifier free guidance scale.")
group.add_argument(
"--embedded-cfg-scale",
type=float,
default=6.0,
help="Embedded classifier free guidance scale.",
)
group.add_argument(
"--reproduce",
action="store_true",
help="Enable reproducibility by setting random seeds and deterministic algorithms.",
)
return parser
def add_parallel_args(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Parallel args")
# ======================== Model loads ========================
group.add_argument(
"--ulysses-degree",
type=int,
default=1,
help="Ulysses degree.",
)
group.add_argument(
"--ring-degree",
type=int,
default=1,
help="Ulysses degree.",
)
return parser
def sanity_check_args(args):
# VAE channels
vae_pattern = r"\d{2,3}-\d{1,2}c-\w+"
if not re.match(vae_pattern, args.vae):
raise ValueError(f"Invalid VAE model: {args.vae}. Must be in the format of '{vae_pattern}'.")
vae_channels = int(args.vae.split("-")[1][:-1])
if args.latent_channels is None:
args.latent_channels = vae_channels
if vae_channels != args.latent_channels:
raise ValueError(f"Latent channels ({args.latent_channels}) must match the VAE channels ({vae_channels}).")
return args
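# Example of the VAE-name convention checked above (values taken from the defaults in this file):
# "884-16c-hy" matches the r"\d{2,3}-\d{1,2}c-\w+" pattern, and args.vae.split("-")[1][:-1]
# yields 16 latent channels, which must agree with --latent-channels.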
import os
import random
import time
from pathlib import Path
import torch
from loguru import logger
from safetensors.torch import load_file as safetensors_load_file
from fastvideo.models.hunyuan.constants import NEGATIVE_PROMPT, PRECISION_TO_TYPE, PROMPT_TEMPLATE
from fastvideo.models.hunyuan.diffusion.pipelines import HunyuanVideoPipeline
from fastvideo.models.hunyuan.diffusion.schedulers import FlowMatchDiscreteScheduler
from fastvideo.models.hunyuan.modules import load_model
from fastvideo.models.hunyuan.text_encoder import TextEncoder
from fastvideo.models.hunyuan.utils.data_utils import align_to
from fastvideo.models.hunyuan.vae import load_vae
from fastvideo.utils.parallel_states import nccl_info
class Inference(object):
def __init__(
self,
args,
vae,
vae_kwargs,
text_encoder,
model,
text_encoder_2=None,
pipeline=None,
use_cpu_offload=False,
device=None,
logger=None,
parallel_args=None,
):
self.vae = vae
self.vae_kwargs = vae_kwargs
self.text_encoder = text_encoder
self.text_encoder_2 = text_encoder_2
self.model = model
self.pipeline = pipeline
self.use_cpu_offload = use_cpu_offload
self.args = args
self.device = (device if device is not None else "cuda" if torch.cuda.is_available() else "cpu")
self.logger = logger
self.parallel_args = parallel_args
@classmethod
def from_pretrained(cls, pretrained_model_path, args, device=None, **kwargs):
"""
Initialize the Inference pipeline.
Args:
pretrained_model_path (str or pathlib.Path): The model path, including t2v, text encoder and vae checkpoints.
args (argparse.Namespace): The arguments for the pipeline.
device (str or torch.device, optional): The device for inference. Defaults to CUDA when available.
"""
# ========================================================================
logger.info(f"Got text-to-video model root path: {pretrained_model_path}")
# ==================== Initialize Distributed Environment ================
if nccl_info.sp_size > 1:
device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}")
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
parallel_args = None # {"ulysses_degree": args.ulysses_degree, "ring_degree": args.ring_degree}
# ======================== Get the args path =============================
# Disable gradient
torch.set_grad_enabled(False)
# =========================== Build main model ===========================
logger.info("Building model...")
factor_kwargs = {"device": device, "dtype": PRECISION_TO_TYPE[args.precision]}
in_channels = args.latent_channels
out_channels = args.latent_channels
model = load_model(
args,
in_channels=in_channels,
out_channels=out_channels,
factor_kwargs=factor_kwargs,
)
model = model.to(device)
model = Inference.load_state_dict(args, model, pretrained_model_path)
if args.enable_torch_compile:
model = torch.compile(model)
model.eval()
# ============================= Build extra models ========================
# VAE
vae, _, s_ratio, t_ratio = load_vae(
args.vae,
args.vae_precision,
logger=logger,
device=device if not args.use_cpu_offload else "cpu",
)
vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}
# Text encoder
if args.prompt_template_video is not None:
crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
elif args.prompt_template is not None:
crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
else:
crop_start = 0
max_length = args.text_len + crop_start
# prompt_template
prompt_template = (PROMPT_TEMPLATE[args.prompt_template] if args.prompt_template is not None else None)
# prompt_template_video
prompt_template_video = (PROMPT_TEMPLATE[args.prompt_template_video]
if args.prompt_template_video is not None else None)
text_encoder = TextEncoder(
text_encoder_type=args.text_encoder,
max_length=max_length,
text_encoder_precision=args.text_encoder_precision,
tokenizer_type=args.tokenizer,
prompt_template=prompt_template,
prompt_template_video=prompt_template_video,
hidden_state_skip_layer=args.hidden_state_skip_layer,
apply_final_norm=args.apply_final_norm,
reproduce=args.reproduce,
logger=logger,
device=device if not args.use_cpu_offload else "cpu",
)
text_encoder_2 = None
if args.text_encoder_2 is not None:
text_encoder_2 = TextEncoder(
text_encoder_type=args.text_encoder_2,
max_length=args.text_len_2,
text_encoder_precision=args.text_encoder_precision_2,
tokenizer_type=args.tokenizer_2,
reproduce=args.reproduce,
logger=logger,
device=device if not args.use_cpu_offload else "cpu",
)
return cls(
args=args,
vae=vae,
vae_kwargs=vae_kwargs,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
model=model,
use_cpu_offload=args.use_cpu_offload,
device=device,
logger=logger,
parallel_args=parallel_args,
)
@staticmethod
def load_state_dict(args, model, pretrained_model_path):
load_key = args.load_key
dit_weight = Path(args.dit_weight) if args.dit_weight is not None else None
if dit_weight is None:
model_dir = pretrained_model_path / f"t2v_{args.model_resolution}"
files = list(model_dir.glob("*.pt"))
if len(files) == 0:
raise ValueError(f"No model weights found in {model_dir}")
if str(files[0]).startswith("pytorch_model_"):
model_path = dit_weight / f"pytorch_model_{load_key}.pt"
bare_model = True
elif any(str(f).endswith("_model_states.pt") for f in files):
files = [f for f in files if str(f).endswith("_model_states.pt")]
model_path = files[0]
if len(files) > 1:
logger.warning(f"Multiple model weights found in {dit_weight}, using {model_path}")
bare_model = False
else:
raise ValueError(f"Invalid model path: {dit_weight} with unrecognized weight format: "
f"{list(map(str, files))}. When given a directory as --dit-weight, only "
f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
f"specific weight file, please provide the full path to the file.")
else:
if dit_weight.is_dir():
files = list(dit_weight.glob("*.pt"))
if len(files) == 0:
raise ValueError(f"No model weights found in {dit_weight}")
if str(files[0]).startswith("pytorch_model_"):
model_path = dit_weight / f"pytorch_model_{load_key}.pt"
bare_model = True
elif any(str(f).endswith("_model_states.pt") for f in files):
files = [f for f in files if str(f).endswith("_model_states.pt")]
model_path = files[0]
if len(files) > 1:
logger.warning(f"Multiple model weights found in {dit_weight}, using {model_path}")
bare_model = False
else:
raise ValueError(f"Invalid model path: {dit_weight} with unrecognized weight format: "
f"{list(map(str, files))}. When given a directory as --dit-weight, only "
f"`pytorch_model_*.pt`(provided by HunyuanDiT official) and "
f"`*_model_states.pt`(saved by deepspeed) can be parsed. If you want to load a "
f"specific weight file, please provide the full path to the file.")
elif dit_weight.is_file():
model_path = dit_weight
bare_model = "unknown"
else:
raise ValueError(f"Invalid model path: {dit_weight}")
if not model_path.exists():
raise ValueError(f"model_path not exists: {model_path}")
logger.info(f"Loading torch model {model_path}...")
if model_path.suffix == ".safetensors":
# Use safetensors library for .safetensors files
state_dict = safetensors_load_file(model_path)
elif model_path.suffix == ".pt":
# Use torch for .pt files
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
else:
raise ValueError(f"Unsupported file format: {model_path}")
if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
bare_model = False
if bare_model is False:
if load_key in state_dict:
state_dict = state_dict[load_key]
else:
raise KeyError(f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
f"are: {list(state_dict.keys())}.")
model.load_state_dict(state_dict, strict=True)
return model
@staticmethod
def parse_size(size):
if isinstance(size, int):
size = [size]
if not isinstance(size, (list, tuple)):
raise ValueError(f"Size must be an integer or (height, width), got {size}.")
if len(size) == 1:
size = [size[0], size[0]]
if len(size) != 2:
raise ValueError(f"Size must be an integer or (height, width), got {size}.")
return size
class HunyuanVideoSampler(Inference):
def __init__(
self,
args,
vae,
vae_kwargs,
text_encoder,
model,
text_encoder_2=None,
pipeline=None,
use_cpu_offload=False,
device=0,
logger=None,
parallel_args=None,
):
super().__init__(
args,
vae,
vae_kwargs,
text_encoder,
model,
text_encoder_2=text_encoder_2,
pipeline=pipeline,
use_cpu_offload=use_cpu_offload,
device=device,
logger=logger,
parallel_args=parallel_args,
)
self.pipeline = self.load_diffusion_pipeline(
args=args,
vae=self.vae,
text_encoder=self.text_encoder,
text_encoder_2=self.text_encoder_2,
model=self.model,
device=self.device,
)
self.default_negative_prompt = NEGATIVE_PROMPT
def load_diffusion_pipeline(
self,
args,
vae,
text_encoder,
text_encoder_2,
model,
scheduler=None,
device=None,
progress_bar_config=None,
data_type="video",
):
"""Load the denoising scheduler for inference."""
if scheduler is None:
if args.denoise_type == "flow":
scheduler = FlowMatchDiscreteScheduler(
shift=args.flow_shift,
reverse=args.flow_reverse,
solver=args.flow_solver,
)
else:
raise ValueError(f"Invalid denoise type {args.denoise_type}")
pipeline = HunyuanVideoPipeline(
vae=vae,
text_encoder=text_encoder,
text_encoder_2=text_encoder_2,
transformer=model,
scheduler=scheduler,
progress_bar_config=progress_bar_config,
args=args,
)
if self.use_cpu_offload:
pipeline.enable_sequential_cpu_offload()
else:
pipeline = pipeline.to(device)
return pipeline
@torch.no_grad()
def predict(
self,
prompt,
height=192,
width=336,
video_length=129,
seed=None,
negative_prompt=None,
infer_steps=50,
guidance_scale=6,
flow_shift=5.0,
embedded_guidance_scale=None,
batch_size=1,
num_videos_per_prompt=1,
mask_strategy=None,
**kwargs,
):
"""
Predict the image/video from the given text.
Args:
prompt (str or List[str]): The input text.
kwargs:
height (int): The height of the output video. Default is 192.
width (int): The width of the output video. Default is 336.
video_length (int): The frame number of the output video. Default is 129.
seed (int or List[int]): The random seed for the generation. Default is a random integer.
negative_prompt (str or List[str]): The negative text prompt. Default is an empty string.
guidance_scale (float): The guidance scale for the generation. Default is 6.0.
num_videos_per_prompt (int): The number of videos per prompt. Default is 1.
infer_steps (int): The number of inference steps. Default is 50.
"""
out_dict = dict()
# ========================================================================
# Arguments: seed
# ========================================================================
if isinstance(seed, torch.Tensor):
seed = seed.tolist()
if seed is None:
seeds = [random.randint(0, 1_000_000) for _ in range(batch_size * num_videos_per_prompt)]
elif isinstance(seed, int):
seeds = [seed + i for _ in range(batch_size) for i in range(num_videos_per_prompt)]
elif isinstance(seed, (list, tuple)):
if len(seed) == batch_size:
seeds = [int(seed[i]) + j for i in range(batch_size) for j in range(num_videos_per_prompt)]
elif len(seed) == batch_size * num_videos_per_prompt:
seeds = [int(s) for s in seed]
else:
raise ValueError(
f"Length of seed must be equal to number of prompt(batch_size) or "
f"batch_size * num_videos_per_prompt ({batch_size} * {num_videos_per_prompt}), got {seed}.")
else:
raise ValueError(f"Seed must be an integer, a list of integers, or None, got {seed}.")
# Peiyuan: using GPU seed will cause A100 and H100 to generate different results...
generator = [torch.Generator("cpu").manual_seed(seed) for seed in seeds]
out_dict["seeds"] = seeds
# ========================================================================
# Arguments: target_width, target_height, target_video_length
# ========================================================================
if width <= 0 or height <= 0 or video_length <= 0:
raise ValueError(
f"`height` and `width` and `video_length` must be positive integers, got height={height}, width={width}, video_length={video_length}"
)
if (video_length - 1) % 4 != 0:
raise ValueError(f"`video_length-1` must be a multiple of 4, got {video_length}")
logger.info(f"Input (height, width, video_length) = ({height}, {width}, {video_length})")
target_height = align_to(height, 16)
target_width = align_to(width, 16)
target_video_length = video_length
out_dict["size"] = (target_height, target_width, target_video_length)
# ========================================================================
# Arguments: prompt, new_prompt, negative_prompt
# ========================================================================
if not isinstance(prompt, str):
raise TypeError(f"`prompt` must be a string, but got {type(prompt)}")
prompt = [prompt.strip()]
# negative prompt
if negative_prompt is None or negative_prompt == "":
negative_prompt = self.default_negative_prompt
if not isinstance(negative_prompt, str):
raise TypeError(f"`negative_prompt` must be a string, but got {type(negative_prompt)}")
negative_prompt = [negative_prompt.strip()]
# ========================================================================
# Scheduler
# ========================================================================
scheduler = FlowMatchDiscreteScheduler(
shift=flow_shift,
reverse=self.args.flow_reverse,
solver=self.args.flow_solver,
)
self.pipeline.scheduler = scheduler
if "884" in self.args.vae:
latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
elif "888" in self.args.vae:
latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
n_tokens = latents_size[0] * latents_size[1] * latents_size[2]
# ========================================================================
# Print infer args
# ========================================================================
debug_str = f"""
height: {target_height}
width: {target_width}
video_length: {target_video_length}
prompt: {prompt}
neg_prompt: {negative_prompt}
seed: {seed}
infer_steps: {infer_steps}
num_videos_per_prompt: {num_videos_per_prompt}
guidance_scale: {guidance_scale}
n_tokens: {n_tokens}
flow_shift: {flow_shift}
embedded_guidance_scale: {embedded_guidance_scale}"""
logger.debug(debug_str)
# ========================================================================
# Pipeline inference
# ========================================================================
start_time = time.time()
samples = self.pipeline(
prompt=prompt,
height=target_height,
width=target_width,
video_length=target_video_length,
num_inference_steps=infer_steps,
guidance_scale=guidance_scale,
negative_prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
generator=generator,
output_type="pil",
n_tokens=n_tokens,
embedded_guidance_scale=embedded_guidance_scale,
data_type="video" if target_video_length > 1 else "image",
is_progress_bar=True,
vae_ver=self.args.vae,
enable_tiling=self.args.vae_tiling,
enable_vae_sp=self.args.vae_sp,
mask_strategy=mask_strategy,
)[0]
out_dict["samples"] = samples
out_dict["prompts"] = prompt
gen_time = time.time() - start_time
logger.info(f"Success, time: {gen_time}")
return out_dict
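# Usage sketch (paths and prompt are hypothetical; `args` comes from parse_args() in the config
# module above, with any remaining FastVideo-specific flags already set on it):
#   args = parse_args()
#   sampler = HunyuanVideoSampler.from_pretrained(args.model_base, args=args)
#   out = sampler.predict(prompt="a cat walking on grass", height=192, width=336,
#                         video_length=129, infer_steps=50, guidance_scale=1.0,
#                         embedded_guidance_scale=6.0, seed=42)
#   videos = out["samples"]   # decoded frames; out["seeds"] holds the per-video seeds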
from .models import HUNYUAN_VIDEO_CONFIG, HYVideoDiffusionTransformer
def load_model(args, in_channels, out_channels, factor_kwargs):
"""load hunyuan video model
Args:
args (argparse.Namespace): model args
in_channels (int): input channels number
out_channels (int): output channels number
factor_kwargs (dict): factor kwargs
Returns:
model (nn.Module): The hunyuan video model
"""
if args.model in HUNYUAN_VIDEO_CONFIG.keys():
model = HYVideoDiffusionTransformer(
in_channels=in_channels,
out_channels=out_channels,
**HUNYUAN_VIDEO_CONFIG[args.model],
**factor_kwargs,
)
return model
else:
raise NotImplementedError()
import torch.nn as nn
def get_activation_layer(act_type):
"""get activation layer
Args:
act_type (str): the activation type
Returns:
callable: a factory returning the activation layer (an nn.Module)
"""
if act_type == "gelu":
return lambda: nn.GELU()
elif act_type == "gelu_tanh":
# Approximate `tanh` requires torch >= 1.13
return lambda: nn.GELU(approximate="tanh")
elif act_type == "relu":
return nn.ReLU
elif act_type == "silu":
return nn.SiLU
else:
raise ValueError(f"Unknown activation type: {act_type}")
import torch
import torch.nn.functional as F
from einops import rearrange
import torch.distributed as dist
try:
from st_attn import sliding_tile_attention
except ImportError:
print("Could not load Sliding Tile Attention.")
sliding_tile_attention = None
from fastvideo.models.flash_attn_no_pad import flash_attn_no_pad
from fastvideo.utils.communications import all_gather, all_to_all_4D
from fastvideo.utils.parallel_states import get_sequence_parallel_state, nccl_info
"""
class VstreamManager:
_vsms: list = []  # class-level list of CUDA streams
_initialized: bool = False  # initialization flag
def __init__(self):
raise RuntimeError("This class should not be instantiated. Use class methods directly.")
@classmethod
def _initialize(cls):
if not cls._initialized:
with torch.cuda.stream(torch.cuda.Stream()):  # guard initialization with a temporary stream
rank = dist.get_rank()
cls._vsms = [
torch.cuda.Stream(device=rank),
torch.cuda.Stream(device=rank)
]
cls._initialized = True
@classmethod
def get(cls, index: int = 0) -> torch.cuda.Stream:
if not cls._initialized:
cls._initialize()
if not cls._vsms or index >= len(cls._vsms):
raise ValueError(f"Invalid stream index: {index}")
return cls._vsms[index]
"""
def attention(
q,
k,
v,
drop_rate=0,
attn_mask=None,
causal=False,
):
qkv = torch.stack([q, k, v], dim=2)
if attn_mask is not None and attn_mask.dtype != torch.bool:
attn_mask = attn_mask.bool()
x = flash_attn_no_pad(qkv, attn_mask, causal=causal, dropout_p=drop_rate, softmax_scale=None)
b, s, a, d = x.shape
out = x.reshape(b, s, -1)
return out
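# Shape sketch for attention() above: q, k and v are expected as (batch, seq_len, heads, head_dim);
# they are stacked into (batch, seq_len, 3, heads, head_dim) for flash_attn_no_pad and the output
# is flattened back to (batch, seq_len, heads * head_dim).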
def tile(x, sp_size):
x = rearrange(x, "b (sp t h w) head d -> b (t sp h w) head d", sp=sp_size, t=30 // sp_size, h=48, w=80)
return rearrange(x,
"b (n_t ts_t n_h ts_h n_w ts_w) h d -> b (n_t n_h n_w ts_t ts_h ts_w) h d",
n_t=5,
n_h=6,
n_w=10,
ts_t=6,
ts_h=8,
ts_w=8)
def untile(x, sp_size):
x = rearrange(x,
"b (n_t n_h n_w ts_t ts_h ts_w) h d -> b (n_t ts_t n_h ts_h n_w ts_w) h d",
n_t=5,
n_h=6,
n_w=10,
ts_t=6,
ts_h=8,
ts_w=8)
return rearrange(x, "b (t sp h w) head d -> b (sp t h w) head d", sp=sp_size, t=30 // sp_size, h=48, w=80)
def parallel_attention(q, k, v, img_q_len, img_kv_len, text_mask, mask_strategy=None):
query, encoder_query = q
key, encoder_key = k
value, encoder_value = v
text_length = text_mask.sum()
if get_sequence_parallel_state():
# batch_size, seq_len, attn_heads, head_dim
query = all_to_all_4D(query, scatter_dim=2, gather_dim=1)
key = all_to_all_4D(key, scatter_dim=2, gather_dim=1)
value = all_to_all_4D(value, scatter_dim=2, gather_dim=1)
"""
with torch.cuda.stream(VstreamManager.get(0)):
key = all_to_all_4D(key, scatter_dim=2, gather_dim=1)
with torch.cuda.stream(VstreamManager.get(1)):
value = all_to_all_4D(value, scatter_dim=2, gather_dim=1)
"""
def shrink_head(encoder_state, dim):
local_heads = encoder_state.shape[dim] // nccl_info.sp_size
return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads)
encoder_query = shrink_head(encoder_query, dim=2)
encoder_key = shrink_head(encoder_key, dim=2)
encoder_value = shrink_head(encoder_value, dim=2)
#torch.cuda.current_stream().wait_stream(VstreamManager.get(0))
#torch.cuda.current_stream().wait_stream(VstreamManager.get(1))
#torch.cuda.synchronize()
# [b, s, h, d]
sequence_length = query.size(1)
encoder_sequence_length = encoder_query.size(1)
if mask_strategy[0] is not None:
query = torch.cat([tile(query, nccl_info.sp_size), encoder_query], dim=1).transpose(1, 2)
key = torch.cat([tile(key, nccl_info.sp_size), encoder_key], dim=1).transpose(1, 2)
value = torch.cat([tile(value, nccl_info.sp_size), encoder_value], dim=1).transpose(1, 2)
head_num = query.size(1)
current_rank = nccl_info.rank_within_group
start_head = current_rank * head_num
windows = [mask_strategy[head_idx + start_head] for head_idx in range(head_num)]
hidden_states = sliding_tile_attention(query, key, value, windows, text_length).transpose(1, 2)
else:
query = torch.cat([query, encoder_query], dim=1)
key = torch.cat([key, encoder_key], dim=1)
value = torch.cat([value, encoder_value], dim=1)
# B, S, 3, H, D
qkv = torch.stack([query, key, value], dim=2)
attn_mask = F.pad(text_mask, (sequence_length, 0), value=True)
hidden_states = flash_attn_no_pad(qkv, attn_mask, causal=False, dropout_p=0.0, softmax_scale=None)
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes((sequence_length, encoder_sequence_length),
dim=1)
if mask_strategy[0] is not None:
hidden_states = untile(hidden_states, nccl_info.sp_size)
if get_sequence_parallel_state():
#hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2)
#encoder_hidden_states = all_gather(encoder_hidden_states, dim=2).contiguous()
hidden_states = all_to_all_4D(hidden_states, scatter_dim=1, gather_dim=2)
encoder_hidden_states = all_gather(encoder_hidden_states, dim=2).contiguous()
hidden_states = hidden_states.to(query.dtype)
encoder_hidden_states = encoder_hidden_states.to(query.dtype)
attn = torch.cat([hidden_states, encoder_hidden_states], dim=1)
b, s, a, d = attn.shape
attn = attn.reshape(b, s, -1)
return attn
import math
import torch
import torch.nn as nn
from ..utils.helpers import to_2tuple
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding
Image to Patch Embedding using Conv2d
A convolution based approach to patchifying a 2D image w/ embedding projection.
Based on the impl in https://github.com/google-research/vision_transformer
Hacked together by / Copyright 2020 Ross Wightman
Remove the _assert function in forward function to be compatible with multi-resolution images.
"""
def __init__(
self,
patch_size=16,
in_chans=3,
embed_dim=768,
norm_layer=None,
flatten=True,
bias=True,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
patch_size = to_2tuple(patch_size)
self.patch_size = patch_size
self.flatten = flatten
self.proj = nn.Conv3d(
in_chans,
embed_dim,
kernel_size=patch_size,
stride=patch_size,
bias=bias,
**factory_kwargs,
)
nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
if bias:
nn.init.zeros_(self.proj.bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2)  # BCTHW -> BNC
x = self.norm(x)
return x
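# Shape sketch for PatchEmbed (hypothetical sizes): with patch_size=(1, 2, 2) and embed_dim=3072,
# an input latent of shape (B, C, T, H, W) = (1, 16, 33, 24, 42) is projected by the strided Conv3d
# to (1, 3072, 33, 12, 21) and flattened to (1, 33 * 12 * 21, 3072) = (1, 8316, 3072) tokens.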
class TextProjection(nn.Module):
"""
Projects text embeddings into the model's hidden size.
Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
"""
def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.linear_1 = nn.Linear(
in_features=in_channels,
out_features=hidden_size,
bias=True,
**factory_kwargs,
)
self.act_1 = act_layer()
self.linear_2 = nn.Linear(
in_features=hidden_size,
out_features=hidden_size,
bias=True,
**factory_kwargs,
)
def forward(self, caption):
hidden_states = self.linear_1(caption)
hidden_states = self.act_1(hidden_states)
hidden_states = self.linear_2(hidden_states)
return hidden_states
@torch.compile(mode="max-autotune-no-cudagraphs")
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
Args:
t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
dim (int): the dimension of the output.
max_period (int): controls the minimum frequency of the embeddings.
Returns:
embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.
.. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
"""
half = dim // 2
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) /
half).to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
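# Worked mini-example for timestep_embedding (hypothetical dim): with dim=8 the function builds
# half=4 frequencies exp(-ln(10000) * [0, 1, 2, 3] / 4), multiplies them by each timestep t, and
# concatenates cos and sin terms into an (N, 8) embedding; odd dims get a zero column appended.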
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(
self,
hidden_size,
act_layer,
frequency_embedding_size=256,
max_period=10000,
out_size=None,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.frequency_embedding_size = frequency_embedding_size
self.max_period = max_period
if out_size is None:
out_size = hidden_size
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
act_layer(),
nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
)
nn.init.normal_(self.mlp[0].weight, std=0.02)
nn.init.normal_(self.mlp[2].weight, std=0.02)
def forward(self, t):
t_freq = timestep_embedding(t, self.frequency_embedding_size, self.max_period).type(self.mlp[0].weight.dtype)
t_emb = self.mlp(t_freq)
return t_emb
# Modified from timm library:
# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
from functools import partial
import torch
import torch.nn as nn
from ..utils.helpers import to_2tuple
from .modulate_layers import modulate
class MLP(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__(
self,
in_channels,
hidden_channels=None,
out_features=None,
act_layer=nn.GELU,
norm_layer=None,
bias=True,
drop=0.0,
use_conv=False,
device=None,
dtype=None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
out_features = out_features or in_channels
hidden_channels = hidden_channels or in_channels
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs)
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.norm = (norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity())
self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs)
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.norm(x)
x = self.fc2(x)
x = self.drop2(x)
return x
#
class MLPEmbedder(nn.Module):
"""copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py"""
def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class FinalLayer(nn.Module):
"""The final layer of DiT."""
def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
# Just use LayerNorm for the final layer
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
if isinstance(patch_size, int):
self.linear = nn.Linear(
hidden_size,
patch_size * patch_size * out_channels,
bias=True,
**factory_kwargs,
)
else:
self.linear = nn.Linear(
hidden_size,
patch_size[0] * patch_size[1] * patch_size[2] * out_channels,
bias=True,
**factory_kwargs,
)
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
# Here we don't distinguish between the modulate types. Just use the simple one.
self.adaLN_modulation = nn.Sequential(
act_layer(),
nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
)
# Zero-initialize the modulation
nn.init.zeros_(self.adaLN_modulation[1].weight)
nn.init.zeros_(self.adaLN_modulation[1].bias)
def forward(self, x, c):
shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
x = modulate(self.norm_final(x), shift=shift, scale=scale)
x = self.linear(x)
return x
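# Design note: the output projection and the adaLN modulation above are zero-initialized, so at
# the start of training FinalLayer emits zeros regardless of the conditioning `c`; this is the
# standard DiT-style zero-init so the block starts out inert and the modulation is learned.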
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from einops import rearrange
from fastvideo.models.hunyuan.modules.posemb_layers import get_nd_rotary_pos_embed
from fastvideo.utils.parallel_states import nccl_info
from .activation_layers import get_activation_layer
from .attenion import parallel_attention
from .embed_layers import PatchEmbed, TextProjection, TimestepEmbedder
from .mlp_layers import MLP, FinalLayer, MLPEmbedder
from .modulate_layers import ModulateDiT, apply_gate, modulate
from .norm_layers import get_norm_layer
from .posemb_layers import apply_rotary_emb
from .token_refiner import SingleTokenRefiner
class MMDoubleStreamBlock(nn.Module):
"""
A multimodal dit block with separate modulation for
text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qkv_bias: bool = False,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.img_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.img_attn_q_norm = (qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm else nn.Identity())
self.img_attn_k_norm = (qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm else nn.Identity())
self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.img_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.txt_mod = ModulateDiT(
hidden_size,
factor=6,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
self.txt_attn_q_norm = (qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm else nn.Identity())
self.txt_attn_k_norm = (qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm else nn.Identity())
self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.txt_mlp = MLP(
hidden_size,
mlp_hidden_dim,
act_layer=get_activation_layer(mlp_act_type),
bias=True,
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
img: torch.Tensor,
txt: torch.Tensor,
vec: torch.Tensor,
freqs_cis: tuple = None,
text_mask: torch.Tensor = None,
mask_strategy=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
(
img_mod1_shift,
img_mod1_scale,
img_mod1_gate,
img_mod2_shift,
img_mod2_scale,
img_mod2_gate,
) = self.img_mod(vec).chunk(6, dim=-1)
(
txt_mod1_shift,
txt_mod1_scale,
txt_mod1_gate,
txt_mod2_shift,
txt_mod2_scale,
txt_mod2_gate,
) = self.txt_mod(vec).chunk(6, dim=-1)
# Prepare image for attention.
img_modulated = self.img_norm1(img)
img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
img_qkv = self.img_attn_qkv(img_modulated)
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
# Apply QK-Norm if needed
img_q = self.img_attn_q_norm(img_q).to(img_v)
img_k = self.img_attn_k_norm(img_k).to(img_v)
# Apply RoPE if needed.
if freqs_cis is not None:
def shrink_head(encoder_state, dim):
local_heads = encoder_state.shape[dim] // nccl_info.sp_size
return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads)
freqs_cis = (
shrink_head(freqs_cis[0], dim=0),
shrink_head(freqs_cis[1], dim=0),
)
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
# Prepare txt for attention.
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
txt_qkv = self.txt_attn_qkv(txt_modulated)
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
# Apply QK-Norm if needed.
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)
attn = parallel_attention(
(img_q, txt_q),
(img_k, txt_k),
(img_v, txt_v),
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
text_mask=text_mask,
mask_strategy=mask_strategy,
)
# attention computation end
img_attn, txt_attn = attn[:, :img.shape[1]], attn[:, img.shape[1]:]
# Calculate the img blocks.
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
img = img + apply_gate(
self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
gate=img_mod2_gate,
)
# Calculate the txt blocks.
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
txt = txt + apply_gate(
self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
gate=txt_mod2_gate,
)
return img, txt
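# Hedged sketch (illustrative, plain torch; the attention and MLP calls are replaced
# by identity stand-ins): the 6-way chunk above produces two (shift, scale, gate)
# triples per modality -- one modulates the stream before attention, the other before
# the MLP -- and each gate scales its residual branch.
def _demo_double_stream_modulation():
    import torch
    hidden = 8
    vec = torch.randn(2, hidden)
    mod = torch.nn.Linear(hidden, 6 * hidden)(torch.nn.functional.silu(vec))
    shift1, scale1, gate1, shift2, scale2, gate2 = mod.chunk(6, dim=-1)
    x = torch.randn(2, 5, hidden)                                   # (batch, tokens, hidden)
    attn_in = x * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1)   # pre-attention modulate
    x = x + gate1.unsqueeze(1) * attn_in                            # gated residual (attention stand-in)
    mlp_in = x * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1)    # pre-MLP modulate
    x = x + gate2.unsqueeze(1) * mlp_in                             # gated residual (MLP stand-in)
    assert x.shape == (2, 5, hidden)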
class MMSingleStreamBlock(nn.Module):
"""
A DiT block with parallel linear layers as described in
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
Also refer to (SD3): https://arxiv.org/abs/2403.03206
(Flux.1): https://github.com/black-forest-labs/flux
"""
def __init__(
self,
hidden_size: int,
heads_num: int,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
qk_norm: bool = True,
qk_norm_type: str = "rms",
qk_scale: float = None,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.deterministic = False
self.hidden_size = hidden_size
self.heads_num = heads_num
head_dim = hidden_size // heads_num
mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
self.mlp_hidden_dim = mlp_hidden_dim
self.scale = qk_scale or head_dim**-0.5
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs)
qk_norm_layer = get_norm_layer(qk_norm_type)
self.q_norm = (qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm else nn.Identity())
self.k_norm = (qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
if qk_norm else nn.Identity())
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
self.mlp_act = get_activation_layer(mlp_act_type)()
self.modulation = ModulateDiT(
hidden_size,
factor=3,
act_layer=get_activation_layer("silu"),
**factory_kwargs,
)
self.hybrid_seq_parallel_attn = None
def enable_deterministic(self):
self.deterministic = True
def disable_deterministic(self):
self.deterministic = False
def forward(
self,
x: torch.Tensor,
vec: torch.Tensor,
txt_len: int,
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
text_mask: torch.Tensor = None,
mask_strategy=None,
) -> torch.Tensor:
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
# Apply QK-Norm if needed.
q = self.q_norm(q).to(v)
k = self.k_norm(k).to(v)
def shrink_head(encoder_state, dim):
local_heads = encoder_state.shape[dim] // nccl_info.sp_size
return encoder_state.narrow(dim, nccl_info.rank_within_group * local_heads, local_heads)
freqs_cis = (
shrink_head(freqs_cis[0], dim=0),
shrink_head(freqs_cis[1], dim=0),
)
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
img_v, txt_v = v[:, :-txt_len, :, :], v[:, -txt_len:, :, :]
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
attn = parallel_attention(
(img_q, txt_q),
(img_k, txt_k),
(img_v, txt_v),
img_q_len=img_q.shape[1],
img_kv_len=img_k.shape[1],
text_mask=text_mask,
mask_strategy=mask_strategy,
)
# attention computation end
# Compute activation in mlp stream, cat again and run second linear layer.
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + apply_gate(output, gate=mod_gate)
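# Hedged sketch (toy shapes only; the attention result is a random placeholder):
# the single-stream block fuses the QKV projection with the MLP input projection
# into linear1, and the attention output projection with the MLP output projection
# into linear2, so each block runs two large matmuls around the attention call.
def _demo_single_stream_fused_projections():
    import torch
    hidden, mlp_hidden = 8, 32
    x = torch.randn(2, 5, hidden)
    linear1 = torch.nn.Linear(hidden, hidden * 3 + mlp_hidden)
    qkv, mlp = torch.split(linear1(x), [3 * hidden, mlp_hidden], dim=-1)
    assert qkv.shape == (2, 5, 3 * hidden) and mlp.shape == (2, 5, mlp_hidden)
    linear2 = torch.nn.Linear(hidden + mlp_hidden, hidden)
    attn_out = torch.randn(2, 5, hidden)                            # placeholder for the attention output
    out = linear2(torch.cat((attn_out, torch.nn.functional.gelu(mlp, approximate="tanh")), dim=2))
    assert out.shape == (2, 5, hidden)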
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
"""
HunyuanVideo Transformer backbone
Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
Reference:
[1] Flux.1: https://github.com/black-forest-labs/flux
[2] MMDiT: http://arxiv.org/abs/2403.03206
Parameters
----------
patch_size: list
The size of the patch.
in_channels: int
The number of input channels.
out_channels: int
The number of output channels.
hidden_size: int
The hidden size of the transformer backbone.
heads_num: int
The number of attention heads.
mlp_width_ratio: float
The ratio of the hidden size of the MLP in the transformer block.
mlp_act_type: str
The activation function of the MLP in the transformer block.
    mm_double_blocks_depth: int
        The number of transformer blocks in the double-stream stage.
    mm_single_blocks_depth: int
        The number of transformer blocks in the single-stream stage.
rope_dim_list: list
The dimension of the rotary embedding for t, h, w.
qkv_bias: bool
Whether to use bias in the qkv linear layer.
qk_norm: bool
Whether to use qk norm.
qk_norm_type: str
The type of qk norm.
guidance_embed: bool
Whether to use guidance embedding for distillation.
text_projection: str
The type of the text projection, default is single_refiner.
use_attention_mask: bool
Whether to use attention mask for text encoder.
dtype: torch.dtype
The dtype of the model.
device: torch.device
The device of the model.
"""
@register_to_config
def __init__(
self,
patch_size: list = [1, 2, 2],
in_channels: int = 4, # Should be VAE.config.latent_channels.
out_channels: int = None,
hidden_size: int = 3072,
heads_num: int = 24,
mlp_width_ratio: float = 4.0,
mlp_act_type: str = "gelu_tanh",
mm_double_blocks_depth: int = 20,
mm_single_blocks_depth: int = 40,
rope_dim_list: List[int] = [16, 56, 56],
qkv_bias: bool = True,
qk_norm: bool = True,
qk_norm_type: str = "rms",
guidance_embed: bool = False, # For modulation.
text_projection: str = "single_refiner",
use_attention_mask: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
text_states_dim: int = 4096,
text_states_dim_2: int = 768,
rope_theta: int = 256,
):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.patch_size = patch_size
self.in_channels = in_channels
self.out_channels = in_channels if out_channels is None else out_channels
self.unpatchify_channels = self.out_channels
self.guidance_embed = guidance_embed
self.rope_dim_list = rope_dim_list
self.rope_theta = rope_theta
# Text projection. Default to linear projection.
# Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
self.use_attention_mask = use_attention_mask
self.text_projection = text_projection
if hidden_size % heads_num != 0:
raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
pe_dim = hidden_size // heads_num
if sum(rope_dim_list) != pe_dim:
raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
self.hidden_size = hidden_size
self.heads_num = heads_num
# image projection
self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
# text projection
if self.text_projection == "linear":
self.txt_in = TextProjection(
self.config.text_states_dim,
self.hidden_size,
get_activation_layer("silu"),
**factory_kwargs,
)
elif self.text_projection == "single_refiner":
self.txt_in = SingleTokenRefiner(
self.config.text_states_dim,
hidden_size,
heads_num,
depth=2,
**factory_kwargs,
)
else:
raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
# time modulation
self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
# text modulation
self.vector_in = MLPEmbedder(self.config.text_states_dim_2, self.hidden_size, **factory_kwargs)
# guidance modulation
self.guidance_in = (TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
if guidance_embed else None)
# double blocks
self.double_blocks = nn.ModuleList([
MMDoubleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
qkv_bias=qkv_bias,
**factory_kwargs,
) for _ in range(mm_double_blocks_depth)
])
# single blocks
self.single_blocks = nn.ModuleList([
MMSingleStreamBlock(
self.hidden_size,
self.heads_num,
mlp_width_ratio=mlp_width_ratio,
mlp_act_type=mlp_act_type,
qk_norm=qk_norm,
qk_norm_type=qk_norm_type,
**factory_kwargs,
) for _ in range(mm_single_blocks_depth)
])
self.final_layer = FinalLayer(
self.hidden_size,
self.patch_size,
self.out_channels,
get_activation_layer("silu"),
**factory_kwargs,
)
def enable_deterministic(self):
for block in self.double_blocks:
block.enable_deterministic()
for block in self.single_blocks:
block.enable_deterministic()
def disable_deterministic(self):
for block in self.double_blocks:
block.disable_deterministic()
for block in self.single_blocks:
block.disable_deterministic()
def get_rotary_pos_embed(self, rope_sizes):
target_ndim = 3
head_dim = self.hidden_size // self.heads_num
rope_dim_list = self.rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
        assert (sum(rope_dim_list) == head_dim), "sum(rope_dim_list) should equal the head_dim of the attention layer"
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=self.rope_theta,
use_real=True,
theta_rescale_factor=1,
)
return freqs_cos, freqs_sin
# x: torch.Tensor,
# t: torch.Tensor, # Should be in range(0, 1000).
# text_states: torch.Tensor = None,
# text_mask: torch.Tensor = None, # Now we don't use it.
# text_states_2: Optional[torch.Tensor] = None, # Text embedding for modulation.
# guidance: torch.Tensor = None, # Guidance for modulation, should be cfg_scale x 1000.
# return_dict: bool = True,
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
mask_strategy=None,
output_features=False,
output_features_stride=8,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = False,
guidance=None,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if guidance is None:
guidance = torch.tensor([6016.0], device=hidden_states.device, dtype=torch.bfloat16)
if mask_strategy is None:
mask_strategy = [[None] * self.heads_num for _ in range(len(self.double_blocks) + len(self.single_blocks))]
img = x = hidden_states
text_mask = encoder_attention_mask
t = timestep
txt = encoder_hidden_states[:, 1:]
text_states_2 = encoder_hidden_states[:, 0, :self.config.text_states_dim_2]
_, _, ot, oh, ow = x.shape # codespell:ignore
tt, th, tw = (
ot // self.patch_size[0], # codespell:ignore
oh // self.patch_size[1], # codespell:ignore
ow // self.patch_size[2], # codespell:ignore
)
original_tt = nccl_info.sp_size * tt
freqs_cos, freqs_sin = self.get_rotary_pos_embed((original_tt, th, tw))
# Prepare modulation vectors.
vec = self.time_in(t)
# text modulation
vec = vec + self.vector_in(text_states_2)
# guidance modulation
if self.guidance_embed:
if guidance is None:
raise ValueError("Didn't get guidance strength for guidance distilled model.")
# our timestep_embedding is merged into guidance_in(TimestepEmbedder)
vec = vec + self.guidance_in(guidance)
# Embed image and text.
img = self.img_in(img) # conv3d
if self.text_projection == "linear":
txt = self.txt_in(txt)
elif self.text_projection == "single_refiner":
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
else:
raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
txt_seq_len = txt.shape[1]
img_seq_len = img.shape[1]
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
# --------------------- Pass through DiT blocks ------------------------
for index, block in enumerate(self.double_blocks):
double_block_args = [img, txt, vec, freqs_cis, text_mask, mask_strategy[index]]
img, txt = block(*double_block_args)
# Merge txt and img to pass through single stream blocks.
x = torch.cat((img, txt), 1)
if output_features:
features_list = []
if len(self.single_blocks) > 0:
for index, block in enumerate(self.single_blocks):
single_block_args = [
x,
vec,
txt_seq_len,
(freqs_cos, freqs_sin),
text_mask,
mask_strategy[index + len(self.double_blocks)],
]
x = block(*single_block_args)
                if output_features and index % output_features_stride == 0:
features_list.append(x[:, :img_seq_len, ...])
img = x[:, :img_seq_len, ...]
# ---------------------------- Final layer ------------------------------
img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
img = self.unpatchify(img, tt, th, tw)
assert not return_dict, "return_dict is not supported."
if output_features:
features_list = torch.stack(features_list, dim=0)
else:
features_list = None
return (img, features_list)
def unpatchify(self, x, t, h, w):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.unpatchify_channels
pt, ph, pw = self.patch_size
assert t * h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
x = torch.einsum("nthwcopq->nctohpwq", x)
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
return imgs
def params_count(self):
counts = {
"double":
sum([
sum(p.numel()
for p in block.img_attn_qkv.parameters()) + sum(p.numel()
for p in block.img_attn_proj.parameters()) +
sum(p.numel() for p in block.img_mlp.parameters()) + sum(p.numel()
for p in block.txt_attn_qkv.parameters()) +
sum(p.numel() for p in block.txt_attn_proj.parameters()) + sum(p.numel()
for p in block.txt_mlp.parameters())
for block in self.double_blocks
]),
"single":
sum([
sum(p.numel() for p in block.linear1.parameters()) + sum(p.numel() for p in block.linear2.parameters())
for block in self.single_blocks
]),
"total":
sum(p.numel() for p in self.parameters()),
}
counts["attn+mlp"] = counts["double"] + counts["single"]
return counts
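# Hedged shape check for unpatchify above (pure torch, no model instance needed):
# the final layer emits (N, T*H*W, pt*ph*pw*C) patch tokens, and unpatchify folds
# them back into an (N, C, T*pt, H*ph, W*pw) latent video via the same einsum.
def _demo_unpatchify_shapes():
    import torch
    N, c = 2, 16
    pt, ph, pw = 1, 2, 2
    t, h, w = 3, 4, 4
    x = torch.randn(N, t * h * w, pt * ph * pw * c)
    x = x.reshape(N, t, h, w, c, pt, ph, pw)
    video = torch.einsum("nthwcopq->nctohpwq", x).reshape(N, c, t * pt, h * ph, w * pw)
    assert video.shape == (2, 16, 3, 8, 8)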
#################################################################################
# HunyuanVideo Configs #
#################################################################################
HUNYUAN_VIDEO_CONFIG = {
"HYVideo-T/2": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
},
"HYVideo-T/2-cfgdistill": {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
"guidance_embed": True,
},
}
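# Hedged usage sketch: the entries above map directly onto HYVideoDiffusionTransformer
# constructor kwargs (the full "HYVideo-T/2" settings correspond to a roughly
# 13B-parameter model). The shrunken config below is purely illustrative so the model
# can be built on CPU; rope_dim_list must still sum to hidden_size // heads_num.
def _demo_build_tiny_hyvideo():
    tiny_config = {
        "mm_double_blocks_depth": 1,
        "mm_single_blocks_depth": 1,
        "rope_dim_list": [4, 14, 14],   # sums to 64 // 2 = 32
        "hidden_size": 64,
        "heads_num": 2,
        "mlp_width_ratio": 4,
    }
    model = HYVideoDiffusionTransformer(in_channels=16, **tiny_config)
    return model.params_count()["total"]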
from typing import Callable
import torch
import torch.nn as nn
class ModulateDiT(nn.Module):
"""Modulation layer for DiT."""
def __init__(
self,
hidden_size: int,
factor: int,
act_layer: Callable,
dtype=None,
device=None,
):
factory_kwargs = {"dtype": dtype, "device": device}
super().__init__()
self.act = act_layer()
self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
# Zero-initialize the modulation
nn.init.zeros_(self.linear.weight)
nn.init.zeros_(self.linear.bias)
#@torch.compile(mode="reduce-overhead")
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.act(x))
# @torch.compile(mode="max-autotune-no-cudagraphs")
# @torch.compile(options={"triton.cudagraphs": False, "triton.cudagraph_trees": True})
#@torch.compile(mode="reduce-overhead")
def modulate(x, shift=None, scale=None):
"""modulate by shift and scale
Args:
x (torch.Tensor): input tensor.
shift (torch.Tensor, optional): shift tensor. Defaults to None.
scale (torch.Tensor, optional): scale tensor. Defaults to None.
Returns:
        torch.Tensor: the output tensor after modulation.
"""
if scale is None and shift is None:
return x
elif shift is None:
return x * (1 + scale.unsqueeze(1))
elif scale is None:
return x + shift.unsqueeze(1)
else:
return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
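# Hedged example: modulate() broadcasts a per-sample (shift, scale) pair over the
# token dimension, computing x * (1 + scale) + shift with x of shape (B, L, D) and
# shift/scale of shape (B, D).
def _demo_modulate():
    x = torch.randn(2, 5, 4)
    shift, scale = torch.randn(2, 4), torch.randn(2, 4)
    out = modulate(x, shift=shift, scale=scale)
    assert out.shape == x.shape
    assert torch.allclose(out[0, 3], x[0, 3] * (1 + scale[0]) + shift[0])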
def apply_gate(x, gate=None, tanh=False):
"""AI is creating summary for apply_gate
Args:
x (torch.Tensor): input tensor.
gate (torch.Tensor, optional): gate tensor. Defaults to None.
tanh (bool, optional): whether to use tanh function. Defaults to False.
Returns:
        torch.Tensor: the output tensor after applying the gate.
"""
if gate is None:
return x
if tanh:
return x * gate.unsqueeze(1).tanh()
else:
return x * gate.unsqueeze(1)
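# Hedged example: apply_gate() scales a residual branch by a per-sample gate,
# broadcast over the token dimension; zero-initialized gates silence the branch.
def _demo_apply_gate():
    x = torch.randn(2, 5, 4)
    assert torch.equal(apply_gate(x, gate=torch.zeros(2, 4)), torch.zeros_like(x))
    assert torch.equal(apply_gate(x), x)   # no gate -> identity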
def ckpt_wrapper(module):
def ckpt_forward(*inputs):
outputs = module(*inputs)
return outputs
return ckpt_forward
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
# @torch.compile(mode="max-autotune-no-cudagraphs")
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
def get_norm_layer(norm_layer):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return nn.LayerNorm
elif norm_layer == "rms":
return RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
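# Hedged example: get_norm_layer returns the class, not an instance, so callers pick
# the norm type once and construct it with their own dim/eps/affine arguments.
def _demo_get_norm_layer():
    norm = get_norm_layer("rms")(8, elementwise_affine=True, eps=1e-6)
    x = torch.randn(2, 3, 8)
    y = norm(x)
    # With the default all-ones weight, the per-feature RMS of the output is ~1.
    assert torch.allclose(y.pow(2).mean(-1), torch.ones(2, 3), atol=1e-3)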
import torch
import torch.nn as nn
class RMSNorm(nn.Module):
def __init__(
self,
dim: int,
elementwise_affine=True,
eps: float = 1e-6,
device=None,
dtype=None,
):
"""
Initialize the RMSNorm normalization layer.
Args:
dim (int): The dimension of the input tensor.
eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
Attributes:
eps (float): A small value added to the denominator for numerical stability.
weight (nn.Parameter): Learnable scaling parameter.
"""
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.eps = eps
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
def _norm(self, x):
"""
Apply the RMSNorm normalization to the input tensor.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The normalized tensor.
"""
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
@torch.compile(mode="max-autotune-no-cudagraphs")
def forward(self, x):
"""
Forward pass through the RMSNorm layer.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor after applying RMSNorm.
"""
output = self._norm(x.float()).type_as(x)
if hasattr(self, "weight"):
output = output * self.weight
return output
def get_norm_layer(norm_layer):
"""
Get the normalization layer.
Args:
norm_layer (str): The type of normalization layer.
Returns:
norm_layer (nn.Module): The normalization layer.
"""
if norm_layer == "layer":
return nn.LayerNorm
elif norm_layer == "rms":
return RMSNorm
else:
raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
from typing import List, Tuple, Union
import torch
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x, ) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, *args, dim=2):
"""
Get n-D meshgrid with start, stop and num.
Args:
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
n-tuples.
*args: See above.
dim (int): Dimension of the meshgrid. Defaults to 2.
Returns:
        grid (torch.Tensor): [dim, ...]
"""
if len(args) == 0:
# start is grid_size
num = _to_tuple(start, dim=dim)
start = (0, ) * dim
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
return grid
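# Hedged example: for a 3-D latent grid the helper returns a stacked meshgrid of
# shape [3, T, H, W], one integer coordinate plane per axis (endpoint excluded).
def _demo_meshgrid_nd():
    grid = get_meshgrid_nd((4, 8, 8), dim=3)
    assert grid.shape == (3, 4, 8, 8)
    assert grid[0, :, 0, 0].tolist() == [0.0, 1.0, 2.0, 3.0]   # time axis runs 0..T-1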
#################################################################################
# Rotary Positional Embedding Functions #
#################################################################################
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
def reshape_for_broadcast(
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
x: torch.Tensor,
head_first=False,
):
"""
Reshape frequency tensor for broadcasting it with another tensor.
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.
Notes:
When using FlashMHAModified, head_first should be False.
When using Attention, head_first should be True.
Args:
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
torch.Tensor: Reshaped frequency tensor.
Raises:
AssertionError: If the frequency tensor doesn't match the expected shape.
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
"""
ndim = x.ndim
assert 0 <= 1 < ndim
if isinstance(freqs_cis, tuple):
# freqs_cis: (cos, sin) in real space
if head_first:
assert freqs_cis[0].shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis[0].shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
# freqs_cis: values in complex space
if head_first:
assert freqs_cis.shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis.shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = (x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
xk_out = None
if isinstance(freqs_cis, tuple):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
cos, sin = cos.to(xq.device), sin.to(xq.device)
# real * cos - imag * sin
# imag * cos + real * sin
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
else:
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
return xq_out, xk_out
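# Hedged sanity check: RoPE is a pure rotation within each 2-D frequency plane, so
# applying it must preserve the per-head norms of q and k (up to float error).
def _demo_apply_rotary_emb():
    B, S, H, D = 1, 6, 2, 8
    q, k = torch.randn(B, S, H, D), torch.randn(B, S, H, D)
    cos, sin = get_1d_rotary_pos_embed(D, S, use_real=True)   # each [S, D]
    q_rot, k_rot = apply_rotary_emb(q, k, (cos, sin), head_first=False)
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4)
    assert torch.allclose(k_rot.norm(dim=-1), k.norm(dim=-1), atol=1e-4)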
def get_nd_rotary_pos_embed(
rope_dim_list,
start,
*args,
theta=10000.0,
use_real=False,
theta_rescale_factor: Union[float, List[float]] = 1.0,
interpolation_factor: Union[float, List[float]] = 1.0,
):
"""
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n.
            sum(rope_dim_list) should equal the head_dim of the attention layer.
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
*args: See above.
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
            Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide the
            real part and the imaginary part separately.
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
Returns:
pos_embed (torch.Tensor): [HW, D/2]
"""
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(
        rope_dim_list), "len(theta_rescale_factor) should equal len(rope_dim_list)"
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(
        rope_dim_list), "len(interpolation_factor) should equal len(rope_dim_list)"
# use 1/ndim of dimensions to encode grid_axis
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i],
grid[i].reshape(-1),
theta,
use_real=use_real,
theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
) # 2 x [WHD, rope_dim_list[i]]
embs.append(emb)
if use_real:
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
return cos, sin
else:
emb = torch.cat(embs, dim=1) # (WHD, D/2)
return emb
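# Hedged example: with the HunyuanVideo split [16, 56, 56] over (T, H, W) and a
# 4x8x8 latent grid, the real-valued tables come back as [T*H*W, 128] cos/sin pairs,
# matching the 128-dim attention heads of the full model.
def _demo_nd_rope_shapes():
    cos, sin = get_nd_rotary_pos_embed([16, 56, 56], (4, 8, 8), use_real=True)
    assert cos.shape == sin.shape == (4 * 8 * 8, 16 + 56 + 56)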
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.FloatTensor, int],
theta: float = 10000.0,
use_real: bool = False,
theta_rescale_factor: float = 1.0,
interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
Returns:
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
"""
if isinstance(pos, int):
pos = torch.arange(pos).float()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor**(dim / (dim - 2))
freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) # [D/2]
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
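# Hedged example: the two output modes differ in shape -- the complex table is
# [S, D/2] (complex64), while the real cos/sin pair duplicates each frequency per
# coordinate and comes back as two [S, D] tensors.
def _demo_1d_rope_modes():
    freqs_cis = get_1d_rotary_pos_embed(8, 6)                  # complex64, [6, 4]
    cos, sin = get_1d_rotary_pos_embed(8, 6, use_real=True)    # float32, [6, 8] each
    assert freqs_cis.shape == (6, 4) and freqs_cis.dtype == torch.complex64
    assert cos.shape == sin.shape == (6, 8)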