Commit ab1b2790 authored by LiangLiu's avatar LiangLiu Committed by GitHub
Browse files

Deploy server and worker (#284)



* Init deploy: not ok

* Test data_manager & task_manager

* pipeline is no need for worker

* Update worker text_encoder

* deploy: submit task

* add apis

* Test pipelineRunner

* Fix pipeline

* Tidy worker & test PipelineWorker ok

* Tidy code

* Fix multi_stage for wan2.1 t2v & i2v

* api query task, get result & report subtasks failed when workers stop

* Add model list functionality to Pipeline and API

* Add task cancel and task resume  to API

* Add RabbitMQ queue manager

* update local task manager atomic

* support postgreSQL task manager, add lifespan async init

* worker print -> logger

* Add S3 data manager, delete temp objects after finished.

* fix worker

* release fetch queue msg when closed, run stuck worker in another thread, stop worker when process down.

* DiTWorker run with thread & tidy logger print

* Init monitor without test

* fix monitor

* github OAuth and jwt token access & static demo html page

* Add user to task, ok for local task manager & update demo ui

* sql task manager support users

* task list with pages

* merge main fix

* Add proxy for auth request

* support wan audio

* worker ping subtask and ping life, fix rabbitmq async get,

* s3 data manager with async api & tidy monitor config

* fix merge main & update req.txt & fix html view video error

* Fix distributed worker

* Limit user visit freq

* Tidy

* Fix only rank save

* Fix audio input

* Fix worker fetch None

* index.html abs path to rel path

* Fix dist worker stuck

* support publish output video to rtmp & graceful stop running dit step or segment step

* Add VAReader

* Enhance VAReader with torch dist

* Fix audio stream input

* fix merge refactor main, support stream input_audio and output_video

* fix audio read with prev frames & fix end take frames & tidy worker end

* split audio model to 4 workers & fix audio end frame

* fix ping subtask with queue

* Fix audio worker put block & add whep, whip without test ok

* Tidy va recorder & va reader log, thread cancel within 30s

* Fix dist worker stuck: broadcast stop signal

* Tidy

* record task active_elapse & subtask status_elapse

* Design prometheus metrics

* Tidy prometheus metrics

* Fix merge main

* send sigint to ffmpeg process

* Fix gstreamer pull audio by whep & Dockerfile for gstreamer & check params when submitting

* Fix merge main

* Query task with more info & va_reader buffer size = 1

* Fix va_recorder

* Add config for prev_frames

* update frontend

* update frontend

* update frontend

* update frontend
merge

* update frontend & partial backend

* Different rank for va_recorder and va_reader

* Fix mem leak: only one rank publish video, other rank should pop gen vids

* fix task category

* va_reader pre-alloc tensor & va_recorder send frames all & fix dist cancel infer

* Fix prev_frame_length

* Tidy

* Tidy

* update frontend & backend

* Fix lint error

* recover some files

* Tidy

* lint code

---------
Co-authored-by: default avatarliuliang1 <liuliang1@sensetime.com>
Co-authored-by: default avatarunknown <qinxinyi@sensetime.com>
parent acacd26f
This diff is collapsed.
import uuid
from enum import Enum
from re import T
from loguru import logger
from lightx2v.deploy.common.utils import current_time, data_name
class TaskStatus(Enum):
    """Lifecycle states shared by tasks and their subtasks."""

    CREATED = 1  # inserted, not yet scheduled
    PENDING = 2  # dependencies satisfied, queued for a worker
    RUNNING = 3  # picked up by a worker
    SUCCEED = 4  # finished successfully
    FAILED = 5   # finished with an error
    CANCEL = 6   # cancelled by the user


# In-flight vs. terminal status groups.
# NOTE: kept as lists (not tuples/sets) because callers pass them as the
# `status` filter of list_tasks(), which checks isinstance(..., list).
ActiveStatus = [TaskStatus.CREATED, TaskStatus.PENDING, TaskStatus.RUNNING]
FinishedStatus = [TaskStatus.SUCCEED, TaskStatus.FAILED, TaskStatus.CANCEL]
class BaseTaskManager:
    """Persistence-agnostic manager of tasks and their subtasks.

    Concrete backends (local JSON files, PostgreSQL, ...) implement the
    async CRUD methods below; this base class provides the shared
    bookkeeping: user/task record construction, status (de)serialization,
    and active-elapse / metrics accounting.

    Subclasses are expected to provide ``self.metrics_monitor`` (may be
    ``None`` to disable metrics).
    """

    def __init__(self):
        # Safe default so the metrics hooks below never hit an
        # AttributeError when a subclass forgets to configure a monitor.
        self.metrics_monitor = None

    async def init(self):
        """Optional async setup hook (connections, pools, ...)."""
        pass

    async def close(self):
        """Optional async teardown hook."""
        pass

    # ---- backend interface, implemented by subclasses -------------------

    async def insert_user_if_not_exists(self, user_info):
        raise NotImplementedError

    async def query_user(self, user_id):
        raise NotImplementedError

    async def insert_task(self, task, subtasks):
        raise NotImplementedError

    async def list_tasks(self, **kwargs):
        raise NotImplementedError

    async def query_task(self, task_id, user_id=None, only_task=True):
        raise NotImplementedError

    async def next_subtasks(self, task_id):
        raise NotImplementedError

    async def run_subtasks(self, subtasks, worker_identity):
        raise NotImplementedError

    async def ping_subtask(self, task_id, worker_name, worker_identity):
        raise NotImplementedError

    async def finish_subtasks(self, task_id, status, worker_identity=None, worker_name=None, fail_msg=None, should_running=False):
        raise NotImplementedError

    async def cancel_task(self, task_id, user_id=None):
        raise NotImplementedError

    async def resume_task(self, task_id, all_subtask=False, user_id=None):
        raise NotImplementedError

    # ---- (de)serialization helpers --------------------------------------

    def fmt_dict(self, data):
        """Convert enum fields to their names for storage (in place)."""
        for k in ["status"]:
            if k in data:
                data[k] = data[k].name

    def parse_dict(self, data):
        """Convert stored status names back to TaskStatus (in place)."""
        for k in ["status"]:
            if k in data:
                data[k] = TaskStatus[data[k]]

    # ---- shared high-level operations -----------------------------------

    async def create_user(self, user_info):
        """Create a user record if absent; returns the composed user id.

        Only GitHub-sourced users are supported (asserted).
        """
        assert user_info["source"] == "github", f"do not support {user_info['source']} user!"
        cur_t = current_time()
        user_id = f"{user_info['source']}_{user_info['id']}"
        data = {
            "user_id": user_id,
            "source": user_info["source"],
            "id": user_info["id"],
            "username": user_info["username"],
            "email": user_info["email"],
            "homepage": user_info["homepage"],
            "avatar_url": user_info["avatar_url"],
            "create_t": cur_t,
            "update_t": cur_t,
            "extra_info": "",
            "tag": "",
        }
        assert await self.insert_user_if_not_exists(data), f"create user {data} failed"
        return user_id

    async def create_task(self, worker_keys, workers, params, inputs, outputs, user_id):
        """Build a task plus one subtask per worker and persist them.

        worker_keys: (task_type, model_cls, stage) triple selecting the pipeline.
        workers: mapping worker_name -> item with inputs/outputs/queue/previous.
        Returns the new task id; raises AssertionError when insertion fails.
        """
        task_type, model_cls, stage = worker_keys
        cur_t = current_time()
        task_id = str(uuid.uuid4())
        task = {
            "task_id": task_id,
            "task_type": task_type,
            "model_cls": model_cls,
            "stage": stage,
            "params": params,
            "create_t": cur_t,
            "update_t": cur_t,
            "status": TaskStatus.CREATED,
            "extra_info": "",
            "tag": "",
            # logical data names are derived from the task id so workers can
            # locate intermediate artifacts
            "inputs": {x: data_name(x, task_id) for x in inputs},
            "outputs": {x: data_name(x, task_id) for x in outputs},
            "user_id": user_id,
        }
        self.mark_task_start(task)
        subtasks = []
        for worker_name, worker_item in workers.items():
            subtasks.append(
                {
                    "task_id": task_id,
                    "worker_name": worker_name,
                    "inputs": {x: data_name(x, task_id) for x in worker_item["inputs"]},
                    "outputs": {x: data_name(x, task_id) for x in worker_item["outputs"]},
                    "queue": worker_item["queue"],
                    "previous": worker_item["previous"],
                    "status": TaskStatus.CREATED,
                    "worker_identity": "",
                    "result": "",
                    "fail_time": 0,
                    "extra_info": "",
                    "create_t": cur_t,
                    "update_t": cur_t,
                    "ping_t": 0.0,
                    "infer_cost": -1.0,
                }
            )
            self.mark_subtask_change(subtasks[-1], None, TaskStatus.CREATED)
        ret = await self.insert_task(task, subtasks)
        # if insert error
        if not ret:
            self.mark_task_end(task, TaskStatus.FAILED)
            for sub in subtasks:
                self.mark_subtask_change(sub, sub["status"], TaskStatus.FAILED)
        assert ret, f"create task {task_id} failed"
        return task_id

    async def mark_server_restart(self):
        """After a restart, re-register still-active tasks/subtasks with metrics."""
        if self.metrics_monitor:
            tasks = await self.list_tasks(status=ActiveStatus)
            subtasks = await self.list_tasks(status=ActiveStatus, subtasks=True)
            logger.warning(f"Mark system restart, {len(tasks)} tasks, {len(subtasks)} subtasks")
            self.metrics_monitor.record_task_recover(tasks)
            self.metrics_monitor.record_subtask_recover(subtasks)

    def mark_task_start(self, task):
        """Stamp the start of a task's active phase in extra_info."""
        t = current_time()
        if not isinstance(task["extra_info"], dict):
            task["extra_info"] = {}
        # a restart of the active phase invalidates any previous elapse
        if "active_elapse" in task["extra_info"]:
            del task["extra_info"]["active_elapse"]
        task["extra_info"]["start_t"] = t
        logger.info(f"Task {task['task_id']} active start")
        if self.metrics_monitor:
            self.metrics_monitor.record_task_start(task)

    def mark_task_end(self, task, end_status):
        """Close the task's active phase, recording active_elapse and metrics."""
        # BUG FIX: `elapse` was only assigned in the else-branch but is used
        # in the log/metrics calls below; default it so a task without a
        # recorded start_t no longer raises NameError.
        elapse = None
        if "start_t" not in task["extra_info"]:
            logger.warning(f"Task {task} has no start time")
        else:
            elapse = current_time() - task["extra_info"]["start_t"]
            task["extra_info"]["active_elapse"] = elapse
            del task["extra_info"]["start_t"]
        logger.info(f"Task {task['task_id']} active end with [{end_status}], elapse: {elapse}")
        if self.metrics_monitor:
            self.metrics_monitor.record_task_end(task, end_status, elapse)

    def mark_subtask_change(self, subtask, old_status, new_status, fail_msg=None):
        """Record a subtask status transition.

        Maintains fail_msg, per-transition elapse entries under
        extra_info["elapses"] (keyed "OLD-NEW"), the start_t stamp for
        active states, and emits metrics. No-ops (with a warning) when the
        status does not actually change.
        """
        t = current_time()
        if not isinstance(subtask["extra_info"], dict):
            subtask["extra_info"] = {}
        if isinstance(fail_msg, str) and len(fail_msg) > 0:
            subtask["extra_info"]["fail_msg"] = fail_msg
        elif "fail_msg" in subtask["extra_info"]:
            del subtask["extra_info"]["fail_msg"]
        if old_status == new_status:
            logger.warning(f"Subtask {subtask} update same status: {old_status} vs {new_status}")
            return
        elapse, elapse_key = None, None
        if old_status in ActiveStatus:
            if "start_t" not in subtask["extra_info"]:
                logger.warning(f"Subtask {subtask} has no start time, status: {old_status}")
            else:
                elapse = t - subtask["extra_info"]["start_t"]
                elapse_key = f"{old_status.name}-{new_status.name}"
                if "elapses" not in subtask["extra_info"]:
                    subtask["extra_info"]["elapses"] = {}
                subtask["extra_info"]["elapses"][elapse_key] = elapse
                del subtask["extra_info"]["start_t"]
        if new_status in ActiveStatus:
            subtask["extra_info"]["start_t"] = t
        # a reset back to CREATED (task resume) clears accumulated elapses
        if new_status == TaskStatus.CREATED and "elapses" in subtask["extra_info"]:
            del subtask["extra_info"]["elapses"]
        logger.info(
            f"Subtask {subtask['task_id']} {subtask['worker_name']} status changed: \
[{old_status}] -> [{new_status}], {elapse_key}: {elapse}, fail_msg: {fail_msg}"
        )
        if self.metrics_monitor:
            self.metrics_monitor.record_subtask_change(subtask, old_status, new_status, elapse_key, elapse)
# Import task manager implementations
from .local_task_manager import LocalTaskManager # noqa
from .sql_task_manager import PostgresSQLTaskManager # noqa
__all__ = ["BaseTaskManager", "LocalTaskManager", "PostgresSQLTaskManager"]
import asyncio
import json
import os
from lightx2v.deploy.common.utils import class_try_catch_async, current_time, str2time, time2str
from lightx2v.deploy.task_manager import ActiveStatus, BaseTaskManager, FinishedStatus, TaskStatus
class LocalTaskManager(BaseTaskManager):
    """Task manager persisting each task (with its subtasks) as a JSON file.

    Layout inside ``local_dir``: ``task_<task_id>.json`` holds
    ``{"task": ..., "subtasks": [...]}``; ``user_<user_id>.json`` holds one
    user record. All mutating operations follow a load -> modify -> save
    cycle on the whole task file.
    """

    def __init__(self, local_dir, metrics_monitor=None):
        self.local_dir = local_dir
        if not os.path.exists(self.local_dir):
            os.makedirs(self.local_dir)
        self.metrics_monitor = metrics_monitor

    def get_task_filename(self, task_id):
        """Path of the JSON file backing a task."""
        return os.path.join(self.local_dir, f"task_{task_id}.json")

    def get_user_filename(self, user_id):
        """Path of the JSON file backing a user."""
        return os.path.join(self.local_dir, f"user_{user_id}.json")

    def fmt_dict(self, data):
        # Besides the base enum->name conversion, serialize timestamps to
        # strings for the JSON files.
        super().fmt_dict(data)
        for k in ["create_t", "update_t", "ping_t"]:
            if k in data:
                data[k] = time2str(data[k])

    def parse_dict(self, data):
        # Inverse of fmt_dict: status name -> enum, time strings -> numbers.
        super().parse_dict(data)
        for k in ["create_t", "update_t", "ping_t"]:
            if k in data:
                data[k] = str2time(data[k])

    def save(self, task, subtasks, with_fmt=True):
        """Write task + subtasks to disk.

        NOTE: with_fmt=True mutates the given dicts in place (enums and
        timestamps become strings), matching how callers re-save loaded data.
        """
        info = {"task": task, "subtasks": subtasks}
        if with_fmt:
            self.fmt_dict(info["task"])
            for sub in info["subtasks"]:
                self.fmt_dict(sub)
        out_name = self.get_task_filename(task["task_id"])
        with open(out_name, "w") as fout:
            fout.write(json.dumps(info, indent=4, ensure_ascii=False))

    def load(self, task_id, user_id=None, only_task=False):
        """Read a task file; optionally verify ownership and skip subtasks.

        Returns ``task`` when only_task, else ``(task, subtasks)``.
        """
        fpath = self.get_task_filename(task_id)
        # use a context manager so the file handle is closed promptly
        with open(fpath) as fin:
            info = json.load(fin)
        task, subtasks = info["task"], info["subtasks"]
        if user_id is not None and task["user_id"] != user_id:
            raise Exception(f"Task {task_id} is not belong to user {user_id}")
        self.parse_dict(task)
        if only_task:
            return task
        for sub in subtasks:
            self.parse_dict(sub)
        return task, subtasks

    @class_try_catch_async
    async def insert_task(self, task, subtasks):
        self.save(task, subtasks)
        return True

    @class_try_catch_async
    async def list_tasks(self, **kwargs):
        """List tasks (or subtasks with subtasks=True) under optional filters.

        Supported kwargs: user_id, status (single value or list),
        start/end_created_t, start/end_updated_t, start/end_ping_t
        (ping_t only exists on subtask records), count, offset, limit.
        """
        tasks = []
        for f in os.listdir(self.local_dir):
            if not f.startswith("task_"):
                continue
            fpath = os.path.join(self.local_dir, f)
            with open(fpath) as fin:
                info = json.load(fin)
            if kwargs.get("subtasks", False):
                items = info["subtasks"]
                assert "user_id" not in kwargs, "user_id is not allowed when subtasks is True"
            else:
                items = [info["task"]]
            for task in items:
                self.parse_dict(task)
                if "user_id" in kwargs and task["user_id"] != kwargs["user_id"]:
                    continue
                if "status" in kwargs:
                    wanted = kwargs["status"]
                    # BUG FIX: the previous `elif` also fired for list
                    # filters whose member matched (list != enum is always
                    # True), dropping every matching task; handle list and
                    # scalar filters separately.
                    if isinstance(wanted, list):
                        if task["status"] not in wanted:
                            continue
                    elif wanted != task["status"]:
                        continue
                if "start_created_t" in kwargs and kwargs["start_created_t"] > task["create_t"]:
                    continue
                if "end_created_t" in kwargs and kwargs["end_created_t"] < task["create_t"]:
                    continue
                if "start_updated_t" in kwargs and kwargs["start_updated_t"] > task["update_t"]:
                    continue
                if "end_updated_t" in kwargs and kwargs["end_updated_t"] < task["update_t"]:
                    continue
                if "start_ping_t" in kwargs and kwargs["start_ping_t"] > task["ping_t"]:
                    continue
                if "end_ping_t" in kwargs and kwargs["end_ping_t"] < task["ping_t"]:
                    continue
                tasks.append(task)
        if "count" in kwargs:
            return len(tasks)
        tasks = sorted(tasks, key=lambda x: x["create_t"], reverse=True)
        if "offset" in kwargs:
            tasks = tasks[kwargs["offset"] :]
        if "limit" in kwargs:
            tasks = tasks[: kwargs["limit"]]
        return tasks

    @class_try_catch_async
    async def query_task(self, task_id, user_id=None, only_task=True):
        return self.load(task_id, user_id, only_task)

    @class_try_catch_async
    async def next_subtasks(self, task_id):
        """Move dependency-satisfied CREATED subtasks to PENDING and return them."""
        task, subtasks = self.load(task_id)
        if task["status"] not in ActiveStatus:
            return []
        succeeds = set()
        for sub in subtasks:
            if sub["status"] == TaskStatus.SUCCEED:
                succeeds.add(sub["worker_name"])
        nexts = []
        for sub in subtasks:
            if sub["status"] == TaskStatus.CREATED:
                dep_ok = True
                for prev in sub["previous"]:
                    if prev not in succeeds:
                        dep_ok = False
                        break
                if dep_ok:
                    self.mark_subtask_change(sub, sub["status"], TaskStatus.PENDING)
                    # workers need the task params alongside the subtask
                    sub["params"] = task["params"]
                    sub["status"] = TaskStatus.PENDING
                    sub["update_t"] = current_time()
                    nexts.append(sub)
        if len(nexts) > 0:
            task["status"] = TaskStatus.PENDING
            task["update_t"] = current_time()
            self.save(task, subtasks)
        return nexts

    @class_try_catch_async
    async def run_subtasks(self, cands, worker_identity):
        """Claim candidate subtasks for a worker; returns the ones accepted."""
        valids = []
        for cand in cands:
            task_id = cand["task_id"]
            worker_name = cand["worker_name"]
            task, subtasks = self.load(task_id)
            # skip tasks already finished (succeed/failed/cancelled)
            if task["status"] in [TaskStatus.SUCCEED, TaskStatus.FAILED, TaskStatus.CANCEL]:
                continue
            for sub in subtasks:
                if sub["worker_name"] == worker_name:
                    self.mark_subtask_change(sub, sub["status"], TaskStatus.RUNNING)
                    sub["status"] = TaskStatus.RUNNING
                    sub["worker_identity"] = worker_identity
                    sub["update_t"] = current_time()
                    task["status"] = TaskStatus.RUNNING
                    task["update_t"] = current_time()
                    task["ping_t"] = current_time()
                    self.save(task, subtasks)
                    valids.append(cand)
                    break
        return valids

    @class_try_catch_async
    async def ping_subtask(self, task_id, worker_name, worker_identity):
        """Refresh a running subtask's heartbeat; asserts the worker owns it."""
        task, subtasks = self.load(task_id)
        for sub in subtasks:
            if sub["worker_name"] == worker_name:
                pre = sub["worker_identity"]
                assert pre == worker_identity, f"worker identity not matched: {pre} vs {worker_identity}"
                sub["ping_t"] = current_time()
                self.save(task, subtasks)
                return True
        return False

    @class_try_catch_async
    async def finish_subtasks(self, task_id, status, worker_identity=None, worker_name=None, fail_msg=None, should_running=False):
        """Finish subtask(s) and propagate the outcome to the task.

        Returns the resulting task status (CANCEL/FAILED/SUCCEED) or None
        when the task is still in flight.
        """
        task, subtasks = self.load(task_id)
        subs = subtasks
        if worker_name:
            subs = [sub for sub in subtasks if sub["worker_name"] == worker_name]
            assert len(subs) >= 1, f"no worker task_id={task_id}, name={worker_name}"
        if worker_identity:
            pre = subs[0]["worker_identity"]
            assert pre == worker_identity, f"worker identity not matched: {pre} vs {worker_identity}"
        assert status in [TaskStatus.SUCCEED, TaskStatus.FAILED], f"invalid finish status: {status}"
        for sub in subs:
            if sub["status"] not in FinishedStatus:
                if should_running and sub["status"] != TaskStatus.RUNNING:
                    # NOTE: print kept for behavior parity; consider logger.warning
                    print(f"task {task_id} is not running, skip finish subtask: {sub}")
                    continue
                self.mark_subtask_change(sub, sub["status"], status, fail_msg=fail_msg)
                sub["status"] = status
                sub["update_t"] = current_time()
        # a cancelled task stays cancelled regardless of late worker reports
        if task["status"] == TaskStatus.CANCEL:
            self.save(task, subtasks)
            return TaskStatus.CANCEL
        running_subs = []
        failed_sub = False
        for sub in subtasks:
            if sub["status"] not in FinishedStatus:
                running_subs.append(sub)
            if sub["status"] == TaskStatus.FAILED:
                failed_sub = True
        # some subtask failed, we should fail all other subtasks
        if failed_sub:
            if task["status"] != TaskStatus.FAILED:
                self.mark_task_end(task, TaskStatus.FAILED)
                task["status"] = TaskStatus.FAILED
                task["update_t"] = current_time()
            for sub in running_subs:
                self.mark_subtask_change(sub, sub["status"], TaskStatus.FAILED, fail_msg="other subtask failed")
                sub["status"] = TaskStatus.FAILED
                sub["update_t"] = current_time()
            self.save(task, subtasks)
            return TaskStatus.FAILED
        # all subtasks finished and all succeed
        elif len(running_subs) == 0:
            if task["status"] != TaskStatus.SUCCEED:
                self.mark_task_end(task, TaskStatus.SUCCEED)
                task["status"] = TaskStatus.SUCCEED
                task["update_t"] = current_time()
            self.save(task, subtasks)
            return TaskStatus.SUCCEED
        self.save(task, subtasks)
        return None

    @class_try_catch_async
    async def cancel_task(self, task_id, user_id=None):
        """Cancel an active task; returns True or an explanatory message."""
        task, subtasks = self.load(task_id, user_id)
        if task["status"] not in ActiveStatus:
            return f"Task {task_id} is not in active status (current status: {task['status']}). Only tasks with status CREATED, PENDING, or RUNNING can be cancelled."
        for sub in subtasks:
            if sub["status"] not in FinishedStatus:
                self.mark_subtask_change(sub, sub["status"], TaskStatus.CANCEL)
                sub["status"] = TaskStatus.CANCEL
                sub["update_t"] = current_time()
        self.mark_task_end(task, TaskStatus.CANCEL)
        task["status"] = TaskStatus.CANCEL
        task["update_t"] = current_time()
        self.save(task, subtasks)
        return True

    @class_try_catch_async
    async def resume_task(self, task_id, all_subtask=False, user_id=None):
        """Reset a finished task (optionally every subtask) back to CREATED."""
        task, subtasks = self.load(task_id, user_id)
        # the task is not finished
        if task["status"] not in FinishedStatus:
            return False
        # the task is no need to resume
        if not all_subtask and task["status"] == TaskStatus.SUCCEED:
            return False
        for sub in subtasks:
            if all_subtask or sub["status"] != TaskStatus.SUCCEED:
                self.mark_subtask_change(sub, None, TaskStatus.CREATED)
                sub["status"] = TaskStatus.CREATED
                sub["update_t"] = current_time()
                sub["ping_t"] = 0.0
        self.mark_task_start(task)
        task["status"] = TaskStatus.CREATED
        task["update_t"] = current_time()
        self.save(task, subtasks)
        return True

    @class_try_catch_async
    async def insert_user_if_not_exists(self, user_info):
        fpath = self.get_user_filename(user_info["user_id"])
        if os.path.exists(fpath):
            return True
        self.fmt_dict(user_info)
        with open(fpath, "w") as fout:
            fout.write(json.dumps(user_info, indent=4, ensure_ascii=False))
        return True

    @class_try_catch_async
    async def query_user(self, user_id):
        fpath = self.get_user_filename(user_id)
        if not os.path.exists(fpath):
            return None
        with open(fpath) as fin:
            data = json.load(fin)
        self.parse_dict(data)
        return data
async def test():
    """Manual end-to-end smoke test for LocalTaskManager against a local pipeline config."""
    from lightx2v.deploy.common.pipeline import Pipeline

    pipeline = Pipeline("/data/nvme1/liuliang1/lightx2v/configs/model_pipeline.json")
    manager = LocalTaskManager("/data/nvme1/liuliang1/lightx2v/local_task")
    await manager.init()

    worker_keys = ["t2v", "wan2.1", "multi_stage"]
    workers = pipeline.get_workers(worker_keys)
    task_inputs = pipeline.get_inputs(worker_keys)
    task_outputs = pipeline.get_outputs(worker_keys)
    params = {
        "prompt": "fake input prompts",
        "resolution": {
            "height": 233,
            "width": 456,
        },
    }
    user_info = {
        "source": "github",
        "id": "test-id-233",
        "username": "test-username-233",
        "email": "test-email-233@test.com",
        "homepage": "https://test.com",
        "avatar_url": "https://test.com/avatar.png",
    }

    user_id = await manager.create_user(user_info)
    print(" - create_user:", user_id)
    print(" - query_user:", await manager.query_user(user_id))

    task_id = await manager.create_task(worker_keys, workers, params, task_inputs, task_outputs, user_id)
    print(" - create_task:", task_id)
    print(" - list_tasks:", await manager.list_tasks())
    print(" - query_task:", await manager.query_task(task_id))

    subtasks = await manager.next_subtasks(task_id)
    print(" - next_subtasks:", subtasks)
    await manager.run_subtasks(subtasks, "fake-worker")
    await manager.finish_subtasks(task_id, TaskStatus.FAILED)
    await manager.cancel_task(task_id)
    await manager.resume_task(task_id)
    for sub in subtasks:
        await manager.finish_subtasks(sub["task_id"], TaskStatus.SUCCEED, worker_name=sub["worker_name"], worker_identity="fake-worker")

    print(" - final next_subtasks:", await manager.next_subtasks(task_id))
    print(" - final task:", await manager.query_task(task_id))
    await manager.close()
if __name__ == "__main__":
    # Manual smoke test entry point: exercises the local task manager end to end.
    asyncio.run(test())
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
from abc import ABC from abc import ABC
import torch
import torch.distributed as dist
from lightx2v.utils.utils import save_videos_grid from lightx2v.utils.utils import save_videos_grid
...@@ -147,3 +150,28 @@ class BaseRunner(ABC): ...@@ -147,3 +150,28 @@ class BaseRunner(ABC):
def end_run(self): def end_run(self):
pass pass
def check_stop(self):
    """Check if the stop signal is received"""
    # Works both single-process and under torch.distributed: the highest
    # rank (world_size - 1) owns the stop flag and broadcasts it so every
    # rank aborts the same iteration together.
    rank, world_size = 0, 1
    if dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    signal_rank = world_size - 1
    stopped = 0
    # stop_signal may not exist on all runners, hence the hasattr guard
    if rank == signal_rank and hasattr(self, "stop_signal") and self.stop_signal:
        stopped = 1
    if world_size > 1:
        if rank == signal_rank:
            t = torch.tensor([stopped], dtype=torch.int32).to(device="cuda")
        else:
            t = torch.zeros(1, dtype=torch.int32, device="cuda")
        dist.broadcast(t, src=signal_rank)
        stopped = t.item()
    print(f"rank {rank} recv stopped: {stopped}")
    if stopped == 1:
        # raising here unwinds the inference loop; callers treat it as a
        # graceful, expected stop rather than an error
        raise Exception(f"find rank: {rank} stop_signal, stop running, it's an expected behavior")
...@@ -111,6 +111,9 @@ class DefaultRunner(BaseRunner): ...@@ -111,6 +111,9 @@ class DefaultRunner(BaseRunner):
if total_steps is None: if total_steps is None:
total_steps = self.model.scheduler.infer_steps total_steps = self.model.scheduler.infer_steps
for step_index in range(total_steps): for step_index in range(total_steps):
# only for single segment, check stop signal every step
if self.video_segment_num == 1:
self.check_stop()
logger.info(f"==> step_index: {step_index + 1} / {total_steps}") logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
with ProfilingContext4Debug("step_pre"): with ProfilingContext4Debug("step_pre"):
...@@ -145,6 +148,9 @@ class DefaultRunner(BaseRunner): ...@@ -145,6 +148,9 @@ class DefaultRunner(BaseRunner):
gc.collect() gc.collect()
def read_image_input(self, img_path): def read_image_input(self, img_path):
if isinstance(img_path, Image.Image):
img_ori = img_path
else:
img_ori = Image.open(img_path).convert("RGB") img_ori = Image.open(img_path).convert("RGB")
img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda() img = TF.to_tensor(img_ori).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
return img, img_ori return img, img_ori
...@@ -219,6 +225,7 @@ class DefaultRunner(BaseRunner): ...@@ -219,6 +225,7 @@ class DefaultRunner(BaseRunner):
for segment_idx in range(self.video_segment_num): for segment_idx in range(self.video_segment_num):
logger.info(f"🔄 segment_idx: {segment_idx + 1}/{self.video_segment_num}") logger.info(f"🔄 segment_idx: {segment_idx + 1}/{self.video_segment_num}")
with ProfilingContext(f"segment end2end {segment_idx}"): with ProfilingContext(f"segment end2end {segment_idx}"):
self.check_stop()
# 1. default do nothing # 1. default do nothing
self.init_run_segment(segment_idx) self.init_run_segment(segment_idx)
# 2. main inference loop # 2. main inference loop
......
...@@ -15,6 +15,8 @@ from loguru import logger ...@@ -15,6 +15,8 @@ from loguru import logger
from torchvision.transforms import InterpolationMode from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize from torchvision.transforms.functional import resize
from lightx2v.deploy.common.va_reader import VAReader
from lightx2v.deploy.common.va_recorder import VARecorder
from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import AudioAdapter
from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel from lightx2v.models.input_encoders.hf.seko_audio.audio_encoder import SekoAudioEncoderModel
from lightx2v.models.networks.wan.audio_model import WanAudioModel from lightx2v.models.networks.wan.audio_model import WanAudioModel
...@@ -221,7 +223,7 @@ class AudioProcessor: ...@@ -221,7 +223,7 @@ class AudioProcessor:
def get_audio_range(self, start_frame: int, end_frame: int) -> Tuple[int, int]: def get_audio_range(self, start_frame: int, end_frame: int) -> Tuple[int, int]:
"""Calculate audio range for given frame range""" """Calculate audio range for given frame range"""
audio_frame_rate = self.audio_sr / self.target_fps audio_frame_rate = self.audio_sr / self.target_fps
return round(start_frame * audio_frame_rate), round((end_frame + 1) * audio_frame_rate) return round(start_frame * audio_frame_rate), round(end_frame * audio_frame_rate)
def segment_audio(self, audio_array: np.ndarray, expected_frames: int, max_num_frames: int, prev_frame_length: int = 5) -> List[AudioSegment]: def segment_audio(self, audio_array: np.ndarray, expected_frames: int, max_num_frames: int, prev_frame_length: int = 5) -> List[AudioSegment]:
"""Segment audio based on frame requirements""" """Segment audio based on frame requirements"""
...@@ -299,6 +301,8 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -299,6 +301,8 @@ class WanAudioRunner(WanRunner): # type:ignore
audio_sr = self.config.get("audio_sr", 16000) audio_sr = self.config.get("audio_sr", 16000)
target_fps = self.config.get("target_fps", 16) target_fps = self.config.get("target_fps", 16)
self._audio_processor = AudioProcessor(audio_sr, target_fps) self._audio_processor = AudioProcessor(audio_sr, target_fps)
if not isinstance(self.config["audio_path"], str):
return [], 0
audio_array = self._audio_processor.load_audio(self.config["audio_path"]) audio_array = self._audio_processor.load_audio(self.config["audio_path"])
video_duration = self.config.get("video_duration", 5) video_duration = self.config.get("video_duration", 5)
...@@ -312,6 +316,9 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -312,6 +316,9 @@ class WanAudioRunner(WanRunner): # type:ignore
return audio_segments, expected_frames return audio_segments, expected_frames
def read_image_input(self, img_path): def read_image_input(self, img_path):
if isinstance(img_path, Image.Image):
ref_img = img_path
else:
ref_img = Image.open(img_path).convert("RGB") ref_img = Image.open(img_path).convert("RGB")
ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda() ref_img = TF.to_tensor(ref_img).sub_(0.5).div_(0.5).unsqueeze(0).cuda()
...@@ -449,9 +456,11 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -449,9 +456,11 @@ class WanAudioRunner(WanRunner): # type:ignore
self.prev_video = None self.prev_video = None
@ProfilingContext4Debug("Init run segment") @ProfilingContext4Debug("Init run segment")
def init_run_segment(self, segment_idx): def init_run_segment(self, segment_idx, audio_array=None):
self.segment_idx = segment_idx self.segment_idx = segment_idx
if audio_array is not None:
self.segment = AudioSegment(audio_array, 0, audio_array.shape[0], False)
else:
self.segment = self.inputs["audio_segments"][segment_idx] self.segment = self.inputs["audio_segments"][segment_idx]
self.config.seed = self.config.seed + segment_idx self.config.seed = self.config.seed + segment_idx
...@@ -477,7 +486,7 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -477,7 +486,7 @@ class WanAudioRunner(WanRunner): # type:ignore
# Extract relevant frames # Extract relevant frames
start_frame = 0 if self.segment_idx == 0 else self.prev_frame_length start_frame = 0 if self.segment_idx == 0 else self.prev_frame_length
start_audio_frame = 0 if self.segment_idx == 0 else int((self.prev_frame_length + 1) * self._audio_processor.audio_sr / self.config.get("target_fps", 16)) start_audio_frame = 0 if self.segment_idx == 0 else int(self.prev_frame_length * self._audio_processor.audio_sr / self.config.get("target_fps", 16))
if self.segment.is_last and self.segment.useful_length: if self.segment.is_last and self.segment.useful_length:
end_frame = self.segment.end_frame - self.segment.start_frame end_frame = self.segment.end_frame - self.segment.start_frame
...@@ -490,6 +499,14 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -490,6 +499,14 @@ class WanAudioRunner(WanRunner): # type:ignore
self.gen_video_list.append(self.gen_video[:, :, start_frame:].cpu()) self.gen_video_list.append(self.gen_video[:, :, start_frame:].cpu())
self.cut_audio_list.append(self.segment.audio_array[start_audio_frame:]) self.cut_audio_list.append(self.segment.audio_array[start_audio_frame:])
if self.va_recorder:
cur_video = vae_to_comfyui_image(self.gen_video_list[-1])
self.va_recorder.pub_livestream(cur_video, self.cut_audio_list[-1])
if self.va_reader:
self.gen_video_list.pop()
self.cut_audio_list.pop()
# Update prev_video for next iteration # Update prev_video for next iteration
self.prev_video = self.gen_video self.prev_video = self.gen_video
...@@ -497,6 +514,102 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -497,6 +514,102 @@ class WanAudioRunner(WanRunner): # type:ignore
del self.gen_video del self.gen_video
torch.cuda.empty_cache() torch.cuda.empty_cache()
def get_rank_and_world_size(self):
    """Return (rank, world_size); (0, 1) when torch.distributed is not initialized."""
    if dist.is_initialized():
        return dist.get_rank(), dist.get_world_size()
    return 0, 1
def init_va_recorder(self):
    # Set up the live-stream recorder. save_video_path may be a plain file
    # path (str) or a {"type": "stream", "data": url} dict; only the dict
    # form gets a VARecorder, and only on a single rank (2 % world_size)
    # so exactly one process publishes the stream.
    output_video_path = self.config.get("save_video_path", None)
    self.va_recorder = None
    if isinstance(output_video_path, dict):
        assert output_video_path["type"] == "stream", f"unexcept save_video_path: {output_video_path}"
        rank, world_size = self.get_rank_and_world_size()
        if rank == 2 % world_size:
            record_fps = self.config.get("target_fps", 16)
            audio_sr = self.config.get("audio_sr", 16000)
            # frame interpolation changes the effective output fps
            if "video_frame_interpolation" in self.config and self.vfi_model is not None:
                record_fps = self.config["video_frame_interpolation"]["target_fps"]
            self.va_recorder = VARecorder(
                livestream_url=output_video_path["data"],
                fps=record_fps,
                sample_rate=audio_sr,
            )
def init_va_reader(self):
    # Set up the streaming audio reader. audio_path may be a plain file
    # path (str) or a {"type": "stream", "data": url} dict; only the dict
    # form gets a VAReader (file input goes through the normal audio path).
    audio_path = self.config.get("audio_path", None)
    self.va_reader = None
    if isinstance(audio_path, dict):
        assert audio_path["type"] == "stream", f"unexcept audio_path: {audio_path}"
        rank, world_size = self.get_rank_and_world_size()
        target_fps = self.config.get("target_fps", 16)
        max_num_frames = self.config.get("target_video_length", 81)
        audio_sr = self.config.get("audio_sr", 16000)
        prev_frames = self.config.get("prev_frame_length", 5)
        self.va_reader = VAReader(
            rank=rank,
            world_size=world_size,
            stream_url=audio_path["data"],
            sample_rate=audio_sr,
            # each fetched segment covers one generation window; segments
            # overlap by the previous-frame context duration
            segment_duration=max_num_frames / target_fps,
            prev_duration=prev_frames / target_fps,
            target_rank=1,  # NOTE(review): presumably the rank that owns stream fetching — confirm in VAReader
        )
def run_main(self, total_steps=None):
    """Main generation loop.

    Falls back to the parent implementation for non-stream audio input;
    for stream input, loops forever fetching audio segments from the
    VAReader and generating video per segment until stopped or the
    stream fails repeatedly.
    """
    try:
        self.init_va_recorder()
        self.init_va_reader()
        logger.info(f"init va_recorder: {self.va_recorder} and va_reader: {self.va_reader}")
        if self.va_reader is None:
            # non-stream input: normal segmented inference
            return super().run_main(total_steps)
        rank, world_size = self.get_rank_and_world_size()
        if rank == 2 % world_size:
            # the publishing rank must have a recorder to emit the output stream
            assert self.va_recorder is not None, "va_recorder is required for stream audio input for rank 0"
        self.va_reader.start()
        self.init_run()
        # stream mode has no fixed number of segments
        self.video_segment_num = "unlimited"
        # allow slightly more than one segment's duration per fetch
        fetch_timeout = self.va_reader.segment_duration + 1
        segment_idx = 0
        fail_count = 0
        max_fail_count = 10
        while True:
            with ProfilingContext4Debug(f"stream segment get audio segment {segment_idx}"):
                self.check_stop()
                audio_array = self.va_reader.get_audio_segment(timeout=fetch_timeout)
                if audio_array is None:
                    # tolerate transient fetch failures up to max_fail_count
                    fail_count += 1
                    logger.warning(f"Failed to get audio chunk {fail_count} times")
                    if fail_count > max_fail_count:
                        raise Exception(f"Failed to get audio chunk {fail_count} times, stop reader")
                    continue
            with ProfilingContext4Debug(f"stream segment end2end {segment_idx}"):
                fail_count = 0
                self.init_run_segment(segment_idx, audio_array)
                latents, generator = self.run_segment(total_steps=None)
                self.gen_video = self.run_vae_decoder(latents)
                self.end_run_segment()
                segment_idx += 1
    finally:
        # always tear down reader/recorder, even on stop-signal exceptions
        if hasattr(self.model, "scheduler"):
            self.end_run()
        if self.va_reader:
            self.va_reader.stop()
            self.va_reader = None
        if self.va_recorder:
            self.va_recorder.stop(wait=False)
            self.va_recorder = None
@ProfilingContext4Debug("Process after vae decoder") @ProfilingContext4Debug("Process after vae decoder")
def process_images_after_vae_decoder(self, save_video=True): def process_images_after_vae_decoder(self, save_video=True):
# Merge results # Merge results
...@@ -515,7 +628,7 @@ class WanAudioRunner(WanRunner): # type:ignore ...@@ -515,7 +628,7 @@ class WanAudioRunner(WanRunner): # type:ignore
target_fps=target_fps, target_fps=target_fps,
) )
if save_video: if save_video and isinstance(self.config["save_video_path"], str):
if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"): if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
fps = self.config["video_frame_interpolation"]["target_fps"] fps = self.config["video_frame_interpolation"]["target_fps"]
else: else:
......
...@@ -21,3 +21,10 @@ easydict ...@@ -21,3 +21,10 @@ easydict
gradio gradio
aiohttp aiohttp
pydantic pydantic
aio-pika
asyncpg>=0.27.0
aioboto3>=12.0.0
fastapi
uvicorn
PyJWT
requests
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment