"vscode:/vscode.git/clone" did not exist on "a48e2d273f85f5beeeddd042b79c41049fe01afa"
Commit e76beda7 authored by gushiqiao, committed by GitHub

Add flf2v model

parents 3d8cb02e ecb2107c
{
"infer_steps": 50,
"target_video_length": 81,
"target_height": 720,
"target_width": 1280,
"self_attn_1_type": "sage_attn2",
"cross_attn_1_type": "sage_attn2",
"cross_attn_2_type": "sage_attn2",
"seed": 442,
"sample_guide_scale": 5,
"sample_shift": 16,
"enable_cfg": true,
"cpu_offload": false
}
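This JSON is the inference config consumed through --config_json (see the run script at the end of this diff). A minimal sketch, assuming only the standard library, of loading it and overriding a couple of fields before a run; the file names are illustrative:

import json

# Hedged sketch: read the flf2v config shown above and tweak two fields locally.
with open("wan_flf2v.json") as f:          # assumed to hold the JSON above
    cfg = json.load(f)

cfg["seed"] = 1234                         # reproducibility knob
cfg["infer_steps"] = 40                    # fewer denoising steps for a quick smoke test

with open("wan_flf2v_local.json", "w") as f:
    json.dump(cfg, f, indent=4)            # pass this path to --config_json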
......@@ -57,7 +57,7 @@ def main():
default="wan2.1",
)
parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i"], default="t2v")
parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "flf2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--use_prompt_enhancer", action="store_true")
......@@ -66,6 +66,7 @@ def main():
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--image_path", type=str, default="", help="The path to input image file for image-to-video (i2v) task")
parser.add_argument("--last_frame_path", type=str, default="", help="The path to last frame file for first-last-frame-to-video (i2v) task")
parser.add_argument("--audio_path", type=str, default="", help="The path to input audio file for audio-to-video (a2v) task")
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
......
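A standalone sketch of how the extended --task choice parses together with the new --last_frame_path flag; only the arguments shown above are reproduced, everything else is omitted:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "flf2v"], default="t2v")
parser.add_argument("--image_path", type=str, default="")
parser.add_argument("--last_frame_path", type=str, default="")

# flf2v reuses --image_path for the first frame and adds --last_frame_path for the last one
args = parser.parse_args(["--task", "flf2v", "--image_path", "first.png", "--last_frame_path", "last.png"])
assert args.task == "flf2v" and args.last_frame_path == "last.png"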
......@@ -52,13 +52,13 @@ class WanCausVidModel(WanModel):
if self.config["cpu_offload"]:
self.pre_weight.to_cuda()
self.post_weight.to_cuda()
self.transformer_weights.post_weights_to_cuda()
embed, grid_sizes, pre_infer_out = self.pre_infer.infer(self.pre_weight, inputs, kv_start=kv_start, kv_end=kv_end)
x = self.transformer_infer.infer(self.transformer_weights, grid_sizes, embed, *pre_infer_out, kv_start, kv_end)
self.scheduler.noise_pred = self.post_infer.infer(self.post_weight, x, embed, grid_sizes)[0]
self.scheduler.noise_pred = self.post_infer.infer(x, embed, grid_sizes)[0]
if self.config["cpu_offload"]:
self.pre_weight.to_cpu()
self.post_weight.to_cpu()
self.transformer_weights.post_weights_to_cpu()
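The cpu_offload branch above follows a move-in / compute / move-out pattern, so the pre/post weights and part of the transformer weights occupy GPU memory only during the forward pass. A generic sketch of that idea; the helper below is illustrative, not the repo's API:

import torch

def run_with_offload(module: torch.nn.Module, x: torch.Tensor, offload: bool = True) -> torch.Tensor:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if offload:
        module.to(device)               # analogous to pre_weight.to_cuda() / post_weights_to_cuda()
    try:
        return module(x.to(device))
    finally:
        if offload:
            module.to("cpu")            # analogous to the to_cpu() calls after inference
            if device == "cuda":
                torch.cuda.empty_cache()

y = run_with_offload(torch.nn.Linear(8, 8), torch.randn(2, 8))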
......@@ -50,7 +50,7 @@ class WanPreInfer:
else:
context = inputs["text_encoder_output"]["context_null"]
if self.task == "i2v":
if self.task in ["i2v", "flf2v"]:
if self.config.get("use_image_encoder", True):
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
......@@ -113,7 +113,11 @@ class WanPreInfer:
del out, stacked
torch.cuda.empty_cache()
if self.task == "i2v" and self.config.get("use_image_encoder", True):
if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
if self.task == "flf2v":
_, n, d = clip_fea.shape
clip_fea = clip_fea.view(2 * n, d)
clip_fea = clip_fea + weights.emb_pos.tensor.squeeze()
context_clip = weights.proj_0.apply(clip_fea)
if self.clean_cuda_cache:
del clip_fea
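For flf2v the CLIP features of the first and last frame arrive stacked along a leading dimension of 2; the new branch flattens them into a single token sequence and adds the learned img_emb.emb_pos table so the projection can tell the two frames apart. A shape-only sketch with illustrative sizes:

import torch

n, d = 257, 1280                        # CLIP tokens per frame and feature dim (illustrative)
clip_fea = torch.randn(2, n, d)         # stacked features for [first_frame, last_frame]
emb_pos = torch.randn(1, 2 * n, d)      # learned positional table, analogous to img_emb.emb_pos

clip_fea = clip_fea.view(2 * n, d)      # (514, 1280): one sequence covering both frames
clip_fea = clip_fea + emb_pos.squeeze() # mark which frame each token came from
print(clip_fea.shape)                   # torch.Size([514, 1280])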
......@@ -125,6 +129,7 @@ class WanPreInfer:
context_clip = weights.proj_3.apply(context_clip)
context_clip = weights.proj_4.apply(context_clip)
context = torch.concat([context_clip, context], dim=0)
if self.clean_cuda_cache:
if self.config.get("use_image_encoder", True):
del context_clip
......
......@@ -438,7 +438,7 @@ class WanTransformerInfer(BaseTransformerInfer):
x.add_(y_out * gate_msa.squeeze())
norm3_out = weights.norm3.apply(x)
if self.task == "i2v" and self.config.get("use_image_encoder", True):
if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
context_img = context[:257]
context = context[257:]
else:
......@@ -446,7 +446,7 @@ class WanTransformerInfer(BaseTransformerInfer):
if self.sensitive_layer_dtype != self.infer_dtype:
context = context.to(self.infer_dtype)
if self.task == "i2v" and self.config.get("use_image_encoder", True):
if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
context_img = context_img.to(self.infer_dtype)
n, d = self.num_heads, self.head_dim
......@@ -469,7 +469,7 @@ class WanTransformerInfer(BaseTransformerInfer):
model_cls=self.config["model_cls"],
)
if self.task == "i2v" and self.config.get("use_image_encoder", True) and context_img is not None:
if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True) and context_img is not None:
k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
......
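Because the projected image tokens are prepended to the text context in WanPreInfer, the cross-attention step splits them back off (context[:257] / context[257:]) and attends to the two token sets separately; the two attention outputs are then combined (summed in the sketch below). A self-contained sketch of that decoupled pattern; sizes are illustrative and scaled_dot_product_attention stands in for the repo's registered attention kernels:

import torch
import torch.nn.functional as F

n_heads, head_dim = 12, 128
seq, n_img, n_txt = 1024, 257, 512                  # illustrative token counts

q     = torch.randn(seq,   n_heads, head_dim)       # video-token queries
k_txt = torch.randn(n_txt, n_heads, head_dim)       # from cross_attn_k / cross_attn_v on context[257:]
v_txt = torch.randn(n_txt, n_heads, head_dim)
k_img = torch.randn(n_img, n_heads, head_dim)       # from cross_attn_k_img / cross_attn_v_img on context[:257]
v_img = torch.randn(n_img, n_heads, head_dim)

def attend(q, k, v):
    # (heads, tokens, dim) layout for SDPA
    out = F.scaled_dot_product_attention(q.transpose(0, 1), k.transpose(0, 1), v.transpose(0, 1))
    return out.transpose(0, 1).reshape(q.shape[0], -1)

y = attend(q, k_txt, v_txt) + attend(q, k_img, v_img)   # text branch + image branch
print(y.shape)                                           # torch.Size([1024, 1536])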
......@@ -3,6 +3,7 @@ from lightx2v.utils.registry_factory import (
CONV3D_WEIGHT_REGISTER,
LN_WEIGHT_REGISTER,
MM_WEIGHT_REGISTER,
TENSOR_REGISTER,
)
......@@ -39,7 +40,7 @@ class WanPreWeights(WeightModule):
MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"),
)
if config.task == "i2v" and config.get("use_image_encoder", True):
if config.task in ["i2v", "flf2v"] and config.get("use_image_encoder", True):
self.add_module(
"proj_0",
LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias"),
......@@ -66,3 +67,9 @@ class WanPreWeights(WeightModule):
"cfg_cond_proj_2",
MM_WEIGHT_REGISTER["Default"]("guidance_embedding.linear_2.weight", "guidance_embedding.linear_2.bias"),
)
if config.task == "flf2v":
self.add_module(
"emb_pos",
TENSOR_REGISTER["Default"](f"img_emb.emb_pos"),
)
......@@ -303,7 +303,7 @@ class WanCrossAttention(WeightModule):
)
self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
if self.config.task == "i2v" and self.config.get("use_image_encoder", True):
if self.config.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
self.add_module(
"cross_attn_k_img",
MM_WEIGHT_REGISTER[self.mm_type](
......
......@@ -39,7 +39,9 @@ class DefaultRunner(BaseRunner):
self.run_vae_decoder = self._run_vae_decoder_local
if self.config["task"] == "i2v":
self.run_input_encoder = self._run_input_encoder_local_i2v
else:
elif self.config["task"] == "flf2v":
self.run_input_encoder = self._run_input_encoder_local_flf2v
elif self.config["task"] == "t2v":
self.run_input_encoder = self._run_input_encoder_local_t2v
def set_init_device(self):
......@@ -165,6 +167,18 @@ class DefaultRunner(BaseRunner):
"image_encoder_output": None,
}
@ProfilingContext("Run Encoders")
def _run_input_encoder_local_flf2v(self):
prompt = self.config["prompt_enhanced"] if self.config["use_prompt_enhancer"] else self.config["prompt"]
first_frame = Image.open(self.config["image_path"]).convert("RGB")
last_frame = Image.open(self.config["last_frame_path"]).convert("RGB")
clip_encoder_out = self.run_image_encoder(first_frame, last_frame) if self.config.get("use_image_encoder", True) else None
vae_encode_out = self.run_vae_encoder(first_frame, last_frame)
text_encoder_output = self.run_text_encoder(prompt, first_frame)
torch.cuda.empty_cache()
gc.collect()
return self.get_encoder_output_i2v(clip_encoder_out, vae_encode_out, text_encoder_output)
@ProfilingContext("Run DiT")
def _run_dit_local(self, total_steps=None):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
......
......@@ -54,7 +54,7 @@ class WanRunner(DefaultRunner):
def load_image_encoder(self):
image_encoder = None
if self.config.task == "i2v" and self.config.get("use_image_encoder", True):
if self.config.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
# quant_config
clip_quantized = self.config.get("clip_quantized", False)
if clip_quantized:
......@@ -139,7 +139,7 @@ class WanRunner(DefaultRunner):
"use_tiling": self.config.get("use_tiling_vae", False),
"cpu_offload": vae_offload,
}
if self.config.task != "i2v":
if self.config.task not in ["i2v", "flf2v"]:
return None
else:
return WanVAE(**vae_config)
......@@ -193,7 +193,7 @@ class WanRunner(DefaultRunner):
scheduler = scheduler_class(self.config)
self.model.set_scheduler(scheduler)
def run_text_encoder(self, text, img):
def run_text_encoder(self, text, img=None):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.text_encoders = self.load_text_encoder()
n_prompt = self.config.get("negative_prompt", "")
......@@ -222,26 +222,32 @@ class WanRunner(DefaultRunner):
return text_encoder_output
def run_image_encoder(self, img):
def run_image_encoder(self, first_frame, last_frame=None):
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.image_encoder = self.load_image_encoder()
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]]).squeeze(0).to(GET_DTYPE())
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).cuda()
if last_frame is None:
clip_encoder_out = self.image_encoder.visual([first_frame[None, :, :, :]]).squeeze(0).to(GET_DTYPE())
else:
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).cuda()
clip_encoder_out = self.image_encoder.visual([first_frame[:, None, :, :].transpose(0, 1), last_frame[:, None, :, :].transpose(0, 1)]).squeeze(0).to(GET_DTYPE())
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.image_encoder
torch.cuda.empty_cache()
gc.collect()
return clip_encoder_out
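The frames are mapped from [0, 1] to [-1, 1] before the CLIP visual encoder, and for flf2v each frame is passed with an extra singleton axis, mirroring the indexing in the diff above. A small preprocessing-only sketch (the encoder call itself is omitted):

import torchvision.transforms.functional as TF
from PIL import Image

first = Image.new("RGB", (1280, 720))               # stand-ins for the real input frames
last = Image.new("RGB", (1280, 720))

first_t = TF.to_tensor(first).sub_(0.5).div_(0.5)   # (3, H, W), values in [-1, 1]
last_t = TF.to_tensor(last).sub_(0.5).div_(0.5)

# mirrors first_frame[:, None, :, :].transpose(0, 1) from the flf2v branch above
clips = [first_t[:, None, :, :].transpose(0, 1), last_t[:, None, :, :].transpose(0, 1)]
print([tuple(c.shape) for c in clips])              # [(1, 3, 720, 1280), (1, 3, 720, 1280)]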
def run_vae_encoder(self, img):
img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
h, w = img.shape[1:]
def run_vae_encoder(self, first_frame, last_frame=None):
first_frame_size = first_frame.size
first_frame = TF.to_tensor(first_frame).sub_(0.5).div_(0.5).cuda()
h, w = first_frame.shape[1:]
aspect_ratio = h / w
max_area = self.config.target_height * self.config.target_width
lat_h = round(np.sqrt(max_area * aspect_ratio) // self.config.vae_stride[1] // self.config.patch_size[1] * self.config.patch_size[1])
lat_w = round(np.sqrt(max_area / aspect_ratio) // self.config.vae_stride[2] // self.config.patch_size[2] * self.config.patch_size[2])
if self.config.get("changing_resolution", False):
assert last_frame is None
self.config.lat_h, self.config.lat_w = lat_h, lat_w
vae_encode_out_list = []
for i in range(len(self.config["resolution_rate"])):
......@@ -249,18 +255,27 @@ class WanRunner(DefaultRunner):
int(self.config.lat_h * self.config.resolution_rate[i]) // 2 * 2,
int(self.config.lat_w * self.config.resolution_rate[i]) // 2 * 2,
)
vae_encode_out_list.append(self.get_vae_encoder_output(img, lat_h, lat_w))
vae_encode_out_list.append(self.get_vae_encoder_output(img, self.config.lat_h, self.config.lat_w))
vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, lat_h, lat_w))
vae_encode_out_list.append(self.get_vae_encoder_output(first_frame, self.config.lat_h, self.config.lat_w))
return vae_encode_out_list
else:
if last_frame is not None:
last_frame_size = last_frame.size
last_frame = TF.to_tensor(last_frame).sub_(0.5).div_(0.5).cuda()
if first_frame_size != last_frame_size:
last_frame_resize_ratio = max(first_frame_size[0] / last_frame_size[0], first_frame_size[1] / last_frame_size[1])
last_frame_size = [
round(last_frame_size[0] * last_frame_resize_ratio),
round(last_frame_size[1] * last_frame_resize_ratio),
]
last_frame = TF.center_crop(last_frame, last_frame_size)
self.config.lat_h, self.config.lat_w = lat_h, lat_w
vae_encoder_out = self.get_vae_encoder_output(img, lat_h, lat_w)
vae_encoder_out = self.get_vae_encoder_output(first_frame, lat_h, lat_w, last_frame)
return vae_encoder_out
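The latent grid is chosen to keep the first frame's aspect ratio while matching roughly target_height * target_width pixels, rounded to multiples of the patch size. A worked sketch with the 720x1280 config above; vae_stride=(4, 8, 8) and patch_size=(1, 2, 2) are the usual Wan2.1 values and are assumed here:

import numpy as np

h, w = 720, 1280                          # first-frame size (illustrative)
target_h, target_w = 720, 1280            # from the config at the top of this diff
vae_stride = (4, 8, 8)                    # assumed Wan2.1 defaults
patch_size = (1, 2, 2)

aspect_ratio = h / w
max_area = target_h * target_w
lat_h = round(np.sqrt(max_area * aspect_ratio) // vae_stride[1] // patch_size[1] * patch_size[1])
lat_w = round(np.sqrt(max_area / aspect_ratio) // vae_stride[2] // patch_size[2] * patch_size[2])
print(lat_h, lat_w)                       # 90 160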
def get_vae_encoder_output(self, img, lat_h, lat_w):
def get_vae_encoder_output(self, first_frame, lat_h, lat_w, last_frame=None):
h = lat_h * self.config.vae_stride[1]
w = lat_w * self.config.vae_stride[2]
msk = torch.ones(
1,
self.config.target_video_length,
......@@ -268,24 +283,38 @@ class WanRunner(DefaultRunner):
lat_w,
device=torch.device("cuda"),
)
if last_frame is not None:
msk[:, 1:-1] = 0
else:
msk[:, 1:] = 0
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w)
msk = msk.transpose(1, 2)[0]
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
self.vae_encoder = self.load_vae_encoder()
vae_encoder_out = self.vae_encoder.encode(
if last_frame is not None:
vae_input = torch.concat(
[
torch.concat(
torch.nn.functional.interpolate(first_frame[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
torch.zeros(3, self.config.target_video_length - 2, h, w),
torch.nn.functional.interpolate(last_frame[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
],
dim=1,
).cuda()
else:
vae_input = torch.concat(
[
torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
torch.nn.functional.interpolate(first_frame[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1),
torch.zeros(3, self.config.target_video_length - 1, h, w),
],
dim=1,
).cuda()
],
self.config,
)[0]
vae_encoder_out = self.vae_encoder.encode([vae_input], self.config)[0]
if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
del self.vae_encoder
torch.cuda.empty_cache()
......@@ -293,7 +322,7 @@ class WanRunner(DefaultRunner):
vae_encoder_out = torch.concat([msk, vae_encoder_out]).to(GET_DTYPE())
return vae_encoder_out
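The flf2v-specific part of get_vae_encoder_output is the conditioning layout: the mask keeps both endpoints (msk[:, 1:-1] = 0), the VAE input holds the first frame, T-2 empty frames, and the last frame along the time axis, and frame 0 of the mask is repeated 4x before folding it down to the latent frame rate. A shape-only sketch with deliberately small spatial sizes:

import torch

T, lat_h, lat_w = 81, 12, 20                       # small illustrative latent grid
h, w = lat_h * 8, lat_w * 8                        # pixel grid implied by a spatial stride of 8

first = torch.zeros(3, h, w)                       # stand-ins for the preprocessed endpoint frames
last = torch.zeros(3, h, w)

# VAE input for flf2v: [first, T-2 empty frames, last] along the time axis
vae_input = torch.concat(
    [first[:, None], torch.zeros(3, T - 2, h, w), last[:, None]],
    dim=1,
)

# mask: 1 = frame is given, 0 = frame must be generated
msk = torch.ones(1, T, lat_h, lat_w)
msk[:, 1:-1] = 0                                   # only the first and last frames are known
msk = torch.concat([torch.repeat_interleave(msk[:, 0:1], repeats=4, dim=1), msk[:, 1:]], dim=1)
msk = msk.view(1, msk.shape[1] // 4, 4, lat_h, lat_w).transpose(1, 2)[0]
print(tuple(vae_input.shape), tuple(msk.shape))    # (3, 81, 96, 160) (4, 21, 12, 20)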
def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img):
def get_encoder_output_i2v(self, clip_encoder_out, vae_encoder_out, text_encoder_output, img=None):
image_encoder_output = {
"clip_encoder_out": clip_encoder_out,
"vae_encoder_out": vae_encoder_out,
......@@ -305,7 +334,7 @@ class WanRunner(DefaultRunner):
def set_target_shape(self):
num_channels_latents = self.config.get("num_channels_latents", 16)
if self.config.task == "i2v":
if self.config.task in ["i2v", "flf2v"]:
self.config.target_shape = (
num_channels_latents,
(self.config.target_video_length - 1) // self.config.vae_stride[0] + 1,
......@@ -435,7 +464,7 @@ class Wan22DenseRunner(WanRunner):
"cpu_offload": vae_offload,
"offload_cache": self.config.get("vae_offload_cache", False),
}
if self.config.task != "i2v":
if self.config.task != ["i2v", "flf2v"]:
return None
else:
return Wan2_2_VAE(**vae_config)
......
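Two hunks up, set_target_shape now treats flf2v like i2v: the latent tensor to denoise has num_channels_latents channels and (target_video_length - 1) // vae_stride[0] + 1 latent frames over the lat_h x lat_w grid. A worked line of arithmetic with the config values; num_channels_latents=16 and vae_stride[0]=4 are assumed Wan2.1 defaults, and lat_h/lat_w come from the earlier latent-size sketch:

num_channels_latents = 16                  # assumed default
target_video_length = 81                   # from the config above
vae_stride_t = 4                           # temporal VAE stride, assumed Wan2.1 value
lat_h, lat_w = 90, 160                     # from the earlier latent-size sketch

target_shape = (
    num_channels_latents,
    (target_video_length - 1) // vae_stride_t + 1,   # 21 latent frames
    lat_h,
    lat_w,
)
print(target_shape)                        # (16, 21, 90, 160)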
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
export CUDA_VISIBLE_DEVICES=1
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls wan2.1 \
--task flf2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan/wan_flf2v.json \
--prompt "CG animation style, a small blue bird takes off from the ground, flapping its wings. The bird’s feathers are delicate, with a unique pattern on its chest. The background shows a blue sky with white clouds under bright sunshine. The camera follows the bird upward, capturing its flight and the vastness of the sky from a close-up, low-angle perspective." \
--negative_prompt "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--image_path ${lightx2v_path}/assets/inputs/imgs/flf2v_input_first_frame-fs8.png \
--last_frame_path ${lightx2v_path}/assets/inputs/imgs/flf2v_input_last_frame-fs8.png \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_flf2v.mp4