Commit c07946d8 authored by hepj

dit & video
import os
from pathlib import Path
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKLHunyuanVideo, AutoencoderKLMochi
from torch import nn
from transformers import AutoTokenizer, T5EncoderModel
from fastvideo.models.hunyuan.modules.models import (HYVideoDiffusionTransformer, MMDoubleStreamBlock,
MMSingleStreamBlock)
from fastvideo.models.hunyuan.text_encoder import TextEncoder
from fastvideo.models.hunyuan.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
from fastvideo.models.hunyuan_hf.modeling_hunyuan import (HunyuanVideoSingleTransformerBlock,
HunyuanVideoTransformer3DModel, HunyuanVideoTransformerBlock)
from fastvideo.models.mochi_hf.modeling_mochi import MochiTransformer3DModel, MochiTransformerBlock
from fastvideo.utils.logging_ import main_print
hunyuan_config = {
"mm_double_blocks_depth": 20,
"mm_single_blocks_depth": 40,
"rope_dim_list": [16, 56, 56],
"hidden_size": 3072,
"heads_num": 24,
"mlp_width_ratio": 4,
"guidance_embed": True,
}
PROMPT_TEMPLATE_ENCODE = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
"quantity, text, spatial relationships of the objects and background:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
PROMPT_TEMPLATE_ENCODE_VIDEO = (
"<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
"1. The main content and theme of the video."
"2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
"3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
"4. background environment, light, style and atmosphere."
"5. camera angles, movements, and transitions used in the video:<|eot_id|>"
"<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>")
NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"
PROMPT_TEMPLATE = {
"dit-llm-encode": {
"template": PROMPT_TEMPLATE_ENCODE,
"crop_start": 36,
},
"dit-llm-encode-video": {
"template": PROMPT_TEMPLATE_ENCODE_VIDEO,
"crop_start": 95,
},
}
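# Hedged example: the video template is filled with the raw user prompt before tokenization.
# "crop_start" (95 for the video template, 36 for the image one) is the number of leading
# system-prompt/header tokens that are later dropped from the encoder hidden states; the
# exact counts depend on the tokenizer used by the "llm" text encoder.
# filled = PROMPT_TEMPLATE["dit-llm-encode-video"]["template"].format("A cat runs across a sunny field.")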
class HunyuanTextEncoderWrapper(nn.Module):
def __init__(self, pretrained_model_name_or_path, device):
super().__init__()
text_len = 256
crop_start = PROMPT_TEMPLATE["dit-llm-encode-video"].get("crop_start", 0)
max_length = text_len + crop_start
# prompt_template
prompt_template = PROMPT_TEMPLATE["dit-llm-encode"]
# prompt_template_video
prompt_template_video = PROMPT_TEMPLATE["dit-llm-encode-video"]
text_encoder_path = os.path.join(pretrained_model_name_or_path, "text_encoder")
self.text_encoder = TextEncoder(
text_encoder_type="llm",
text_encoder_path=text_encoder_path,
max_length=max_length,
text_encoder_precision="fp16",
tokenizer_type="llm",
prompt_template=prompt_template,
prompt_template_video=prompt_template_video,
hidden_state_skip_layer=2,
apply_final_norm=False,
reproduce=False,
logger=None,
device=device,
)
text_encoder_path_2 = os.path.join(pretrained_model_name_or_path, "text_encoder_2")
self.text_encoder_2 = TextEncoder(
text_encoder_type="clipL",
text_encoder_path=text_encoder_path_2,
max_length=77,
text_encoder_precision="fp16",
tokenizer_type="clipL",
reproduce=False,
logger=None,
device=device,
)
def encode_(self, prompt, text_encoder, clip_skip=None):
# TODO
device = text_encoder.device
data_type = "video"
num_videos_per_prompt = 1
text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)
if clip_skip is None:
prompt_outputs = text_encoder.encode(text_inputs, data_type="video", device=device)
prompt_embeds = prompt_outputs.hidden_state
else:
prompt_outputs = text_encoder.encode(
text_inputs,
output_hidden_states=True,
data_type=data_type,
device=device,
)
prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
prompt_embeds = text_encoder.model.text_model.final_layer_norm(prompt_embeds)
attention_mask = prompt_outputs.attention_mask
if attention_mask is not None:
attention_mask = attention_mask.to(device)
bs_embed, seq_len = attention_mask.shape
attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
attention_mask = attention_mask.view(bs_embed * num_videos_per_prompt, seq_len)
if text_encoder is not None:
prompt_embeds_dtype = text_encoder.dtype
elif self.transformer is not None:
prompt_embeds_dtype = self.transformer.dtype
else:
prompt_embeds_dtype = prompt_embeds.dtype
prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
if prompt_embeds.ndim == 2:
bs_embed, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
else:
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
return (prompt_embeds, attention_mask)
def encode_prompt(self, prompt):
prompt_embeds, attention_mask = self.encode_(prompt, self.text_encoder)
prompt_embeds_2, attention_mask_2 = self.encode_(prompt, self.text_encoder_2)
prompt_embeds_2 = F.pad(
prompt_embeds_2,
(0, prompt_embeds.shape[2] - prompt_embeds_2.shape[1]),
value=0,
).unsqueeze(1)
prompt_embeds = torch.cat([prompt_embeds_2, prompt_embeds], dim=1)
return prompt_embeds, attention_mask
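# Usage sketch (assumes pretrained_model_name_or_path contains text_encoder/ and
# text_encoder_2/ subfolders, as in the __main__ test below):
# wrapper = HunyuanTextEncoderWrapper("data/hunyuan", device="cuda")
# prompt_embeds, attention_mask = wrapper.encode_prompt("A dog surfing a wave at sunset.")
# prompt_embeds concatenates the zero-padded CLIP-L pooled embedding (one extra token at
# position 0) with the LLM hidden states along the sequence dimension.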
class MochiTextEncoderWrapper(nn.Module):
def __init__(self, pretrained_model_name_or_path, device):
super().__init__()
self.text_encoder = T5EncoderModel.from_pretrained(os.path.join(pretrained_model_name_or_path,
"text_encoder")).to(device)
self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(pretrained_model_name_or_path, "tokenizer"))
self.max_sequence_length = 256
def encode_prompt(self, prompt):
device = self.text_encoder.device
dtype = self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.bool().to(device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.max_sequence_length - 1:-1])
main_print(f"Truncated text input: {prompt} to: {removed_text} for model input.")
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.view(batch_size, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
return prompt_embeds, prompt_attention_mask
def load_hunyuan_state_dict(model, dit_model_name_or_path):
load_key = "module"
model_path = dit_model_name_or_path
bare_model = "unknown"
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
if bare_model == "unknown" and ("ema" in state_dict or "module" in state_dict):
bare_model = False
if bare_model is False:
if load_key in state_dict:
state_dict = state_dict[load_key]
else:
raise KeyError(f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
f"are: {list(state_dict.keys())}.")
model.load_state_dict(state_dict, strict=True)
return model
def load_transformer(
model_type,
dit_model_name_or_path,
pretrained_model_name_or_path,
master_weight_type,
):
if model_type == "mochi":
if dit_model_name_or_path:
transformer = MochiTransformer3DModel.from_pretrained(
dit_model_name_or_path,
torch_dtype=master_weight_type,
# torch_dtype=torch.bfloat16 if args.use_lora else torch.float32,
)
else:
transformer = MochiTransformer3DModel.from_pretrained(
pretrained_model_name_or_path,
subfolder="transformer",
torch_dtype=master_weight_type,
# torch_dtype=torch.bfloat16 if args.use_lora else torch.float32,
)
elif model_type == "hunyuan_hf":
if dit_model_name_or_path:
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
dit_model_name_or_path,
torch_dtype=master_weight_type,
# torch_dtype=torch.bfloat16 if args.use_lora else torch.float32,
)
else:
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
pretrained_model_name_or_path,
subfolder="transformer",
torch_dtype=master_weight_type,
# torch_dtype=torch.bfloat16 if args.use_lora else torch.float32,
)
elif model_type == "hunyuan":
transformer = HYVideoDiffusionTransformer(
in_channels=16,
out_channels=16,
**hunyuan_config,
dtype=master_weight_type,
)
transformer = load_hunyuan_state_dict(transformer, dit_model_name_or_path)
if master_weight_type == torch.bfloat16:
transformer = transformer.bfloat16()
else:
raise ValueError(f"Unsupported model type: {model_type}")
return transformer
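# Usage sketch (paths are placeholders; "hunyuan" expects a raw DiT checkpoint loadable by
# load_hunyuan_state_dict, while "mochi"/"hunyuan_hf" expect diffusers-style model folders):
# transformer = load_transformer(
#     model_type="hunyuan_hf",
#     dit_model_name_or_path=None,
#     pretrained_model_name_or_path="data/hunyuanvideo-diffusers",
#     master_weight_type=torch.bfloat16,
# )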
def load_vae(model_type, pretrained_model_name_or_path):
weight_dtype = torch.float32
if model_type == "mochi":
vae = AutoencoderKLMochi.from_pretrained(pretrained_model_name_or_path,
subfolder="vae",
torch_dtype=weight_dtype).to("cuda")
autocast_type = torch.bfloat16
fps = 30
elif model_type == "hunyuan_hf":
vae = AutoencoderKLHunyuanVideo.from_pretrained(pretrained_model_name_or_path,
subfolder="vae",
torch_dtype=weight_dtype).to("cuda")
autocast_type = torch.bfloat16
fps = 24
elif model_type == "hunyuan":
vae_precision = torch.float32
vae_path = os.path.join(pretrained_model_name_or_path, "hunyuan-video-t2v-720p/vae")
config = AutoencoderKLCausal3D.load_config(vae_path)
vae = AutoencoderKLCausal3D.from_config(config)
vae_ckpt = Path(vae_path) / "pytorch_model.pt"
assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}"
ckpt = torch.load(vae_ckpt, map_location=vae.device, weights_only=True)
if "state_dict" in ckpt:
ckpt = ckpt["state_dict"]
if any(k.startswith("vae.") for k in ckpt.keys()):
ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
vae.load_state_dict(ckpt)
vae = vae.to(dtype=vae_precision)
vae.requires_grad_(False)
vae = vae.to("cuda")
vae.eval()
autocast_type = torch.float32
fps = 24
else:
raise ValueError(f"Unsupported model type: {model_type}")
return vae, autocast_type, fps
def load_text_encoder(model_type, pretrained_model_name_or_path, device):
if model_type == "mochi":
text_encoder = MochiTextEncoderWrapper(pretrained_model_name_or_path, device)
elif model_type in ("hunyuan", "hunyuan_hf"):
text_encoder = HunyuanTextEncoderWrapper(pretrained_model_name_or_path, device)
else:
raise ValueError(f"Unsupported model type: {model_type}")
return text_encoder
def get_no_split_modules(transformer):
# if of type MochiTransformer3DModel
if isinstance(transformer, MochiTransformer3DModel):
return (MochiTransformerBlock, )
elif isinstance(transformer, HunyuanVideoTransformer3DModel):
return (HunyuanVideoSingleTransformerBlock, HunyuanVideoTransformerBlock)
elif isinstance(transformer, HYVideoDiffusionTransformer):
return (MMDoubleStreamBlock, MMSingleStreamBlock)
else:
raise ValueError(f"Unsupported transformer type: {type(transformer)}")
if __name__ == "__main__":
# test encode prompt
device = torch.cuda.current_device()
pretrained_model_name_or_path = "data/hunyuan"
text_encoder = load_text_encoder("hunyuan", pretrained_model_name_or_path, device)
prompt = "A man on stage claps his hands together while facing the audience. The audience, visible in the foreground, holds up mobile devices to record the event, capturing the moment from various angles. The background features a large banner with text identifying the man on stage. Throughout the sequence, the man's expression remains engaged and directed towards the audience. The camera angle remains constant, focusing on capturing the interaction between the man on stage and the audience."
prompt_embeds, attention_mask = text_encoder.encode_prompt(prompt)
import os
import pdb
import sys
def main_print(content):
if int(os.environ.get("LOCAL_RANK", "0")) <= 0:
print(content)
# ForkedPdb().set_trace()
class ForkedPdb(pdb.Pdb):
"""A Pdb subclass that may be used
from a forked multiprocessing child
"""
def interaction(self, *args, **kwargs):
_stdin = sys.stdin
try:
sys.stdin = open("/dev/stdin")
pdb.Pdb.interaction(self, *args, **kwargs)
finally:
sys.stdin = _stdin
import torch
from accelerate.logging import get_logger
logger = get_logger(__name__)
def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
# Optimizer creation
supported_optimizers = ["adam", "adamw", "prodigy"]
if args.optimizer not in supported_optimizers:
logger.warning(
f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include {supported_optimizers}. Defaulting to AdamW"
)
args.optimizer = "adamw"
if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
logger.warning(f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
f"set to {args.optimizer.lower()}")
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
if args.optimizer.lower() == "adamw":
optimizer_class = (bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW)
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_epsilon,
weight_decay=args.adam_weight_decay,
)
elif args.optimizer.lower() == "adam":
optimizer_class = bnb.optim.Adam8bit if args.use_8bit_adam else torch.optim.Adam
optimizer = optimizer_class(
params_to_optimize,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_epsilon,
weight_decay=args.adam_weight_decay,
)
elif args.optimizer.lower() == "prodigy":
try:
import prodigyopt
except ImportError:
raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
optimizer_class = prodigyopt.Prodigy
if args.learning_rate <= 0.1:
logger.warning(
"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0")
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
beta3=args.prodigy_beta3,
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
decouple=args.prodigy_decouple,
use_bias_correction=args.prodigy_use_bias_correction,
safeguard_warmup=args.prodigy_safeguard_warmup,
)
return optimizer
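# Usage sketch (hypothetical args namespace; field names mirror the attributes read above):
# from types import SimpleNamespace
# args = SimpleNamespace(optimizer="adamw", use_8bit_adam=False, learning_rate=1e-4,
#                        adam_beta1=0.9, adam_beta2=0.999, adam_epsilon=1e-8,
#                        adam_weight_decay=1e-2)
# optimizer = get_optimizer(args, transformer.parameters())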
import os
import torch.distributed as dist
class COMM_INFO:
def __init__(self):
self.group = None
self.sp_size = 1
self.global_rank = 0
self.rank_within_group = 0
self.group_id = 0
nccl_info = COMM_INFO()
_SEQUENCE_PARALLEL_STATE = False
def initialize_sequence_parallel_state(sequence_parallel_size):
global _SEQUENCE_PARALLEL_STATE
if sequence_parallel_size > 1:
_SEQUENCE_PARALLEL_STATE = True
initialize_sequence_parallel_group(sequence_parallel_size)
else:
nccl_info.sp_size = 1
nccl_info.global_rank = int(os.getenv("RANK", "0"))
nccl_info.rank_within_group = 0
nccl_info.group_id = int(os.getenv("RANK", "0"))
def set_sequence_parallel_state(state):
global _SEQUENCE_PARALLEL_STATE
_SEQUENCE_PARALLEL_STATE = state
def get_sequence_parallel_state():
return _SEQUENCE_PARALLEL_STATE
def initialize_sequence_parallel_group(sequence_parallel_size):
"""Initialize the sequence parallel group."""
rank = int(os.getenv("RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))
assert (
world_size % sequence_parallel_size == 0
), "world_size must be divisible by sequence_parallel_size, but got world_size: {}, sequence_parallel_size: {}".format(
world_size, sequence_parallel_size)
nccl_info.sp_size = sequence_parallel_size
nccl_info.global_rank = rank
num_sequence_parallel_groups: int = world_size // sequence_parallel_size
for i in range(num_sequence_parallel_groups):
ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
group = dist.new_group(ranks)
if rank in ranks:
nccl_info.group = group
nccl_info.rank_within_group = rank - i * sequence_parallel_size
nccl_info.group_id = i
def destroy_sequence_parallel_group():
"""Destroy the sequence parallel group."""
dist.destroy_process_group()
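# Worked example (assumption: WORLD_SIZE=8, sequence_parallel_size=2):
# groups are ranks {0,1}, {2,3}, {4,5}, {6,7}. For global rank 5 this gives
# nccl_info.group_id = 2 and nccl_info.rank_within_group = 5 - 2 * 2 = 1.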
# isort: skip_file
import gc
import os
from typing import List, Optional, Union
import numpy as np
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import export_to_video
from diffusers.utils.torch_utils import randn_tensor
from diffusers.video_processor import VideoProcessor
from einops import rearrange
from tqdm import tqdm
import wandb
from fastvideo.distill.solver import PCMFMScheduler
from fastvideo.models.mochi_hf.pipeline_mochi import (linear_quadratic_schedule, retrieve_timesteps)
from fastvideo.utils.communications import all_gather
from fastvideo.utils.load import load_vae
from fastvideo.utils.parallel_states import (get_sequence_parallel_state, nccl_info)
def prepare_latents(
batch_size,
num_channels_latents,
height,
width,
num_frames,
dtype,
device,
generator,
vae_spatial_scale_factor,
vae_temporal_scale_factor,
):
height = height // vae_spatial_scale_factor
width = width // vae_spatial_scale_factor
num_frames = (num_frames - 1) // vae_temporal_scale_factor + 1
shape = (batch_size, num_channels_latents, num_frames, height, width)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
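# Worked example with the Mochi defaults used below (spatial factor 8, temporal factor 6,
# 12 latent channels): e.g. height=480, width=848, num_frames=163 gives a latent tensor of
# shape (batch_size, 12, (163 - 1) // 6 + 1 = 28, 480 // 8 = 60, 848 // 8 = 106).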
def sample_validation_video(
model_type,
transformer,
vae,
scheduler,
scheduler_type="euler",
height: Optional[int] = None,
width: Optional[int] = None,
num_frames: int = 16,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 4.5,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
vae_spatial_scale_factor=8,
vae_temporal_scale_factor=6,
num_channels_latents=12,
):
device = vae.device
batch_size = prompt_embeds.shape[0]
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare latent variables
# TODO: Remove hardcoded values
latents = prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
prompt_embeds.dtype,
device,
generator,
vae_spatial_scale_factor,
vae_temporal_scale_factor,
)
world_size, rank = nccl_info.sp_size, nccl_info.rank_within_group
if get_sequence_parallel_state():
latents = rearrange(latents, "b t (n s) h w -> b t n s h w", n=world_size).contiguous()
latents = latents[:, :, rank, :, :, :]
# 5. Prepare timestep
# from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
threshold_noise = 0.025
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
sigmas = np.array(sigmas)
if scheduler_type == "euler" and model_type == "mochi": #todo
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
)
else:
timesteps, num_inference_steps = retrieve_timesteps(
scheduler,
num_inference_steps,
device,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
# 6. Denoising loop
# with self.progress_bar(total=num_inference_steps) as progress_bar:
# write with tqdm instead
# only enable if nccl_info.global_rank == 0
with tqdm(
total=num_inference_steps,
disable=nccl_info.rank_within_group != 0,
desc="Validation sampling...",
) as progress_bar:
for i, t in enumerate(timesteps):
latent_model_input = (torch.cat([latents] * 2) if do_classifier_free_guidance else latents)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])
with torch.autocast("cuda", dtype=torch.bfloat16):
noise_pred = transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
return_dict=False,
)[0]
# Mochi CFG + Sampling runs in FP32
noise_pred = noise_pred.to(torch.float32)
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0]
latents = latents.to(latents_dtype)
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
progress_bar.update()
if get_sequence_parallel_state():
latents = all_gather(latents, dim=2)
if output_type == "latent":
video = latents
else:
# unscale/denormalize the latents
# denormalize with the mean and std if available and not None
has_latents_mean = (hasattr(vae.config, "latents_mean") and vae.config.latents_mean is not None)
has_latents_std = (hasattr(vae.config, "latents_std") and vae.config.latents_std is not None)
if has_latents_mean and has_latents_std:
latents_mean = (torch.tensor(vae.config.latents_mean).view(1, 12, 1, 1,
1).to(latents.device, latents.dtype))
latents_std = (torch.tensor(vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype))
latents = latents * latents_std / vae.config.scaling_factor + latents_mean
else:
latents = latents / vae.config.scaling_factor
with torch.autocast("cuda", dtype=vae.dtype):
video = vae.decode(latents, return_dict=False)[0]
video_processor = VideoProcessor(vae_scale_factor=vae_spatial_scale_factor)
video = video_processor.postprocess_video(video, output_type=output_type)
return (video, )
@torch.no_grad()
@torch.autocast("cuda", dtype=torch.bfloat16)
def log_validation(
args,
transformer,
device,
weight_dtype, # TODO
global_step,
scheduler_type="euler",
shift=1.0,
num_euler_timesteps=100,
linear_quadratic_threshold=0.025,
linear_range=0.5,
ema=False,
):
# TODO
print("Running validation....\n")
if args.model_type == "mochi":
vae_spatial_scale_factor = 8
vae_temporal_scale_factor = 6
num_channels_latents = 12
elif args.model_type in ("hunyuan", "hunyuan_hf"):
vae_spatial_scale_factor = 8
vae_temporal_scale_factor = 4
num_channels_latents = 16
else:
raise ValueError(f"Model type {args.model_type} not supported")
vae, autocast_type, fps = load_vae(args.model_type, args.pretrained_model_name_or_path)
vae.enable_tiling()
if scheduler_type == "euler":
scheduler = FlowMatchEulerDiscreteScheduler(shift=shift)
else:
linear_quadratic = scheduler_type == "pcm_linear_quadratic"
scheduler = PCMFMScheduler(
1000,
shift,
num_euler_timesteps,
linear_quadratic,
linear_quadratic_threshold,
linear_range,
)
# args.validation_prompt_dir
validation_guidance_scale_ls = args.validation_guidance_scale.split(",")
validation_guidance_scale_ls = [float(scale) for scale in validation_guidance_scale_ls]
for validation_sampling_step in args.validation_sampling_steps.split(","):
validation_sampling_step = int(validation_sampling_step)
for validation_guidance_scale in validation_guidance_scale_ls:
videos = []
# prompt_embed are named embed0 to embedN
# check how many embeds are there
embed_dir = os.path.join(args.validation_prompt_dir, "prompt_embed")
mask_dir = os.path.join(args.validation_prompt_dir, "prompt_attention_mask")
embeds = sorted(os.listdir(embed_dir))
masks = sorted(os.listdir(mask_dir))
num_embeds = len(embeds)
validation_prompt_ids = list(range(num_embeds))
num_sp_groups = int(os.getenv("WORLD_SIZE", "1")) // nccl_info.sp_size
# pad to multiple of groups
if num_embeds % num_sp_groups != 0:
validation_prompt_ids += [0] * (num_sp_groups - num_embeds % num_sp_groups)
num_embeds_per_group = len(validation_prompt_ids) // num_sp_groups
local_prompt_ids = validation_prompt_ids[nccl_info.group_id *
num_embeds_per_group:(nccl_info.group_id + 1) *
num_embeds_per_group]
for i in local_prompt_ids:
prompt_embed_path = os.path.join(embed_dir, embeds[i])
prompt_mask_path = os.path.join(mask_dir, masks[i])
prompt_embeds = (torch.load(prompt_embed_path, map_location="cpu",
weights_only=True).to(device).unsqueeze(0))
prompt_attention_mask = (torch.load(prompt_mask_path, map_location="cpu",
weights_only=True).to(device).unsqueeze(0))
negative_prompt_embeds = torch.zeros(256, 4096).to(device).unsqueeze(0)
negative_prompt_attention_mask = (torch.zeros(256).bool().to(device).unsqueeze(0))
generator = torch.Generator(device="cpu").manual_seed(12345)
video = sample_validation_video(
args.model_type,
transformer,
vae,
scheduler,
scheduler_type=scheduler_type,
num_frames=args.num_frames,
height=args.num_height,
width=args.num_width,
num_inference_steps=validation_sampling_step,
guidance_scale=validation_guidance_scale,
generator=generator,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_embeds=negative_prompt_embeds,
negative_prompt_attention_mask=negative_prompt_attention_mask,
vae_spatial_scale_factor=vae_spatial_scale_factor,
vae_temporal_scale_factor=vae_temporal_scale_factor,
num_channels_latents=num_channels_latents,
)[0]
if nccl_info.rank_within_group == 0:
videos.append(video[0])
# collect videos from all process to process zero
gc.collect()
torch.cuda.empty_cache()
# log if main process
torch.distributed.barrier()
all_videos = [None for i in range(int(os.getenv("WORLD_SIZE", "1")))] # remove padded videos
torch.distributed.all_gather_object(all_videos, videos)
if nccl_info.global_rank == 0:
# remove padding
videos = [video for videos in all_videos for video in videos]
videos = videos[:num_embeds]
# linearize all videos
video_filenames = []
for i, video in enumerate(videos):
filename = os.path.join(
args.output_dir,
f"validation_step_{global_step}_sample_{validation_sampling_step}_guidance_{validation_guidance_scale}_video_{i}.mp4",
)
export_to_video(video, filename, fps=fps)
video_filenames.append(filename)
logs = {
f"{'ema_' if ema else ''}validation_sample_{validation_sampling_step}_guidance_{validation_guidance_scale}":
[wandb.Video(filename) for i, filename in enumerate(video_filenames)]
}
wandb.log(logs, step=global_step)
# SPDX-License-Identifier: Apache-2.0
from fastvideo.v1.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
AttentionMetadataBuilder)
from fastvideo.v1.attention.layer import DistributedAttention, LocalAttention
from fastvideo.v1.attention.selector import get_attn_backend
__all__ = [
"DistributedAttention",
"LocalAttention",
"AttentionBackend",
"AttentionMetadata",
"AttentionMetadataBuilder",
# "AttentionState",
"get_attn_backend",
]
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/attention/backends/abstract.py
from abc import ABC, abstractmethod
from dataclasses import dataclass, fields
from typing import (TYPE_CHECKING, Any, Dict, Generic, Optional, Protocol, Set,
Type, TypeVar)
if TYPE_CHECKING:
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
import torch
class AttentionBackend(ABC):
"""Abstract class for attention backends."""
# For some attention backends, we allocate an output tensor before
# calling the custom op. When piecewise cudagraph is enabled, this
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False
@staticmethod
@abstractmethod
def get_name() -> str:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_impl_cls() -> Type["AttentionImpl"]:
raise NotImplementedError
@staticmethod
@abstractmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError
# @staticmethod
# @abstractmethod
# def get_state_cls() -> Type["AttentionState"]:
# raise NotImplementedError
# @classmethod
# def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata":
# return cls.get_metadata_cls()(*args, **kwargs)
@staticmethod
@abstractmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError
@dataclass
class AttentionMetadata:
"""Attention metadata for prefill and decode batched together."""
# Current step of diffusion process
current_timestep: int
# @property
# @abstractmethod
# def inference_metadata(self) -> Optional["AttentionMetadata"]:
# """Return the attention metadata that's required to run prefill
# attention."""
# pass
# @property
# @abstractmethod
# def training_metadata(self) -> Optional["AttentionMetadata"]:
# """Return the attention metadata that's required to run decode
# attention."""
# pass
def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
) -> Dict[str, Any]:
"""Similar to dataclasses.asdict, but avoids deepcopying."""
if skip_fields is None:
skip_fields = set()
# Note that if we add dataclasses as fields, they will need
# similar handling.
return {
field.name: getattr(self, field.name)
for field in fields(self) if field.name not in skip_fields
}
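# Example: asdict_zerocopy returns a shallow dict of dataclass fields without deep-copying,
# so tensor-valued fields stay as references:
# AttentionMetadata(current_timestep=10).asdict_zerocopy()  # -> {"current_timestep": 10}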
T = TypeVar("T", bound=AttentionMetadata)
# class AttentionState(ABC, Generic[T]):
# """Holds attention backend-specific objects reused during the
# lifetime of the model runner."""
# @abstractmethod
# def __init__(self, runner: "ModelRunnerBase"):
# ...
# @abstractmethod
# @contextmanager
# def graph_capture(self, max_batch_size: int):
# """Context manager used when capturing CUDA graphs."""
# yield
# @abstractmethod
# def graph_clone(self, batch_size: int) -> "AttentionState[T]":
# """Clone attention state to save in CUDA graph metadata."""
# ...
# @abstractmethod
# def graph_capture_get_metadata_for_batch(
# self,
# batch_size: int,
# is_encoder_decoder_model: bool = False) -> T:
# """Get attention metadata for CUDA graph capture of batch_size."""
# ...
# @abstractmethod
# def get_graph_input_buffers(
# self,
# attn_metadata: T,
# is_encoder_decoder_model: bool = False) -> Dict[str, Any]:
# """Get attention-specific input buffers for CUDA graph capture."""
# ...
# @abstractmethod
# def prepare_graph_input_buffers(
# self,
# input_buffers: Dict[str, Any],
# attn_metadata: T,
# is_encoder_decoder_model: bool = False) -> None:
# """In-place modify input buffers dict for CUDA graph replay."""
# ...
# @abstractmethod
# def begin_forward(self, model_input: "ModelRunnerInputBase") -> None:
# """Prepare state for forward pass."""
# ...
class AttentionMetadataBuilder(ABC, Generic[T]):
"""Abstract class for attention metadata builders."""
@abstractmethod
def __init__(self) -> None:
"""Create the builder, remember some configuration and parameters."""
raise NotImplementedError
@abstractmethod
def prepare(self) -> None:
"""Prepare for one batch."""
raise NotImplementedError
@abstractmethod
def build(
self,
current_timestep: int,
forward_batch: "ForwardBatch",
fastvideo_args: "FastVideoArgs",
) -> T:
"""Build attention metadata with on-device tensors."""
raise NotImplementedError
class AttentionLayer(Protocol):
_k_scale: torch.Tensor
_v_scale: torch.Tensor
_k_scale_float: float
_v_scale_float: float
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
...
class AttentionImpl(ABC, Generic[T]):
@abstractmethod
def __init__(
self,
num_heads: int,
head_size: int,
softmax_scale: float,
causal: bool = False,
num_kv_heads: Optional[int] = None,
prefix: str = "",
**extra_impl_args,
) -> None:
raise NotImplementedError
def preprocess_qkv(self, qkv: torch.Tensor,
attn_metadata: T) -> torch.Tensor:
"""Preprocess QKV tensor before performing attention operation.
Default implementation returns the tensor unchanged.
Subclasses can override this to implement custom preprocessing
like reshaping, tiling, scaling, or other transformations.
Called AFTER all_to_all for distributed attention
Args:
qkv: The query-key-value tensor
attn_metadata: Metadata for the attention operation
Returns:
Processed QKV tensor
"""
return qkv
def postprocess_output(
self,
output: torch.Tensor,
attn_metadata: T,
) -> torch.Tensor:
"""Postprocess the output tensor after the attention operation.
Default implementation returns the tensor unchanged.
Subclasses can override this to implement custom postprocessing
like untiling, scaling, or other transformations.
Called BEFORE all_to_all for distributed attention
Args:
output: The output tensor from the attention operation
attn_metadata: Metadata for the attention operation
Returns:
Postprocessed output tensor
"""
return output
@abstractmethod
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: T,
) -> torch.Tensor:
raise NotImplementedError
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Type
import torch
from flash_attn import flash_attn_func as flash_attn_2_func
try:
from flash_attn_interface import flash_attn_func as flash_attn_3_func
# flash_attn 3 has slightly different API: it returns lse by default
flash_attn_func = lambda q, k, v, softmax_scale, causal: flash_attn_3_func(
q, k, v, softmax_scale, causal)[0]
except ImportError:
flash_attn_func = flash_attn_2_func
from fastvideo.v1.attention.backends.abstract import (AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder)
from fastvideo.v1.logger import init_logger
logger = init_logger(__name__)
class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "FLASH_ATTN"
@staticmethod
def get_impl_cls() -> Type["FlashAttentionImpl"]:
return FlashAttentionImpl
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
raise NotImplementedError
@staticmethod
def get_builder_cls() -> Type["AttentionMetadataBuilder"]:
raise NotImplementedError
class FlashAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
causal: bool,
softmax_scale: float,
num_kv_heads: Optional[int] = None,
prefix: str = "",
**extra_impl_args,
) -> None:
self.causal = causal
self.softmax_scale = softmax_scale
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata,
):
output = flash_attn_func(
query, # type: ignore[no-untyped-call]
key,
value,
softmax_scale=self.softmax_scale,
causal=self.causal)
return output
from typing import List, Optional, Type
import torch
from sageattention import sageattn
from fastvideo.v1.attention.backends.abstract import (
AttentionBackend) # FlashAttentionMetadata,
from fastvideo.v1.attention.backends.abstract import (AttentionImpl,
AttentionMetadata)
from fastvideo.v1.logger import init_logger
logger = init_logger(__name__)
class SageAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "SAGE_ATTN"
@staticmethod
def get_impl_cls() -> Type["SageAttentionImpl"]:
return SageAttentionImpl
# @staticmethod
# def get_metadata_cls() -> Type["AttentionMetadata"]:
# return FlashAttentionMetadata
class SageAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
causal: bool,
softmax_scale: float,
num_kv_heads: Optional[int] = None,
prefix: str = "",
**extra_impl_args,
) -> None:
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout = extra_impl_args.get("dropout_p", 0.0)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
output = sageattn(
query,
key,
value,
# since input is (batch_size, seq_len, head_num, head_dim)
tensor_layout="NHD",
is_causal=self.causal)
return output
from typing import List, Optional, Type
import torch
from fastvideo.v1.attention.backends.abstract import (
AttentionBackend) # FlashAttentionMetadata,
from fastvideo.v1.attention.backends.abstract import (AttentionImpl,
AttentionMetadata)
from fastvideo.v1.logger import init_logger
logger = init_logger(__name__)
class SDPABackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "SDPA"
@staticmethod
def get_impl_cls() -> Type["SDPAImpl"]:
return SDPAImpl
# @staticmethod
# def get_metadata_cls() -> Type["AttentionMetadata"]:
# return FlashAttentionMetadata
class SDPAImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
causal: bool,
softmax_scale: float,
num_kv_heads: Optional[int] = None,
prefix: str = "",
**extra_impl_args,
) -> None:
self.causal = causal
self.softmax_scale = softmax_scale
self.dropout = extra_impl_args.get("dropout_p", 0.0)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
# transpose to bs, heads, seq_len, head_dim
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
attn_kwargs = {
"attn_mask": None,
"dropout_p": self.dropout,
"is_causal": self.causal,
"scale": self.softmax_scale
}
if query.shape[1] != key.shape[1]:
attn_kwargs["enable_gqa"] = True
output = torch.nn.functional.scaled_dot_product_attention(
query, key, value, **attn_kwargs)
output = output.transpose(1, 2)
return output
import json
from dataclasses import dataclass
from typing import List, Optional, Type
import torch
from einops import rearrange
from st_attn import sliding_tile_attention
import fastvideo.v1.envs as envs
from fastvideo.v1.attention.backends.abstract import (AttentionBackend,
AttentionImpl,
AttentionMetadata,
AttentionMetadataBuilder)
from fastvideo.v1.distributed import get_sp_group
from fastvideo.v1.fastvideo_args import FastVideoArgs
from fastvideo.v1.logger import init_logger
from fastvideo.v1.pipelines.pipeline_batch_info import ForwardBatch
logger = init_logger(__name__)
# TODO(will-refactor): move this to a utils file
def dict_to_3d_list(mask_strategy) -> List[List[List[Optional[torch.Tensor]]]]:
indices = [tuple(map(int, key.split('_'))) for key in mask_strategy]
max_timesteps_idx = max(
timesteps_idx for timesteps_idx, layer_idx, head_idx in indices) + 1
max_layer_idx = max(layer_idx
for timesteps_idx, layer_idx, head_idx in indices) + 1
max_head_idx = max(head_idx
for timesteps_idx, layer_idx, head_idx in indices) + 1
result = [[[None for _ in range(max_head_idx)]
for _ in range(max_layer_idx)] for _ in range(max_timesteps_idx)]
for key, value in mask_strategy.items():
timesteps_idx, layer_idx, head_idx = map(int, key.split('_'))
result[timesteps_idx][layer_idx][head_idx] = value
return result
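# Example: keys are "timestep_layer_head" strings; dict_to_3d_list builds a dense
# [timestep][layer][head] nested list sized by the largest index seen, with None elsewhere.
# dict_to_3d_list({"0_0_0": (3, 3, 3), "1_2_1": (1, 6, 10)}) -> a 2 x 3 x 2 nested list.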
class RangeDict(dict):
def __getitem__(self, item):
for key in self.keys():
if isinstance(key, tuple):
low, high = key
if low <= item <= high:
return super().__getitem__(key)
elif key == item:
return super().__getitem__(key)
raise KeyError(f"seq_len {item} not supported for STA")
class SlidingTileAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_supported_head_sizes() -> List[int]:
# TODO(will-refactor): check this
return [32, 64, 96, 128, 160, 192, 224, 256]
@staticmethod
def get_name() -> str:
return "SLIDING_TILE_ATTN"
@staticmethod
def get_impl_cls() -> Type["SlidingTileAttentionImpl"]:
return SlidingTileAttentionImpl
@staticmethod
def get_metadata_cls() -> Type["SlidingTileAttentionMetadata"]:
return SlidingTileAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["SlidingTileAttentionMetadataBuilder"]:
return SlidingTileAttentionMetadataBuilder
@dataclass
class SlidingTileAttentionMetadata(AttentionMetadata):
current_timestep: int
class SlidingTileAttentionMetadataBuilder(AttentionMetadataBuilder):
def __init__(self):
pass
def prepare(self):
pass
def build(
self,
current_timestep: int,
forward_batch: ForwardBatch,
fastvideo_args: FastVideoArgs,
) -> SlidingTileAttentionMetadata:
return SlidingTileAttentionMetadata(current_timestep=current_timestep, )
class SlidingTileAttentionImpl(AttentionImpl):
def __init__(
self,
num_heads: int,
head_size: int,
causal: bool,
softmax_scale: float,
num_kv_heads: Optional[int] = None,
prefix: str = "",
**extra_impl_args,
) -> None:
# TODO(will-refactor): for now this is the mask strategy, but maybe we should
# have a more general config for STA?
config_file = envs.FASTVIDEO_ATTENTION_CONFIG
if config_file is None:
raise ValueError("FASTVIDEO_ATTENTION_CONFIG is not set")
with open(config_file) as f:
mask_strategy = json.load(f)
mask_strategy = dict_to_3d_list(mask_strategy)
self.prefix = prefix
self.mask_strategy = mask_strategy
sp_group = get_sp_group()
self.sp_size = sp_group.world_size
# STA config
self.STA_base_tile_size = [6, 8, 8]
self.img_latent_shape_mapping = RangeDict({
(115200, 115456): '30x48x80',
82944: '36x48x48',
69120: '18x48x80',
})
self.full_window_mapping = {
'30x48x80': [5, 6, 10],
'36x48x48': [6, 6, 6],
'18x48x80': [3, 6, 10]
}
def tile(self, x: torch.Tensor) -> torch.Tensor:
x = rearrange(x,
"b (sp t h w) head d -> b (t sp h w) head d",
sp=self.sp_size,
t=self.img_latent_shape_int[0] // self.sp_size,
h=self.img_latent_shape_int[1],
w=self.img_latent_shape_int[2])
return rearrange(
x,
"b (n_t ts_t n_h ts_h n_w ts_w) h d -> b (n_t n_h n_w ts_t ts_h ts_w) h d",
n_t=self.full_window_size[0],
n_h=self.full_window_size[1],
n_w=self.full_window_size[2],
ts_t=self.STA_base_tile_size[0],
ts_h=self.STA_base_tile_size[1],
ts_w=self.STA_base_tile_size[2])
def untile(self, x: torch.Tensor) -> torch.Tensor:
x = rearrange(
x,
"b (n_t n_h n_w ts_t ts_h ts_w) h d -> b (n_t ts_t n_h ts_h n_w ts_w) h d",
n_t=self.full_window_size[0],
n_h=self.full_window_size[1],
n_w=self.full_window_size[2],
ts_t=self.STA_base_tile_size[0],
ts_h=self.STA_base_tile_size[1],
ts_w=self.STA_base_tile_size[2])
return rearrange(x,
"b (t sp h w) head d -> b (sp t h w) head d",
sp=self.sp_size,
t=self.img_latent_shape_int[0] // self.sp_size,
h=self.img_latent_shape_int[1],
w=self.img_latent_shape_int[2])
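# Note (sketch of intent): tile() and untile() are inverses over the image tokens. tile()
# regroups the sequence-parallel (t, h, w) latent grid into contiguous STA tiles of size
# self.STA_base_tile_size (6 x 8 x 8 as set above) so the sliding-tile kernel sees a
# tile-major ordering; untile() restores the original sequence-parallel layout.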
def preprocess_qkv(
self,
qkv: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
img_sequence_length = qkv.shape[1]
self.img_latent_shape_str = self.img_latent_shape_mapping[
img_sequence_length]
self.full_window_size = self.full_window_mapping[
self.img_latent_shape_str]
self.img_latent_shape_int = list(
map(int, self.img_latent_shape_str.split('x')))
self.img_seq_length = self.img_latent_shape_int[
0] * self.img_latent_shape_int[1] * self.img_latent_shape_int[2]
return self.tile(qkv)
def postprocess_output(
self,
output: torch.Tensor,
attn_metadata: SlidingTileAttentionMetadata,
) -> torch.Tensor:
return self.untile(output)
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
attn_metadata: SlidingTileAttentionMetadata,
) -> torch.Tensor:
assert self.mask_strategy is not None, "mask_strategy cannot be None for SlidingTileAttention"
assert self.mask_strategy[
0] is not None, "mask_strategy[0] cannot be None for SlidingTileAttention"
timestep = attn_metadata.current_timestep
# pattern:'.double_blocks.0.attn.impl' or '.single_blocks.0.attn.impl'
layer_idx = int(self.prefix.split('.')[-3])
# TODO: remove hardcode
text_length = q.shape[1] - self.img_seq_length
has_text = text_length > 0
query = q.transpose(1, 2).contiguous()
key = k.transpose(1, 2).contiguous()
value = v.transpose(1, 2).contiguous()
head_num = query.size(1)
sp_group = get_sp_group()
current_rank = sp_group.rank_in_group
start_head = current_rank * head_num
windows = [
self.mask_strategy[timestep][layer_idx][head_idx + start_head]
for head_idx in range(head_num)
]
# if has_text is False:
# from IPython import embed
# embed()
hidden_states = sliding_tile_attention(
query, key, value, windows, text_length, has_text,
self.img_latent_shape_str).transpose(1, 2)
return hidden_states
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import torch
import torch.nn as nn
from fastvideo.v1.attention.selector import (backend_name_to_enum,
get_attn_backend)
from fastvideo.v1.distributed.communication_op import (
sequence_model_parallel_all_gather, sequence_model_parallel_all_to_all_4D)
from fastvideo.v1.distributed.parallel_state import (
get_sequence_model_parallel_rank, get_sequence_model_parallel_world_size)
from fastvideo.v1.forward_context import ForwardContext, get_forward_context
from fastvideo.v1.platforms import _Backend
class DistributedAttention(nn.Module):
"""Distributed attention layer.
"""
def __init__(self,
num_heads: int,
head_size: int,
num_kv_heads: Optional[int] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
supported_attention_backends: Optional[Tuple[_Backend,
...]] = None,
prefix: str = "",
**extra_impl_args) -> None:
super().__init__()
if softmax_scale is None:
self.softmax_scale = head_size**-0.5
else:
self.softmax_scale = softmax_scale
if num_kv_heads is None:
num_kv_heads = num_heads
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(
head_size,
dtype,
supported_attention_backends=supported_attention_backends)
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads=num_heads,
head_size=head_size,
causal=causal,
softmax_scale=self.softmax_scale,
num_kv_heads=num_kv_heads,
prefix=f"{prefix}.impl",
**extra_impl_args)
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.backend = backend_name_to_enum(attn_backend.get_name())
self.dtype = dtype
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
replicated_q: Optional[torch.Tensor] = None,
replicated_k: Optional[torch.Tensor] = None,
replicated_v: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Forward pass for distributed attention.
Args:
q (torch.Tensor): Query tensor [batch_size, seq_len, num_heads, head_dim]
k (torch.Tensor): Key tensor [batch_size, seq_len, num_heads, head_dim]
v (torch.Tensor): Value tensor [batch_size, seq_len, num_heads, head_dim]
replicated_q (Optional[torch.Tensor]): Replicated query tensor, typically for text tokens
replicated_k (Optional[torch.Tensor]): Replicated key tensor
replicated_v (Optional[torch.Tensor]): Replicated value tensor
Returns:
Tuple[torch.Tensor, Optional[torch.Tensor]]: A tuple containing:
- o (torch.Tensor): Output tensor after attention for the main sequence
- replicated_o (Optional[torch.Tensor]): Output tensor for replicated tokens, if provided
"""
# Check input shapes
assert q.dim() == 4 and k.dim() == 4 and v.dim(
) == 4, "Expected 4D tensors"
# assert bs = 1
assert q.shape[
0] == 1, "Batch size must be 1, and there should be no padding tokens"
batch_size, seq_len, num_heads, head_dim = q.shape
local_rank = get_sequence_model_parallel_rank()
world_size = get_sequence_model_parallel_world_size()
forward_context: ForwardContext = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata
# Stack QKV
qkv = torch.cat([q, k, v], dim=0) # [3, seq_len, num_heads, head_dim]
# Redistribute heads across sequence dimension
qkv = sequence_model_parallel_all_to_all_4D(qkv,
scatter_dim=2,
gather_dim=1)
# Apply backend-specific preprocess_qkv
qkv = self.impl.preprocess_qkv(qkv, ctx_attn_metadata)
# Concatenate with replicated QKV if provided
if replicated_q is not None:
assert replicated_k is not None and replicated_v is not None
replicated_qkv = torch.cat(
[replicated_q, replicated_k, replicated_v],
dim=0) # [3, seq_len, num_heads, head_dim]
heads_per_rank = num_heads // world_size
replicated_qkv = replicated_qkv[:, :, local_rank *
heads_per_rank:(local_rank + 1) *
heads_per_rank]
qkv = torch.cat([qkv, replicated_qkv], dim=1)
q, k, v = qkv.chunk(3, dim=0)
output = self.impl.forward(q, k, v, ctx_attn_metadata)
# Redistribute back if using sequence parallelism
replicated_output = None
if replicated_q is not None:
replicated_output = output[:, seq_len * world_size:]
output = output[:, :seq_len * world_size]
# TODO: make this asynchronous
replicated_output = sequence_model_parallel_all_gather(
replicated_output.contiguous(), dim=2)
# Apply backend-specific postprocess_output
output = self.impl.postprocess_output(output, ctx_attn_metadata)
output = sequence_model_parallel_all_to_all_4D(output,
scatter_dim=1,
gather_dim=2)
return output, replicated_output
class LocalAttention(nn.Module):
"""Attention layer.
"""
def __init__(self,
num_heads: int,
head_size: int,
num_kv_heads: Optional[int] = None,
softmax_scale: Optional[float] = None,
causal: bool = False,
supported_attention_backends: Optional[Tuple[_Backend,
...]] = None,
**extra_impl_args) -> None:
super().__init__()
if softmax_scale is None:
self.softmax_scale = head_size**-0.5
else:
self.softmax_scale = softmax_scale
if num_kv_heads is None:
num_kv_heads = num_heads
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(
head_size,
dtype,
supported_attention_backends=supported_attention_backends)
impl_cls = attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads=num_heads,
head_size=head_size,
softmax_scale=self.softmax_scale,
num_kv_heads=num_kv_heads,
causal=causal,
**extra_impl_args)
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.backend = backend_name_to_enum(attn_backend.get_name())
self.dtype = dtype
def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) -> torch.Tensor:
"""
Apply local attention between query, key and value tensors.
Args:
q (torch.Tensor): Query tensor of shape [batch_size, seq_len, num_heads, head_dim]
k (torch.Tensor): Key tensor of shape [batch_size, seq_len, num_heads, head_dim]
v (torch.Tensor): Value tensor of shape [batch_size, seq_len, num_heads, head_dim]
Returns:
torch.Tensor: Output tensor after local attention
"""
# Check input shapes
assert q.dim() == 4 and k.dim() == 4 and v.dim(
) == 4, "Expected 4D tensors"
forward_context: ForwardContext = get_forward_context()
ctx_attn_metadata = forward_context.attn_metadata
output = self.impl.forward(q, k, v, ctx_attn_metadata)
return output
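# Usage sketch (assumes the attention backend selection and a forward context are already
# set up, since forward() reads attn_metadata from get_forward_context()):
# attn = LocalAttention(num_heads=24, head_size=128,
#                       supported_attention_backends=(_Backend.FLASH_ATTN, _Backend.TORCH_SDPA))
# out = attn(q, k, v)  # q, k, v: [batch_size, seq_len, num_heads, head_dim]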
# SPDX-License-Identifier: Apache-2.0
# Adapted from vllm: https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/attention/selector.py
import os
from contextlib import contextmanager
from functools import cache
from typing import Generator, Optional, Tuple, Type, cast
import torch
import fastvideo.v1.envs as envs
from fastvideo.v1.attention.backends.abstract import AttentionBackend
from fastvideo.v1.logger import init_logger
from fastvideo.v1.platforms import _Backend, current_platform
from fastvideo.v1.utils import STR_BACKEND_ENV_VAR, resolve_obj_by_qualname
logger = init_logger(__name__)
def backend_name_to_enum(backend_name: str) -> Optional[_Backend]:
"""
Convert a string backend name to a _Backend enum value.
Returns:
* _Backend: enum value if backend_name is a valid in-tree type
* None: otherwise it's an invalid in-tree type or an out-of-tree platform is
loaded.
"""
assert backend_name is not None
return _Backend[backend_name] if backend_name in _Backend.__members__ else \
None
def get_env_variable_attn_backend() -> Optional[_Backend]:
'''
Get the backend override specified by the FastVideo attention
backend environment variable, if one is specified.
Returns:
* _Backend enum value if an override is specified
* None otherwise
'''
backend_name = os.environ.get(STR_BACKEND_ENV_VAR)
return (None
if backend_name is None else backend_name_to_enum(backend_name))
# Global state allows a particular choice of backend
# to be forced, overriding the logic which auto-selects
# a backend based on system & workload configuration
# (default behavior if this variable is None)
#
# THIS SELECTION TAKES PRECEDENCE OVER THE
# FASTVIDEO ATTENTION BACKEND ENVIRONMENT VARIABLE
forced_attn_backend: Optional[_Backend] = None
def global_force_attn_backend(attn_backend: Optional[_Backend]) -> None:
'''
Force all attention operations to use a specified backend.
Passing `None` for the argument re-enables automatic
backend selection.
Arguments:
* attn_backend: backend selection (None to revert to auto)
'''
global forced_attn_backend
forced_attn_backend = attn_backend
def get_global_forced_attn_backend() -> Optional[_Backend]:
'''
Get the currently-forced choice of attention backend,
or None if auto-selection is currently enabled.
'''
return forced_attn_backend
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
supported_attention_backends: Optional[Tuple[_Backend, ...]] = None,
) -> Type[AttentionBackend]:
return _cached_get_attn_backend(head_size, dtype,
supported_attention_backends)
@cache
def _cached_get_attn_backend(
head_size: int,
dtype: torch.dtype,
supported_attention_backends: Optional[Tuple[_Backend, ...]] = None,
) -> Type[AttentionBackend]:
# Check whether a particular choice of backend was
# previously forced.
#
# THIS SELECTION OVERRIDES THE FASTVIDEO_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
if not supported_attention_backends:
raise ValueError("supported_attention_backends is empty")
selected_backend = None
backend_by_global_setting: Optional[_Backend] = (
get_global_forced_attn_backend())
if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting
else:
# Check the environment variable and override if specified
backend_by_env_var: Optional[str] = envs.FASTVIDEO_ATTENTION_BACKEND
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
# get device-specific attn_backend
if selected_backend not in supported_attention_backends:
selected_backend = None
attention_cls = current_platform.get_attn_backend_cls(
selected_backend, head_size, dtype)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}")
return cast(Type[AttentionBackend], resolve_obj_by_qualname(attention_cls))
@contextmanager
def global_force_attn_backend_context_manager(
attn_backend: _Backend) -> Generator[None, None, None]:
'''
Globally force a FastVideo attention backend override within a
context manager, reverting the global attention backend
override to its prior state upon exiting the context
manager.
Arguments:
* attn_backend: attention backend to force
Returns:
* Generator
'''
# Save the current state of the global backend override (if any)
original_value = get_global_forced_attn_backend()
# Globally force the new backend override
global_force_attn_backend(attn_backend)
# Yield control back to the enclosed code block
try:
yield
finally:
# Revert the original global backend override, if any
global_force_attn_backend(original_value)
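# Usage sketch: temporarily force a backend, then restore the previous override on exit.
# with global_force_attn_backend_context_manager(_Backend.FLASH_ATTN):
#     backend_cls = get_attn_backend(head_size=128, dtype=torch.bfloat16,
#                                    supported_attention_backends=(_Backend.FLASH_ATTN,))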
from fastvideo.v1.configs.models.base import ModelConfig
from fastvideo.v1.configs.models.dits.base import DiTConfig
from fastvideo.v1.configs.models.encoders.base import EncoderConfig
from fastvideo.v1.configs.models.vaes.base import VAEConfig
__all__ = ["ModelConfig", "VAEConfig", "DiTConfig", "EncoderConfig"]
from dataclasses import dataclass, field, fields
from typing import Any, Dict
from fastvideo.v1.logger import init_logger
logger = init_logger(__name__)
# 1. ArchConfig contains all fields from diffuser's/transformer's config.json (i.e. all fields related to the architecture of the model)
# 2. ArchConfig should be inherited & overridden by each model arch_config
# 3. Any field in ArchConfig is fixed upon initialization, and should be hidden away from users
@dataclass
class ArchConfig:
pass
@dataclass
class ModelConfig:
# Every model config parameter can be categorized into either ArchConfig or everything else
# Diffuser/Transformer parameters
arch_config: ArchConfig = field(default_factory=ArchConfig)
# FastVideo-specific parameters here
# i.e. STA, quantization, teacache
def __getattr__(self, name):
# Only called if 'name' is not found in ModelConfig directly
if hasattr(self.arch_config, name):
return getattr(self.arch_config, name)
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'")
def __getstate__(self):
# Return a dictionary of attributes to pickle
# Convert to dict and exclude any problematic attributes
state = self.__dict__.copy()
return state
def __setstate__(self, state):
# Restore instance attributes from the unpickled state
self.__dict__.update(state)
# This should be used only when loading from transformers/diffusers
def update_model_arch(self, source_model_dict: Dict[str, Any]) -> None:
arch_config = self.arch_config
valid_fields = {f.name for f in fields(arch_config)}
for key, value in source_model_dict.items():
if key in valid_fields:
setattr(arch_config, key, value)
else:
raise AttributeError(
f"{type(arch_config).__name__} has no field '{key}'")
if hasattr(arch_config, "__post_init__"):
arch_config.__post_init__()
def update_model_config(self, source_model_dict: Dict[str, Any]) -> None:
assert "arch_config" not in source_model_dict, "Source model config shouldn't contain arch_config."
valid_fields = {f.name for f in fields(self)}
for key, value in source_model_dict.items():
if key in valid_fields:
setattr(self, key, value)
else:
logger.warning("%s does not contain field '%s'!",
type(self).__name__, key)
raise AttributeError(f"Invalid field: {key}")
if hasattr(self, "__post_init__"):
self.__post_init__()
from fastvideo.v1.configs.models.dits.hunyuanvideo import HunyuanVideoConfig
from fastvideo.v1.configs.models.dits.wanvideo import WanVideoConfig
__all__ = ["HunyuanVideoConfig", "WanVideoConfig"]
from dataclasses import dataclass, field
from typing import Optional, Tuple
from fastvideo.v1.configs.models.base import ArchConfig, ModelConfig
from fastvideo.v1.configs.quantization import QuantizationConfig
from fastvideo.v1.platforms import _Backend
@dataclass
class DiTArchConfig(ArchConfig):
_fsdp_shard_conditions: list = field(default_factory=list)
_param_names_mapping: dict = field(default_factory=dict)
_supported_attention_backends: Tuple[_Backend,
...] = (_Backend.SLIDING_TILE_ATTN,
_Backend.SAGE_ATTN,
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA)
hidden_size: int = 0
num_attention_heads: int = 0
num_channels_latents: int = 0
@dataclass
class DiTConfig(ModelConfig):
arch_config: DiTArchConfig = field(default_factory=DiTArchConfig)
# FastVideoDiT-specific parameters
prefix: str = ""
quant_config: Optional[QuantizationConfig] = None
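# Example (sketch): architecture fields live on arch_config but resolve through
# ModelConfig.__getattr__, and update_model_arch copies matching keys from a
# diffusers/transformers-style config dict:
# cfg = DiTConfig()
# cfg.update_model_arch({"hidden_size": 3072, "num_attention_heads": 24})
# cfg.hidden_size  # -> 3072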