Unverified Commit f21da849 authored by Yang Yong (雍洋), committed by GitHub
parent 3efc43f5
......@@ -11,6 +11,7 @@ class WanPreInfer:
assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
self.config = config
d = config["dim"] // config["num_heads"]
self.run_device = self.config.get("run_device", "cuda")
self.clean_cuda_cache = config.get("clean_cuda_cache", False)
self.task = config["task"]
self.device = torch.device(self.config.get("run_device", "cuda"))
......@@ -21,7 +22,7 @@ class WanPreInfer:
rope_params(1024, 2 * (d // 6)),
],
dim=1,
).to(self.device)
).to(torch.device(self.run_device))
self.freq_dim = config["freq_dim"]
self.dim = config["dim"]
self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
......
......@@ -48,6 +48,7 @@ class WanModel(CompiledMethodsMixin):
super().__init__()
self.model_path = model_path
self.config = config
self.run_device = self.config.get("run_device", "cuda")
self.cpu_offload = self.config.get("cpu_offload", False)
self.offload_granularity = self.config.get("offload_granularity", "block")
self.model_type = model_type
......
......@@ -59,10 +59,11 @@ class DefaultRunner(BaseRunner):
self.model.compile(self.config.get("compile_shapes", []))
def set_init_device(self):
self.run_device = self.config.get("run_device", "cuda")
if self.config["cpu_offload"]:
self.init_device = torch.device("cpu")
else:
self.init_device = torch.device(self.config.get("run_device", "cuda"))
self.init_device = torch.device(self.run_device)
def load_vfi_model(self):
if self.config["video_frame_interpolation"].get("algo", None) == "rife":
......@@ -162,15 +163,25 @@ class DefaultRunner(BaseRunner):
total_all_steps = self.video_segment_num * infer_steps
self.progress_callback((current_step / total_all_steps) * 100, 100)
if segment_idx is not None and segment_idx == self.video_segment_num - 1:
del self.inputs
torch.cuda.empty_cache()
return self.model.scheduler.latents
def run_step(self):
self.inputs = self.run_input_encoder()
self.run_main()
if hasattr(self, "sr_version") and self.sr_version is not None is not None:
self.config_sr["is_sr_running"] = True
self.inputs_sr = self.run_input_encoder()
self.config_sr["is_sr_running"] = False
self.run_main(total_steps=1)
def end_run(self):
self.model.scheduler.clear()
del self.inputs
if hasattr(self, "inputs"):
del self.inputs
self.input_info = None
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
if hasattr(self.model.transformer_infer, "weights_stream_mgr"):
......@@ -211,7 +222,7 @@ class DefaultRunner(BaseRunner):
@ProfilingContext4DebugL2("Run Encoders")
def _run_input_encoder_local_t2v(self):
self.input_info.latent_shape = self.get_latent_shape_with_target_hw(self.config["target_height"], self.config["target_width"]) # Important: set latent_shape in input_info
self.input_info.latent_shape = self.get_latent_shape_with_target_hw() # Important: set latent_shape in input_info
text_encoder_output = self.run_text_encoder(self.input_info)
torch.cuda.empty_cache()
gc.collect()
......@@ -273,6 +284,14 @@ class DefaultRunner(BaseRunner):
if self.config.get("model_cls") == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
self.inputs["image_encoder_output"]["vae_encoder_out"] = None
if hasattr(self, "sr_version") and self.sr_version is not None is not None:
self.lq_latents_shape = self.model.scheduler.latents.shape
self.model_sr.set_scheduler(self.scheduler_sr)
self.config_sr["is_sr_running"] = True
self.inputs_sr = self.run_input_encoder()
self.config_sr["is_sr_running"] = False
@ProfilingContext4DebugL2("Run DiT")
def run_main(self):
self.init_run()
......
import copy
import gc
import os
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from loguru import logger
from lightx2v.models.input_encoders.hf.hunyuan15.byt5.model import ByT5TextEncoder
from lightx2v.models.input_encoders.hf.hunyuan15.qwen25.model import Qwen25VL_TextEncoder
from lightx2v.models.input_encoders.hf.hunyuan15.siglip.model import SiglipVisionEncoder
from lightx2v.models.networks.hunyuan_video.model import HunyuanVideo15Model
from lightx2v.models.runners.default_runner import DefaultRunner
from lightx2v.models.schedulers.hunyuan_video.feature_caching.scheduler import HunyuanVideo15SchedulerCaching
from lightx2v.models.schedulers.hunyuan_video.scheduler import HunyuanVideo15SRScheduler, HunyuanVideo15Scheduler
from lightx2v.models.video_encoders.hf.hunyuanvideo15.hunyuanvideo_15_vae import HunyuanVideo15VAE
from lightx2v.models.video_encoders.hf.hunyuanvideo15.lighttae_hy15 import LightTaeHy15
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import *
@RUNNER_REGISTER("hunyuan_video_1.5")
class HunyuanVideo15Runner(DefaultRunner):
def __init__(self, config):
config["is_sr_running"] = False
if "video_super_resolution" in config and "sr_version" in config["video_super_resolution"]:
self.sr_version = config["video_super_resolution"]["sr_version"]
else:
self.sr_version = None
if self.sr_version is not None:
self.config_sr = copy.deepcopy(config)
self.config_sr["is_sr_running"] = False
self.config_sr["sample_shift"] = config["video_super_resolution"]["flow_shift"] # for SR model
self.config_sr["sample_guide_scale"] = config["video_super_resolution"]["guidance_scale"] # for SR model
self.config_sr["infer_steps"] = config["video_super_resolution"]["num_inference_steps"]
super().__init__(config)
self.target_size_config = {
"360p": {"bucket_hw_base_size": 480, "bucket_hw_bucket_stride": 16},
"480p": {"bucket_hw_base_size": 640, "bucket_hw_bucket_stride": 16},
"720p": {"bucket_hw_base_size": 960, "bucket_hw_bucket_stride": 16},
"1080p": {"bucket_hw_base_size": 1440, "bucket_hw_bucket_stride": 16},
}
self.vision_num_semantic_tokens = 729
self.vision_states_dim = 1152
self.vae_cls = HunyuanVideo15VAE
self.tae_cls = LightTaeHy15
def init_scheduler(self):
if self.config["feature_caching"] == "NoCaching":
scheduler_class = HunyuanVideo15Scheduler
elif self.config.feature_caching in ["Mag", "Tea"]:
scheduler_class = HunyuanVideo15SchedulerCaching
else:
raise NotImplementedError(f"Unsupported feature_caching type: {self.config.feature_caching}")
self.scheduler = scheduler_class(self.config)
if self.sr_version is not None:
self.scheduler_sr = HunyuanVideo15SRScheduler(self.config_sr)
else:
self.scheduler_sr = None
def load_text_encoder(self):
qwen25vl_offload = self.config.get("qwen25vl_cpu_offload", self.config.get("cpu_offload"))
if qwen25vl_offload:
qwen25vl_device = torch.device("cpu")
else:
qwen25vl_device = torch.device("cuda")
qwen25vl_quantized = self.config.get("qwen25vl_quantized", False)
qwen25vl_quant_scheme = self.config.get("qwen25vl_quant_scheme", None)
qwen25vl_quantized_ckpt = self.config.get("qwen25vl_quantized_ckpt", None)
text_encoder_path = os.path.join(self.config["model_path"], "text_encoder/llm")
logger.info(f"Loading text encoder from {text_encoder_path}")
text_encoder = Qwen25VL_TextEncoder(
dtype=torch.float16,
device=qwen25vl_device,
checkpoint_path=text_encoder_path,
cpu_offload=qwen25vl_offload,
qwen25vl_quantized=qwen25vl_quantized,
qwen25vl_quant_scheme=qwen25vl_quant_scheme,
qwen25vl_quant_ckpt=qwen25vl_quantized_ckpt,
)
byt5_offload = self.config.get("byt5_cpu_offload", self.config.get("cpu_offload"))
if byt5_offload:
byt5_device = torch.device("cpu")
else:
byt5_device = torch.device("cuda")
byt5 = ByT5TextEncoder(config=self.config, device=byt5_device, checkpoint_path=self.config["model_path"], cpu_offload=byt5_offload)
text_encoders = [text_encoder, byt5]
return text_encoders
def load_transformer(self):
model = HunyuanVideo15Model(self.config["model_path"], self.config, self.init_device)
if self.sr_version is not None:
self.config_sr["transformer_model_path"] = os.path.join(os.path.dirname(self.config.transformer_model_path), self.sr_version)
self.config_sr["is_sr_running"] = True
model_sr = HunyuanVideo15Model(self.config_sr["model_path"], self.config_sr, self.init_device)
self.config_sr["is_sr_running"] = False
else:
model_sr = None
self.model_sr = model_sr
return model
def get_latent_shape_with_target_hw(self, origin_size=None):
if origin_size is None:
width, height = self.config["aspect_ratio"].split(":")
else:
width, height = origin_size
target_size = self.config["transformer_model_name"].split("_")[0]
target_height, target_width = self.get_closest_resolution_given_original_size((int(width), int(height)), target_size)
latent_shape = [
self.config.get("in_channels", 32),
(self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
target_height // self.config["vae_stride"][1],
target_width // self.config["vae_stride"][2],
]
self.target_height = target_height
self.target_width = target_width
return latent_shape
def get_sr_latent_shape_with_target_hw(self):
SizeMap = {
"480p": 640,
"720p": 960,
"1080p": 1440,
}
sr_stride = 16
base_size = SizeMap[self.config_sr["video_super_resolution"]["base_resolution"]]
sr_size = SizeMap[self.sr_version.split("_")[0]]
lr_video_height, lr_video_width = [x * 16 for x in self.lq_latents_shape[-2:]]
hr_bucket_map = self.build_bucket_map(lr_base_size=base_size, hr_base_size=sr_size, lr_patch_size=16, hr_patch_size=sr_stride)
target_width, target_height = hr_bucket_map((lr_video_width, lr_video_height))
latent_shape = [
self.config_sr.get("in_channels", 32),
(self.config_sr["target_video_length"] - 1) // self.config_sr["vae_stride"][0] + 1,
target_height // self.config_sr["vae_stride"][1],
target_width // self.config_sr["vae_stride"][2],
]
self.target_sr_height = target_height
self.target_sr_width = target_width
return latent_shape
def get_closest_resolution_given_original_size(self, origin_size, target_size):
bucket_hw_base_size = self.target_size_config[target_size]["bucket_hw_base_size"]
bucket_hw_bucket_stride = self.target_size_config[target_size]["bucket_hw_bucket_stride"]
assert bucket_hw_base_size in [128, 256, 480, 512, 640, 720, 960, 1440], f"bucket_hw_base_size must be in [128, 256, 480, 512, 640, 720, 960, 1440], but got {bucket_hw_base_size}"
crop_size_list = self.generate_crop_size_list(bucket_hw_base_size, bucket_hw_bucket_stride)
aspect_ratios = np.array([round(float(h) / float(w), 5) for h, w in crop_size_list])
closest_size, closest_ratio = self.get_closest_ratio(origin_size[1], origin_size[0], aspect_ratios, crop_size_list)
height = closest_size[0]
width = closest_size[1]
return height, width
def generate_crop_size_list(self, base_size=256, patch_size=16, max_ratio=4.0):
num_patches = round((base_size / patch_size) ** 2)
assert max_ratio >= 1.0
crop_size_list = []
wp, hp = num_patches, 1
while wp > 0:
if max(wp, hp) / min(wp, hp) <= max_ratio:
crop_size_list.append((wp * patch_size, hp * patch_size))
if (hp + 1) * wp <= num_patches:
hp += 1
else:
wp -= 1
return crop_size_list
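# Illustrative sketch (not part of the class): with base_size=256 and patch_size=16 the
# loop enumerates (width, height) buckets whose patch count stays near (256 // 16) ** 2,
# walking from wide to tall while keeping max(w, h) / min(w, h) <= max_ratio. Assuming a
# runner instance `runner`, a quick sanity check could look like:
#     buckets = runner.generate_crop_size_list(base_size=256, patch_size=16)
#     assert (256, 256) in buckets
#     assert all(max(w, h) / min(w, h) <= 4.0 for w, h in buckets)
#     assert all(w % 16 == 0 and h % 16 == 0 for w, h in buckets)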
def get_closest_ratio(self, height: float, width: float, ratios: list, buckets: list):
aspect_ratio = float(height) / float(width)
diff_ratios = ratios - aspect_ratio
if aspect_ratio >= 1:
indices = [(index, x) for index, x in enumerate(diff_ratios) if x <= 0]
else:
indices = [(index, x) for index, x in enumerate(diff_ratios) if x > 0]
closest_ratio_id = min(indices, key=lambda pair: abs(pair[1]))[0]
closest_size = buckets[closest_ratio_id]
closest_ratio = ratios[closest_ratio_id]
return closest_size, closest_ratio
def run_text_encoder(self, input_info):
prompt = input_info.prompt_enhanced if self.config["use_prompt_enhancer"] else input_info.prompt
neg_prompt = input_info.negative_prompt
# run qwen25vl
if self.config.get("enable_cfg", False) and self.config["cfg_parallel"]:
cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
cfg_p_rank = dist.get_rank(cfg_p_group)
if cfg_p_rank == 0:
context = self.text_encoders[0].infer([prompt])
text_encoder_output = {"context": context}
else:
context_null = self.text_encoders[0].infer([neg_prompt])
text_encoder_output = {"context_null": context_null}
else:
context = self.text_encoders[0].infer([prompt])
context_null = self.text_encoders[0].infer([neg_prompt]) if self.config.get("enable_cfg", False) else None
text_encoder_output = {
"context": context,
"context_null": context_null,
}
# run byt5
byt5_features, byt5_masks = self.text_encoders[1].infer([prompt])
text_encoder_output.update({"byt5_features": byt5_features, "byt5_masks": byt5_masks})
return text_encoder_output
def load_image_encoder(self):
siglip_offload = self.config.get("siglip_cpu_offload", self.config.get("cpu_offload"))
if siglip_offload:
siglip_device = torch.device("cpu")
else:
siglip_device = torch.device("cuda")
image_encoder = SiglipVisionEncoder(
config=self.config,
device=siglip_device,
checkpoint_path=self.config["model_path"],
cpu_offload=siglip_offload,
)
return image_encoder
def load_vae_encoder(self):
# offload config
vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_config = {
"checkpoint_path": self.config["model_path"],
"device": vae_device,
"cpu_offload": vae_offload,
"dtype": GET_DTYPE(),
}
if self.config["task"] not in ["i2v", "flf2v", "animate", "vace", "s2v"]:
return None
else:
return self.vae_cls(**vae_config)
def load_vae_decoder(self):
# offload config
vae_offload = self.config.get("vae_cpu_offload", self.config.get("cpu_offload"))
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device("cuda")
vae_config = {
"checkpoint_path": self.config["model_path"],
"device": vae_device,
"cpu_offload": vae_offload,
"dtype": GET_DTYPE(),
}
if self.config.get("use_tae", False):
tae_path = self.config["tae_path"]
vae_decoder = self.tae_cls(vae_path=tae_path, dtype=GET_DTYPE()).to("cuda")
else:
vae_decoder = self.vae_cls(**vae_config)
return vae_decoder
def load_vae(self):
vae_encoder = self.load_vae_encoder()
if vae_encoder is None or self.config.get("use_tae", False):
vae_decoder = self.load_vae_decoder()
else:
vae_decoder = vae_encoder
return vae_encoder, vae_decoder
def load_vsr_model(self):
if self.sr_version:
from lightx2v.models.runners.vsr.vsr_wrapper_hy15 import SRModel3DV2, Upsampler
upsampler_cls = SRModel3DV2 if "720p" in self.sr_version else Upsampler
upsampler_path = os.path.join(self.config["model_path"], "upsampler", self.sr_version)
logger.info("Loading VSR model from {}".format(upsampler_path))
upsampler = upsampler_cls.from_pretrained(upsampler_path).to(self.init_device)
return upsampler
else:
return None
def build_bucket_map(self, lr_base_size, hr_base_size, lr_patch_size, hr_patch_size):
lr_buckets = self.generate_crop_size_list(base_size=lr_base_size, patch_size=lr_patch_size)
hr_buckets = self.generate_crop_size_list(base_size=hr_base_size, patch_size=hr_patch_size)
lr_aspect_ratios = np.array([w / h for w, h in lr_buckets])
hr_aspect_ratios = np.array([w / h for w, h in hr_buckets])
hr_bucket_map = {}
for i, (lr_w, lr_h) in enumerate(lr_buckets):
lr_ratio = lr_aspect_ratios[i]
closest_hr_ratio_id = np.abs(hr_aspect_ratios - lr_ratio).argmin()
hr_bucket_map[(lr_w, lr_h)] = hr_buckets[closest_hr_ratio_id]
def hr_bucket_fn(lr_bucket):
if lr_bucket not in hr_bucket_map:
raise ValueError(f"LR bucket {lr_bucket} not found in bucket map")
return hr_bucket_map[lr_bucket]
hr_bucket_fn.map = hr_bucket_map
return hr_bucket_fn
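# Hedged usage sketch: the returned callable maps an exact low-resolution bucket to the
# high-resolution bucket with the closest aspect ratio, and raises ValueError for any
# (w, h) pair that is not a generated LR bucket. get_sr_latent_shape_with_target_hw
# reconstructs that LR bucket from the low-quality latent shape (latent H/W * 16), e.g.:
#     hr_map = runner.build_bucket_map(lr_base_size=640, hr_base_size=960, lr_patch_size=16, hr_patch_size=16)
#     hr_w, hr_h = hr_map((lr_w, lr_h))  # (lr_w, lr_h) must be one of the 640-base buckets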
@ProfilingContext4DebugL1("Run SR")
def run_sr(self, lq_latents):
self.config_sr["is_sr_running"] = True
self.model_sr.scheduler.prepare(
seed=self.input_info.seed, latent_shape=self.latent_sr_shape, lq_latents=lq_latents, upsampler=self.vsr_model, image_encoder_output=self.inputs_sr["image_encoder_output"]
)
total_steps = self.model_sr.scheduler.infer_steps
for step_index in range(total_steps):
with ProfilingContext4DebugL1(
f"Run SR Dit every step",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_per_step_dit_duration,
metrics_labels=[step_index + 1, total_steps],
):
logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
with ProfilingContext4DebugL1("step_pre"):
self.model_sr.scheduler.step_pre(step_index=step_index)
with ProfilingContext4DebugL1("🚀 infer_main"):
self.model_sr.infer(self.inputs_sr)
with ProfilingContext4DebugL1("step_post"):
self.model_sr.scheduler.step_post()
del self.inputs_sr
torch.cuda.empty_cache()
self.config_sr["is_sr_running"] = False
return self.model_sr.scheduler.latents
@ProfilingContext4DebugL1("Run VAE Decoder")
def run_vae_decoder(self, latents):
if self.sr_version:
latents = self.run_sr(latents)
images = super().run_vae_decoder(latents)
return images
@ProfilingContext4DebugL2("Run Encoders")
def _run_input_encoder_local_t2v(self):
self.input_info.latent_shape = self.get_latent_shape_with_target_hw() # Important: set latent_shape in input_info
text_encoder_output = self.run_text_encoder(self.input_info)
# vision_states is all zero, because we don't have any image input
siglip_output = torch.zeros(1, self.vision_num_semantic_tokens, self.config["hidden_size"], dtype=torch.bfloat16).cuda()
siglip_mask = torch.zeros(1, self.vision_num_semantic_tokens, dtype=torch.bfloat16, device=torch.device("cuda"))
torch.cuda.empty_cache()
gc.collect()
return {
"text_encoder_output": text_encoder_output,
"image_encoder_output": {
"siglip_output": siglip_output,
"siglip_mask": siglip_mask,
"cond_latents": None,
},
}
def read_image_input(self, img_path):
if isinstance(img_path, Image.Image):
img_ori = img_path
else:
img_ori = Image.open(img_path).convert("RGB")
return img_ori
@ProfilingContext4DebugL2("Run Encoders")
def _run_input_encoder_local_i2v(self):
img_ori = self.read_image_input(self.input_info.image_path)
if self.sr_version and self.config_sr["is_sr_running"]:
self.latent_sr_shape = self.get_sr_latent_shape_with_target_hw()
self.input_info.latent_shape = self.get_latent_shape_with_target_hw(origin_size=img_ori.size) # Important: set latent_shape in input_info
siglip_output, siglip_mask = self.run_image_encoder(img_ori) if self.config.get("use_image_encoder", True) else (None, None)
cond_latents = self.run_vae_encoder(img_ori)
text_encoder_output = self.run_text_encoder(self.input_info)
torch.cuda.empty_cache()
gc.collect()
return {
"text_encoder_output": text_encoder_output,
"image_encoder_output": {
"siglip_output": siglip_output,
"siglip_mask": siglip_mask,
"cond_latents": cond_latents,
},
}
@ProfilingContext4DebugL1(
"Run Image Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_img_encode_duration,
metrics_labels=["WanRunner"],
)
def run_image_encoder(self, first_frame, last_frame=None):
if self.sr_version and self.config_sr["is_sr_running"]:
target_width = self.target_sr_width
target_height = self.target_sr_height
else:
target_width = self.target_width
target_height = self.target_height
input_image_np = self.resize_and_center_crop(first_frame, target_width=target_width, target_height=target_height)
vision_states = self.image_encoder.encode_images(input_image_np).last_hidden_state.to(device=torch.device("cuda"), dtype=torch.bfloat16)
image_encoder_output = self.image_encoder.infer(vision_states)
image_encoder_mask = torch.ones((1, image_encoder_output.shape[1]), dtype=torch.bfloat16, device=torch.device("cuda"))
return image_encoder_output, image_encoder_mask
def resize_and_center_crop(self, image, target_width, target_height):
image = np.array(image)
if target_height == image.shape[0] and target_width == image.shape[1]:
return image
pil_image = Image.fromarray(image)
original_width, original_height = pil_image.size
scale_factor = max(target_width / original_width, target_height / original_height)
resized_width = int(round(original_width * scale_factor))
resized_height = int(round(original_height * scale_factor))
resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
left = (resized_width - target_width) / 2
top = (resized_height - target_height) / 2
right = (resized_width + target_width) / 2
bottom = (resized_height + target_height) / 2
cropped_image = resized_image.crop((left, top, right, bottom))
return np.array(cropped_image)
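# Worked example (sketch, not part of the class): a 640x480 input with a 480x480 target
# gives scale_factor = max(480 / 640, 480 / 480) = 1.0, so the image stays 640x480 and
# the crop box is (80, 0, 560, 480), i.e. a centered 480x480 crop. Using max() ensures
# the resized image always covers the target area before cropping.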
@ProfilingContext4DebugL1(
"Run VAE Encoder",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_vae_encoder_image_duration,
metrics_labels=["WanRunner"],
)
def run_vae_encoder(self, first_frame):
origin_size = first_frame.size
original_width, original_height = origin_size
if self.sr_version and self.config_sr["is_sr_running"]:
target_width = self.target_sr_width
target_height = self.target_sr_height
else:
target_width = self.target_width
target_height = self.target_height
scale_factor = max(target_width / original_width, target_height / original_height)
resize_width = int(round(original_width * scale_factor))
resize_height = int(round(original_height * scale_factor))
ref_image_transform = transforms.Compose(
[
transforms.Resize((resize_height, resize_width), interpolation=transforms.InterpolationMode.LANCZOS),
transforms.CenterCrop((target_height, target_width)),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
ref_images_pixel_values = ref_image_transform(first_frame).unsqueeze(0).unsqueeze(2).cuda()
cond_latents = self.vae_encoder.encode(ref_images_pixel_values.to(GET_DTYPE()))
return cond_latents
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from einops import rearrange
from torch import Tensor
from lightx2v.models.video_encoders.hf.hunyuanvideo15.hunyuanvideo_15_vae import (
CausalConv3d,
RMS_norm,
ResnetBlock,
forward_with_checkpointing,
swish,
)
class UpsamplerType(Enum):
LEARNED = "learned"
FIXED = "fixed"
NONE = "none"
LEARNED_FIXED = "learned_fixed"
@dataclass
class UpsamplerConfig:
load_from: str
enable: bool = False
hidden_channels: int = 128
num_blocks: int = 16
model_type: UpsamplerType = UpsamplerType.NONE
version: str = "720p"
class SRResidualCausalBlock3D(nn.Module):
def __init__(self, channels: int):
super().__init__()
self.block = nn.Sequential(
CausalConv3d(channels, channels, kernel_size=3),
nn.SiLU(inplace=True),
CausalConv3d(channels, channels, kernel_size=3),
nn.SiLU(inplace=True),
CausalConv3d(channels, channels, kernel_size=3),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.block(x)
class SRModel3DV2(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
in_channels: int,
out_channels: int,
hidden_channels: int | None = None,
num_blocks: int = 6,
global_residual: bool = False,
):
super().__init__()
if hidden_channels is None:
hidden_channels = 64
self.in_conv = CausalConv3d(in_channels, hidden_channels, kernel_size=3)
self.blocks = nn.ModuleList([SRResidualCausalBlock3D(hidden_channels) for _ in range(num_blocks)])
self.out_conv = CausalConv3d(hidden_channels, out_channels, kernel_size=3)
self.global_residual = bool(global_residual)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
y = self.in_conv(x)
for blk in self.blocks:
y = blk(y)
y = self.out_conv(y)
if self.global_residual and (y.shape == residual.shape):
y = y + residual
return y
class Upsampler(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
z_channels: int,
out_channels: int,
block_out_channels: tuple[int, ...],
num_res_blocks: int = 2,
is_residual: bool = False,
):
super().__init__()
self.num_res_blocks = num_res_blocks
self.block_out_channels = block_out_channels
self.z_channels = z_channels
# assert block_out_channels[0] % z_channels == 0
block_in = block_out_channels[0]
self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3)
self.up = nn.ModuleList()
for i_level, ch in enumerate(block_out_channels):
block = nn.ModuleList()
block_out = ch
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
self.up.append(up)
self.norm_out = RMS_norm(block_in, images=False)
self.conv_out = CausalConv3d(block_in, out_channels, kernel_size=3)
self.gradient_checkpointing = False
self.is_residual = is_residual
def forward(self, z: Tensor, target_shape: Sequence[int] = None) -> Tensor:
"""
Args:
z: (B, C, T, H, W)
target_shape: (H, W)
"""
use_checkpointing = bool(self.training and self.gradient_checkpointing)
if target_shape is not None and z.shape[-2:] != target_shape:
bsz = z.shape[0]
z = rearrange(z, "b c f h w -> (b f) c h w")
z = F.interpolate(z, size=target_shape, mode="bilinear", align_corners=False)
z = rearrange(z, "(b f) c h w -> b c f h w", b=bsz)
# z to block_in
repeats = self.block_out_channels[0] // (self.z_channels)
h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
# upsampling
for i_level in range(len(self.block_out_channels)):
for i_block in range(self.num_res_blocks + 1):
h = forward_with_checkpointing(
self.up[i_level].block[i_block],
h,
use_checkpointing=use_checkpointing,
)
if hasattr(self.up[i_level], "upsample"):
h = forward_with_checkpointing(self.up[i_level].upsample, h, use_checkpointing=use_checkpointing)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
......@@ -450,7 +450,7 @@ class WanAudioRunner(WanRunner): # type:ignore
ref_img = img_path
else:
ref_img = load_image(img_path)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(self.init_device)
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).to(self.run_device)
ref_img, h, w = resize_image(
ref_img,
......@@ -538,7 +538,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def prepare_prev_latents(self, prev_video: Optional[torch.Tensor], prev_frame_length: int) -> Optional[Dict[str, torch.Tensor]]:
"""Prepare previous latents for conditioning"""
device = self.init_device
device = self.run_device
dtype = GET_DTYPE()
tgt_h, tgt_w = self.input_info.target_shape[0], self.input_info.target_shape[1]
......@@ -835,7 +835,7 @@ class WanAudioRunner(WanRunner): # type:ignore
def load_audio_encoder(self):
audio_encoder_path = self.config.get("audio_encoder_path", os.path.join(self.config["model_path"], "TencentGameMate-chinese-hubert-large"))
audio_encoder_offload = self.config.get("audio_encoder_cpu_offload", self.config.get("cpu_offload", False))
model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, device=self.config.get("run_device", "cuda"))
model = SekoAudioEncoderModel(audio_encoder_path, self.config["audio_sr"], audio_encoder_offload, run_device=self.config.get("run_device", "cuda"))
return model
def load_audio_adapter(self):
......@@ -843,7 +843,8 @@ class WanAudioRunner(WanRunner): # type:ignore
if audio_adapter_offload:
device = torch.device("cpu")
else:
device = torch.device(self.config.get("run_device", "cuda"))
device = torch.device(self.run_device)
audio_adapter = AudioAdapter(
attention_head_dim=self.config["dim"] // self.config["num_heads"],
num_attention_heads=self.config["num_heads"],
......@@ -856,7 +857,7 @@ class WanAudioRunner(WanRunner): # type:ignore
quantized=self.config.get("adapter_quantized", False),
quant_scheme=self.config.get("adapter_quant_scheme", None),
cpu_offload=audio_adapter_offload,
device=self.config.get("run_device", "cuda"),
run_device=self.run_device,
)
audio_adapter.to(device)
......@@ -896,6 +897,7 @@ class Wan22AudioRunner(WanAudioRunner):
vae_config = {
"vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
"device": vae_device,
"run_device": self.run_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
......@@ -912,6 +914,7 @@ class Wan22AudioRunner(WanAudioRunner):
vae_config = {
"vae_path": find_torch_model_path(self.config, "vae_path", "Wan2.2_VAE.pth"),
"device": vae_device,
"run_device": self.run_device,
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
......
......@@ -65,7 +65,7 @@ class WanRunner(DefaultRunner):
if clip_offload:
clip_device = torch.device("cpu")
else:
clip_device = torch.device(self.init_device)
clip_device = torch.device(self.run_device)
# quant_config
clip_quantized = self.config.get("clip_quantized", False)
if clip_quantized:
......@@ -84,6 +84,7 @@ class WanRunner(DefaultRunner):
image_encoder = CLIPModel(
dtype=torch.float16,
device=clip_device,
run_device=self.run_device,
checkpoint_path=clip_original_ckpt,
clip_quantized=clip_quantized,
clip_quantized_ckpt=clip_quantized_ckpt,
......@@ -101,7 +102,7 @@ class WanRunner(DefaultRunner):
if t5_offload:
t5_device = torch.device("cpu")
else:
t5_device = torch.device(self.init_device)
t5_device = torch.device(self.run_device)
# quant_config
t5_quantized = self.config.get("t5_quantized", False)
......@@ -123,6 +124,7 @@ class WanRunner(DefaultRunner):
text_encoder = T5EncoderModel(
text_len=self.config["text_len"],
dtype=torch.bfloat16,
run_device=self.run_device,
device=t5_device,
checkpoint_path=t5_original_ckpt,
tokenizer_path=tokenizer_path,
......@@ -142,11 +144,12 @@ class WanRunner(DefaultRunner):
if vae_offload:
vae_device = torch.device("cpu")
else:
vae_device = torch.device(self.init_device)
vae_device = torch.device(self.run_device)
vae_config = {
"vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
"device": vae_device,
"run_device": self.run_device,
"parallel": self.config["parallel"],
"use_tiling": self.config.get("use_tiling_vae", False),
"cpu_offload": vae_offload,
......@@ -170,6 +173,7 @@ class WanRunner(DefaultRunner):
vae_config = {
"vae_path": find_torch_model_path(self.config, "vae_path", self.vae_name),
"device": vae_device,
"run_device": self.run_device,
"parallel": self.config["parallel"],
"use_tiling": self.config.get("use_tiling_vae", False),
"cpu_offload": vae_offload,
......@@ -222,7 +226,7 @@ class WanRunner(DefaultRunner):
monitor_cli.lightx2v_input_prompt_len.observe(len(prompt))
neg_prompt = input_info.negative_prompt
if self.config["cfg_parallel"]:
if self.config.get("enable_cfg", False) and self.config["cfg_parallel"]:
cfg_p_group = self.config["device_mesh"].get_group(mesh_dim="cfg_p")
cfg_p_rank = dist.get_rank(cfg_p_group)
if cfg_p_rank == 0:
......@@ -236,8 +240,11 @@ class WanRunner(DefaultRunner):
else:
context = self.text_encoders[0].infer([prompt])
context = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context])
context_null = self.text_encoders[0].infer([neg_prompt])
context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null])
if self.config.get("enable_cfg", False):
context_null = self.text_encoders[0].infer([neg_prompt])
context_null = torch.stack([torch.cat([u, u.new_zeros(self.config["text_len"] - u.size(0), u.size(1))]) for u in context_null])
else:
context_null = None
text_encoder_output = {
"context": context,
"context_null": context_null,
......@@ -376,12 +383,12 @@ class WanRunner(DefaultRunner):
]
return latent_shape
def get_latent_shape_with_target_hw(self, target_h, target_w):
def get_latent_shape_with_target_hw(self):
latent_shape = [
self.config.get("num_channels_latents", 16),
(self.config["target_video_length"] - 1) // self.config["vae_stride"][0] + 1,
int(target_h) // self.config["vae_stride"][1],
int(target_w) // self.config["vae_stride"][2],
int(self.config["target_height"]) // self.config["vae_stride"][1],
int(self.config["target_width"]) // self.config["vae_stride"][2],
]
return latent_shape
......
from lightx2v.models.schedulers.hunyuan_video.scheduler import HunyuanVideo15Scheduler
class HunyuanVideo15SchedulerCaching(HunyuanVideo15Scheduler):
def __init__(self, config):
super().__init__(config)
def clear(self):
self.transformer_infer.clear()
from typing import List, Tuple, Union
import torch
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, *args, dim=2):
"""
Get n-D meshgrid with start, stop and num.
Args:
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
n-tuples.
*args: See above.
dim (int): Dimension of the meshgrid. Defaults to 2.
Returns:
grid (np.ndarray): [dim, ...]
"""
if len(args) == 0:
# start is grid_size
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
# PyTorch implementation of np.linspace(start[i], stop[i], num[i], endpoint=False)
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
return grid
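# Hedged example: get_meshgrid_nd((2, 3), dim=2) treats (2, 3) as the grid size and
# returns a [2, 2, 3] tensor with grid[0][i][j] == i and grid[1][i][j] == j; the integer
# coordinates come from linspace(0, n, n + 1)[:n], i.e. 0, 1, ..., n - 1 on each axis.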
#################################################################################
# Rotary Positional Embedding Functions #
#################################################################################
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
def reshape_for_broadcast(
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
x: torch.Tensor,
):
"""
Reshape frequency tensor for broadcasting it with another tensor.
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.
Notes:
When using FlashMHAModified, head_first should be False.
When using Attention, head_first should be True.
Args:
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
torch.Tensor: Reshaped frequency tensor.
Raises:
AssertionError: If the frequency tensor doesn't match the expected shape.
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
"""
ndim = x.ndim
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
def rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
cos, sin = reshape_for_broadcast(freqs_cis, xq) # [S, D]
# real * cos - imag * sin
# imag * cos + real * sin
xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
return xq_out, xk_out
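# Descriptive note (sketch): because cos/sin are built with repeat_interleave(2, dim=1),
# each adjacent pair (x0, x1) along the head dim shares one angle, so the two lines above
# reduce to the usual rotation out0 = x0 * cos - x1 * sin, out1 = x1 * cos + x0 * sin.
# xq/xk are [B, S, H, D]; reshape_for_broadcast views cos/sin as [1, S, 1, D] for broadcasting.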
def rotate_half_force_bf16(x):
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb_force_bf16(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
cos, sin = reshape_for_broadcast(freqs_cis, xq) # [S, D]
# real * cos - imag * sin
# imag * cos + real * sin
xq_out = xq * cos + rotate_half_force_bf16(xq) * sin
xk_out = xk * cos + rotate_half_force_bf16(xk) * sin
return xq_out, xk_out
def get_nd_rotary_pos_embed(
rope_dim_list,
start,
*args,
theta=10000.0,
use_real=False,
theta_rescale_factor: Union[float, List[float]] = 1.0,
interpolation_factor: Union[float, List[float]] = 1.0,
):
"""
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
Args:
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
sum(rope_dim_list) should equal to head_dim of attention layer.
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
*args: See above.
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide the real and imaginary parts separately.
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
Returns:
pos_embed (torch.Tensor): [HW, D/2]
"""
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)"
# use 1/ndim of dimensions to encode grid_axis
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i],
grid[i].reshape(-1),
theta,
use_real=use_real,
theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
) # 2 x [WHD, rope_dim_list[i]]
embs.append(emb)
if use_real:
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
return cos, sin
else:
emb = torch.cat(embs, dim=1) # (WHD, D/2)
return emb
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.FloatTensor, int],
theta: float = 10000.0,
use_real: bool = False,
theta_rescale_factor: float = 1.0,
interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
Returns:
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
"""
if isinstance(pos, int):
pos = torch.arange(pos).float()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # [D/2]
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
freqs = torch.outer(pos * interpolation_factor, freqs).cuda() # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
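# Hedged usage sketch (requires CUDA, since freqs is moved to the GPU above): with dim=4,
# pos=3 and use_real=True the call returns cos/sin tensors of shape [3, 4]; the two base
# frequencies 1 / theta ** (0 / 4) and 1 / theta ** (2 / 4) are each repeated twice along
# the last dim so apply_rotary_emb can consume them without further reshaping.
#     cos, sin = get_1d_rotary_pos_embed(4, 3, use_real=True)
#     assert cos.shape == sin.shape == (3, 4)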
import torch
import torch.distributed as dist
from einops import rearrange
from torch.nn import functional as F
from lightx2v.models.schedulers.scheduler import BaseScheduler
from .posemb_layers import get_nd_rotary_pos_embed
class HunyuanVideo15Scheduler(BaseScheduler):
def __init__(self, config):
super().__init__(config)
self.device = torch.device("cuda")
self.reverse = True
self.num_train_timesteps = 1000
self.sample_shift = self.config["sample_shift"]
self.reorg_token = False
self.keep_latents_dtype_in_scheduler = True
self.sample_guide_scale = self.config["sample_guide_scale"]
if self.config["seq_parallel"]:
self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
else:
self.seq_p_group = None
def prepare(self, seed, latent_shape, image_encoder_output=None):
self.prepare_latents(seed, latent_shape, dtype=torch.bfloat16)
self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
self.multitask_mask = self.get_task_mask(self.config["task"], latent_shape[-3])
self.cond_latents_concat, self.mask_concat = self._prepare_cond_latents_and_mask(self.config["task"], image_encoder_output["cond_latents"], self.latents, self.multitask_mask, self.reorg_token)
self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3]))
def prepare_latents(self, seed, latent_shape, dtype=torch.bfloat16):
self.generator = torch.Generator(device=self.device).manual_seed(seed)
self.latents = torch.randn(
1,
latent_shape[0],
latent_shape[1],
latent_shape[2],
latent_shape[3],
dtype=dtype,
device=self.device,
generator=self.generator,
)
def set_timesteps(self, num_inference_steps, device, shift):
sigmas = torch.linspace(1, 0, num_inference_steps + 1)
# Apply timestep shift
if shift != 1.0:
sigmas = self.sd3_time_shift(sigmas, shift)
if not self.reverse:
sigmas = 1 - sigmas
self.sigmas = sigmas
self.timesteps = (sigmas[:-1] * self.num_train_timesteps).to(dtype=torch.float32, device=device)
def sd3_time_shift(self, t: torch.Tensor, shift):
return (shift * t) / (1 + (shift - 1) * t)
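# Worked example (not part of the scheduler): the shift keeps the endpoints fixed
# (t = 0 and t = 1 map to themselves) and pushes intermediate sigmas toward 1 when
# shift > 1, e.g. with shift = 5: (5 * 0.5) / (1 + 4 * 0.5) = 2.5 / 3 ≈ 0.833.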
def get_task_mask(self, task_type, latent_target_length):
if task_type == "t2v":
mask = torch.zeros(latent_target_length)
elif task_type == "i2v":
mask = torch.zeros(latent_target_length)
mask[0] = 1.0
else:
raise ValueError(f"{task_type} is not supported !")
return mask
def _prepare_cond_latents_and_mask(self, task_type, cond_latents, latents, multitask_mask, reorg_token):
"""
Prepare the conditional latents and the multitask mask for concatenation.
Args:
task_type: Type of task ("i2v" or "t2v")
cond_latents: Conditional latents tensor
latents: Main latents tensor
multitask_mask: Multitask mask tensor
reorg_token: Whether to reorganize tokens
Returns:
tuple: (latents_concat, mask_concat) - may contain None values
"""
latents_concat = None
mask_concat = None
if cond_latents is not None and task_type == "i2v":
latents_concat = cond_latents.repeat(1, 1, latents.shape[2], 1, 1)
latents_concat[:, :, 1:, :, :] = 0.0
else:
if reorg_token:
latents_concat = torch.zeros(latents.shape[0], latents.shape[1] // 2, latents.shape[2], latents.shape[3], latents.shape[4]).to(latents.device)
else:
latents_concat = torch.zeros(latents.shape[0], latents.shape[1], latents.shape[2], latents.shape[3], latents.shape[4]).to(latents.device)
mask_zeros = torch.zeros(latents.shape[0], 1, latents.shape[2], latents.shape[3], latents.shape[4])
mask_ones = torch.ones(latents.shape[0], 1, latents.shape[2], latents.shape[3], latents.shape[4])
mask_concat = self.merge_tensor_by_mask(mask_zeros.cpu(), mask_ones.cpu(), mask=multitask_mask.cpu(), dim=2).to(device=latents.device)
return latents_concat, mask_concat
def merge_tensor_by_mask(self, tensor_1, tensor_2, mask, dim):
assert tensor_1.shape == tensor_2.shape
# Mask is a 0/1 vector. Choose tensor_2 when the value is 1; otherwise, tensor_1
masked_indices = torch.nonzero(mask).squeeze(1)
tmp = tensor_1.clone()
if dim == 0:
tmp[masked_indices] = tensor_2[masked_indices]
elif dim == 1:
tmp[:, masked_indices] = tensor_2[:, masked_indices]
elif dim == 2:
tmp[:, :, masked_indices] = tensor_2[:, :, masked_indices]
return tmp
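# Hedged example: for i2v the multitask mask is [1, 0, 0, ...] over latent frames, so
# merging zeros/ones along dim=2 produces a mask volume that is 1 only on the first
# temporal slice, marking frame 0 as the image-conditioned frame.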
def step_post(self):
model_output = self.noise_pred.to(torch.float32)
sample = self.latents.to(torch.float32)
dt = self.sigmas[self.step_index + 1] - self.sigmas[self.step_index]
self.latents = sample + model_output * dt
def prepare_cos_sin(self, rope_sizes):
target_ndim = 3
head_dim = self.config["hidden_size"] // self.config["heads_num"]
rope_dim_list = self.config["rope_dim_list"]
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=self.config["rope_theta"],
use_real=True,
theta_rescale_factor=1,
)
cos_half = freqs_cos[:, ::2].contiguous()
sin_half = freqs_sin[:, ::2].contiguous()
cos_sin = torch.cat([cos_half, sin_half], dim=-1)
if self.seq_p_group is not None:
world_size = dist.get_world_size(self.seq_p_group)
cur_rank = dist.get_rank(self.seq_p_group)
seqlen = cos_sin.shape[0]
padding_size = (world_size - (seqlen % world_size)) % world_size
if padding_size > 0:
cos_sin = F.pad(cos_sin, (0, 0, 0, padding_size))
cos_sin = torch.chunk(cos_sin, world_size, dim=0)[cur_rank]
return cos_sin
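# Sequence-parallel note (sketch): the rotary table is zero-padded so its length is a
# multiple of the seq_p world size, split into equal contiguous chunks, and each rank
# keeps only chunk[cur_rank], presumably mirroring how the token sequence is sharded.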
class HunyuanVideo15SRScheduler(HunyuanVideo15Scheduler):
def __init__(self, config):
super().__init__(config)
self.noise_scale = 0.7
def prepare(self, seed, latent_shape, lq_latents, upsampler, image_encoder_output=None):
dtype = lq_latents.dtype
device = lq_latents.device
self.prepare_latents(seed, latent_shape, lq_latents, dtype=dtype)
self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
self.cos_sin = self.prepare_cos_sin((latent_shape[1], latent_shape[2], latent_shape[3]))
tgt_shape = latent_shape[-2:]
bsz = lq_latents.shape[0]
lq_latents = rearrange(lq_latents, "b c f h w -> (b f) c h w")
lq_latents = F.interpolate(lq_latents, size=tgt_shape, mode="bilinear", align_corners=False)
lq_latents = rearrange(lq_latents, "(b f) c h w -> b c f h w", b=bsz)
lq_latents = upsampler(lq_latents.to(dtype=torch.float32, device=device))
lq_latents = lq_latents.to(dtype=dtype)
lq_latents = self.add_noise_to_lq(lq_latents, self.noise_scale)
condition = self.get_condition(lq_latents, image_encoder_output["cond_latents"], self.config["task"])
c = lq_latents.shape[1]
zero_condition = condition.clone()
zero_condition[:, c + 1 : 2 * c + 1] = torch.zeros_like(lq_latents)
zero_condition[:, 2 * c + 1] = 0
self.condition = condition
self.zero_condition = zero_condition
def prepare_latents(self, seed, latent_shape, lq_latents, dtype=torch.bfloat16):
self.generator = torch.Generator(device=lq_latents.device).manual_seed(seed)
self.latents = torch.randn(
1,
latent_shape[0],
latent_shape[1],
latent_shape[2],
latent_shape[3],
dtype=dtype,
device=lq_latents.device,
generator=self.generator,
)
def get_condition(self, lq_latents, img_cond, task):
"""
latents: shape (b c f h w)
"""
b, c, f, h, w = self.latents.shape
cond = torch.zeros([b, c * 2 + 2, f, h, w], device=lq_latents.device, dtype=lq_latents.dtype)
cond[:, c + 1 : 2 * c + 1] = lq_latents
cond[:, 2 * c + 1] = 1
if "t2v" in task:
return cond
elif "i2v" in task:
cond[:, :c, :1] = img_cond
cond[:, c + 1, 0] = 1
return cond
else:
raise ValueError(f"Unsupported task: {task}")
def add_noise_to_lq(self, lq_latents, strength=0.7):
def expand_dims(tensor: torch.Tensor, ndim: int):
shape = tensor.shape + (1,) * (ndim - tensor.ndim)
return tensor.reshape(shape)
noise = torch.randn_like(lq_latents)
timestep = torch.tensor([1000.0], device=lq_latents.device) * strength
t = expand_dims(timestep, lq_latents.ndim)
return (1 - t / 1000.0) * lq_latents + (t / 1000.0) * noise
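# Worked example (sketch): with the default strength of 0.7 the timestep is 700, so the
# blend is 0.3 * lq_latents + 0.7 * noise, i.e. the SR pass starts from latents that are
# mostly noise but still carry the upsampled low-quality structure.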
......@@ -11,10 +11,11 @@ class BaseScheduler:
self.flag_df = False
self.transformer_infer = None
self.infer_condition = True # cfg status
self.keep_latents_dtype_in_scheduler = False
def step_pre(self, step_index):
self.step_index = step_index
if GET_DTYPE() == GET_SENSITIVE_DTYPE():
if GET_DTYPE() == GET_SENSITIVE_DTYPE() and not self.keep_latents_dtype_in_scheduler:
self.latents = self.latents.to(GET_DTYPE())
def clear(self):
......
import math
import os
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.autoencoders.vae import BaseOutput, DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.models.modeling_utils import ModelMixin
from einops import rearrange
from torch import Tensor, nn
@dataclass
class DecoderOutput(BaseOutput):
sample: torch.FloatTensor
posterior: Optional[DiagonalGaussianDistribution] = None
def swish(x: Tensor) -> Tensor:
"""Applies the swish activation function."""
return x * torch.sigmoid(x)
def forward_with_checkpointing(module, *inputs, use_checkpointing=False):
"""Forward with optional gradient checkpointing."""
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
if use_checkpointing:
return torch.utils.checkpoint.checkpoint(create_custom_forward(module), *inputs, use_reentrant=False)
else:
return module(*inputs)
# Optimized implementation of CogVideoXSafeConv3d
# https://github.com/huggingface/diffusers/blob/c9ff360966327ace3faad3807dc871a4e5447501/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py#L38
class PatchCausalConv3d(nn.Conv3d):
r"""Causal Conv3d with efficient patch processing for large tensors."""
def find_split_indices(self, seq_len, part_num):
ideal_interval = seq_len / part_num
possible_indices = list(range(0, seq_len, self.stride[0]))
selected_indices = []
for i in range(1, part_num):
closest = min(possible_indices, key=lambda x: abs(x - round(i * ideal_interval)))
if closest not in selected_indices:
selected_indices.append(closest)
merged_indices = []
prev_idx = 0
for idx in selected_indices:
if idx - prev_idx >= self.kernel_size[0]:
merged_indices.append(idx)
prev_idx = idx
return merged_indices
def forward(self, input):
T = input.shape[2] # input: NCTHW
memory_count = torch.prod(torch.tensor(input.shape)).item() * 2 / 1024**3
if T > self.kernel_size[0] and memory_count > 2:
kernel_size = self.kernel_size[0]
part_num = int(memory_count / 2) + 1
split_indices = self.find_split_indices(T, part_num)
input_chunks = torch.tensor_split(input, split_indices, dim=2) if len(split_indices) > 0 else [input]
if kernel_size > 1:
input_chunks = [input_chunks[0]] + [torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2) for i in range(1, len(input_chunks))]
output_chunks = []
for input_chunk in input_chunks:
output_chunks.append(super().forward(input_chunk))
output = torch.cat(output_chunks, dim=2)
return output
else:
return super().forward(input)
class RMS_norm(nn.Module):
"""Root Mean Square Layer Normalization for Channel-First or Last"""
def __init__(self, dim, channel_first=True, images=True, bias=False):
super().__init__()
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
self.channel_first = channel_first
self.scale = dim**0.5
self.gamma = nn.Parameter(torch.ones(shape))
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
class Conv3d(nn.Conv3d):
"""Perform Conv3d on patches with memory-efficient symmetric padding."""
def forward(self, input):
B, C, T, H, W = input.shape
memory_count = (C * T * H * W) * 2 / 1024**3
n_split = math.ceil(memory_count / 2)
if memory_count > 2 and input.shape[-3] % n_split == 0:
chunks = torch.chunk(input, chunks=n_split, dim=-3)
padded_chunks = []
for i in range(len(chunks)):
if self.padding[0] > 0:
padded_chunk = F.pad(
chunks[i],
(0, 0, 0, 0, self.padding[0], self.padding[0]),
mode="constant" if self.padding_mode == "zeros" else self.padding_mode,
value=0,
)
if i > 0:
padded_chunk[:, :, : self.padding[0]] = chunks[i - 1][:, :, -self.padding[0] :]
if i < len(chunks) - 1:
padded_chunk[:, :, -self.padding[0] :] = chunks[i + 1][:, :, : self.padding[0]]
else:
padded_chunk = chunks[i]
padded_chunks.append(padded_chunk)
padding_bak = self.padding
self.padding = (0, self.padding[1], self.padding[2])
outputs = []
for chunk in padded_chunks:
outputs.append(super().forward(chunk))
self.padding = padding_bak
return torch.cat(outputs, dim=-3)
else:
return super().forward(input)
class CausalConv3d(nn.Module):
"""Causal Conv3d with configurable padding for temporal axis."""
def __init__(
self,
chan_in,
chan_out,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
dilation: Union[int, Tuple[int, int, int]] = 1,
pad_mode="replicate",
disable_causal=False,
enable_patch_conv=False,
**kwargs,
):
super().__init__()
self.pad_mode = pad_mode
if disable_causal:
padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2)
else:
padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0) # W, H, T
self.time_causal_padding = padding
if enable_patch_conv:
self.conv = PatchCausalConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
else:
self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, x):
x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
return self.conv(x)
def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
"""Prepare a causal attention mask for 3D videos.
Args:
n_frame (int): Number of frames (temporal length).
n_hw (int): Product of height and width.
dtype: Desired mask dtype.
device: Device for the mask.
batch_size (int, optional): If set, expands for batch.
Returns:
torch.Tensor: Causal attention mask.
"""
seq_len = n_frame * n_hw
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
for i in range(seq_len):
i_frame = i // n_hw
mask[i, : (i_frame + 1) * n_hw] = 0
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask
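# Hedged example: with n_frame=2 and n_hw=2 the mask is 4x4; rows 0-1 (frame 0) are 0 on
# columns 0-1 and -inf elsewhere, while rows 2-3 (frame 1) are all 0, i.e. attention is
# causal across frames but unrestricted within a frame.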
class AttnBlock(nn.Module):
"""Self-attention block for 3D video tensors."""
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = RMS_norm(in_channels, images=False)
self.q = Conv3d(in_channels, in_channels, kernel_size=1)
self.k = Conv3d(in_channels, in_channels, kernel_size=1)
self.v = Conv3d(in_channels, in_channels, kernel_size=1)
self.proj_out = Conv3d(in_channels, in_channels, kernel_size=1)
def attention(self, h_: Tensor) -> Tensor:
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
b, c, f, h, w = q.shape
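        # Flatten (f, h, w) into a single token sequence with one attention head, and apply
        # the block-causal mask so each frame attends only to itself and earlier frames.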
q = rearrange(q, "b c f h w -> b 1 (f h w) c").contiguous()
k = rearrange(k, "b c f h w -> b 1 (f h w) c").contiguous()
v = rearrange(v, "b c f h w -> b 1 (f h w) c").contiguous()
attention_mask = prepare_causal_attention_mask(f, h * w, h_.dtype, h_.device, batch_size=b)
h_ = nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask.unsqueeze(1))
return rearrange(h_, "b 1 (f h w) c -> b c f h w", f=f, h=h, w=w, c=c, b=b)
def forward(self, x: Tensor) -> Tensor:
return x + self.proj_out(self.attention(x))
class ResnetBlock(nn.Module):
"""ResNet-style block for 3D video tensors."""
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = RMS_norm(in_channels, images=False)
self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3)
self.norm2 = RMS_norm(out_channels, images=False)
self.conv2 = CausalConv3d(out_channels, out_channels, kernel_size=3)
if self.in_channels != self.out_channels:
self.nin_shortcut = Conv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h = x
h = self.norm1(h)
h = swish(h)
h = self.conv1(h)
h = self.norm2(h)
h = swish(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_downsample else 1 * 2 * 2
assert out_channels % factor == 0
self.conv = CausalConv3d(in_channels, out_channels // factor, kernel_size=3)
self.add_temporal_downsample = add_temporal_downsample
self.group_size = factor * in_channels // out_channels
def forward(self, x: Tensor):
r1 = 2 if self.add_temporal_downsample else 1
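        # Downsample by moving 2x2 spatial (and optionally 2x temporal) patches into channels
        # after a stride-1 causal conv. The first frame is kept separate so the temporal factor
        # stays causal; its channels are duplicated to match the r1*r2*r3 layout of later frames.
        # The residual shortcut applies the same rearrangement to the raw input and averages
        # groups of channels so its channel count matches h.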
h = self.conv(x)
if self.add_temporal_downsample:
h_first = h[:, :, :1, :, :]
h_first = rearrange(h_first, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
h_first = torch.cat([h_first, h_first], dim=1)
h_next = h[:, :, 1:, :, :]
h_next = rearrange(h_next, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
h = torch.cat([h_first, h_next], dim=2)
# shortcut computation
x_first = x[:, :, :1, :, :]
x_first = rearrange(x_first, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
B, C, T, H, W = x_first.shape
x_first = x_first.view(B, h.shape[1], self.group_size // 2, T, H, W).mean(dim=2)
x_next = x[:, :, 1:, :, :]
x_next = rearrange(x_next, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
B, C, T, H, W = x_next.shape
x_next = x_next.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
shortcut = torch.cat([x_first, x_next], dim=2)
else:
h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
B, C, T, H, W = shortcut.shape
shortcut = shortcut.view(B, h.shape[1], self.group_size, T, H, W).mean(dim=2)
return h + shortcut
class Upsample(nn.Module):
"""Hierarchical upsampling with temporal/ spatial support."""
def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: bool = True):
super().__init__()
factor = 2 * 2 * 2 if add_temporal_upsample else 1 * 2 * 2
self.conv = CausalConv3d(in_channels, out_channels * factor, kernel_size=3)
self.add_temporal_upsample = add_temporal_upsample
self.repeats = factor * out_channels // in_channels
def forward(self, x: Tensor):
r1 = 2 if self.add_temporal_upsample else 1
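        # Inverse of Downsample: a causal conv expands channels, which are then unfolded back
        # into 2x2 spatial (and optionally 2x temporal) positions. The first frame is handled
        # separately (spatial unfolding only, half the channels) to keep temporal upsampling
        # causal, and the shortcut repeats input channels so it matches the upsampled shape.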
h = self.conv(x)
if self.add_temporal_upsample:
h_first = h[:, :, :1, :, :]
h_first = rearrange(h_first, "b (r2 r3 c) f h w -> b c f (h r2) (w r3)", r2=2, r3=2)
h_first = h_first[:, : h_first.shape[1] // 2]
h_next = h[:, :, 1:, :, :]
h_next = rearrange(h_next, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2)
h = torch.cat([h_first, h_next], dim=2)
# shortcut computation
x_first = x[:, :, :1, :, :]
x_first = rearrange(x_first, "b (r2 r3 c) f h w -> b c f (h r2) (w r3)", r2=2, r3=2)
x_first = x_first.repeat_interleave(repeats=self.repeats // 2, dim=1)
x_next = x[:, :, 1:, :, :]
x_next = rearrange(x_next, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2)
x_next = x_next.repeat_interleave(repeats=self.repeats, dim=1)
shortcut = torch.cat([x_first, x_next], dim=2)
else:
h = rearrange(h, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2)
shortcut = x.repeat_interleave(repeats=self.repeats, dim=1)
shortcut = rearrange(shortcut, "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)", r1=r1, r2=2, r3=2)
return h + shortcut
class Encoder(nn.Module):
"""Hierarchical video encoder with temporal and spatial factorization."""
def __init__(
self,
in_channels: int,
z_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
ffactor_spatial: int,
ffactor_temporal: int,
downsample_match_channel: bool = True,
):
super().__init__()
assert block_out_channels[-1] % (2 * z_channels) == 0
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
# downsampling
self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3)
self.down = nn.ModuleList()
block_in = block_out_channels[0]
for i_level, ch in enumerate(block_out_channels):
block = nn.ModuleList()
block_out = ch
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = nn.Module()
down.block = block
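            # Spatial downsampling is applied at the first log2(ffactor_spatial) levels;
            # temporal downsampling only starts once the remaining spatial factor equals
            # the temporal factor, so the overall ratios come out to ffactor_spatial and
            # ffactor_temporal.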
add_spatial_downsample = bool(i_level < np.log2(ffactor_spatial))
add_temporal_downsample = add_spatial_downsample and bool(i_level >= np.log2(ffactor_spatial // ffactor_temporal))
if add_spatial_downsample or add_temporal_downsample:
assert i_level < len(block_out_channels) - 1
block_out = block_out_channels[i_level + 1] if downsample_match_channel else block_in
down.downsample = Downsample(block_in, block_out, add_temporal_downsample)
block_in = block_out
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = RMS_norm(block_in, images=False)
self.conv_out = CausalConv3d(block_in, 2 * z_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, x: Tensor) -> Tensor:
"""Forward pass through the encoder."""
use_checkpointing = bool(self.training and self.gradient_checkpointing)
# downsampling
h = self.conv_in(x)
for i_level in range(len(self.block_out_channels)):
for i_block in range(self.num_res_blocks):
h = forward_with_checkpointing(self.down[i_level].block[i_block], h, use_checkpointing=use_checkpointing)
if hasattr(self.down[i_level], "downsample"):
h = forward_with_checkpointing(self.down[i_level].downsample, h, use_checkpointing=use_checkpointing)
# middle
h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)
# end
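        # Final residual: average groups of encoder channels down to 2 * z_channels and add
        # them to the conv_out output (mean/logvar), mirroring the channel-averaged shortcuts
        # used elsewhere in the network.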
group_size = self.block_out_channels[-1] // (2 * self.z_channels)
shortcut = rearrange(h, "b (c r) f h w -> b c r f h w", r=group_size).mean(dim=2)
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
h += shortcut
return h
class Decoder(nn.Module):
"""Hierarchical video decoder with upsampling factories."""
def __init__(
self,
z_channels: int,
out_channels: int,
block_out_channels: Tuple[int, ...],
num_res_blocks: int,
ffactor_spatial: int,
ffactor_temporal: int,
upsample_match_channel: bool = True,
):
super().__init__()
assert block_out_channels[0] % z_channels == 0
self.z_channels = z_channels
self.block_out_channels = block_out_channels
self.num_res_blocks = num_res_blocks
block_in = block_out_channels[0]
self.conv_in = CausalConv3d(z_channels, block_in, kernel_size=3)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid.attn_1 = AttnBlock(block_in)
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = nn.ModuleList()
for i_level, ch in enumerate(block_out_channels):
block = nn.ModuleList()
block_out = ch
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = nn.Module()
up.block = block
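            # Mirror of the encoder schedule: spatial upsampling at the first
            # log2(ffactor_spatial) levels, temporal upsampling at the first
            # log2(ffactor_temporal) levels.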
add_spatial_upsample = bool(i_level < np.log2(ffactor_spatial))
add_temporal_upsample = bool(i_level < np.log2(ffactor_temporal))
if add_spatial_upsample or add_temporal_upsample:
assert i_level < len(block_out_channels) - 1
block_out = block_out_channels[i_level + 1] if upsample_match_channel else block_in
up.upsample = Upsample(block_in, block_out, add_temporal_upsample)
block_in = block_out
self.up.append(up)
# end
self.norm_out = RMS_norm(block_in, images=False)
self.conv_out = CausalConv3d(block_in, out_channels, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, z: Tensor) -> Tensor:
"""Forward pass through the decoder."""
use_checkpointing = bool(self.training and self.gradient_checkpointing)
# z to block_in
repeats = self.block_out_channels[0] // (self.z_channels)
h = self.conv_in(z) + z.repeat_interleave(repeats=repeats, dim=1)
# middle
h = forward_with_checkpointing(self.mid.block_1, h, use_checkpointing=use_checkpointing)
h = forward_with_checkpointing(self.mid.attn_1, h, use_checkpointing=use_checkpointing)
h = forward_with_checkpointing(self.mid.block_2, h, use_checkpointing=use_checkpointing)
# upsampling
for i_level in range(len(self.block_out_channels)):
for i_block in range(self.num_res_blocks + 1):
h = forward_with_checkpointing(self.up[i_level].block[i_block], h, use_checkpointing=use_checkpointing)
if hasattr(self.up[i_level], "upsample"):
h = forward_with_checkpointing(self.up[i_level].upsample, h, use_checkpointing=use_checkpointing)
# end
h = self.norm_out(h)
h = swish(h)
h = self.conv_out(h)
return h
class AutoencoderKLConv3D(ModelMixin, ConfigMixin):
"""KL regularized 3D Conv VAE with advanced tiling and slicing strategies."""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int,
out_channels: int,
latent_channels: int,
block_out_channels: Tuple[int, ...],
layers_per_block: int,
ffactor_spatial: int,
ffactor_temporal: int,
sample_size: int,
sample_tsize: int,
scaling_factor: float = None,
shift_factor: Optional[float] = None,
downsample_match_channel: bool = True,
upsample_match_channel: bool = True,
spatial_compression_ratio: int = 16,
time_compression_ratio: int = 4,
):
super().__init__()
self.ffactor_spatial = ffactor_spatial
self.ffactor_temporal = ffactor_temporal
self.scaling_factor = scaling_factor
self.shift_factor = shift_factor
self.encoder = Encoder(
in_channels=in_channels,
z_channels=latent_channels,
block_out_channels=block_out_channels,
num_res_blocks=layers_per_block,
ffactor_spatial=ffactor_spatial,
ffactor_temporal=ffactor_temporal,
downsample_match_channel=downsample_match_channel,
)
self.decoder = Decoder(
z_channels=latent_channels,
out_channels=out_channels,
block_out_channels=list(reversed(block_out_channels)),
num_res_blocks=layers_per_block,
ffactor_spatial=ffactor_spatial,
ffactor_temporal=ffactor_temporal,
upsample_match_channel=upsample_match_channel,
)
self.use_slicing = False
self.use_spatial_tiling = False
self.use_temporal_tiling = False
# only relevant if vae tiling is enabled
self.tile_sample_min_size = sample_size
self.tile_latent_min_size = sample_size // ffactor_spatial
self.tile_sample_min_tsize = sample_tsize
self.tile_latent_min_tsize = sample_tsize // ffactor_temporal
self.tile_overlap_factor = 0.25
def _set_gradient_checkpointing(self, module, value=False):
"""Enable or disable gradient checkpointing on encoder and decoder."""
if isinstance(module, (Encoder, Decoder)):
module.gradient_checkpointing = value
def enable_temporal_tiling(self, use_tiling: bool = True):
self.use_temporal_tiling = use_tiling
def disable_temporal_tiling(self):
self.enable_temporal_tiling(False)
def enable_spatial_tiling(self, use_tiling: bool = True):
self.use_spatial_tiling = use_tiling
def disable_spatial_tiling(self):
self.enable_spatial_tiling(False)
def enable_tiling(self, use_tiling: bool = True):
self.enable_spatial_tiling(use_tiling)
def disable_tiling(self):
self.disable_spatial_tiling()
def enable_slicing(self):
self.use_slicing = True
def disable_slicing(self):
self.use_slicing = False
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
"""Blend tensor b horizontally into a at blend_extent region."""
blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
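        # Linear cross-fade over the overlapping columns: a's right edge fades out while
        # b's left edge fades in, hiding the seam between adjacent tiles.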
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
return b
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
"""Blend tensor b vertically into a at blend_extent region."""
blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
return b
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int):
"""Blend tensor b temporally into a at blend_extent region."""
blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
for x in range(blend_extent):
b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
return b
def spatial_tiled_encode(self, x: torch.Tensor):
"""Tiled spatial encoding for large inputs via overlapping."""
B, C, T, H, W = x.shape
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
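        # Encode overlapping tiles in pixel space, blend neighbouring tiles over the overlap
        # region in latent space, then crop each tile to its non-overlapping part before
        # concatenating rows and columns back together.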
rows = []
for i in range(0, H, overlap_size):
row = []
for j in range(0, W, overlap_size):
tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
moments = torch.cat(result_rows, dim=-2)
return moments
def temporal_tiled_encode(self, x: torch.Tensor):
"""Tiled temporal encoding for large video sequences."""
B, C, T, H, W = x.shape
overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
t_limit = self.tile_latent_min_tsize - blend_extent
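        # Same idea along the temporal axis. Each chunk grabs one extra frame (the "+ 1")
        # so the causal encoder produces an extra leading latent; for every chunk except
        # the first, that leading latent is dropped before blending.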
row = []
for i in range(0, T, overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
if self.use_spatial_tiling and (tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
tile = self.spatial_tiled_encode(tile)
else:
tile = self.encoder(tile)
if i > 0:
tile = tile[:, :, 1:, :, :]
row.append(tile)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_extent)
result_row.append(tile[:, :, :t_limit, :, :])
else:
result_row.append(tile[:, :, : t_limit + 1, :, :])
moments = torch.cat(result_row, dim=-3)
return moments
def spatial_tiled_decode(self, z: torch.Tensor):
"""Tiled spatial decoding for large latent maps."""
B, C, T, H, W = z.shape
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
rows = []
for i in range(0, H, overlap_size):
row = []
for j in range(0, W, overlap_size):
tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=-1))
dec = torch.cat(result_rows, dim=-2)
return dec
def temporal_tiled_decode(self, z: torch.Tensor):
"""Tiled temporal decoding for long sequence latents."""
B, C, T, H, W = z.shape
overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
t_limit = self.tile_sample_min_tsize - blend_extent
assert 0 < overlap_size < self.tile_latent_min_tsize
row = []
for i in range(0, T, overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
if self.use_spatial_tiling and (tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
decoded = self.spatial_tiled_decode(tile)
else:
decoded = self.decoder(tile)
if i > 0:
decoded = decoded[:, :, 1:, :, :]
row.append(decoded)
result_row = []
for i, tile in enumerate(row):
if i > 0:
tile = self.blend_t(row[i - 1], tile, blend_extent)
result_row.append(tile[:, :, :t_limit, :, :])
else:
result_row.append(tile[:, :, : t_limit + 1, :, :])
dec = torch.cat(result_row, dim=-3)
return dec
@torch.no_grad()
def encode(self, x: Tensor, return_dict: bool = True):
if self.cpu_offload:
self.encoder = self.encoder.to("cuda")
def _encode(x):
if self.use_temporal_tiling and x.shape[-3] > self.tile_sample_min_tsize:
return self.temporal_tiled_encode(x)
if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
return self.spatial_tiled_encode(x)
return self.encoder(x)
assert len(x.shape) == 5 # (B, C, T, H, W)
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [_encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = _encode(x)
posterior = DiagonalGaussianDistribution(h)
if self.cpu_offload:
self.encoder = self.encoder.to("cpu")
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
@torch.no_grad()
def decode(self, z: Tensor, return_dict: bool = True, generator=None):
if self.cpu_offload:
self.decoder = self.decoder.to("cuda")
def _decode(z):
if self.use_temporal_tiling and z.shape[-3] > self.tile_latent_min_tsize:
return self.temporal_tiled_decode(z)
if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.spatial_tiled_decode(z)
return self.decoder(z)
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [_decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = _decode(z)
if self.cpu_offload:
self.decoder = self.decoder.to("cpu")
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
@torch.no_grad()
def forward(self, sample: torch.Tensor, sample_posterior: bool = False, return_posterior: bool = True, return_dict: bool = True):
"""Forward autoencoder pass. Returns both reconstruction and optionally the posterior."""
posterior = self.encode(sample).latent_dist
z = posterior.sample() if sample_posterior else posterior.mode()
dec = self.decode(z).sample
return DecoderOutput(sample=dec, posterior=posterior) if return_dict else (dec, posterior)
class HunyuanVideo15VAE:
def __init__(self, checkpoint_path=None, dtype=torch.float16, device="cuda", cpu_offload=False):
self.vae = AutoencoderKLConv3D.from_pretrained(os.path.join(checkpoint_path, "vae")).to(dtype).to(device)
self.vae.cpu_offload = cpu_offload
@torch.no_grad()
def encode(self, x):
return self.vae.encode(x).latent_dist.mode() * self.vae.config.scaling_factor
@torch.no_grad()
def decode(self, z):
z = z / self.vae.config.scaling_factor
self.vae.enable_tiling()
video_frames = self.vae.decode(z, return_dict=False)[0]
self.vae.disable_tiling()
return video_frames
if __name__ == "__main__":
vae = HunyuanVideo15VAE(checkpoint_path="/data/nvme1/yongyang/models/HunyuanVideo-1.5/ckpts/hunyuanvideo-1.5", dtype=torch.float16, device="cuda")
z = torch.randn(1, 32, 31, 30, 53, dtype=torch.float16, device="cuda")
video_frames = vae.decode(z)
print(video_frames.shape)
import torch
import torch.nn as nn
from lightx2v.models.video_encoders.hf.tae import TAEHV
class LightTaeHy15(nn.Module):
def __init__(self, vae_path="lighttae_hy1_5.pth", dtype=torch.bfloat16):
super().__init__()
self.dtype = dtype
self.taehv = TAEHV(vae_path, model_type="hy15", latent_channels=32, patch_size=2).to(self.dtype)
self.scaling_factor = 1.03682
@torch.no_grad()
def decode(self, latents, parallel=True, show_progress_bar=True, skip_trim=False):
latents = latents / self.scaling_factor
return self.taehv.decode_video(latents.transpose(1, 2).to(self.dtype), parallel, show_progress_bar).transpose(1, 2)
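# A minimal usage sketch for the lightweight TAE decoder above (assumptions: the default
# checkpoint "lighttae_hy1_5.pth" is present locally, a CUDA device is available, and the
# latent shape below is only illustrative).
if __name__ == "__main__":
    tae = LightTaeHy15(vae_path="lighttae_hy1_5.pth").cuda()
    z = torch.randn(1, 32, 9, 30, 53, dtype=torch.bfloat16, device="cuda")  # (B, C, T, H, W) latents
    frames = tae.decode(z)
    print(frames.shape)  # decoded pixel-space video, roughly (B, 3, T', H', W')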
......@@ -27,11 +27,11 @@ class Clamp(nn.Module):
class MemBlock(nn.Module):
def __init__(self, n_in, n_out):
def __init__(self, n_in, n_out, act_func):
super().__init__()
self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True), conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
self.conv = nn.Sequential(conv(n_in * 2, n_out), act_func, conv(n_out, n_out), act_func, conv(n_out, n_out))
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
self.act = nn.ReLU(inplace=True)
self.act = act_func
def forward(self, x, past):
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
......@@ -177,27 +177,32 @@ class TAEHV(nn.Module):
self.is_cogvideox = checkpoint_path is not None and "taecvx" in checkpoint_path
# if checkpoint_path is not None and "taew2_2" in checkpoint_path:
# self.patch_size, self.latent_channels = 2, 48
self.model_type = model_type
if model_type == "wan22":
self.patch_size, self.latent_channels = 2, 48
if model_type == "hy15":
act_func = nn.LeakyReLU(0.2, inplace=True)
else:
act_func = nn.ReLU(inplace=True)
self.encoder = nn.Sequential(
conv(self.image_channels * self.patch_size**2, 64),
nn.ReLU(inplace=True),
act_func,
TPool(64, 2),
conv(64, 64, stride=2, bias=False),
MemBlock(64, 64),
MemBlock(64, 64),
MemBlock(64, 64),
MemBlock(64, 64, act_func),
MemBlock(64, 64, act_func),
MemBlock(64, 64, act_func),
TPool(64, 2),
conv(64, 64, stride=2, bias=False),
MemBlock(64, 64),
MemBlock(64, 64),
MemBlock(64, 64),
MemBlock(64, 64, act_func),
MemBlock(64, 64, act_func),
MemBlock(64, 64, act_func),
TPool(64, 1),
conv(64, 64, stride=2, bias=False),
MemBlock(64, 64),
MemBlock(64, 64),
MemBlock(64, 64),
MemBlock(64, 64, act_func),
MemBlock(64, 64, act_func),
MemBlock(64, 64, act_func),
conv(64, self.latent_channels),
)
n_f = [256, 128, 64, 64]
......@@ -205,26 +210,26 @@ class TAEHV(nn.Module):
self.decoder = nn.Sequential(
Clamp(),
conv(self.latent_channels, n_f[0]),
nn.ReLU(inplace=True),
MemBlock(n_f[0], n_f[0]),
MemBlock(n_f[0], n_f[0]),
MemBlock(n_f[0], n_f[0]),
act_func,
MemBlock(n_f[0], n_f[0], act_func),
MemBlock(n_f[0], n_f[0], act_func),
MemBlock(n_f[0], n_f[0], act_func),
nn.Upsample(scale_factor=2 if decoder_space_upscale[0] else 1),
TGrow(n_f[0], 1),
conv(n_f[0], n_f[1], bias=False),
MemBlock(n_f[1], n_f[1]),
MemBlock(n_f[1], n_f[1]),
MemBlock(n_f[1], n_f[1]),
MemBlock(n_f[1], n_f[1], act_func),
MemBlock(n_f[1], n_f[1], act_func),
MemBlock(n_f[1], n_f[1], act_func),
nn.Upsample(scale_factor=2 if decoder_space_upscale[1] else 1),
TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1),
conv(n_f[1], n_f[2], bias=False),
MemBlock(n_f[2], n_f[2]),
MemBlock(n_f[2], n_f[2]),
MemBlock(n_f[2], n_f[2]),
MemBlock(n_f[2], n_f[2], act_func),
MemBlock(n_f[2], n_f[2], act_func),
MemBlock(n_f[2], n_f[2], act_func),
nn.Upsample(scale_factor=2 if decoder_space_upscale[2] else 1),
TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1),
conv(n_f[2], n_f[3], bias=False),
nn.ReLU(inplace=True),
act_func,
conv(n_f[3], self.image_channels * self.patch_size**2),
)
if checkpoint_path is not None:
......@@ -285,7 +290,10 @@ class TAEHV(nn.Module):
"""
skip_trim = self.is_cogvideox and x.shape[1] % 2 == 0
x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
x = x.clamp_(0, 1)
if self.model_type == "hy15":
x = x.clamp_(-1, 1)
else:
x = x.clamp_(0, 1)
if self.patch_size > 1:
x = F.pixel_shuffle(x, self.patch_size)
if skip_trim:
......
......@@ -821,9 +821,11 @@ class WanVAE:
use_2d_split=True,
load_from_rank0=False,
use_lightvae=False,
run_device=torch.device("cuda"),
):
self.dtype = dtype
self.device = device
self.run_device = run_device
self.parallel = parallel
self.use_tiling = use_tiling
self.cpu_offload = cpu_offload
......@@ -953,9 +955,9 @@ class WanVAE:
self.scale = [self.mean, self.inv_std]
def to_cuda(self):
self.model.encoder = self.model.encoder.to("cuda")
self.model.decoder = self.model.decoder.to("cuda")
self.model = self.model.to("cuda")
self.model.encoder = self.model.encoder.to(self.run_device)
self.model.decoder = self.model.decoder.to(self.run_device)
self.model = self.model.to(self.run_device)
self.mean = self.mean.cuda()
self.inv_std = self.inv_std.cuda()
self.scale = [self.mean, self.inv_std]
......@@ -1328,7 +1330,7 @@ class WanVAE:
def device_synchronize(
self,
):
if "cuda" in str(self.device):
if "cuda" in str(self.run_device):
torch.cuda.synchronize()
elif "mlu" in str(self.device):
elif "mlu" in str(self.run_device):
torch.mlu.synchronize()
......@@ -43,36 +43,43 @@ def set_config(args):
config_json = json.load(f)
config.update(config_json)
if os.path.exists(os.path.join(config["model_path"], "config.json")):
with open(os.path.join(config["model_path"], "config.json"), "r") as f:
model_config = json.load(f)
config.update(model_config)
elif os.path.exists(os.path.join(config["model_path"], "low_noise_model", "config.json")): # 需要一个更优雅的update方法
with open(os.path.join(config["model_path"], "low_noise_model", "config.json"), "r") as f:
model_config = json.load(f)
config.update(model_config)
elif os.path.exists(os.path.join(config["model_path"], "distill_models", "low_noise_model", "config.json")): # 需要一个更优雅的update方法
with open(os.path.join(config["model_path"], "distill_models", "low_noise_model", "config.json"), "r") as f:
model_config = json.load(f)
config.update(model_config)
elif os.path.exists(os.path.join(config["model_path"], "original", "config.json")):
with open(os.path.join(config["model_path"], "original", "config.json"), "r") as f:
model_config = json.load(f)
config.update(model_config)
# load quantized config
if config.get("dit_quantized_ckpt", None) is not None:
config_path = os.path.join(config["dit_quantized_ckpt"], "config.json")
if os.path.exists(config_path):
with open(config_path, "r") as f:
if config["model_cls"] == "hunyuan_video_1.5": # Special config for hunyuan video 1.5 model folder structure
config["transformer_model_path"] = os.path.join(config["model_path"], "transformer", config["transformer_model_name"]) # transformer_model_name: [480p_t2v, 480p_i2v, 720p_t2v, 720p_i2v]
if os.path.exists(os.path.join(config["transformer_model_path"], "config.json")):
with open(os.path.join(config["transformer_model_path"], "config.json"), "r") as f:
model_config = json.load(f)
config.update(model_config)
else:
if os.path.exists(os.path.join(config["model_path"], "config.json")):
with open(os.path.join(config["model_path"], "config.json"), "r") as f:
model_config = json.load(f)
config.update(model_config)
elif os.path.exists(os.path.join(config["model_path"], "low_noise_model", "config.json")): # 需要一个更优雅的update方法
with open(os.path.join(config["model_path"], "low_noise_model", "config.json"), "r") as f:
model_config = json.load(f)
config.update(model_config)
elif os.path.exists(os.path.join(config["model_path"], "distill_models", "low_noise_model", "config.json")): # 需要一个更优雅的update方法
with open(os.path.join(config["model_path"], "distill_models", "low_noise_model", "config.json"), "r") as f:
model_config = json.load(f)
config.update(model_config)
elif os.path.exists(os.path.join(config["model_path"], "original", "config.json")):
with open(os.path.join(config["model_path"], "original", "config.json"), "r") as f:
model_config = json.load(f)
config.update(model_config)
# load quantized config
if config.get("dit_quantized_ckpt", None) is not None:
config_path = os.path.join(config["dit_quantized_ckpt"], "config.json")
if os.path.exists(config_path):
with open(config_path, "r") as f:
model_config = json.load(f)
config.update(model_config)
if config["task"] in ["i2v", "s2v"]:
if config["target_video_length"] % config["vae_stride"][0] != 1:
logger.warning(f"`num_frames - 1` has to be divisible by {config['vae_stride'][0]}. Rounding to the nearest number.")
config["target_video_length"] = config["target_video_length"] // config["vae_stride"][0] * config["vae_stride"][0] + 1
if config["task"] not in ["t2i", "i2i"]:
if config["task"] not in ["t2i", "i2i"] and config["model_cls"] != "hunyuan_video_1.5":
config["attnmap_frame_num"] = ((config["target_video_length"] - 1) // config["vae_stride"][0] + 1) // config["patch_size"][0]
if config["model_cls"] == "seko_talk":
config["attnmap_frame_num"] += 1
......
......@@ -333,16 +333,13 @@ def load_safetensors(in_path, remove_key=None, include_keys=None):
def load_safetensors_from_path(in_path, remove_key=None, include_keys=None):
"""从单个safetensors文件加载权重,支持按key筛选"""
include_keys = include_keys or []
tensors = {}
with safetensors.safe_open(in_path, framework="pt", device="cpu") as f:
for key in f.keys():
            # include_keys takes precedence: if non-empty, keep only entries whose name contains any of the given keys
if include_keys:
if any(inc_key in key for inc_key in include_keys):
tensors[key] = f.get_tensor(key)
            # otherwise, exclude entries matching remove_key
else:
if not (remove_key and remove_key in key):
tensors[key] = f.get_tensor(key)
......