Commit 075ec4c0 authored by gushiqiao, committed by GitHub

[Feat] Add vace model (#236)

parent 0dd7ca09
{
"infer_steps": 50,
"target_video_length": 81,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 5,
"sample_shift": 16,
"enable_cfg": true,
"cpu_offload": false
}
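This appears to be the new VACE sampling config; the run script at the bottom of this commit points at configs/wan/wan_vace.json. A minimal sketch of reading it, assuming that path:

```python
# Minimal sketch, assuming the config path used by the run script below.
import json

with open("configs/wan/wan_vace.json") as f:
    cfg = json.load(f)

print(cfg["infer_steps"], cfg["target_video_length"])  # 50 81
print(cfg["target_height"], cfg["target_width"])       # 480 832
```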
......@@ -52,7 +52,7 @@ class SageAttn2Weight(AttnWeightTemplate):
)
x = torch.cat((x1, x2), dim=1)
x = x.view(max_seqlen_q, -1)
elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "wan2.1_audio", "wan2.2"]:
elif model_cls in ["wan2.1", "wan2.1_distill", "wan2.1_causvid", "wan2.1_df", "wan2.1_audio", "wan2.2", "wan2.1_vace"]:
x = sageattn(
q.unsqueeze(0),
k.unsqueeze(0),
......
......@@ -13,6 +13,7 @@ from lightx2v.models.runners.wan.wan_causvid_runner import WanCausVidRunner # n
from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner # noqa: F401
from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner # noqa: F401
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner # noqa: F401
from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner # noqa: F401
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
......@@ -45,6 +46,7 @@ def main():
"wan2.1_distill",
"wan2.1_causvid",
"wan2.1_skyreels_v2_df",
"wan2.1_vace",
"cogvideox",
"wan2.1_audio",
"wan2.2_moe",
......@@ -57,7 +59,7 @@ def main():
default="wan2.1",
)
parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v"], default="t2v")
parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--use_prompt_enhancer", action="store_true")
......@@ -69,6 +71,25 @@ def main():
parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (flf2v) task")
parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file for audio-to-video (a2v) task")
parser.add_argument(
"--src_ref_images",
type=str,
default=None,
help="The file list of the source reference images. Separated by ','. Default None.",
)
parser.add_argument(
"--src_video",
type=str,
default=None,
help="The file of the source video. Default None.",
)
parser.add_argument(
"--src_mask",
type=str,
default=None,
help="The file of the source mask. Default None.",
)
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
args = parser.parse_args()
......
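The three new flags carry the VACE conditioning inputs; --src_ref_images is a comma-separated list that the runner later splits per sample (see _run_input_encoder_local_vace further down). A minimal sketch of that parsing, using the image paths from the run script below:

```python
# Minimal sketch: the comma-separated --src_ref_images value becomes a list of
# per-sample reference image paths (paths taken from the example run script).
src_ref_images = "assets/inputs/imgs/girl.png,assets/inputs/imgs/snake.png"
ref_list = None if src_ref_images is None else src_ref_images.split(",")
print(ref_list)  # ['assets/inputs/imgs/girl.png', 'assets/inputs/imgs/snake.png']
```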
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import numpy as np
import torch
import torch.nn.functional as F
class VaceVideoProcessor(object):
def __init__(self, downsample, min_area, max_area, min_fps, max_fps, zero_start, seq_len, keep_last, **kwargs):
self.downsample = downsample
self.min_area = min_area
self.max_area = max_area
self.min_fps = min_fps
self.max_fps = max_fps
self.zero_start = zero_start
self.keep_last = keep_last
self.seq_len = seq_len
assert seq_len >= min_area / (self.downsample[1] * self.downsample[2])
def set_area(self, area):
self.min_area = area
self.max_area = area
def set_seq_len(self, seq_len):
self.seq_len = seq_len
@staticmethod
def resize_crop(video: torch.Tensor, oh: int, ow: int):
"""
Resize, center crop and normalize for decord loaded video (torch.Tensor type)
Parameters:
video - video to process (torch.Tensor): Tensor from `reader.get_batch(frame_ids)`, in shape of (T, H, W, C)
oh - target height (int)
ow - target width (int)
Returns:
The processed video (torch.Tensor): Normalized tensor range [-1, 1], in shape of (C, T, H, W)
Raises:
"""
# permute ([t, h, w, c] -> [t, c, h, w])
video = video.permute(0, 3, 1, 2)
# resize and crop
ih, iw = video.shape[2:]
if ih != oh or iw != ow:
# resize
scale = max(ow / iw, oh / ih)
video = F.interpolate(video, size=(round(scale * ih), round(scale * iw)), mode="bicubic", antialias=True)
assert video.size(3) >= ow and video.size(2) >= oh
# center crop
x1 = (video.size(3) - ow) // 2
y1 = (video.size(2) - oh) // 2
video = video[:, :, y1 : y1 + oh, x1 : x1 + ow]
# permute ([t, c, h, w] -> [c, t, h, w]) and normalize
video = video.transpose(0, 1).float().div_(127.5).sub_(1.0)
return video
def _video_preprocess(self, video, oh, ow):
return self.resize_crop(video, oh, ow)
def _get_frameid_bbox_default(self, fps, frame_timestamps, h, w, crop_box, rng):
target_fps = min(fps, self.max_fps)
duration = frame_timestamps[-1].mean()
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
df, dh, dw = self.downsample
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
of = min((int(duration * target_fps) - 1) // df + 1, int(self.seq_len / area_z))
# deduce target shape of the [latent video]
target_area_z = min(area_z, int(self.seq_len / of))
oh = round(np.sqrt(target_area_z * ratio))
ow = int(target_area_z / oh)
of = (of - 1) * df + 1
oh *= dh
ow *= dw
# sample frame ids
target_duration = of / target_fps
begin = 0.0 if self.zero_start else rng.uniform(0, duration - target_duration)
timestamps = np.linspace(begin, begin + target_duration, of)
frame_ids = np.argmax(np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0], timestamps[:, None] < frame_timestamps[None, :, 1]), axis=1).tolist()
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox_adjust_last(self, fps, frame_timestamps, h, w, crop_box, rng):
duration = frame_timestamps[-1].mean()
x1, x2, y1, y2 = [0, w, 0, h] if crop_box is None else crop_box
h, w = y2 - y1, x2 - x1
ratio = h / w
df, dh, dw = self.downsample
area_z = min(self.seq_len, self.max_area / (dh * dw), (h // dh) * (w // dw))
of = min((len(frame_timestamps) - 1) // df + 1, int(self.seq_len / area_z))
# deduce target shape of the [latent video]
target_area_z = min(area_z, int(self.seq_len / of))
oh = round(np.sqrt(target_area_z * ratio))
ow = int(target_area_z / oh)
of = (of - 1) * df + 1
oh *= dh
ow *= dw
# sample frame ids
target_duration = duration
target_fps = of / target_duration
timestamps = np.linspace(0.0, target_duration, of)
frame_ids = np.argmax(np.logical_and(timestamps[:, None] >= frame_timestamps[None, :, 0], timestamps[:, None] <= frame_timestamps[None, :, 1]), axis=1).tolist()
# print(oh, ow, of, target_duration, target_fps, len(frame_timestamps), len(frame_ids))
return frame_ids, (x1, x2, y1, y2), (oh, ow), target_fps
def _get_frameid_bbox(self, fps, frame_timestamps, h, w, crop_box, rng):
if self.keep_last:
return self._get_frameid_bbox_adjust_last(fps, frame_timestamps, h, w, crop_box, rng)
else:
return self._get_frameid_bbox_default(fps, frame_timestamps, h, w, crop_box, rng)
def load_video(self, data_key, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, crop_box=crop_box, seed=seed, **kwargs)
def load_video_pair(self, data_key, data_key2, crop_box=None, seed=2024, **kwargs):
return self.load_video_batch(data_key, data_key2, crop_box=crop_box, seed=seed, **kwargs)
def load_video_batch(self, *data_key_batch, crop_box=None, seed=2024, **kwargs):
rng = np.random.default_rng(seed + hash(data_key_batch[0]) % 10000)
# read video
import decord
decord.bridge.set_bridge("torch")
readers = []
for data_k in data_key_batch:
reader = decord.VideoReader(data_k)
readers.append(reader)
fps = readers[0].get_avg_fps()
length = min([len(r) for r in readers])
frame_timestamps = [readers[0].get_frame_timestamp(i) for i in range(length)]
frame_timestamps = np.array(frame_timestamps, dtype=np.float32)
h, w = readers[0].next().shape[:2]
frame_ids, (x1, x2, y1, y2), (oh, ow), fps = self._get_frameid_bbox(fps, frame_timestamps, h, w, crop_box, rng)
# preprocess video
videos = [reader.get_batch(frame_ids)[:, y1:y2, x1:x2, :] for reader in readers]
videos = [self._video_preprocess(video, oh, ow) for video in videos]
return *videos, frame_ids, (oh, ow), fps
# return videos if len(videos) > 1 else videos[0]
def prepare_source(src_video, src_mask, src_ref_images, num_frames, image_size, device):
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
if sub_src_video is None and sub_src_mask is None:
src_video[i] = torch.zeros((3, num_frames, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones((1, num_frames, image_size[0], image_size[1]), device=device)
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
for j, ref_img in enumerate(ref_images):
if ref_img is not None and ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode="bilinear", align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top : top + new_height, left : left + new_width] = resized_image
src_ref_images[i][j] = white_canvas
return src_video, src_mask, src_ref_images
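A minimal sketch of resize_crop in isolation, feeding a dummy (T, H, W, C) frame tensor and checking the output layout; the import path is the one used elsewhere in this commit:

```python
# Minimal sketch, assuming the module path added by this commit.
import torch
from lightx2v.models.input_encoders.hf.vace.vace_processor import VaceVideoProcessor

video = torch.randint(0, 256, (16, 540, 960, 3)).float()  # (T, H, W, C), pixel range [0, 255]
out = VaceVideoProcessor.resize_crop(video, oh=480, ow=832)
print(out.shape)                            # torch.Size([3, 16, 480, 832]) -> (C, T, oh, ow)
print(out.min().item(), out.max().item())   # roughly within [-1, 1] after div_(127.5).sub_(1.0)
```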
from dataclasses import dataclass
from typing import List
from typing import Any, List, Optional
import torch
......@@ -13,5 +13,7 @@ class WanPreInferModuleOutput:
seq_lens: torch.Tensor
freqs: torch.Tensor
context: torch.Tensor
audio_dit_blocks: List = None
valid_patch_length: int = None
audio_dit_blocks: Optional[List[Any]] = None
valid_patch_length: Optional[int] = None
hints: Optional[List[Any]] = None
context_scale: float = 1.0
from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer
from lightx2v.utils.envs import *
class WanVaceTransformerInfer(WanOffloadTransformerInfer):
def __init__(self, config):
super().__init__(config)
self.vace_block_nums = len(self.config.vace_layers)
self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config.vace_layers)}
def infer(self, weights, pre_infer_out):
pre_infer_out.hints = self.infer_vace(weights, pre_infer_out)
x = self.infer_main_blocks(weights, pre_infer_out)
return self.infer_non_blocks(weights, x, pre_infer_out.embed)
def infer_vace(self, weights, pre_infer_out):
c = weights.vace_patch_embedding.apply(pre_infer_out.vace_context.unsqueeze(0).to(self.sensitive_layer_dtype))
c = c.flatten(2).transpose(1, 2).contiguous().squeeze(0)
self.infer_state = "vace"
hints = []
for i in range(self.vace_block_nums):
c, c_skip = self.infer_vace_block(weights.vace_blocks[i], c, pre_infer_out.x, pre_infer_out)
hints.append(c_skip)
self.infer_state = "base"
return hints
def infer_vace_block(self, weights, c, x, pre_infer_out):
if hasattr(weights, "before_proj"):
c = weights.before_proj.apply(c) + x
c = self.infer_block(weights, c, pre_infer_out)
c_skip = weights.after_proj.apply(c)
return c, c_skip
def post_process(self, x, y, c_gate_msa, pre_infer_out):
x = super().post_process(x, y, c_gate_msa, pre_infer_out)
if self.infer_state == "base" and self.block_idx in self.vace_blocks_mapping:
hint_idx = self.vace_blocks_mapping[self.block_idx]
x = x + pre_infer_out.hints[hint_idx] * pre_infer_out.context_scale
return x
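A minimal sketch of how vace_blocks_mapping pairs base-block indices with hint indices, assuming the default vace_layers listed in the comment of vace/transformer_weights.py below:

```python
# Minimal sketch: vace_blocks_mapping for the default vace_layers
# [0, 2, 4, ..., 28] (taken from the comment in vace/transformer_weights.py).
vace_layers = list(range(0, 30, 2))
vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(vace_layers)}
print(vace_blocks_mapping[0], vace_blocks_mapping[28])  # 0 14
# In post_process(), base block 28 adds hints[14] * context_scale to its output.
```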
......@@ -26,7 +26,6 @@ from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.transformer_infer import (
WanTransformerInfer,
)
from lightx2v.models.networks.wan.weights.post_weights import WanPostWeights
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerWeights,
......@@ -42,7 +41,6 @@ except ImportError:
class WanModel:
pre_weight_class = WanPreWeights
post_weight_class = WanPostWeights
transformer_weight_class = WanTransformerWeights
def __init__(self, model_path, config, device):
......@@ -220,6 +218,8 @@ class WanModel:
"time",
"img_emb.proj.0",
"img_emb.proj.4",
"before_proj", # vace
"after_proj", # vace
}
if weight_dict is None:
......@@ -333,7 +333,7 @@ class WanModel:
self.to_cuda()
elif self.offload_granularity != "model":
self.pre_weight.to_cuda()
self.transformer_weights.post_weights_to_cuda()
self.transformer_weights.non_block_weights_to_cuda()
if self.transformer_infer.mask_map is None:
_, c, h, w = self.scheduler.latents.shape
......@@ -371,7 +371,7 @@ class WanModel:
self.to_cpu()
elif self.offload_granularity != "model":
self.pre_weight.to_cpu()
self.transformer_weights.post_weights_to_cpu()
self.transformer_weights.non_block_weights_to_cpu()
@torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
@torch.no_grad()
......
import torch
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
from lightx2v.models.networks.wan.infer.vace.transformer_infer import WanVaceTransformerInfer
from lightx2v.models.networks.wan.model import WanModel
from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
from lightx2v.models.networks.wan.weights.vace.transformer_weights import (
WanVaceTransformerWeights,
)
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *
class WanVaceModel(WanModel):
pre_weight_class = WanPreWeights
transformer_weight_class = WanVaceTransformerWeights
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
def _init_infer_class(self):
self.pre_infer_class = WanPreInfer
self.post_infer_class = WanPostInfer
self.transformer_infer_class = WanVaceTransformerInfer
@torch.no_grad()
def _infer_cond_uncond(self, inputs, infer_condition=True):
self.scheduler.infer_condition = infer_condition
pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs)
pre_infer_out.vace_context = inputs["image_encoder_output"]["vae_encoder_out"][0]
x = self.transformer_infer.infer(self.transformer_weights, pre_infer_out)
noise_pred = self.post_infer.infer(x, pre_infer_out)[0]
if self.clean_cuda_cache:
del x, pre_infer_out
torch.cuda.empty_cache()
return noise_pred
......@@ -24,6 +24,7 @@ class WanTransformerWeights(WeightModule):
else:
self.mm_type = config["mm_config"].get("mm_type", "Default") if config["mm_config"] else "Default"
self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
self.add_module("blocks", self.blocks)
# post blocks weights
......@@ -36,19 +37,19 @@ class WanTransformerWeights(WeightModule):
for phase in block.compute_phases:
phase.clear()
def post_weights_to_cuda(self):
def non_block_weights_to_cuda(self):
self.norm.to_cuda()
self.head.to_cuda()
self.head_modulation.to_cuda()
def post_weights_to_cpu(self):
def non_block_weights_to_cpu(self):
self.norm.to_cpu()
self.head.to_cpu()
self.head_modulation.to_cpu()
class WanTransformerAttentionBlock(WeightModule):
def __init__(self, block_index, task, mm_type, config):
def __init__(self, block_index, task, mm_type, config, block_prefix="blocks"):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
......@@ -68,6 +69,7 @@ class WanTransformerAttentionBlock(WeightModule):
[
WanModulation(
block_index,
block_prefix,
task,
mm_type,
config,
......@@ -76,6 +78,7 @@ class WanTransformerAttentionBlock(WeightModule):
),
WanSelfAttention(
block_index,
block_prefix,
task,
mm_type,
config,
......@@ -84,6 +87,7 @@ class WanTransformerAttentionBlock(WeightModule):
),
WanCrossAttention(
block_index,
block_prefix,
task,
mm_type,
config,
......@@ -92,6 +96,7 @@ class WanTransformerAttentionBlock(WeightModule):
),
WanFFN(
block_index,
block_prefix,
task,
mm_type,
config,
......@@ -105,7 +110,7 @@ class WanTransformerAttentionBlock(WeightModule):
class WanModulation(WeightModule):
def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
......@@ -120,7 +125,7 @@ class WanModulation(WeightModule):
self.add_module(
"modulation",
TENSOR_REGISTER["Default"](
f"blocks.{self.block_index}.modulation",
f"{block_prefix}.{self.block_index}.modulation",
self.lazy_load,
self.lazy_load_file,
),
......@@ -128,7 +133,7 @@ class WanModulation(WeightModule):
class WanSelfAttention(WeightModule):
def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
......@@ -148,8 +153,8 @@ class WanSelfAttention(WeightModule):
self.add_module(
"self_attn_q",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.self_attn.q.weight",
f"blocks.{self.block_index}.self_attn.q.bias",
f"{block_prefix}.{self.block_index}.self_attn.q.weight",
f"{block_prefix}.{self.block_index}.self_attn.q.bias",
self.lazy_load,
self.lazy_load_file,
),
......@@ -157,8 +162,8 @@ class WanSelfAttention(WeightModule):
self.add_module(
"self_attn_k",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.self_attn.k.weight",
f"blocks.{self.block_index}.self_attn.k.bias",
f"{block_prefix}.{self.block_index}.self_attn.k.weight",
f"{block_prefix}.{self.block_index}.self_attn.k.bias",
self.lazy_load,
self.lazy_load_file,
),
......@@ -166,8 +171,8 @@ class WanSelfAttention(WeightModule):
self.add_module(
"self_attn_v",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.self_attn.v.weight",
f"blocks.{self.block_index}.self_attn.v.bias",
f"{block_prefix}.{self.block_index}.self_attn.v.weight",
f"{block_prefix}.{self.block_index}.self_attn.v.bias",
self.lazy_load,
self.lazy_load_file,
),
......@@ -175,8 +180,8 @@ class WanSelfAttention(WeightModule):
self.add_module(
"self_attn_o",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.self_attn.o.weight",
f"blocks.{self.block_index}.self_attn.o.bias",
f"{block_prefix}.{self.block_index}.self_attn.o.weight",
f"{block_prefix}.{self.block_index}.self_attn.o.bias",
self.lazy_load,
self.lazy_load_file,
),
......@@ -184,7 +189,7 @@ class WanSelfAttention(WeightModule):
self.add_module(
"self_attn_norm_q",
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"blocks.{self.block_index}.self_attn.norm_q.weight",
f"{block_prefix}.{self.block_index}.self_attn.norm_q.weight",
self.lazy_load,
self.lazy_load_file,
),
......@@ -192,7 +197,7 @@ class WanSelfAttention(WeightModule):
self.add_module(
"self_attn_norm_k",
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"blocks.{self.block_index}.self_attn.norm_k.weight",
f"{block_prefix}.{self.block_index}.self_attn.norm_k.weight",
self.lazy_load,
self.lazy_load_file,
),
......@@ -201,7 +206,7 @@ class WanSelfAttention(WeightModule):
assert self.config["sparge_ckpt"], "sparge_ckpt must be set when sparge is True"
self.add_module(
"self_attn_1",
ATTN_WEIGHT_REGISTER["Sparge"](f"blocks.{self.block_index}"),
ATTN_WEIGHT_REGISTER["Sparge"](f"{block_prefix}.{self.block_index}"),
)
sparge_ckpt = torch.load(self.config["sparge_ckpt"])
self.self_attn_1.load(sparge_ckpt)
......@@ -215,7 +220,7 @@ class WanSelfAttention(WeightModule):
self.add_module(
"smooth_norm1_weight",
TENSOR_REGISTER["Default"](
f"blocks.{self.block_index}.affine_norm1.weight",
f"{block_prefix}.{self.block_index}.affine_norm1.weight",
self.lazy_load,
self.lazy_load_file,
),
......@@ -223,7 +228,7 @@ class WanSelfAttention(WeightModule):
self.add_module(
"smooth_norm1_bias",
TENSOR_REGISTER["Default"](
f"blocks.{self.block_index}.affine_norm1.bias",
f"{block_prefix}.{self.block_index}.affine_norm1.bias",
self.lazy_load,
self.lazy_load_file,
),
......@@ -231,7 +236,7 @@ class WanSelfAttention(WeightModule):
class WanCrossAttention(WeightModule):
def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
......@@ -243,8 +248,8 @@ class WanCrossAttention(WeightModule):
self.add_module(
"norm3",
LN_WEIGHT_REGISTER["Default"](
f"blocks.{self.block_index}.norm3.weight",
f"blocks.{self.block_index}.norm3.bias",
f"{block_prefix}.{self.block_index}.norm3.weight",
f"{block_prefix}.{self.block_index}.norm3.bias",
self.lazy_load,
self.lazy_load_file,
),
......@@ -252,8 +257,8 @@ class WanCrossAttention(WeightModule):
self.add_module(
"cross_attn_q",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.q.weight",
f"blocks.{self.block_index}.cross_attn.q.bias",
f"{block_prefix}.{self.block_index}.cross_attn.q.weight",
f"{block_prefix}.{self.block_index}.cross_attn.q.bias",
self.lazy_load,
self.lazy_load_file,
),
......@@ -261,8 +266,8 @@ class WanCrossAttention(WeightModule):
self.add_module(
"cross_attn_k",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.k.weight",
f"blocks.{self.block_index}.cross_attn.k.bias",
f"{block_prefix}.{self.block_index}.cross_attn.k.weight",
f"{block_prefix}.{self.block_index}.cross_attn.k.bias",
self.lazy_load,
self.lazy_load_file,
),
......@@ -270,8 +275,8 @@ class WanCrossAttention(WeightModule):
self.add_module(
"cross_attn_v",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.v.weight",
f"blocks.{self.block_index}.cross_attn.v.bias",
f"{block_prefix}.{self.block_index}.cross_attn.v.weight",
f"{block_prefix}.{self.block_index}.cross_attn.v.bias",
self.lazy_load,
self.lazy_load_file,
),
......@@ -279,8 +284,8 @@ class WanCrossAttention(WeightModule):
self.add_module(
"cross_attn_o",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.o.weight",
f"blocks.{self.block_index}.cross_attn.o.bias",
f"{block_prefix}.{self.block_index}.cross_attn.o.weight",
f"{block_prefix}.{self.block_index}.cross_attn.o.bias",
self.lazy_load,
self.lazy_load_file,
),
......@@ -288,7 +293,7 @@ class WanCrossAttention(WeightModule):
self.add_module(
"cross_attn_norm_q",
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"blocks.{self.block_index}.cross_attn.norm_q.weight",
f"{block_prefix}.{self.block_index}.cross_attn.norm_q.weight",
self.lazy_load,
self.lazy_load_file,
),
......@@ -296,7 +301,7 @@ class WanCrossAttention(WeightModule):
self.add_module(
"cross_attn_norm_k",
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"blocks.{self.block_index}.cross_attn.norm_k.weight",
f"{block_prefix}.{self.block_index}.cross_attn.norm_k.weight",
self.lazy_load,
self.lazy_load_file,
),
......@@ -307,8 +312,8 @@ class WanCrossAttention(WeightModule):
self.add_module(
"cross_attn_k_img",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.k_img.weight",
f"blocks.{self.block_index}.cross_attn.k_img.bias",
f"{block_prefix}.{self.block_index}.cross_attn.k_img.weight",
f"{block_prefix}.{self.block_index}.cross_attn.k_img.bias",
self.lazy_load,
self.lazy_load_file,
),
......@@ -316,8 +321,8 @@ class WanCrossAttention(WeightModule):
self.add_module(
"cross_attn_v_img",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.cross_attn.v_img.weight",
f"blocks.{self.block_index}.cross_attn.v_img.bias",
f"{block_prefix}.{self.block_index}.cross_attn.v_img.weight",
f"{block_prefix}.{self.block_index}.cross_attn.v_img.bias",
self.lazy_load,
self.lazy_load_file,
),
......@@ -325,7 +330,7 @@ class WanCrossAttention(WeightModule):
self.add_module(
"cross_attn_norm_k_img",
RMS_WEIGHT_REGISTER["sgl-kernel"](
f"blocks.{self.block_index}.cross_attn.norm_k_img.weight",
f"{block_prefix}.{self.block_index}.cross_attn.norm_k_img.weight",
self.lazy_load,
self.lazy_load_file,
),
......@@ -334,7 +339,7 @@ class WanCrossAttention(WeightModule):
class WanFFN(WeightModule):
def __init__(self, block_index, task, mm_type, config, lazy_load, lazy_load_file):
def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
......@@ -352,8 +357,8 @@ class WanFFN(WeightModule):
self.add_module(
"ffn_0",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.ffn.0.weight",
f"blocks.{self.block_index}.ffn.0.bias",
f"{block_prefix}.{self.block_index}.ffn.0.weight",
f"{block_prefix}.{self.block_index}.ffn.0.bias",
self.lazy_load,
self.lazy_load_file,
),
......@@ -361,8 +366,8 @@ class WanFFN(WeightModule):
self.add_module(
"ffn_2",
MM_WEIGHT_REGISTER[self.mm_type](
f"blocks.{self.block_index}.ffn.2.weight",
f"blocks.{self.block_index}.ffn.2.bias",
f"{block_prefix}.{self.block_index}.ffn.2.weight",
f"{block_prefix}.{self.block_index}.ffn.2.bias",
self.lazy_load,
self.lazy_load_file,
),
......@@ -372,7 +377,7 @@ class WanFFN(WeightModule):
self.add_module(
"smooth_norm2_weight",
TENSOR_REGISTER["Default"](
f"blocks.{self.block_index}.affine_norm3.weight",
f"{block_prefix}.{self.block_index}.affine_norm3.weight",
self.lazy_load,
self.lazy_load_file,
),
......@@ -380,7 +385,7 @@ class WanFFN(WeightModule):
self.add_module(
"smooth_norm2_bias",
TENSOR_REGISTER["Default"](
f"blocks.{self.block_index}.affine_norm3.bias",
f"{block_prefix}.{self.block_index}.affine_norm3.bias",
self.lazy_load,
self.lazy_load_file,
),
......
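The block_prefix parameter only changes the weight-key strings used for loading; a minimal sketch of the keys produced for a base block versus a VACE block (the block indices here are illustrative):

```python
# Minimal sketch: weight keys for a base block vs. a VACE block after the
# block_prefix change (f-strings mirror the diff above; indices are illustrative).
def self_attn_q_keys(block_prefix, block_index):
    return (
        f"{block_prefix}.{block_index}.self_attn.q.weight",
        f"{block_prefix}.{block_index}.self_attn.q.bias",
    )

print(self_attn_q_keys("blocks", 5))       # ('blocks.5.self_attn.q.weight', 'blocks.5.self_attn.q.bias')
print(self_attn_q_keys("vace_blocks", 3))  # ('vace_blocks.3.self_attn.q.weight', 'vace_blocks.3.self_attn.q.bias')
```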
from lightx2v.common.modules.weight_module import WeightModuleList
from lightx2v.models.networks.wan.weights.transformer_weights import (
WanTransformerAttentionBlock,
WanTransformerWeights,
)
from lightx2v.utils.registry_factory import (
CONV3D_WEIGHT_REGISTER,
MM_WEIGHT_REGISTER,
)
# "vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
# {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14}
class WanVaceTransformerWeights(WanTransformerWeights):
def __init__(self, config):
super().__init__(config)
self.patch_size = (1, 2, 2)
self.vace_blocks = WeightModuleList(
[WanVaceTransformerAttentionBlock(self.config.vace_layers[i], i, self.task, self.mm_type, self.config, "vace_blocks") for i in range(len(self.config.vace_layers))]
)
self.add_module("vace_blocks", self.vace_blocks)
self.add_module(
"vace_patch_embedding",
CONV3D_WEIGHT_REGISTER["Default"]("vace_patch_embedding.weight", "vace_patch_embedding.bias", stride=self.patch_size),
)
def clear(self):
super().clear()
for vace_block in self.vace_blocks:
for vace_phase in vace_block.compute_phases:
vace_phase.clear()
def non_block_weights_to_cuda(self):
super().non_block_weights_to_cuda()
self.vace_patch_embedding.to_cuda()
def non_block_weights_to_cpu(self):
super().non_block_weights_to_cpu()
self.vace_patch_embedding.to_cpu()
class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
def __init__(self, base_block_idx, block_index, task, mm_type, config, block_prefix):
super().__init__(block_index, task, mm_type, config, block_prefix)
if base_block_idx == 0:
self.add_module(
"before_proj",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.before_proj.weight",
f"{block_prefix}.{self.block_index}.before_proj.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"after_proj",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.after_proj.weight",
f"{block_prefix}.{self.block_index}.after_proj.bias",
self.lazy_load,
self.lazy_load_file,
),
)
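A minimal sketch of which projection weights each VACE block registers, assuming the default vace_layers [0, 2, ..., 28]: only the block whose base index is 0 gets before_proj, while every VACE block gets after_proj:

```python
# Minimal sketch: per-block projection keys, assuming the default vace_layers.
vace_layers = list(range(0, 30, 2))
for seq_idx, base_idx in enumerate(vace_layers[:3]):
    keys = [f"vace_blocks.{seq_idx}.after_proj.weight"]
    if base_idx == 0:
        keys.insert(0, f"vace_blocks.{seq_idx}.before_proj.weight")
    print(base_idx, keys)
# 0 ['vace_blocks.0.before_proj.weight', 'vace_blocks.0.after_proj.weight']
# 2 ['vace_blocks.1.after_proj.weight']
# 4 ['vace_blocks.2.after_proj.weight']
```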
......@@ -43,6 +43,8 @@ class DefaultRunner(BaseRunner):
self.run_input_encoder = self._run_input_encoder_local_flf2v
elif self.config["task"] == "t2v":
self.run_input_encoder = self._run_input_encoder_local_t2v
elif self.config["task"] == "vace":
self.run_input_encoder = self._run_input_encoder_local_vace
def set_init_device(self):
if self.config.cpu_offload:
......@@ -179,6 +181,26 @@ class DefaultRunner(BaseRunner):
gc.collect()
return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)
@ProfilingContext("Run Encoders")
def _run_input_encoder_local_vace(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
src_video = self.config.get("src_video", None)
src_mask = self.config.get("src_mask", None)
src_ref_images = self.config.get("src_ref_images", None)
src_video, src_mask, src_ref_images = self.prepare_source(
[src_video],
[src_mask],
[None if src_ref_images is None else src_ref_images.split(",")],
(self.config.target_width, self.config.target_height),
)
self.src_ref_images = src_ref_images
vae_encoder_out = self.run_vae_encoder(src_video, src_ref_images, src_mask)
text_encoder_output = self.run_text_encoder(prompt)
torch.cuda.empty_cache()
gc.collect()
return self.get_encoder_output_i2v(None, vae_encoder_out, text_encoder_output)
@ProfilingContext("Run DiT")
def _run_dit_local(self, total_steps=None):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
......
......@@ -139,7 +139,7 @@ class WanRunner(DefaultRunner):
"use_tiling": self.config.get("use_tiling_vae", False),
"cpu_offload": vae_offload,
}
if self.config.task not in ["i2v", "flf2v"]:
if self.config.task not in ["i2v", "flf2v", "vace"]:
return None
else:
return WanVAE(**vae_config)
......
import gc
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.models.input_encoders.hf.vace.vace_processor import VaceVideoProcessor
from lightx2v.models.networks.wan.vace_model import WanVaceModel
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.registry_factory import RUNNER_REGISTER
@RUNNER_REGISTER("wan2.1_vace")
class WanVaceRunner(WanRunner):
def __init__(self, config):
super().__init__(config)
assert self.config.task == "vace"
self.vid_proc = VaceVideoProcessor(
downsample=tuple([x * y for x, y in zip(self.config.vae_stride, self.config.patch_size)]),
min_area=720 * 1280,
max_area=720 * 1280,
min_fps=self.config.get("fps", 16),
max_fps=self.config.get("fps", 16),
zero_start=True,
seq_len=75600,
keep_last=True,
)
def load_transformer(self):
model = WanVaceModel(
self.config.model_path,
self.config,
self.init_device,
)
return model
def prepare_source(self, src_video, src_mask, src_ref_images, image_size, device=torch.device("cuda")):
area = image_size[0] * image_size[1]
self.vid_proc.set_area(area)
if area == 720 * 1280:
self.vid_proc.set_seq_len(75600)
elif area == 480 * 832:
self.vid_proc.set_seq_len(32760)
else:
raise NotImplementedError(f"image_size {image_size} is not supported")
image_size = (image_size[1], image_size[0])
image_sizes = []
for i, (sub_src_video, sub_src_mask) in enumerate(zip(src_video, src_mask)):
if sub_src_mask is not None and sub_src_video is not None:
src_video[i], src_mask[i], _, _, _ = self.vid_proc.load_video_pair(sub_src_video, sub_src_mask)
src_video[i] = src_video[i].to(device)
src_mask[i] = src_mask[i].to(device)
src_mask[i] = torch.clamp((src_mask[i][:1, :, :, :] + 1) / 2, min=0, max=1)
image_sizes.append(src_video[i].shape[2:])
elif sub_src_video is None:
src_video[i] = torch.zeros((3, self.config.target_video_length, image_size[0], image_size[1]), device=device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(image_size)
else:
src_video[i], _, _, _ = self.vid_proc.load_video(sub_src_video)
src_video[i] = src_video[i].to(device)
src_mask[i] = torch.ones_like(src_video[i], device=device)
image_sizes.append(src_video[i].shape[2:])
for i, ref_images in enumerate(src_ref_images):
if ref_images is not None:
image_size = image_sizes[i]
for j, ref_img in enumerate(ref_images):
if ref_img is not None:
ref_img = Image.open(ref_img).convert("RGB")
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(1)
if ref_img.shape[-2:] != image_size:
canvas_height, canvas_width = image_size
ref_height, ref_width = ref_img.shape[-2:]
white_canvas = torch.ones((3, 1, canvas_height, canvas_width), device=device) # [-1, 1]
scale = min(canvas_height / ref_height, canvas_width / ref_width)
new_height = int(ref_height * scale)
new_width = int(ref_width * scale)
resized_image = F.interpolate(ref_img.squeeze(1).unsqueeze(0), size=(new_height, new_width), mode="bilinear", align_corners=False).squeeze(0).unsqueeze(1)
top = (canvas_height - new_height) // 2
left = (canvas_width - new_width) // 2
white_canvas[:, :, top : top + new_height, left : left + new_width] = resized_image
ref_img = white_canvas
src_ref_images[i][j] = ref_img.to(device)
return src_video, src_mask, src_ref_images
def run_vae_encoder(self, frames, ref_images, masks):
if ref_images is None:
ref_images = [None] * len(frames)
else:
assert len(frames) == len(ref_images)
if masks is None:
latents = self.vae_encoder.encode(frames)
else:
masks = [torch.where(m > 0.5, 1.0, 0.0) for m in masks]
inactive = [i * (1 - m) + 0 * m for i, m in zip(frames, masks)]
reactive = [i * m + 0 * (1 - m) for i, m in zip(frames, masks)]
inactive = self.vae_encoder.encode(inactive)
reactive = self.vae_encoder.encode(reactive)
latents = [torch.cat((u, c), dim=0) for u, c in zip(inactive, reactive)]
cat_latents = []
for latent, refs in zip(latents, ref_images):
if refs is not None:
if masks is None:
ref_latent = self.vae_encoder.encode(refs)
else:
ref_latent = self.vae_encoder.encode(refs)
ref_latent = [torch.cat((u, torch.zeros_like(u)), dim=0) for u in ref_latent]
assert all([x.shape[1] == 1 for x in ref_latent])
latent = torch.cat([*ref_latent, latent], dim=1)
cat_latents.append(latent)
self.latent_shape = list(cat_latents[0].shape)
return self.get_vae_encoder_output(cat_latents, masks, ref_images)
def get_vae_encoder_output(self, cat_latents, masks, ref_images):
if ref_images is None:
ref_images = [None] * len(masks)
else:
assert len(masks) == len(ref_images)
result_masks = []
for mask, refs in zip(masks, ref_images):
c, depth, height, width = mask.shape
new_depth = int((depth + 3) // self.config.vae_stride[0])
height = 2 * (int(height) // (self.config.vae_stride[1] * 2))
width = 2 * (int(width) // (self.config.vae_stride[2] * 2))
# reshape
mask = mask[0, :, :, :]
mask = mask.view(depth, height, self.config.vae_stride[1], width, self.config.vae_stride[1]) # depth, height, 8, width, 8
mask = mask.permute(2, 4, 0, 1, 3) # 8, 8, depth, height, width
mask = mask.reshape(self.config.vae_stride[1] * self.config.vae_stride[2], depth, height, width) # 8*8, depth, height, width
# interpolation
mask = F.interpolate(mask.unsqueeze(0), size=(new_depth, height, width), mode="nearest-exact").squeeze(0)
if refs is not None:
length = len(refs)
mask_pad = torch.zeros_like(mask[:, :length, :, :])
mask = torch.cat((mask_pad, mask), dim=1)
result_masks.append(mask)
return [torch.cat([zz, mm], dim=0) for zz, mm in zip(cat_latents, result_masks)]
def set_target_shape(self):
target_shape = self.latent_shape
target_shape[0] = int(target_shape[0] / 2)
self.config.target_shape = target_shape
@ProfilingContext("Run VAE Decoder")
def _run_vae_decoder_local(self, latents, generator):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_decoder = self.load_vae_decoder()
if self.src_ref_images is not None:
assert len(self.src_ref_images) == 1
refs = self.src_ref_images[0]
if refs is not None:
latents = latents[:, len(refs) :, :, :]
images = self.vae_decoder.decode(latents, generator=generator, config=self.config)
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_decoder
torch.cuda.empty_cache()
gc.collect()
return images
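A minimal sketch of the mask reshaping in get_vae_encoder_output, using the 480x832 / 81-frame settings from the config above; vae_stride = (4, 8, 8) is an assumption (the runner reads it from the config at runtime):

```python
# Minimal sketch of the mask reshaping, with vae_stride = (4, 8, 8) assumed.
import torch
import torch.nn.functional as F

vae_stride = (4, 8, 8)
depth, height, width = 81, 480, 832
mask = torch.ones(1, depth, height, width)

new_depth = int((depth + 3) // vae_stride[0])   # 21 latent frames
height = 2 * (height // (vae_stride[1] * 2))    # 60
width = 2 * (width // (vae_stride[2] * 2))      # 104

m = mask[0]
m = m.view(depth, height, vae_stride[1], width, vae_stride[2])  # (81, 60, 8, 104, 8)
m = m.permute(2, 4, 0, 1, 3)                                    # (8, 8, 81, 60, 104)
m = m.reshape(vae_stride[1] * vae_stride[2], depth, height, width)
m = F.interpolate(m.unsqueeze(0), size=(new_depth, height, width), mode="nearest-exact").squeeze(0)
print(m.shape)  # torch.Size([64, 21, 60, 104])
```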
#!/bin/bash
# set lightx2v_path and model_path first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=1
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls wan2.1_vace \
--task vace \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan/wan_vace.json \
--prompt "在一个欢乐而充满节日气氛的场景中,穿着鲜艳红色春服的小女孩正与她的可爱卡通蛇嬉戏。她的春服上绣着金色吉祥图案,散发着喜庆的气息,脸上洋溢着灿烂的笑容。蛇身呈现出亮眼的绿色,形状圆润,宽大的眼睛让它显得既友善又幽默。小女孩欢快地用手轻轻抚摸着蛇的头部,共同享受着这温馨的时刻。周围五彩斑斓的灯笼和彩带装饰着环境,阳光透过洒在她们身上,营造出一个充满友爱与幸福的新年氛围。" \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--src_ref_images ${lightx2v_path}/assets/inputs/imgs/girl.png,${lightx2v_path}/assets/inputs/imgs/snake.png \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_vace.mp4