"llama/vscode:/vscode.git/clone" did not exist on "c68f367ef6688972de6798e631a7aa50c48af763"
Unverified commit 69c2f650, authored by Yang Yong (雍洋), committed by GitHub

Remove outdated models (#348)

parent 08d2f46a
from lightx2v.models.input_encoders.hf.t5_v1_1_xxl.model import T5EncoderModel_v1_1_xxl
from lightx2v.models.networks.cogvideox.model import CogvideoxModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.cogvideox.scheduler import CogvideoxXDPMScheduler
from lightx2v.models.video_encoders.hf.cogvideox.model import CogvideoxVAE
from lightx2v.utils.registry_factory import RUNNER_REGISTER
@RUNNER_REGISTER("cogvideox")
class CogvideoxRunner(DefaultRunner):
def __init__(self, config):
super().__init__(config)
def load_transformer(self):
model = CogvideoxModel(self.config)
return model
def load_image_encoder(self):
return None
def load_text_encoder(self):
text_encoder = T5EncoderModel_v1_1_xxl(self.config)
text_encoders = [text_encoder]
return text_encoders
def load_vae(self):
vae_model = CogvideoxVAE(self.config)
return vae_model, vae_model
def init_scheduler(self):
self.scheduler = CogvideoxXDPMScheduler(self.config)
def run_text_encoder(self, text, img):
text_encoder_output = {}
n_prompt = self.config.get("negative_prompt", "")
context = self.text_encoders[0].infer([text], self.config)
context_null = self.text_encoders[0].infer([n_prompt if n_prompt else ""], self.config)
text_encoder_output["context"] = context
text_encoder_output["context_null"] = context_null
return text_encoder_output
def run_vae_encoder(self, img):
# TODO: implement vae encoder for Cogvideox
raise NotImplementedError("I2V inference is not implemented for Cogvideox.")
def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
# TODO: Implement image encoder for Cogvideox-I2V
raise ValueError(f"Unsupported model class: {self.config['model_cls']}")
def set_target_shape(self):
ret = {}
if self.config.task == "i2v":
# TODO: implement set_target_shape for Cogvideox-I2V
raise NotImplementedError("I2V inference is not implemented for Cogvideox.")
else:
num_frames = self.config.target_video_length
latent_frames = (num_frames - 1) // self.config.vae_scale_factor_temporal + 1
additional_frames = 0
patch_size_t = self.config.patch_size_t
if patch_size_t is not None and latent_frames % patch_size_t != 0:
additional_frames = patch_size_t - latent_frames % patch_size_t
num_frames += additional_frames * self.config.vae_scale_factor_temporal
self.config.target_shape = (
self.config.batch_size,
(num_frames - 1) // self.config.vae_scale_factor_temporal + 1,
self.config.latent_channels,
self.config.height // self.config.vae_scale_factor_spatial,
self.config.width // self.config.vae_scale_factor_spatial,
)
ret["target_shape"] = self.config.target_shape
return ret
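For reference, a minimal sketch of the t2v latent-shape arithmetic in set_target_shape above, with assumed example values rather than real CogVideoX config defaults:
# All numbers below are assumed example values for illustration only.
batch_size = 1
target_video_length = 49          # requested output frames
vae_scale_factor_temporal = 4
vae_scale_factor_spatial = 8
latent_channels = 16
patch_size_t = 2                  # CogVideoX 1.5-style temporal patching
height, width = 480, 720

num_frames = target_video_length
latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1   # 13
if patch_size_t is not None and latent_frames % patch_size_t != 0:
    additional = patch_size_t - latent_frames % patch_size_t        # 1
    num_frames += additional * vae_scale_factor_temporal            # 53

target_shape = (
    batch_size,
    (num_frames - 1) // vae_scale_factor_temporal + 1,              # 14 latent frames
    latent_channels,
    height // vae_scale_factor_spatial,                             # 60
    width // vae_scale_factor_spatial,                              # 90
)
print(target_shape)   # (1, 14, 16, 60, 90)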
import os
import numpy as np
import torch
import torchvision
from lightx2v.models.input_encoders.hf.clip.model import TextEncoderHFClipModel
from lightx2v.models.input_encoders.hf.llama.model import TextEncoderHFLlamaModel
from lightx2v.models.input_encoders.hf.llava.model import TextEncoderHFLlavaModel
from lightx2v.models.networks.hunyuan.model import HunyuanModel
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.hunyuan.feature_caching.scheduler import HunyuanSchedulerAdaCaching, HunyuanSchedulerCustomCaching, HunyuanSchedulerTaylorCaching, HunyuanSchedulerTeaCaching
from lightx2v.models.schedulers.hunyuan.scheduler import HunyuanScheduler
from lightx2v.models.video_encoders.hf.hunyuan.hunyuan_vae import HunyuanVAE
from lightx2v.utils.envs import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
@RUNNER_REGISTER("hunyuan")
class HunyuanRunner(DefaultRunner):
def __init__(self, config):
super().__init__(config)
def load_transformer(self):
return HunyuanModel(self.config.model_path, self.config, self.init_device, self.config)
def load_image_encoder(self):
return None
def load_text_encoder(self):
if self.config.task == "t2v":
text_encoder_1 = TextEncoderHFLlamaModel(os.path.join(self.config.model_path, "text_encoder"), self.init_device)
else:
text_encoder_1 = TextEncoderHFLlavaModel(os.path.join(self.config.model_path, "text_encoder_i2v"), self.init_device)
text_encoder_2 = TextEncoderHFClipModel(os.path.join(self.config.model_path, "text_encoder_2"), self.init_device)
text_encoders = [text_encoder_1, text_encoder_2]
return text_encoders
def load_vae(self):
vae_model = HunyuanVAE(self.config.model_path, dtype=torch.float16, device=self.init_device, config=self.config)
return vae_model, vae_model
def init_scheduler(self):
if self.config.feature_caching == "NoCaching":
scheduler = HunyuanScheduler(self.config)
elif self.config.feature_caching == "Tea":
scheduler = HunyuanSchedulerTeaCaching(self.config)
elif self.config.feature_caching == "TaylorSeer":
scheduler = HunyuanSchedulerTaylorCaching(self.config)
elif self.config.feature_caching == "Ada":
scheduler = HunyuanSchedulerAdaCaching(self.config)
elif self.config.feature_caching == "Custom":
scheduler = HunyuanSchedulerCustomCaching(self.config)
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
self.model.set_scheduler(scheduler)
def run_text_encoder(self, text, img):
text_encoder_output = {}
for i, encoder in enumerate(self.text_encoders):
if self.config.task == "i2v" and i == 0:
text_state, attention_mask = encoder.infer(text, img, self.config)
else:
text_state, attention_mask = encoder.infer(text, self.config)
text_encoder_output[f"text_encoder_{i + 1}_text_states"] = text_state.to(dtype=GET_DTYPE())
text_encoder_output[f"text_encoder_{i + 1}_attention_mask"] = attention_mask
return text_encoder_output
@staticmethod
def get_closest_ratio(height: float, width: float, ratios: list, buckets: list):
aspect_ratio = float(height) / float(width)
diff_ratios = ratios - aspect_ratio
if aspect_ratio >= 1:
indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
else:
indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0]
closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
closest_size = buckets[closest_ratio_id]
closest_ratio = ratios[closest_ratio_id]
return closest_size, closest_ratio
@staticmethod
def generate_crop_size_list(base_size=256, patch_size=32, max_ratio=4.0):
num_patches = round((base_size / patch_size) ** 2)
assert max_ratio >= 1.0
crop_size_list = []
wp, hp = num_patches, 1
while wp > 0:
if max(wp, hp) / min(wp, hp) <= max_ratio:
crop_size_list.append((wp * patch_size, hp * patch_size))
if (hp + 1) * wp <= num_patches:
hp += 1
else:
wp -= 1
return crop_size_list
def run_image_encoder(self, img):
return None
def run_vae_encoder(self, img):
kwargs = {}
if self.config.i2v_resolution == "720p":
bucket_hw_base_size = 960
elif self.config.i2v_resolution == "540p":
bucket_hw_base_size = 720
elif self.config.i2v_resolution == "360p":
bucket_hw_base_size = 480
else:
raise ValueError(f"self.config.i2v_resolution: {self.config.i2v_resolution} must be in [360p, 540p, 720p]")
origin_size = img.size
crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, 32)
aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
self.config.target_height, self.config.target_width = closest_size
kwargs["target_height"], kwargs["target_width"] = closest_size
resize_param = min(closest_size)
center_crop_param = closest_size
ref_image_transform = torchvision.transforms.Compose(
[torchvision.transforms.Resize(resize_param), torchvision.transforms.CenterCrop(center_crop_param), torchvision.transforms.ToTensor(), torchvision.transforms.Normalize([0.5], [0.5])]
)
semantic_image_pixel_values = [ref_image_transform(img)]
semantic_image_pixel_values = torch.cat(semantic_image_pixel_values).unsqueeze(0).unsqueeze(2).to(torch.float16).to(torch.device("cuda"))
img_latents = self.vae_encoder.encode(semantic_image_pixel_values, self.config).mode()
scaling_factor = 0.476986
img_latents.mul_(scaling_factor)
return img_latents, kwargs
def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
image_encoder_output = {"img": img, "img_latents": vae_encoder_out}
return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
def set_target_shape(self):
vae_scale_factor = 2 ** (4 - 1)
self.config.target_shape = (
1,
16,
(self.config.target_video_length - 1) // 4 + 1,
int(self.config.target_height) // vae_scale_factor,
int(self.config.target_width) // vae_scale_factor,
)
return {"target_height": self.config.target_height, "target_width": self.config.target_width, "target_shape": self.config.target_shape}
import gc
import torch
from loguru import logger
from lightx2v.models.networks.wan.causvid_model import WanCausVidModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
@RUNNER_REGISTER("wan2.1_causvid")
class WanCausVidRunner(WanRunner):
def __init__(self, config):
super().__init__(config)
self.num_frame_per_block = self.config.num_frame_per_block
self.num_frames = self.config.num_frames
self.frame_seq_length = self.config.frame_seq_length
self.infer_blocks = self.config.num_blocks
self.num_fragments = self.config.num_fragments
def load_transformer(self):
if self.config.get("lora_configs") and self.config.lora_configs:
model = WanModel(
self.config.model_path,
self.config,
self.init_device,
)
lora_wrapper = WanLoraWrapper(model)
for lora_config in self.config.lora_configs:
lora_path = lora_config["path"]
strength = lora_config.get("strength", 1.0)
lora_name = lora_wrapper.load_lora(lora_path)
lora_wrapper.apply_lora(lora_name, strength)
logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
else:
model = WanCausVidModel(self.config.model_path, self.config, self.init_device)
return model
def set_inputs(self, inputs):
super().set_inputs(inputs)
self.config["num_fragments"] = inputs.get("num_fragments", 1)
self.num_fragments = self.config["num_fragments"]
def init_scheduler(self):
self.scheduler = WanStepDistillScheduler(self.config)
def set_target_shape(self):
if self.config.task in ["i2v", "s2v"]:
self.config.target_shape = (16, self.config.num_frame_per_block, self.config.lat_h, self.config.lat_w)
# For i2v, frame_seq_length must be recomputed from the input latent shape
frame_seq_length = (self.config.lat_h // 2) * (self.config.lat_w // 2)
self.model.transformer_infer.frame_seq_length = frame_seq_length
self.frame_seq_length = frame_seq_length
elif self.config.task == "t2v":
self.config.target_shape = (
16,
self.config.num_frame_per_block,
int(self.config.target_height) // self.config.vae_stride[1],
int(self.config.target_width) // self.config.vae_stride[2],
)
def run(self):
self.model.transformer_infer._init_kv_cache(dtype=GET_DTYPE(), device="cuda")
self.model.transformer_infer._init_crossattn_cache(dtype=GET_DTYPE(), device="cuda")
output_latents = torch.zeros(
(self.model.config.target_shape[0], self.num_frames + (self.num_fragments - 1) * (self.num_frames - self.num_frame_per_block), *self.model.config.target_shape[2:]),
device="cuda",
dtype=GET_DTYPE(),
)
start_block_idx = 0
for fragment_idx in range(self.num_fragments):
logger.info(f"========> fragment_idx: {fragment_idx + 1} / {self.num_fragments}")
kv_start = 0
kv_end = kv_start + self.num_frame_per_block * self.frame_seq_length
if fragment_idx > 0:
logger.info("recompute the kv_cache ...")
with ProfilingContext4DebugL1("step_pre"):
self.model.scheduler.latents = self.model.scheduler.last_sample
self.model.scheduler.step_pre(step_index=self.model.scheduler.infer_steps - 1)
with ProfilingContext4DebugL1("🚀 infer_main"):
self.model.infer(self.inputs, kv_start, kv_end)
kv_start += self.num_frame_per_block * self.frame_seq_length
kv_end += self.num_frame_per_block * self.frame_seq_length
infer_blocks = self.infer_blocks - (fragment_idx > 0)
for block_idx in range(infer_blocks):
logger.info(f"=====> block_idx: {block_idx + 1} / {infer_blocks}")
logger.info(f"=====> kv_start: {kv_start}, kv_end: {kv_end}")
self.model.scheduler.reset()
for step_index in range(self.model.scheduler.infer_steps):
logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
with ProfilingContext4DebugL1("step_pre"):
self.model.scheduler.step_pre(step_index=step_index)
with ProfilingContext4DebugL1("🚀 infer_main"):
self.model.infer(self.inputs, kv_start, kv_end)
with ProfilingContext4DebugL1("step_post"):
self.model.scheduler.step_post()
kv_start += self.num_frame_per_block * self.frame_seq_length
kv_end += self.num_frame_per_block * self.frame_seq_length
output_latents[:, start_block_idx * self.num_frame_per_block : (start_block_idx + 1) * self.num_frame_per_block] = self.model.scheduler.latents
start_block_idx += 1
return output_latents, self.model.scheduler.generator
def end_run(self):
self.model.scheduler.clear()
del self.inputs, self.model.scheduler, self.model.transformer_infer.kv_cache, self.model.transformer_infer.crossattn_cache
gc.collect()
torch.cuda.empty_cache()
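A worked example of the output-latent length used to allocate output_latents in run() above (the three values below are assumptions for illustration):
num_frames, num_frame_per_block, num_fragments = 21, 3, 2
total = num_frames + (num_fragments - 1) * (num_frames - num_frame_per_block)
print(total)   # 39 latent frames across both fragments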
@@ -8,8 +8,8 @@ import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger
-from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
-from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
+from lightx2v.models.input_encoders.hf.wan.t5.model import T5EncoderModel
+from lightx2v.models.input_encoders.hf.wan.xlm_roberta.model import CLIPModel
from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.runners.default_runner import DefaultRunner
import gc
import os
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from loguru import logger
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.df.skyreels_v2_df_scheduler import WanSkyreelsV2DFScheduler
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
@RUNNER_REGISTER("wan2.1_skyreels_v2_df")
class WanSkyreelsV2DFRunner(WanRunner):  # Diffusion forcing for SkyReelsV2 DF I2V/T2V
def __init__(self, config):
super().__init__(config)
def init_scheduler(self):
self.scheduler = WanSkyreelsV2DFScheduler(self.config)
def run_image_encoder(self, config, image_encoder, vae_model):
img = Image.open(config.image_path).convert("RGB")
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
h, w = img.shape[1:]
aspect_ratio = h / w
max_area = config.target_height * config.target_width
lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
h = lat_h * config.vae_stride[1]
w = lat_w * config.vae_stride[2]
config.lat_h = lat_h
config.lat_w = lat_w
vae_encoder_out = vae_model.encode([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1).cuda()])[0]
vae_encoder_out = vae_encoder_out.to(GET_DTYPE())
return vae_encoder_out
def set_target_shape(self):
if os.path.isfile(self.config.image_path):
self.config.target_shape = (16, (self.config.target_video_length - 1) // 4 + 1, self.config.lat_h, self.config.lat_w)
else:
self.config.target_shape = (
16,
(self.config.target_video_length - 1) // 4 + 1,
int(self.config.target_height) // self.config.vae_stride[1],
int(self.config.target_width) // self.config.vae_stride[2],
)
def run_input_encoder(self):
image_encoder_output = None
if os.path.isfile(self.config.image_path):
with ProfilingContext4DebugL2("Run Img Encoder"):
image_encoder_output = self.run_image_encoder(self.config, self.image_encoder, self.vae_model)
with ProfilingContext4DebugL2("Run Text Encoder"):
text_encoder_output = self.run_text_encoder(self.config["prompt"], self.text_encoders, self.config, image_encoder_output)
self.set_target_shape()
self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
gc.collect()
torch.cuda.empty_cache()
def run(self):
num_frames = self.config.num_frames
overlap_history = self.config.overlap_history
base_num_frames = self.config.base_num_frames
addnoise_condition = self.config.addnoise_condition
causal_block_size = self.config.causal_block_size
latent_length = (num_frames - 1) // 4 + 1
base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
overlap_history_frames = (overlap_history - 1) // 4 + 1
n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
prefix_video = self.inputs["image_encoder_output"]
predix_video_latent_length = 0
if prefix_video is not None:
predix_video_latent_length = prefix_video.size(1)
output_video = None
logger.info(f"Diffusion-Forcing n_iter:{n_iter}")
for i in range(n_iter):
if output_video is not None: # i !=0
prefix_video = output_video[:, :, -overlap_history:].to(self.model.scheduler.device)
prefix_video = self.vae_model.encode(prefix_video)[0] # [(b, c, f, h, w)]
if prefix_video.shape[1] % causal_block_size != 0:
truncate_len = prefix_video.shape[1] % causal_block_size
# the prefix video length is truncated to align with the causal block size
prefix_video = prefix_video[:, : prefix_video.shape[1] - truncate_len]
predix_video_latent_length = prefix_video.shape[1]
finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
left_frame_num = latent_length - finished_frame_num
base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
else: # i == 0
base_num_frames_iter = base_num_frames
if prefix_video is not None:
input_dtype = self.model.scheduler.latents.dtype
self.model.scheduler.latents[:, :predix_video_latent_length] = prefix_video.to(input_dtype)
self.model.scheduler.generate_timestep_matrix(base_num_frames_iter, base_num_frames_iter, addnoise_condition, predix_video_latent_length, causal_block_size)
for step_index in range(self.model.scheduler.infer_steps):
logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
with ProfilingContext4DebugL1("step_pre"):
self.model.scheduler.step_pre(step_index=step_index)
with ProfilingContext4DebugL1("🚀 infer_main"):
self.model.infer(self.inputs)
with ProfilingContext4DebugL1("step_post"):
self.model.scheduler.step_post()
videos = self.run_vae(self.model.scheduler.latents, self.model.scheduler.generator)
self.model.scheduler.prepare(self.inputs["image_encoder_output"]) # reset
if output_video is None:
output_video = videos.clamp(-1, 1).cpu() # b, c, f, h, w
else:
output_video = torch.cat([output_video, videos[:, :, overlap_history:].clamp(-1, 1).cpu()], 2)
return output_video
def run_pipeline(self):
self.init_scheduler()
self.model.set_scheduler(self.scheduler)
self.run_input_encoder()
self.model.scheduler.prepare()
output_video = self.run()
self.end_run()
self.save_video(output_video)
import numpy as np
import torch
from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.utils.torch_utils import randn_tensor
from lightx2v.models.schedulers.scheduler import BaseScheduler
# Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
tw = tgt_width
th = tgt_height
h, w = src
r = h / w
if r > (th / tw):
resize_height = th
resize_width = int(round(th / h * w))
else:
resize_width = tw
resize_height = int(round(tw / w * h))
crop_top = int(round((th - resize_height) / 2.0))
crop_left = int(round((tw - resize_width) / 2.0))
return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
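A minimal usage sketch of get_resize_crop_region_for_grid above, mapping a 30x60 source grid into a 48-wide, 32-tall target grid (the sizes are illustrative assumptions):
(top, left), (bottom, right) = get_resize_crop_region_for_grid((30, 60), tgt_width=48, tgt_height=32)
print((top, left), (bottom, right))   # (4, 0) (28, 48): the region is vertically centered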
def rescale_zero_terminal_snr(alphas_cumprod):
"""
Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
Args:
betas (`torch.Tensor`):
the betas that the scheduler is being initialized with.
Returns:
`torch.Tensor`: rescaled betas with zero terminal SNR
"""
alphas_bar_sqrt = alphas_cumprod.sqrt()
# Store old values.
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
# Shift so the last timestep is zero.
alphas_bar_sqrt -= alphas_bar_sqrt_T
# Scale so the first timestep is back to the old value.
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
# Convert alphas_bar_sqrt to betas
alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
return alphas_bar
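A minimal sketch checking the effect of rescale_zero_terminal_snr above on a toy cumulative-alpha schedule (the beta schedule itself is an assumption):
import torch

betas = torch.linspace(1e-4, 2e-2, 1000, dtype=torch.float64)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
rescaled = rescale_zero_terminal_snr(alphas_cumprod)
print(alphas_cumprod[-1].item() > 0)   # True: non-zero terminal alpha_bar before rescaling
print(rescaled[-1].item())             # 0.0 afterwards, i.e. zero terminal SNR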
class CogvideoxXDPMScheduler(BaseScheduler):
def __init__(self, config):
self.config = config
self.set_timesteps()
self.generator = torch.Generator().manual_seed(config.seed)
self.noise_pred = None
if self.config.beta_schedule == "scaled_linear":
# this schedule is very specific to the latent diffusion model.
self.betas = torch.linspace(self.config.scheduler_beta_start**0.5, self.config.scheduler_beta_end**0.5, self.config.num_train_timesteps, dtype=torch.float64) ** 2
else:
raise NotImplementedError(f"{self.config.beta_schedule} is not implemented for {self.__class__}")
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0).to(torch.device("cuda"))
# Modify: SNR shift following SD3
self.alphas_cumprod = self.alphas_cumprod / (self.config.scheduler_snr_shift_scale + (1 - self.config.scheduler_snr_shift_scale) * self.alphas_cumprod)
# Rescale for zero SNR
if self.config.scheduler_rescale_betas_zero_snr:
self.alphas_cumprod = rescale_zero_terminal_snr(self.alphas_cumprod)
# At every step in ddim, we are looking into the previous alphas_cumprod
# For the final step, there is no previous alphas_cumprod because we are already at 0
# `set_alpha_to_one` decides whether we set this parameter simply to one or
# whether we use the final alpha of the "non-previous" one.
self.final_alpha_cumprod = torch.tensor(1.0) if self.config.scheduler_set_alpha_to_one else self.alphas_cumprod[0]
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
def scale_model_input(self, sample, timestep=None):
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The current timestep in the diffusion chain.
Returns:
`torch.Tensor`:
A scaled input sample.
"""
return sample
def set_timesteps(self):
if self.config.num_inference_steps > self.config.num_train_timesteps:
raise ValueError(
f"`num_inference_steps`: {self.config.num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
f" maximal {self.config.num_train_timesteps} timesteps."
)
self.infer_steps = self.config.num_inference_steps
# "linspace", "leading", "trailing" corresponds to annotation of Table 2. of https://arxiv.org/abs/2305.08891
if self.config.timestep_spacing == "linspace":
timesteps = np.linspace(0, self.config.num_train_timesteps - 1, self.config.num_inference_steps).round()[::-1].copy().astype(np.int64)
elif self.config.timestep_spacing == "leading":
step_ratio = self.config.num_train_timesteps // self.config.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = (np.arange(0, self.infer_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
timesteps += self.config.steps_offset
elif self.config.timestep_spacing == "trailing":
step_ratio = self.config.num_train_timesteps / self.config.num_inference_steps
# creates integer timesteps by multiplying by ratio
# casting to int to avoid issues when num_inference_step is power of 3
timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
timesteps -= 1
else:
raise ValueError(f"{self.config.timestep_spacing} is not supported. Please make sure to choose one of 'leading' or 'trailing'.")
self.timesteps = torch.Tensor(timesteps).to(torch.device("cuda")).int()
def prepare(self, image_encoder_output):
self.image_encoder_output = image_encoder_output
self.prepare_latents(shape=self.config.target_shape, dtype=torch.bfloat16)
self.prepare_guidance()
self.prepare_rotary_pos_embedding()
def prepare_latents(self, shape, dtype):
latents = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.init_noise_sigma
self.latents = latents
self.old_pred_original_sample = None
def prepare_guidance(self):
self.guidance_scale = self.config.guidance_scale
def prepare_rotary_pos_embedding(self):
grid_height = self.config.height // (self.config.vae_scale_factor_spatial * self.config.patch_size)
grid_width = self.config.width // (self.config.vae_scale_factor_spatial * self.config.patch_size)
p = self.config.patch_size
p_t = self.config.patch_size_t
base_size_width = self.config.transformer_sample_width // p
base_size_height = self.config.transformer_sample_height // p
num_frames = self.latents.size(1)
if p_t is None:
# CogVideoX 1.0
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.transformer.config.attention_head_dim,
crops_coords=grid_crops_coords,
grid_size=(grid_height, grid_width),
temporal_size=num_frames,
device=torch.device("cuda"),
)
else:
# CogVideoX 1.5
base_num_frames = (num_frames + p_t - 1) // p_t
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
embed_dim=self.config.transformer_attention_head_dim,
crops_coords=None,
grid_size=(grid_height, grid_width),
temporal_size=base_num_frames,
grid_type="slice",
max_size=(base_size_height, base_size_width),
device=torch.device("cuda"),
)
self.freqs_cos = freqs_cos
self.freqs_sin = freqs_sin
self.image_rotary_emb = (freqs_cos, freqs_sin) if self.config.use_rotary_positional_embeddings else None
def get_variables(self, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back=None):
lamb = ((alpha_prod_t / (1 - alpha_prod_t)) ** 0.5).log()
lamb_next = ((alpha_prod_t_prev / (1 - alpha_prod_t_prev)) ** 0.5).log()
h = lamb_next - lamb
if alpha_prod_t_back is not None:
lamb_previous = ((alpha_prod_t_back / (1 - alpha_prod_t_back)) ** 0.5).log()
h_last = lamb - lamb_previous
r = h_last / h
return h, r, lamb, lamb_next
else:
return h, None, lamb, lamb_next
def get_mult(self, h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back):
mult1 = ((1 - alpha_prod_t_prev) / (1 - alpha_prod_t)) ** 0.5 * (-h).exp()
mult2 = (-2 * h).expm1() * alpha_prod_t_prev**0.5
if alpha_prod_t_back is not None:
mult3 = 1 + 1 / (2 * r)
mult4 = 1 / (2 * r)
return mult1, mult2, mult3, mult4
else:
return mult1, mult2
def step_post(self):
if self.infer_steps is None:
raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler")
# See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
# Ideally, read DDIM paper in-detail understanding
# Notation (<variable name> -> <name in paper>
# - pred_noise_t -> e_theta(x_t, t)
# - pred_original_sample -> f_theta(x_t, t) or x_0
# - std_dev_t -> sigma_t
# - eta -> η
# - pred_sample_direction -> "direction pointing to x_t"
# - pred_prev_sample -> "x_t-1"
timestep = self.timesteps[self.step_index]
timestep_back = self.timesteps[self.step_index - 1] if self.step_index > 0 else None
# 1. get previous step value (=t-1)
prev_timestep = timestep - self.config.num_train_timesteps // self.infer_steps
# 2. compute alphas, betas
alpha_prod_t = self.alphas_cumprod[timestep]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
alpha_prod_t_back = self.alphas_cumprod[timestep_back] if timestep_back is not None else None
beta_prod_t = 1 - alpha_prod_t
# 3. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
# To make style tests pass, commented out `pred_epsilon` as it is an unused variable
if self.config.scheduler_prediction_type == "epsilon":
pred_original_sample = (self.latents - beta_prod_t ** (0.5) * self.noise_pred) / alpha_prod_t ** (0.5)
# pred_epsilon = model_output
elif self.config.scheduler_prediction_type == "sample":
pred_original_sample = self.noise_pred
# pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
elif self.config.scheduler_prediction_type == "v_prediction":
pred_original_sample = (alpha_prod_t**0.5) * self.latents - (beta_prod_t**0.5) * self.noise_pred
# pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
else:
raise ValueError(f"prediction_type given as {self.config.scheduler_prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`")
h, r, lamb, lamb_next = self.get_variables(alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back)
mult = list(self.get_mult(h, r, alpha_prod_t, alpha_prod_t_prev, alpha_prod_t_back))
mult_noise = (1 - alpha_prod_t_prev) ** 0.5 * (1 - (-2 * h).exp()) ** 0.5
noise = randn_tensor(self.latents.shape, generator=self.generator, device=self.latents.device, dtype=self.latents.dtype)
prev_sample = mult[0] * self.latents - mult[1] * pred_original_sample + mult_noise * noise
if self.old_pred_original_sample is None or prev_timestep < 0:
# Save a network evaluation if all noise levels are 0 or on the first step
self.latents = prev_sample
self.old_pred_original_sample = pred_original_sample
else:
denoised_d = mult[2] * pred_original_sample - mult[3] * self.old_pred_original_sample
noise = randn_tensor(self.latents.shape, generator=self.generator, device=self.latents.device, dtype=self.latents.dtype)
x_advanced = mult[0] * self.latents - mult[1] * denoised_d + mult_noise * noise
self.latents = x_advanced
self.old_pred_original_sample = pred_original_sample
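A minimal sketch of the "trailing" spacing branch of set_timesteps above, with assumed example values for the step counts:
import numpy as np

num_train_timesteps, num_inference_steps = 1000, 8
step_ratio = num_train_timesteps / num_inference_steps
timesteps = np.round(np.arange(num_train_timesteps, 0, -step_ratio)).astype(np.int64) - 1
print(timesteps)   # [999 874 749 624 499 374 249 124]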
from ..scheduler import HunyuanScheduler
class HunyuanSchedulerTeaCaching(HunyuanScheduler):
def __init__(self, config):
super().__init__(config)
def clear(self):
self.transformer_infer.clear()
class HunyuanSchedulerTaylorCaching(HunyuanScheduler):
def __init__(self, config):
super().__init__(config)
pattern = [True, False, False, False]
self.caching_records = (pattern * ((config.infer_steps + 3) // 4))[: config.infer_steps]
def clear(self):
self.transformer_infer.clear()
class HunyuanSchedulerAdaCaching(HunyuanScheduler):
def __init__(self, config):
super().__init__(config)
def clear(self):
self.transformer_infer.clear()
class HunyuanSchedulerCustomCaching(HunyuanScheduler):
def __init__(self, config):
super().__init__(config)
def clear(self):
self.transformer_infer.clear()
import torch
def cache_init(num_steps, model_kwargs=None):
"""
Initialization for cache.
"""
cache_dic = {}
cache = {}
cache_index = {}
cache[-1] = {}
cache_index[-1] = {}
cache_index["layer_index"] = {}
cache_dic["attn_map"] = {}
cache_dic["attn_map"][-1] = {}
cache_dic["attn_map"][-1]["double_stream"] = {}
cache_dic["attn_map"][-1]["single_stream"] = {}
cache_dic["k-norm"] = {}
cache_dic["k-norm"][-1] = {}
cache_dic["k-norm"][-1]["double_stream"] = {}
cache_dic["k-norm"][-1]["single_stream"] = {}
cache_dic["v-norm"] = {}
cache_dic["v-norm"][-1] = {}
cache_dic["v-norm"][-1]["double_stream"] = {}
cache_dic["v-norm"][-1]["single_stream"] = {}
cache_dic["cross_attn_map"] = {}
cache_dic["cross_attn_map"][-1] = {}
cache[-1]["double_stream"] = {}
cache[-1]["single_stream"] = {}
cache_dic["cache_counter"] = 0
for j in range(20):
cache[-1]["double_stream"][j] = {}
cache_index[-1][j] = {}
cache_dic["attn_map"][-1]["double_stream"][j] = {}
cache_dic["attn_map"][-1]["double_stream"][j]["total"] = {}
cache_dic["attn_map"][-1]["double_stream"][j]["txt_mlp"] = {}
cache_dic["attn_map"][-1]["double_stream"][j]["img_mlp"] = {}
cache_dic["k-norm"][-1]["double_stream"][j] = {}
cache_dic["k-norm"][-1]["double_stream"][j]["txt_mlp"] = {}
cache_dic["k-norm"][-1]["double_stream"][j]["img_mlp"] = {}
cache_dic["v-norm"][-1]["double_stream"][j] = {}
cache_dic["v-norm"][-1]["double_stream"][j]["txt_mlp"] = {}
cache_dic["v-norm"][-1]["double_stream"][j]["img_mlp"] = {}
for j in range(40):
cache[-1]["single_stream"][j] = {}
cache_index[-1][j] = {}
cache_dic["attn_map"][-1]["single_stream"][j] = {}
cache_dic["attn_map"][-1]["single_stream"][j]["total"] = {}
cache_dic["k-norm"][-1]["single_stream"][j] = {}
cache_dic["k-norm"][-1]["single_stream"][j]["total"] = {}
cache_dic["v-norm"][-1]["single_stream"][j] = {}
cache_dic["v-norm"][-1]["single_stream"][j]["total"] = {}
cache_dic["taylor_cache"] = False
cache_dic["duca"] = False
cache_dic["test_FLOPs"] = False
mode = "Taylor"
if mode == "original":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.0
cache_dic["fresh_threshold"] = 1
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["max_order"] = 0
cache_dic["first_enhance"] = 1
elif mode == "ToCa":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.10
cache_dic["fresh_threshold"] = 5
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["max_order"] = 0
cache_dic["first_enhance"] = 1
cache_dic["duca"] = False
elif mode == "DuCa":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.10
cache_dic["fresh_threshold"] = 5
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["max_order"] = 0
cache_dic["first_enhance"] = 1
cache_dic["duca"] = True
elif mode == "Taylor":
cache_dic["cache_type"] = "random"
cache_dic["cache_index"] = cache_index
cache_dic["cache"] = cache
cache_dic["fresh_ratio_schedule"] = "ToCa"
cache_dic["fresh_ratio"] = 0.0
cache_dic["fresh_threshold"] = 2
cache_dic["max_order"] = 1
cache_dic["force_fresh"] = "global"
cache_dic["soft_fresh_weight"] = 0.0
cache_dic["taylor_cache"] = True
cache_dic["first_enhance"] = 1
current = {}
current["num_steps"] = num_steps
current["activated_steps"] = [0]
return cache_dic, current
def force_scheduler(cache_dic, current):
if cache_dic["fresh_ratio"] == 0:
# FORA
linear_step_weight = 0.0
else:
# TokenCache
linear_step_weight = 0.0
step_factor = torch.tensor(1 - linear_step_weight + 2 * linear_step_weight * current["step"] / current["num_steps"])
threshold = torch.round(cache_dic["fresh_threshold"] / step_factor)
# No forced constraint for sensitive steps, since the performance is already good enough;
# feel free to experiment with one.
cache_dic["cal_threshold"] = threshold
# return threshold
def cal_type(cache_dic, current):
"""
Determine calculation type for this step
"""
if (cache_dic["fresh_ratio"] == 0.0) and (not cache_dic["taylor_cache"]):
# FORA:Uniform
first_step = current["step"] == 0
else:
# ToCa: First enhanced
first_step = current["step"] < cache_dic["first_enhance"]
# first_step = (current['step'] <= 3)
force_fresh = cache_dic["force_fresh"]
if not first_step:
fresh_interval = cache_dic["cal_threshold"]
else:
fresh_interval = cache_dic["fresh_threshold"]
if (first_step) or (cache_dic["cache_counter"] == fresh_interval - 1):
current["type"] = "full"
cache_dic["cache_counter"] = 0
current["activated_steps"].append(current["step"])
# current['activated_times'].append(current['t'])
force_scheduler(cache_dic, current)
elif cache_dic["taylor_cache"]:
cache_dic["cache_counter"] += 1
current["type"] = "taylor_cache"
else:
cache_dic["cache_counter"] += 1
if cache_dic["duca"]:
if cache_dic["cache_counter"] % 2 == 1: # 0: ToCa-Aggresive-ToCa, 1: Aggresive-ToCa-Aggresive
current["type"] = "ToCa"
# 'cache_noise' 'ToCa' 'FORA'
else:
current["type"] = "aggressive"
else:
current["type"] = "ToCa"
# if current['step'] < 25:
# current['type'] = 'FORA'
# else:
# current['type'] = 'aggressive'
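A minimal sketch of how the helpers above could be wired together; 50 steps is an assumed value and the surrounding runner code is not shown:
cache_dic, current = cache_init(num_steps=50)   # Taylor mode, as hard-coded above
current["step"] = 0
cal_type(cache_dic, current)
print(current["type"])            # "full": the first step always recomputes everything
print(cache_dic["taylor_cache"])  # True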
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
from diffusers.utils.torch_utils import randn_tensor
from lightx2v.models.schedulers.scheduler import BaseScheduler
from lightx2v.utils.envs import *
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.FloatTensor, int],
theta: float = 10000.0,
use_real: bool = False,
theta_rescale_factor: float = 1.0,
interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
Returns:
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
"""
if isinstance(pos, int):
pos = torch.arange(pos).float()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
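A minimal usage sketch of get_1d_rotary_pos_embed above for an 8-dimensional head and 4 positions (toy sizes, purely illustrative):
cos, sin = get_1d_rotary_pos_embed(8, 4, use_real=True)
print(cos.shape, sin.shape)   # torch.Size([4, 8]) torch.Size([4, 8])
cis = get_1d_rotary_pos_embed(8, 4, use_real=False)
print(cis.shape, cis.dtype)   # torch.Size([4, 4]) torch.complex64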
def get_meshgrid_nd(start, *args, dim=2):
"""
Get n-D meshgrid with start, stop and num.
Args:
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
n-tuples.
*args: See above.
dim (int): Dimension of the meshgrid. Defaults to 2.
Returns:
grid (np.ndarray): [dim, ...]
"""
if len(args) == 0:
# start is grid_size
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
return grid
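A minimal usage sketch of get_meshgrid_nd above, building a 3-D grid from a single grid-size tuple (the sizes are illustrative):
grid = get_meshgrid_nd((2, 3, 4), dim=3)
print(grid.shape)         # torch.Size([3, 2, 3, 4])
print(grid[0, :, 0, 0])   # tensor([0., 1.]): coordinates along the first axis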
def get_nd_rotary_pos_embed(
rope_dim_list,
start,
*args,
theta=10000.0,
use_real=False,
theta_rescale_factor: Union[float, List[float]] = 1.0,
interpolation_factor: Union[float, List[float]] = 1.0,
):
"""
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
Args:
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
sum(rope_dim_list) should equal to head_dim of attention layer.
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
*args: See above.
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
part and an imaginary part separately.
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
Returns:
pos_embed (torch.Tensor): [HW, D/2]
"""
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)"
# use 1/ndim of dimensions to encode grid_axis
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i],
grid[i].reshape(-1),
theta,
use_real=use_real,
theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
) # 2 x [WHD, rope_dim_list[i]]
embs.append(emb)
if use_real:
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
return cos, sin
else:
emb = torch.cat(embs, dim=1) # (WHD, D/2)
return emb
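A minimal usage sketch of get_nd_rotary_pos_embed above with an assumed Hunyuan-like latent grid of 8x22x40 tokens and a rope_dim_list summing to a 128-dim head:
cos, sin = get_nd_rotary_pos_embed([16, 56, 56], (8, 22, 40), use_real=True)
print(cos.shape, sin.shape)   # torch.Size([7040, 128]) twice, for 8*22*40 = 7040 tokens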
def set_timesteps_sigmas(num_inference_steps, shift, device, num_train_timesteps=1000):
sigmas = torch.linspace(1, 0, num_inference_steps + 1)
sigmas = (shift * sigmas) / (1 + (shift - 1) * sigmas)
timesteps = (sigmas[:-1] * num_train_timesteps).to(dtype=torch.float32, device=device)
return timesteps, sigmas
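A minimal sketch of set_timesteps_sigmas above: a shift of 7 skews the schedule toward high-noise sigmas (the step count and device are assumed values):
import torch

timesteps, sigmas = set_timesteps_sigmas(4, shift=7.0, device="cpu")
print(sigmas)      # tensor([1.0000, 0.9545, 0.8750, 0.7000, 0.0000])
print(timesteps)   # tensor([1000.0000, 954.5455, 875.0000, 700.0000])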
def get_1d_rotary_pos_embed_riflex(
dim: int,
pos: Union[np.ndarray, int],
theta: float = 10000.0,
use_real=False,
k: Optional[int] = None,
L_test: Optional[int] = None,
):
"""
RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end
index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
data type.
Args:
dim (`int`): Dimension of the frequency tensor.
pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
theta (`float`, *optional*, defaults to 10000.0):
Scaling factor for frequency computation. Defaults to 10000.0.
use_real (`bool`, *optional*):
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
L_test (`int`, *optional*, defaults to None): the number of frames for inference
Returns:
`torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
"""
assert dim % 2 == 0
if isinstance(pos, int):
pos = torch.arange(pos)
if isinstance(pos, np.ndarray):
pos = torch.from_numpy(pos) # [S]
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=pos.device)[: (dim // 2)].float() / dim)) # [D/2]
# === Riflex modification start ===
# Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
# Empirical observations show that a few videos may exhibit repetition in the tail frames.
# To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
if k is not None:
freqs[k - 1] = 0.9 * 2 * torch.pi / L_test
# === Riflex modification end ===
freqs = torch.outer(pos, freqs) # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
return freqs_cos, freqs_sin
else:
# lumina
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
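A minimal sketch of the RIFLEx modification above: with k set, only the k-th frequency band is lowered so that one period covers the L_test frames (the toy sizes below are assumptions):
import torch

cos_plain, _ = get_1d_rotary_pos_embed_riflex(16, 8, use_real=True)
cos_riflex, _ = get_1d_rotary_pos_embed_riflex(16, 8, use_real=True, k=4, L_test=8)
print(torch.allclose(cos_plain, cos_riflex))   # False: only the k-th band changed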
class HunyuanScheduler(BaseScheduler):
def __init__(self, config):
super().__init__(config)
self.shift = 7.0
self.timesteps, self.sigmas = set_timesteps_sigmas(self.infer_steps, self.shift, device=torch.device("cuda"))
assert len(self.timesteps) == self.infer_steps
self.embedded_guidance_scale = 6.0
self.generator = [torch.Generator("cuda").manual_seed(seed) for seed in [self.config.seed]]
self.noise_pred = None
def prepare(self, image_encoder_output):
self.image_encoder_output = image_encoder_output
self.prepare_latents(shape=self.config.target_shape, dtype=torch.float32, image_encoder_output=image_encoder_output)
self.prepare_guidance()
self.prepare_rotary_pos_embedding(video_length=self.config.target_video_length, height=self.config.target_height, width=self.config.target_width)
def prepare_guidance(self):
self.guidance = torch.tensor([self.embedded_guidance_scale], dtype=GET_DTYPE(), device=torch.device("cuda")) * 1000.0
def step_post(self):
if self.config.task == "t2v":
sample = self.latents.to(torch.float32)
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
self.latents = sample + self.noise_pred.to(torch.float32) * dt
else:
sample = self.latents[:, :, 1:, :, :].to(torch.float32)
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
latents = sample + self.noise_pred[:, :, 1:, :, :].to(torch.float32) * dt
self.latents = torch.concat([self.image_encoder_output["img_latents"], latents], dim=2)
def prepare_latents(self, shape, dtype, image_encoder_output):
if self.config.task == "t2v":
self.latents = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
else:
x1 = image_encoder_output["img_latents"].repeat(1, 1, (self.config.target_video_length - 1) // 4 + 1, 1, 1)
x0 = randn_tensor(shape, generator=self.generator, device=torch.device("cuda"), dtype=dtype)
t = torch.tensor([0.999]).to(device=torch.device("cuda"))
self.latents = x0 * t + x1 * (1 - t)
self.latents = self.latents.to(dtype=dtype)
self.latents = torch.concat([image_encoder_output["img_latents"], self.latents[:, :, 1:, :, :]], dim=2)
def prepare_rotary_pos_embedding(self, video_length, height, width):
target_ndim = 3
ndim = 5 - 2
# 884
vae = "884-16c-hy"
patch_size = [1, 2, 2]
hidden_size = 3072
heads_num = 24
rope_theta = 256
rope_dim_list = [16, 56, 56]
if "884" in vae:
latents_size = [(video_length - 1) // 4 + 1, height // 8, width // 8]
elif "888" in vae:
latents_size = [(video_length - 1) // 8 + 1, height // 8, width // 8]
else:
latents_size = [video_length, height // 8, width // 8]
if isinstance(patch_size, int):
assert all(s % patch_size == 0 for s in latents_size), f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), but got {latents_size}."
rope_sizes = [s // patch_size for s in latents_size]
elif isinstance(patch_size, list):
assert all(s % patch_size[idx] == 0 for idx, s in enumerate(latents_size)), f"Latent size(last {ndim} dimensions) should be divisible by patch size({patch_size}), but got {latents_size}."
rope_sizes = [s // patch_size[idx] for idx, s in enumerate(latents_size)]
if len(rope_sizes) != target_ndim:
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
if self.config.task == "t2v":
head_dim = hidden_size // heads_num
rope_dim_list = rope_dim_list
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
self.freqs_cos, self.freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=rope_theta,
use_real=True,
theta_rescale_factor=1,
)
self.freqs_cos = self.freqs_cos.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
self.freqs_sin = self.freqs_sin.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
else:
L_test = rope_sizes[0] # Latent frames
L_train = 25 # Training length from HunyuanVideo
actual_num_frames = video_length # Use input video_length directly
head_dim = hidden_size // heads_num
rope_dim_list = rope_dim_list or [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) must equal head_dim"
if actual_num_frames > 192:
k = 2 + ((actual_num_frames + 3) // (4 * L_train))
k = max(4, min(8, k))
# Compute positional grids for RIFLEx
axes_grids = [torch.arange(size, device=torch.device("cuda"), dtype=torch.float32) for size in rope_sizes]
grid = torch.meshgrid(*axes_grids, indexing="ij")
grid = torch.stack(grid, dim=0) # [3, t, h, w]
pos = grid.reshape(3, -1).t() # [t * h * w, 3]
# Apply RIFLEx to temporal dimension
freqs = []
for i in range(3):
if i == 0: # Temporal with RIFLEx
freqs_cos, freqs_sin = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], pos[:, i], theta=rope_theta, use_real=True, k=k, L_test=L_test)
else: # Spatial with default RoPE
freqs_cos, freqs_sin = get_1d_rotary_pos_embed_riflex(rope_dim_list[i], pos[:, i], theta=rope_theta, use_real=True, k=None, L_test=None)
freqs.append((freqs_cos, freqs_sin))
freqs_cos = torch.cat([f[0] for f in freqs], dim=1)
freqs_sin = torch.cat([f[1] for f in freqs], dim=1)
else:
# 20250316 pftq: Original code for <= 192 frames
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=rope_theta,
use_real=True,
theta_rescale_factor=1,
)
self.freqs_cos = freqs_cos.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
self.freqs_sin = freqs_sin.to(dtype=GET_DTYPE(), device=torch.device("cuda"))
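A worked example of the RIFLEx frequency-index heuristic used above for long i2v generations (the frame count is an assumed value):
actual_num_frames, L_train = 261, 25                  # only used when > 192 frames
k = 2 + ((actual_num_frames + 3) // (4 * L_train))    # 2 + 264 // 100 = 4
k = max(4, min(8, k))                                 # clamped to [4, 8] -> 4
print(k)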
import math
import os
import numpy as np
import torch
from lightx2v.models.schedulers.wan.scheduler import WanScheduler
from lightx2v.utils.envs import *
class WanSkyreelsV2DFScheduler(WanScheduler):
def __init__(self, config):
super().__init__(config)
self.df_schedulers = []
self.flag_df = True
def prepare(self, image_encoder_output=None):
self.generator = torch.Generator(device=self.device)
self.generator.manual_seed(self.config.seed)
self.prepare_latents(self.config.target_shape, dtype=torch.float32)
if os.path.isfile(self.config.image_path):
self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
else:
self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
sigmas = 1.0 - alphas
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
self.sigmas = sigmas
self.timesteps = sigmas * self.num_train_timesteps
self.model_outputs = [None] * self.solver_order
self.timestep_list = [None] * self.solver_order
self.last_sample = None
self.sigmas = self.sigmas.to("cpu")
self.sigma_min = self.sigmas[-1].item()
self.sigma_max = self.sigmas[0].item()
self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
def generate_timestep_matrix(
self,
num_frames,
base_num_frames,
addnoise_condition,
num_pre_ready,
casual_block_size=1,
ar_step=0,
shrink_interval_with_mask=False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
self.addnoise_condition = addnoise_condition
self.predix_video_latent_length = num_pre_ready
step_template = self.timesteps
step_matrix, step_index = [], []
update_mask, valid_interval = [], []
num_iterations = len(step_template) + 1
num_frames_block = num_frames // casual_block_size
base_num_frames_block = base_num_frames // casual_block_size
if base_num_frames_block < num_frames_block:
infer_step_num = len(step_template)
gen_block = base_num_frames_block
min_ar_step = infer_step_num / gen_block
assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
# print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
step_template = torch.cat(
[
torch.tensor([999], dtype=torch.int64, device=step_template.device),
step_template.long(),
torch.tensor([0], dtype=torch.int64, device=step_template.device),
]
) # to handle the counter in row works starting from 1
pre_row = torch.zeros(num_frames_block, dtype=torch.long)
if num_pre_ready > 0:
pre_row[: num_pre_ready // casual_block_size] = num_iterations
while not torch.all(pre_row >= (num_iterations - 1)):
new_row = torch.zeros(num_frames_block, dtype=torch.long)
for i in range(num_frames_block):
if i == 0 or pre_row[i - 1] >= (num_iterations - 1):  # the first block, or the previous block is already fully denoised
new_row[i] = pre_row[i] + 1
else:
new_row[i] = new_row[i - 1] - ar_step
new_row = new_row.clamp(0, num_iterations)
update_mask.append((new_row != pre_row) & (new_row != num_iterations)) # False: no need to update, True: need to update
step_index.append(new_row)
step_matrix.append(step_template[new_row])
pre_row = new_row
# for long video we split into several sequences, base_num_frames is set to the model max length (for training)
terminal_flag = base_num_frames_block
if shrink_interval_with_mask:
idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
update_mask = update_mask[0]
update_mask_idx = idx_sequence[update_mask]
last_update_idx = update_mask_idx[-1].item()
terminal_flag = last_update_idx + 1
# for i in range(0, len(update_mask)):
for curr_mask in update_mask:
if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
terminal_flag += 1
valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
step_update_mask = torch.stack(update_mask, dim=0)
step_index = torch.stack(step_index, dim=0)
step_matrix = torch.stack(step_matrix, dim=0)
if casual_block_size > 1:
step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
self.step_matrix = step_matrix
self.step_update_mask = step_update_mask
self.valid_interval = valid_interval
self.df_timesteps = torch.zeros_like(self.step_matrix)
self.df_schedulers = []
for _ in range(base_num_frames):
sample_scheduler = WanScheduler(self.config)
sample_scheduler.prepare()
self.df_schedulers.append(sample_scheduler)
def step_pre(self, step_index):
self.step_index = step_index
if GET_DTYPE() == GET_SENSITIVE_DTYPE():
self.latents = self.latents.to(GET_DTYPE())
valid_interval_start, valid_interval_end = self.valid_interval[step_index]
timestep = self.step_matrix[step_index][None, valid_interval_start:valid_interval_end].clone()
if self.addnoise_condition > 0 and valid_interval_start < self.predix_video_latent_length:
latent_model_input = self.latents[:, valid_interval_start:valid_interval_end, :, :].clone()
noise_factor = 0.001 * self.addnoise_condition
self.latents[:, valid_interval_start : self.predix_video_latent_length] = (
latent_model_input[:, valid_interval_start : self.predix_video_latent_length] * (1.0 - noise_factor)
+ torch.randn_like(latent_model_input[:, valid_interval_start : self.predix_video_latent_length]) * noise_factor
)
timestep[:, valid_interval_start : self.predix_video_latent_length] = self.addnoise_condition
self.df_timesteps[step_index] = timestep
def step_post(self):
update_mask_i = self.step_update_mask[self.step_index]
valid_interval_start, valid_interval_end = self.valid_interval[self.step_index]
timestep = self.df_timesteps[self.step_index]
for idx in range(valid_interval_start, valid_interval_end):  # step each latent frame with its own scheduler
if update_mask_i[idx].item():
self.df_schedulers[idx].step_pre(step_index=self.step_index)
self.df_schedulers[idx].noise_pred = self.noise_pred[:, idx - valid_interval_start]
self.df_schedulers[idx].timesteps[self.step_index] = timestep[idx]
self.df_schedulers[idx].latents = self.latents[:, idx]
self.df_schedulers[idx].step_post()
self.latents[:, idx] = self.df_schedulers[idx].latents
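A minimal, self-contained sketch of the diffusion-forcing "staircase" that generate_timestep_matrix above builds: each later block lags the previous one by ar_step denoising iterations (toy numbers, not the real configuration):
import torch

num_blocks, num_iterations, ar_step = 4, 6, 2
pre_row = torch.zeros(num_blocks, dtype=torch.long)
rows = []
while not torch.all(pre_row >= num_iterations - 1):
    new_row = torch.zeros(num_blocks, dtype=torch.long)
    for i in range(num_blocks):
        if i == 0 or pre_row[i - 1] >= num_iterations - 1:
            new_row[i] = pre_row[i] + 1             # this block advances one step
        else:
            new_row[i] = new_row[i - 1] - ar_step   # otherwise it trails its neighbor
    pre_row = new_row.clamp(0, num_iterations)
    rows.append(pre_row)
print(torch.stack(rows))   # one row per outer iteration, one column per block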
import glob
import os
import torch # type: ignore
from diffusers.video_processor import VideoProcessor # type: ignore
from safetensors import safe_open # type: ignore
from lightx2v.models.video_encoders.hf.cogvideox.autoencoder_ks_cogvidex import AutoencoderKLCogVideoX
from lightx2v.utils.envs import *
class CogvideoxVAE:
def __init__(self, config):
self.config = config
self.load()
def _load_safetensor_to_dict(self, file_path):
with safe_open(file_path, framework="pt") as f:
tensor_dict = {key: f.get_tensor(key).to(GET_DTYPE()).cuda() for key in f.keys()}
return tensor_dict
def _load_ckpt(self, model_path):
safetensors_pattern = os.path.join(model_path, "*.safetensors")
safetensors_files = glob.glob(safetensors_pattern)
if not safetensors_files:
raise FileNotFoundError(f"No .safetensors files found in directory: {model_path}")
weight_dict = {}
for file_path in safetensors_files:
file_weights = self._load_safetensor_to_dict(file_path)
weight_dict.update(file_weights)
return weight_dict
def load(self):
vae_path = os.path.join(self.config.model_path, "vae")
self.vae_config = AutoencoderKLCogVideoX.load_config(vae_path)
self.model = AutoencoderKLCogVideoX.from_config(self.vae_config)
vae_ckpt = self._load_ckpt(vae_path)
self.vae_scale_factor_spatial = 2 ** (len(self.vae_config["block_out_channels"]) - 1) # 8
self.vae_scale_factor_temporal = self.vae_config["temporal_compression_ratio"] # 4
self.vae_scaling_factor_image = self.vae_config["scaling_factor"] # 0.7
self.model.load_state_dict(vae_ckpt)
self.model.to(GET_DTYPE()).to(torch.device("cuda"))
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
@torch.no_grad()
def decode(self, latents, generator, config):
latents = latents.permute(0, 2, 1, 3, 4)
latents = 1 / self.config.vae_scaling_factor_image * latents
frames = self.model.decode(latents).sample
images = self.video_processor.postprocess_video(video=frames, output_type="pil")[0]
return images