Unverified Commit f7cdbcb5 authored by LiangLiu, committed by GitHub

multi-person & animate & podcast (#554)



- New service features (frontend + backend):
1. The seko-talk model supports multi-person input
2. Podcast synthesis and management
3. Support for the wan2.2 animate model

- New backend interfaces:
1. Volcengine-based WebSocket podcast synthesis interface, so users can listen while synthesis is still in progress
2. Podcast query and management interface
3. YOLO-based multi-person face detection interface (see the sketch after this list)
4. Multi-speaker audio splitting interface
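
This diff does not include the detection code itself; purely as a rough illustration, a minimal sketch of what a YOLO-based multi-person face detection helper could look like with the ultralytics package (the function name and the "yolov8n-face.pt" weights file are assumptions, not taken from this PR):

from ultralytics import YOLO

def detect_faces(image_path, conf=0.5):
    # assumed face-detection checkpoint; not part of this PR
    model = YOLO("yolov8n-face.pt")
    results = model(image_path, conf=conf)
    # one bounding box + confidence score per detected face
    return [{"bbox": box.xyxy[0].tolist(), "score": float(box.conf[0])} for box in results[0].boxes]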

- Invasive changes to the inference code:
1. Moved the animate-related input file paths (mask/image/pose, etc.) out of the hard-coded config and into the mutable input_info
2. Wrapped the animate preprocessing code as an interface for use by the service layer (see the usage sketch after this list)
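
A minimal usage sketch of the wrapped preprocessing interface, mirroring what prepare_input_video in the diff below does; the concrete paths here are illustrative:

# requires tools/preprocess on sys.path, e.g. via init_tools_preprocess()
from preprocess_data import get_preprocess_parser, process_input_video

args = get_preprocess_parser().parse_args([])              # start from the parser defaults
args.ckpt_path = "/models/wan_animate/process_checkpoint"  # illustrative paths
args.video_path = "/tmp/input.mp4"
args.refer_path = "/tmp/ref.png"
args.save_path = "/tmp/process_results"
process_input_video(args)
# expected outputs under save_path: src_pose.mp4, src_face.mp4, src_ref.png
# (plus src_bg.mp4 and src_mask.mp4 when replace_flag is set)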

@xinyiqin

---------
Co-authored-by: qinxinyi <qxy118045534@163.com>
parent 61dd69ca
......@@ -3,6 +3,7 @@ import ctypes
import gc
import json
import os
import sys
import tempfile
import threading
import traceback
......@@ -11,6 +12,7 @@ import torch
import torch.distributed as dist
from loguru import logger
import lightx2v
from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.infer import init_runner # noqa
from lightx2v.utils.input_info import set_input_info
......@@ -20,6 +22,12 @@ from lightx2v.utils.set_config import set_config, set_parallel_config
from lightx2v.utils.utils import seed_all
def init_tools_preprocess():
preprocess_path = os.path.abspath(os.path.join(lightx2v.__path__[0], "..", "tools", "preprocess"))
assert os.path.exists(preprocess_path), f"lightx2v tools preprocess path not found: {preprocess_path}"
sys.path.append(preprocess_path)
class BaseWorker:
@ProfilingContext4DebugL1("Init Worker Cost:")
def __init__(self, args):
......@@ -49,6 +57,9 @@ class BaseWorker:
self.input_info.save_result_path = params.get("save_result_path", "")
self.input_info.seed = params.get("seed", self.input_info.seed)
self.input_info.audio_path = params.get("audio_path", "")
for k, v in params.get("processed_video_paths", {}).items():
logger.info(f"set {k} to {v}")
setattr(self.input_info, k, v)
async def prepare_input_image(self, params, inputs, tmp_dir, data_manager):
input_image_path = inputs.get("input_image", "")
......@@ -59,8 +70,55 @@ class BaseWorker:
img_data = await data_manager.load_bytes(input_image_path)
with open(tmp_image_path, "wb") as fout:
fout.write(img_data)
params["image_path"] = tmp_image_path
async def prepare_input_video(self, params, inputs, tmp_dir, data_manager):
if not self.is_animate_model():
return
init_tools_preprocess()
from preprocess_data import get_preprocess_parser, process_input_video
result_paths = {}
if self.rank == 0:
tmp_image_path = params.get("image_path", "")
assert os.path.exists(tmp_image_path), f"input_image should have been saved by prepare_input_image but is not valid: {tmp_image_path}"
# prepare tmp input video
input_video_path = inputs.get("input_video", "")
tmp_video_path = os.path.join(tmp_dir, input_video_path)
processed_video_path = os.path.join(tmp_dir, "process_results")
video_data = await data_manager.load_bytes(input_video_path)
with open(tmp_video_path, "wb") as fout:
fout.write(video_data)
# prepare preprocess args
pre_args = get_preprocess_parser().parse_args([])
pre_args.ckpt_path = self.runner.config["model_path"] + "/process_checkpoint"
pre_args.video_path = tmp_video_path
pre_args.refer_path = tmp_image_path
pre_args.save_path = processed_video_path
pre_args.replace_flag = self.runner.config.get("replace_flag", False)
pre_config = self.runner.config.get("preprocess_config", {})
pre_keys = ["resolution_area", "fps", "replace_flag", "retarget_flag", "use_flux", "iterations", "k", "w_len", "h_len"]
for k in pre_keys:
if k in pre_config:
setattr(pre_args, k, pre_config[k])
process_input_video(pre_args)
result_paths = {
"src_pose_path": os.path.join(processed_video_path, "src_pose.mp4"),
"src_face_path": os.path.join(processed_video_path, "src_face.mp4"),
"src_ref_images": os.path.join(processed_video_path, "src_ref.png"),
}
if pre_args.replace_flag:
result_paths["src_bg_path"] = os.path.join(processed_video_path, "src_bg.mp4")
result_paths["src_mask_path"] = os.path.join(processed_video_path, "src_mask.mp4")
params["image_path"] = tmp_image_path
# for distributed runs, broadcast the processed-video result paths to all ranks
result_paths = await self.broadcast_data(result_paths, 0)
for p in result_paths.values():
assert os.path.exists(p), f"Input video processed result not found: {p}!"
params["processed_video_paths"] = result_paths
async def prepare_input_audio(self, params, inputs, tmp_dir, data_manager):
input_audio_path = inputs.get("input_audio", "")
......@@ -72,9 +130,20 @@ class BaseWorker:
tmp_audio_path = stream_audio_path
if input_audio_path and self.is_audio_model() and isinstance(tmp_audio_path, str):
audio_data = await data_manager.load_bytes(input_audio_path)
with open(tmp_audio_path, "wb") as fout:
fout.write(audio_data)
extra_audio_inputs = params.get("extra_inputs", {}).get("input_audio", [])
# for multi-person audio directory input
if len(extra_audio_inputs) > 0:
os.makedirs(tmp_audio_path, exist_ok=True)
for inp in extra_audio_inputs:
tmp_path = os.path.join(tmp_dir, inputs[inp])
inp_data = await data_manager.load_bytes(inputs[inp])
with open(tmp_path, "wb") as fout:
fout.write(inp_data)
else:
audio_data = await data_manager.load_bytes(input_audio_path)
with open(tmp_audio_path, "wb") as fout:
fout.write(audio_data)
params["audio_path"] = tmp_audio_path
......@@ -83,7 +152,6 @@ class BaseWorker:
tmp_video_path = os.path.join(tmp_dir, output_video_path)
if data_manager.name == "local":
tmp_video_path = os.path.join(data_manager.local_dir, output_video_path)
# for stream video output, value is dict
stream_video_path = params.get("output_video", None)
if stream_video_path is not None:
......@@ -129,6 +197,32 @@ class BaseWorker:
def is_audio_model(self):
return "audio" in self.runner.config["model_cls"] or "seko_talk" in self.runner.config["model_cls"]
def is_animate_model(self):
return self.runner.config.get("task") == "animate"
async def broadcast_data(self, data, src_rank=0):
if self.world_size <= 1:
return data
if self.rank == src_rank:
val = json.dumps(data, ensure_ascii=False).encode("utf-8")
T = torch.frombuffer(bytearray(val), dtype=torch.uint8).to(device="cuda")
S = torch.tensor([T.shape[0]], dtype=torch.int32).to(device="cuda")
logger.info(f"hub rank {self.rank} send data: {data}")
else:
S = torch.zeros(1, dtype=torch.int32, device="cuda")
dist.broadcast(S, src=src_rank)
if self.rank != src_rank:
T = torch.zeros(S.item(), dtype=torch.uint8, device="cuda")
dist.broadcast(T, src=src_rank)
if self.rank != src_rank:
val = T.cpu().numpy().tobytes()
data = json.loads(val.decode("utf-8"))
logger.info(f"hub rank {self.rank} recv data: {data}")
return data
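
For reference, the same size-then-payload broadcast pattern in standalone form, runnable under torchrun with an initialized CUDA-capable process group; the helper name broadcast_json is illustrative:

import json
import torch
import torch.distributed as dist

def broadcast_json(data, src_rank=0):
    rank = dist.get_rank()
    if rank == src_rank:
        payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
        buf = torch.frombuffer(bytearray(payload), dtype=torch.uint8).cuda()
        size = torch.tensor([buf.numel()], dtype=torch.int32, device="cuda")
    else:
        size = torch.zeros(1, dtype=torch.int32, device="cuda")
    dist.broadcast(size, src=src_rank)  # 1) announce the payload length
    if rank != src_rank:
        buf = torch.zeros(size.item(), dtype=torch.uint8, device="cuda")
    dist.broadcast(buf, src=src_rank)   # 2) ship the JSON bytes
    return json.loads(buf.cpu().numpy().tobytes().decode("utf-8"))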
class RunnerThread(threading.Thread):
def __init__(self, loop, future, run_func, rank, *args, **kwargs):
......@@ -197,6 +291,7 @@ class PipelineWorker(BaseWorker):
with tempfile.TemporaryDirectory() as tmp_dir:
await self.prepare_input_image(params, inputs, tmp_dir, data_manager)
await self.prepare_input_audio(params, inputs, tmp_dir, data_manager)
await self.prepare_input_video(params, inputs, tmp_dir, data_manager)
tmp_video_path, output_video_path = self.prepare_output_video(params, outputs, tmp_dir, data_manager)
logger.info(f"run params: {params}, {inputs}, {outputs}")
......
......@@ -91,7 +91,30 @@ def main():
default=None,
help="The file of the source mask. Default None.",
)
parser.add_argument(
"--src_pose_path",
type=str,
default=None,
help="The file of the source pose. Default None.",
)
parser.add_argument(
"--src_face_path",
type=str,
default=None,
help="The file of the source face. Default None.",
)
parser.add_argument(
"--src_bg_path",
type=str,
default=None,
help="The file of the source background. Default None.",
)
parser.add_argument(
"--src_mask_path",
type=str,
default=None,
help="The file of the source mask. Default None.",
)
parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file")
parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)")
args = parser.parse_args()
......
......@@ -238,9 +238,9 @@ class WanAnimateRunner(WanRunner):
return y, pose_latents
def prepare_input(self):
src_pose_path = self.config["src_pose_path"] if "src_pose_path" in self.config else None
src_face_path = self.config["src_face_path"] if "src_face_path" in self.config else None
src_ref_path = self.config["src_ref_images"] if "src_ref_images" in self.config else None
src_pose_path = self.input_info.src_pose_path
src_face_path = self.input_info.src_face_path
src_ref_path = self.input_info.src_ref_images
self.cond_images, self.face_images, self.refer_images = self.prepare_source(src_pose_path, src_face_path, src_ref_path)
self.refer_pixel_values = torch.tensor(self.refer_images / 127.5 - 1, dtype=GET_DTYPE(), device="cuda").permute(2, 0, 1) # chw
self.latent_t = self.config["target_video_length"] // self.config["vae_stride"][0] + 1
......@@ -258,8 +258,8 @@ class WanAnimateRunner(WanRunner):
self.face_images = self.inputs_padding(self.face_images, target_len)
if self.config["replace_flag"] if "replace_flag" in self.config else False:
src_bg_path = self.config["src_bg_path"]
src_mask_path = self.config["src_mask_path"]
src_bg_path = self.input_info.src_bg_path
src_mask_path = self.input_info.src_mask_path
self.bg_images, self.mask_images = self.prepare_source_for_replace(src_bg_path, src_mask_path)
self.bg_images = self.inputs_padding(self.bg_images, target_len)
self.mask_images = self.inputs_padding(self.mask_images, target_len)
......
......@@ -94,6 +94,11 @@ class AnimateInputInfo:
prompt_enhanced: str = field(default_factory=str)
negative_prompt: str = field(default_factory=str)
image_path: str = field(default_factory=str)
src_pose_path: str = field(default_factory=str)
src_face_path: str = field(default_factory=str)
src_ref_images: str = field(default_factory=str)
src_bg_path: str = field(default_factory=str)
src_mask_path: str = field(default_factory=str)
save_result_path: str = field(default_factory=str)
return_result_tensor: bool = field(default_factory=lambda: False)
# shape related
......@@ -181,6 +186,11 @@ def set_input_info(args):
prompt=args.prompt,
negative_prompt=args.negative_prompt,
image_path=args.image_path,
src_pose_path=args.src_pose_path,
src_face_path=args.src_face_path,
src_ref_images=args.src_ref_images,
src_bg_path=args.src_bg_path,
src_mask_path=args.src_mask_path,
save_result_path=args.save_result_path,
return_result_tensor=args.return_result_tensor,
)
......
......@@ -25,6 +25,9 @@ python -m lightx2v.infer \
--task animate \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan22/wan_animate.json \
--src_pose_path ${lightx2v_path}/save_results/animate/process_results/src_pose.mp4 \
--src_face_path ${lightx2v_path}/save_results/animate/process_results/src_face.mp4 \
--src_ref_images ${lightx2v_path}/save_results/animate/process_results/src_ref.png \
--prompt "视频中的人在做动作" \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_animate.mp4
......@@ -25,6 +25,9 @@ python -m lightx2v.infer \
--task animate \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan22/wan_animate_lora.json \
--src_pose_path ${lightx2v_path}/save_results/animate/process_results/src_pose.mp4 \
--src_face_path ${lightx2v_path}/save_results/animate/process_results/src_face.mp4 \
--src_ref_images ${lightx2v_path}/save_results/animate/process_results/src_ref.png \
--prompt "视频中的人在做动作" \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_animate_lora.mp4
......@@ -29,6 +29,11 @@ python -m lightx2v.infer \
--task animate \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan22/wan_animate_replace_4090.json \
--src_pose_path ${lightx2v_path}/save_results/animate/process_results/src_pose.mp4 \
--src_face_path ${lightx2v_path}/save_results/animate/process_results/src_face.mp4 \
--src_ref_images ${lightx2v_path}/save_results/animate/process_results/src_ref.png \
--src_bg_path ${lightx2v_path}/save_results/animate/process_results/src_bg.mp4 \
--src_mask_path ${lightx2v_path}/save_results/animate/process_results/src_mask.mp4 \
--prompt "视频中的人在做动作" \
--negative_prompt "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan22_replace.mp4
......@@ -5,7 +5,7 @@ import os
from process_pipepline import ProcessPipeline
def _parse_args():
def get_preprocess_parser():
parser = argparse.ArgumentParser(description="The preprocessing pipeline for Wan-animate.")
parser.add_argument("--ckpt_path", type=str, default=None, help="The path to the preprocessing model's checkpoint directory. ")
......@@ -47,13 +47,10 @@ def _parse_args():
default=1,
help="The number of subdivisions for the grid along the 'h' dimension. A higher value results in a more detailed contour. A value of 1 means no subdivision is performed.",
)
args = parser.parse_args()
return args
return parser
if __name__ == "__main__":
args = _parse_args()
def process_input_video(args):
args_dict = vars(args)
print(args_dict)
......@@ -83,3 +80,9 @@ if __name__ == "__main__":
use_flux=args.use_flux,
replace_flag=args.replace_flag,
)
if __name__ == "__main__":
parser = get_preprocess_parser()
args = parser.parse_args()
process_input_video(args)