Commit a1db6fa0 authored by LiangLiu, committed by GitHub

update deploy (#323)



update deploy

---------
Co-authored-by: liuliang1 <liuliang1@sensetime.com>
Co-authored-by: qinxinyi <qinxinyi@sensetime.com>
Co-authored-by: Yang Yong(雍洋) <yongyang1030@163.com>
parent 99158e75
FROM pytorch/pytorch:2.7.1-cuda12.8-cudnn9-devel AS base
FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel AS base
WORKDIR /app
ENV DEBIAN_FRONTEND=noninteractive
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8
# use Tsinghua mirrors
RUN sed -i 's|http://archive.ubuntu.com/ubuntu/|https://mirrors.tuna.tsinghua.edu.cn/ubuntu/|g' /etc/apt/sources.list \
&& sed -i 's|http://security.ubuntu.com/ubuntu/|https://mirrors.tuna.tsinghua.edu.cn/ubuntu/|g' /etc/apt/sources.list
ENV LD_LIBRARY_PATH=/usr/local/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH
RUN apt-get update && apt-get install -y vim tmux zip unzip wget git build-essential libibverbs-dev ca-certificates \
curl iproute2 ffmpeg libsm6 libxext6 kmod ccache libnuma-dev \
curl iproute2 libsm6 libxext6 kmod ccache libnuma-dev libssl-dev flex bison libgtk-3-dev libpango1.0-dev \
libsoup2.4-dev libnice-dev libopus-dev libvpx-dev libx264-dev libsrtp2-dev libglib2.0-dev libdrm-dev \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
RUN pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
RUN pip install --no-cache-dir packaging ninja cmake scikit-build-core uv ruff pre-commit -U
RUN pip install --no-cache-dir packaging ninja cmake scikit-build-core uv meson ruff pre-commit fastapi uvicorn requests -U
RUN git clone https://github.com/vllm-project/vllm.git && cd vllm \
&& python use_existing_torch.py && pip install -r requirements/build.txt \
......@@ -28,6 +24,8 @@ RUN git clone https://github.com/sgl-project/sglang.git && cd sglang/sgl-kernel
RUN pip install --no-cache-dir diffusers transformers tokenizers accelerate safetensors opencv-python numpy imageio \
imageio-ffmpeg einops loguru qtorch ftfy easydict
RUN conda install conda-forge::ffmpeg=8.0.0 -y && ln -s /opt/conda/bin/ffmpeg /usr/bin/ffmpeg
RUN git clone https://github.com/Dao-AILab/flash-attention.git --recursive
RUN cd flash-attention && python setup.py install && rm -rf build
......@@ -42,4 +40,34 @@ RUN git clone https://github.com/KONAKONA666/q8_kernels.git
RUN cd q8_kernels && git submodule init && git submodule update && python setup.py install && rm -rf build
# cloud deploy
RUN pip install --no-cache-dir aio-pika "asyncpg>=0.27.0" "aioboto3>=12.0.0" PyJWT alibabacloud_dypnsapi20170525==1.2.2 redis==6.4.0 tos -U
RUN cd /opt \
&& wget https://mirrors.tuna.tsinghua.edu.cn/gnu//libiconv/libiconv-1.15.tar.gz \
&& tar zxvf libiconv-1.15.tar.gz \
&& cd libiconv-1.15 \
&& ./configure \
&& make \
&& make install
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
ENV PATH=/root/.cargo/bin:$PATH
RUN cd /opt \
&& git clone https://github.com/GStreamer/gstreamer.git -b 1.24.12 --depth 1 \
&& cd gstreamer \
&& meson setup builddir \
&& meson compile -C builddir \
&& meson install -C builddir \
&& ldconfig
RUN cd /opt \
&& git clone https://github.com/GStreamer/gst-plugins-rs.git -b gstreamer-1.24.12 --depth 1 \
&& cd gst-plugins-rs \
&& cargo build --package gst-plugin-webrtchttp --release \
&& install -m 644 target/release/libgstwebrtchttp.so $(pkg-config --variable=pluginsdir gstreamer-1.0)/
RUN ldconfig
WORKDIR /workspace
......@@ -11,7 +11,7 @@ RUN sed -i 's|http://archive.ubuntu.com/ubuntu/|https://mirrors.tuna.tsinghua.ed
&& sed -i 's|http://security.ubuntu.com/ubuntu/|https://mirrors.tuna.tsinghua.edu.cn/ubuntu/|g' /etc/apt/sources.list
RUN apt-get update && apt-get install -y vim tmux zip unzip wget git build-essential libibverbs-dev ca-certificates \
curl iproute2 ffmpeg libsm6 libxext6 kmod ccache libnuma-dev \
curl iproute2 libsm6 libxext6 kmod ccache libnuma-dev \
&& apt-get clean && rm -rf /var/lib/apt/lists/*
RUN pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
......@@ -28,6 +28,8 @@ RUN git clone https://github.com/sgl-project/sglang.git && cd sglang/sgl-kernel
RUN pip install --no-cache-dir diffusers transformers tokenizers accelerate safetensors opencv-python numpy imageio \
imageio-ffmpeg einops loguru qtorch ftfy easydict
RUN conda install conda-forge::ffmpeg=8.0.0 -y && ln -s /opt/conda/bin/ffmpeg /usr/bin/ffmpeg
RUN git clone https://github.com/Dao-AILab/flash-attention.git --recursive
RUN cd flash-attention && python setup.py install && rm -rf build
......
# For RTC WHEP, build GStreamer with the whepsrc plugin
FROM registry.ms-sc-01.maoshanwangtech.com/ms-ccr/lightx2v:25080601-cu128-SageSm90 AS gstreamer-base
FROM lightx2v/lightx2v:25091903-cu128 AS base
RUN apt update -y \
&& apt update -y \
&& apt install -y libssl-dev flex bison \
libgtk-3-dev libpango1.0-dev libsoup2.4-dev \
libnice-dev libopus-dev libvpx-dev libx264-dev \
libsrtp2-dev libglib2.0-dev libdrm-dev
RUN cd /opt \
&& wget https://mirrors.tuna.tsinghua.edu.cn/gnu//libiconv/libiconv-1.15.tar.gz \
&& tar zxvf libiconv-1.15.tar.gz \
&& cd libiconv-1.15 \
&& ./configure \
&& make \
&& make install
RUN pip install meson
RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
ENV PATH=/root/.cargo/bin:$PATH
RUN cd /opt \
&& git clone https://github.com/GStreamer/gstreamer.git -b 1.24.12 --depth 1 \
&& cd gstreamer \
&& meson setup builddir \
&& meson compile -C builddir \
&& meson install -C builddir \
&& ldconfig
RUN cd /opt \
&& git clone https://github.com/GStreamer/gst-plugins-rs.git -b gstreamer-1.24.12 --depth 1 \
&& cd gst-plugins-rs \
&& cargo build --package gst-plugin-webrtchttp --release \
&& install -m 644 target/release/libgstwebrtchttp.so $(pkg-config --variable=pluginsdir gstreamer-1.0)/
# Lightx2v deploy image
FROM registry.ms-sc-01.maoshanwangtech.com/ms-ccr/lightx2v:25080601-cu128-SageSm90
RUN mkdir /workspace/lightx2v
WORKDIR /workspace/lightx2v
ENV PYTHONPATH=/workspace/lightx2v
COPY requirements.txt requirements.txt
RUN pip install -r requirements.txt
RUN conda install conda-forge::ffmpeg=8.0.0 -y
RUN rm /usr/bin/ffmpeg && ln -s /opt/conda/bin/ffmpeg /usr/bin/ffmpeg
RUN apt update -y \
&& apt install -y libssl-dev \
libgtk-3-dev libpango1.0-dev libsoup2.4-dev \
libnice-dev libopus-dev libvpx-dev libx264-dev \
libsrtp2-dev libglib2.0-dev libdrm-dev
ENV LBDIR=/usr/local/lib/x86_64-linux-gnu
COPY --from=gstreamer-base /usr/local/bin/gst-* /usr/local/bin/
COPY --from=gstreamer-base $LBDIR $LBDIR
RUN ldconfig
ENV LD_LIBRARY_PATH=$LBDIR:$LD_LIBRARY_PATH
RUN gst-launch-1.0 --version
RUN gst-inspect-1.0 whepsrc
RUN mkdir /workspace/LightX2V
WORKDIR /workspace/LightX2V
ENV PYTHONPATH=/workspace/LightX2V
COPY assets assets
COPY configs configs
......
......@@ -102,6 +102,10 @@
"latents": "TENSOR",
"output_video": "VIDEO"
},
"model_name_inner_to_outer": {
"seko_talk": "SekoTalk"
},
"model_name_outer_to_inner": {},
"monitor": {
"subtask_created_timeout": 1800,
"subtask_pending_timeout": 1800,
......
......@@ -27,16 +27,16 @@ We strongly recommend using the Docker environment, which is the simplest and fa
#### 1. Pull Image
Visit LightX2V's [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags) and select a tag with the latest date, such as `25090503-cu128`:
Visit LightX2V's [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags) and select a tag with the latest date, such as `25091903-cu128`:
```bash
docker pull lightx2v/lightx2v:25090503-cu128
docker pull lightx2v/lightx2v:25091903-cu128
```
We recommend using the `cuda128` environment for faster inference speed. If you need to use the `cuda124` environment, you can use image versions with the `-cu124` suffix:
```bash
docker pull lightx2v/lightx2v:25090503-cu124
docker pull lightx2v/lightx2v:25091903-cu124
```
#### 2. Run Container
......@@ -51,10 +51,10 @@ For mainland China, if the network is unstable when pulling images, you can pull
```bash
# cuda128
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25090503-cu128
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25091903-cu128
# cuda124
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25090503-cu124
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25091903-cu124
```
### 🐍 Conda Environment Setup
......
......@@ -27,16 +27,16 @@
#### 1. Pull Image
Visit LightX2V's [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags) and select a tag with the latest date, such as `25090503-cu128`:
Visit LightX2V's [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags) and select a tag with the latest date, such as `25091903-cu128`:
```bash
docker pull lightx2v/lightx2v:25090503-cu128
docker pull lightx2v/lightx2v:25091903-cu128
```
We recommend the `cuda128` environment for faster inference speed; if you need the `cuda124` environment, use image versions with the `-cu124` suffix:
```bash
docker pull lightx2v/lightx2v:25090503-cu124
docker pull lightx2v/lightx2v:25091903-cu124
```
#### 2. Run Container
......@@ -51,10 +51,10 @@ docker run --gpus all -itd --ipc=host --name [container name] -v [mount settings] --ent
```bash
# cuda128
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25090503-cu128
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25091903-cu128
# cuda124
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25090503-cu124
docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25091903-cu124
```
### 🐍 Conda Environment Setup
......
......@@ -16,6 +16,8 @@ class Pipeline:
self.model_lists = []
self.types = {}
self.queues = set()
self.model_name_inner_to_outer = self.meta.get("model_name_inner_to_outer", {})
self.model_name_outer_to_inner = self.meta.get("model_name_outer_to_inner", {})
self.tidy_pipeline()
def init_dict(self, base, task, model_cls):
......@@ -132,6 +134,14 @@ class Pipeline:
item = item[k]
return item
def check_item_by_keys(self, keys):
item = self.data
for k in keys:
if k not in item:
return False
item = item[k]
return True
def get_model_lists(self):
return self.model_lists
......@@ -144,6 +154,12 @@ class Pipeline:
def get_queues(self):
return self.queues
def inner_model_name(self, name):
return self.model_name_outer_to_inner.get(name, name)
def outer_model_name(self, name):
return self.model_name_inner_to_outer.get(name, name)
if __name__ == "__main__":
pipeline = Pipeline(sys.argv[1])
......
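A note on the aliasing hooks added above: `outer_model_name` is applied wherever tasks are rendered for clients, and `inner_model_name` on submit. A minimal standalone sketch of the `.get(name, name)` fallback semantics follows; the inverted reverse map is an assumption for illustration, since the shipped config leaves `model_name_outer_to_inner` empty.

```python
# Standalone sketch of Pipeline.inner_model_name/outer_model_name semantics.
# The inner-to-outer entry mirrors the "seko_talk" -> "SekoTalk" config added
# in this commit; deriving the reverse map by inversion is an assumption.
inner_to_outer = {"seko_talk": "SekoTalk"}
outer_to_inner = {v: k for k, v in inner_to_outer.items()}

def outer_model_name(name: str) -> str:
    return inner_to_outer.get(name, name)  # unmapped names pass through unchanged

def inner_model_name(name: str) -> str:
    return outer_to_inner.get(name, name)

assert outer_model_name("seko_talk") == "SekoTalk"  # shown to API clients
assert inner_model_name("SekoTalk") == "seko_talk"  # resolved back on task submit
assert outer_model_name("wan2.1") == "wan2.1"       # identity for unmapped models
```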
import asyncio
import base64
import io
import os
......@@ -87,6 +88,72 @@ async def fetch_resource(url, timeout):
return content
# check size, resize if needed, and apply EXIF rotation metadata
def format_image_data(data, max_size=1280):
image = Image.open(io.BytesIO(data)).convert("RGB")
exif = image.getexif()
changed = False
w, h = image.size
assert w > 0 and h > 0, "image is empty"
logger.info(f"load image: {w}x{h}, exif: {exif}")
if w > max_size or h > max_size:
ratio = max_size / max(w, h)
w = int(w * ratio)
h = int(h * ratio)
image = image.resize((w, h))
logger.info(f"resize image to: {image.size}")
changed = True
orientation_key = 274
if orientation_key in exif:
orientation = exif[orientation_key]
if orientation == 2:
image = image.transpose(Image.FLIP_LEFT_RIGHT)
elif orientation == 3:
image = image.rotate(180, expand=True)
elif orientation == 4:
image = image.transpose(Image.FLIP_TOP_BOTTOM)
elif orientation == 5:
image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(90, expand=True)
elif orientation == 6:
image = image.rotate(270, expand=True)
elif orientation == 7:
image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(270, expand=True)
elif orientation == 8:
image = image.rotate(90, expand=True)
# reset orientation to 1
if orientation != 1:
logger.info(f"reset orientation from {orientation} to 1")
exif[orientation_key] = 1
changed = True
if not changed:
return data
output = io.BytesIO()
image.save(output, format=image.format or "JPEG", exif=exif.tobytes())
return output.getvalue()
def format_audio_data(data):
if len(data) < 4:
raise ValueError("Audio file too short")
try:
waveform, sample_rate = torchaudio.load(io.BytesIO(data), num_frames=10)
logger.info(f"load audio: {waveform.size()}, {sample_rate}")
assert waveform.size(0) > 0, "audio is empty"
assert sample_rate > 0, "audio sample rate is not valid"
except Exception as e:
logger.warning(f"torchaudio failed to load audio, trying alternative method: {e}")
# check audio headers
audio_headers = [b"RIFF", b"ID3", b"\xff\xfb", b"\xff\xf3", b"\xff\xf2", b"OggS"]
if not any(data.startswith(header) for header in audio_headers):
logger.warning("Audio file doesn't have recognized header, but continuing...")
logger.info(f"Audio validation passed (alternative method), size: {len(data)} bytes")
return data
async def preload_data(inp, inp_type, typ, val):
try:
if typ == "url":
......@@ -102,27 +169,10 @@ async def preload_data(inp, inp_type, typ, val):
# check if valid image bytes
if inp_type == "IMAGE":
image = Image.open(io.BytesIO(data))
logger.info(f"load image: {image.size}")
assert image.size[0] > 0 and image.size[1] > 0, "image is empty"
data = await asyncio.to_thread(format_image_data, data)
elif inp_type == "AUDIO":
if typ != "stream":
try:
waveform, sample_rate = torchaudio.load(io.BytesIO(data), num_frames=10)
logger.info(f"load audio: {waveform.size()}, {sample_rate}")
assert waveform.size(0) > 0, "audio is empty"
assert sample_rate > 0, "audio sample rate is not valid"
except Exception as e:
logger.warning(f"torchaudio failed to load audio, trying alternative method: {e}")
# try other methods to validate the audio file
# check whether the file header is a valid audio format
if len(data) < 4:
raise ValueError("Audio file too short")
# check common audio file headers
audio_headers = [b"RIFF", b"ID3", b"\xff\xfb", b"\xff\xf3", b"\xff\xf2", b"OggS"]
if not any(data.startswith(header) for header in audio_headers):
logger.warning("Audio file doesn't have recognized header, but continuing...")
logger.info(f"Audio validation passed (alternative method), size: {len(data)} bytes")
data = await asyncio.to_thread(format_audio_data, data)
else:
raise Exception(f"cannot parse inp_type={inp_type} data")
return data
......@@ -152,3 +202,21 @@ def check_params(params, raw_inputs, raw_outputs, types):
assert stream_audio, "stream audio is not supported, please set env STREAM_AUDIO=1"
elif types[x] == "VIDEO":
assert stream_video, "stream video is not supported, please set env STREAM_VIDEO=1"
if __name__ == "__main__":
# https://github.com/recurser/exif-orientation-examples
exif_dir = "/data/nvme0/liuliang1/exif-orientation-examples"
out_dir = "/data/nvme0/liuliang1/exif-orientation-examples/outs"
os.makedirs(out_dir, exist_ok=True)
for base_name in ["Landscape", "Portrait"]:
for i in range(9):
fin_name = os.path.join(exif_dir, f"{base_name}_{i}.jpg")
fout_name = os.path.join(out_dir, f"{base_name}_{i}_formatted.jpg")
logger.info(f"format image: {fin_name} -> {fout_name}")
with open(fin_name, "rb") as f:
data = f.read()
data = format_image_data(data)
with open(fout_name, "wb") as f:
f.write(data)
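A hedged usage sketch for `format_audio_data` (the path below is a placeholder): on success the bytes are returned unchanged; payloads under 4 bytes raise `ValueError`, and unrecognized headers only log a warning.

```python
# Hypothetical caller of format_audio_data; the .wav path is a placeholder.
with open("/path/to/sample.wav", "rb") as f:
    audio_bytes = f.read()
try:
    checked = format_audio_data(audio_bytes)  # returns the input bytes on success
except ValueError as e:
    print(f"rejected: {e}")  # "Audio file too short" for payloads under 4 bytes
```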
import queue
import random
import signal
import socket
import subprocess
......@@ -18,14 +19,13 @@ class VARecorder:
livestream_url: str,
fps: float = 16.0,
sample_rate: int = 16000,
audio_port: int = 30200,
video_port: int = 30201,
):
self.livestream_url = livestream_url
self.fps = fps
self.sample_rate = sample_rate
self.audio_port = audio_port
self.video_port = video_port
self.audio_port = random.choice(range(32000, 40000))
self.video_port = self.audio_port + 1
logger.info(f"VARecorder audio port: {self.audio_port}, video port: {self.video_port}")
self.width = None
self.height = None
......@@ -116,6 +116,58 @@ class VARecorder:
finally:
logger.info("Video push worker thread stopped")
def start_ffmpeg_process_local(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
"/opt/conda/bin/ffmpeg",
"-re",
"-f",
"s16le",
"-ar",
str(self.sample_rate),
"-ac",
"1",
"-i",
f"tcp://127.0.0.1:{self.audio_port}",
"-f",
"rawvideo",
"-re",
"-pix_fmt",
"rgb24",
"-r",
str(self.fps),
"-s",
f"{self.width}x{self.height}",
"-i",
f"tcp://127.0.0.1:{self.video_port}",
"-ar",
"44100",
"-b:v",
"4M",
"-c:v",
"libx264",
"-preset",
"ultrafast",
"-tune",
"zerolatency",
"-g",
f"{self.fps}",
"-pix_fmt",
"yuv420p",
"-f",
"mp4",
self.livestream_url,
"-y",
"-loglevel",
"info",
]
try:
self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
except Exception as e:
logger.error(f"Failed to start FFmpeg: {e}")
def start_ffmpeg_process_rtmp(self):
"""Start ffmpeg process that connects to our TCP sockets"""
ffmpeg_cmd = [
......@@ -240,7 +292,7 @@ class VARecorder:
elif self.livestream_url.startswith("http"):
self.start_ffmpeg_process_whip()
else:
raise Exception(f"Unsupported livestream URL: {self.livestream_url}")
self.start_ffmpeg_process_local()
self.audio_thread = threading.Thread(target=self.audio_worker)
self.video_thread = threading.Thread(target=self.video_worker)
self.audio_thread.start()
......@@ -353,12 +405,13 @@ if __name__ == "__main__":
recorder = VARecorder(
# livestream_url="rtmp://localhost/live/test",
livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
# livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
livestream_url="/path/to/output_video.mp4",
fps=fps,
sample_rate=sample_rate,
)
audio_path = "/mtc/liuliang1/lightx2v/test_deploy/media_test/test_b_2min.wav"
audio_path = "/path/to/test_b_2min.wav"
audio_array, ori_sr = ta.load(audio_path)
audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
audio_array = audio_array.numpy().reshape(-1)
......
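For context on `start_ffmpeg_process_local` above: the command passes no `?listen` on its `tcp://` inputs, so ffmpeg connects as a client and the recorder side must already be listening. A minimal sketch of that wire format follows; the real VARecorder feeds the sockets from its audio/video worker threads, and this only illustrates the raw rgb24 framing.

```python
# Sketch, not the project's worker: listen on video_port, accept ffmpeg's
# connection, then stream raw rgb24 frames whose size matches the -s WxH
# and -pix_fmt rgb24 arguments passed to ffmpeg above.
import socket

import numpy as np

def serve_raw_video(port: int, frames: np.ndarray) -> None:
    """frames: uint8 array of shape [T, H, W, 3]."""
    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.bind(("127.0.0.1", port))
    srv.listen(1)
    conn, _ = srv.accept()  # blocks until ffmpeg dials tcp://127.0.0.1:{port}
    try:
        for frame in frames:
            conn.sendall(frame.tobytes())  # exactly H*W*3 bytes per frame
    finally:
        conn.close()
        srv.close()
```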
......@@ -236,6 +236,11 @@ async def prepare_subtasks(task_id):
await server_monitor.pending_subtasks_add(sub["queue"], sub["task_id"])
def format_task(task):
task["status"] = task["status"].name
task["model_cls"] = model_pipelines.outer_model_name(task["model_cls"])
@app.get("/api/v1/model/list")
async def api_v1_model_list(user=Depends(verify_user_access)):
try:
......@@ -254,6 +259,7 @@ async def api_v1_task_submit(request: Request, user=Depends(verify_user_access))
return error_response(msg, 400)
params = await request.json()
keys = [params.pop("task"), params.pop("model_cls"), params.pop("stage")]
keys[1] = model_pipelines.inner_model_name(keys[1])
assert len(params["prompt"]) > 0, "valid prompt is required"
# get worker infos, model input names
......@@ -303,7 +309,7 @@ async def api_v1_task_query(request: Request, user=Depends(verify_user_access)):
task, subtasks = await task_manager.query_task(task_id, user["user_id"], only_task=False)
if task is not None:
task["subtasks"] = await server_monitor.format_subtask(subtasks)
task["status"] = task["status"].name
format_task(task)
tasks.append(task)
return {"tasks": tasks}
......@@ -313,7 +319,7 @@ async def api_v1_task_query(request: Request, user=Depends(verify_user_access)):
if task is None:
return error_response(f"Task {task_id} not found", 404)
task["subtasks"] = await server_monitor.format_subtask(subtasks)
task["status"] = task["status"].name
format_task(task)
return task
except Exception as e:
traceback.print_exc()
......@@ -344,7 +350,7 @@ async def api_v1_task_list(request: Request, user=Depends(verify_user_access)):
tasks = await task_manager.list_tasks(**query_params)
for task in tasks:
task["status"] = task["status"].name
format_task(task)
return {"tasks": tasks, "pagination": page_info}
except Exception as e:
......@@ -457,12 +463,18 @@ async def api_v1_task_cancel(request: Request, user=Depends(verify_user_access))
async def api_v1_task_resume(request: Request, user=Depends(verify_user_access)):
try:
task_id = request.query_params["task_id"]
task = await task_manager.query_task(task_id, user_id=user["user_id"])
keys = [task["task_type"], task["model_cls"], task["stage"]]
if not model_pipelines.check_item_by_keys(keys):
return error_response(f"Model {keys} is not supported now, please submit a new task", 400)
ret = await task_manager.resume_task(task_id, user_id=user["user_id"], all_subtask=False)
if ret:
if ret is True:
await prepare_subtasks(task_id)
return {"msg": "ok"}
else:
return error_response(f"Task {task_id} resume failed", 400)
return error_response(f"Task {task_id} resume failed: {ret}", 400)
except Exception as e:
traceback.print_exc()
return error_response(str(e), 500)
......@@ -605,7 +617,7 @@ async def api_v1_worker_ping_subtask(request: Request, valid=Depends(verify_work
queue = params.pop("queue")
task = await task_manager.query_task(task_id)
if task["status"] != TaskStatus.RUNNING:
if task is None or task["status"] != TaskStatus.RUNNING:
return {"msg": "delete"}
assert await task_manager.ping_subtask(task_id, worker_name, identity)
......@@ -714,27 +726,18 @@ async def api_v1_template_list(request: Request, valid=Depends(verify_user_acces
if page <= total_pages:
start_idx = (page - 1) * page_size
end_idx = start_idx + page_size
all_images.sort(key=lambda x: x)
all_audios.sort(key=lambda x: x)
all_videos.sort(key=lambda x: x)
for image in all_images[start_idx:end_idx]:
url = await data_manager.presign_template_url("images", image)
if url is None:
url = f"./assets/template/images/{image}"
paginated_image_templates.append({"filename": image, "url": url})
for audio in all_audios[start_idx:end_idx]:
url = await data_manager.presign_template_url("audios", audio)
if url is None:
url = f"./assets/template/audios/{audio}"
paginated_audio_templates.append({"filename": audio, "url": url})
for video in all_videos[start_idx:end_idx]:
url = await data_manager.presign_template_url("videos", video)
if url is None:
url = f"./assets/template/videos/{video}"
paginated_video_templates.append({"filename": video, "url": url})
async def handle_media(media_type, media_names, paginated_media_templates):
media_names.sort()
for media_name in media_names[start_idx:end_idx]:
url = await data_manager.presign_template_url(media_type, media_name)
if url is None:
url = f"./assets/template/{media_type}/{media_name}"
paginated_media_templates.append({"filename": media_name, "url": url})
await handle_media("images", all_images, paginated_image_templates)
await handle_media("audios", all_audios, paginated_audio_templates)
await handle_media("videos", all_videos, paginated_video_templates)
return {
"templates": {"images": paginated_image_templates, "audios": paginated_audio_templates, "videos": paginated_video_templates},
......@@ -760,6 +763,7 @@ async def api_v1_template_tasks(request: Request, valid=Depends(verify_user_acce
page_size = min(page_size, 100)
all_templates = []
all_categories = set()
template_files = await data_manager.list_template_files("tasks")
template_files = [] if template_files is None else template_files
......@@ -767,6 +771,8 @@ async def api_v1_template_tasks(request: Request, valid=Depends(verify_user_acce
try:
bytes_data = await data_manager.load_template_file("tasks", template_file)
template_data = json.loads(bytes_data)
template_data["task"]["model_cls"] = model_pipelines.outer_model_name(template_data["task"]["model_cls"])
all_categories.update(template_data["task"]["tags"])
if category is not None and category != "all" and category not in template_data["task"]["tags"]:
continue
if search is not None and search not in template_data["task"]["params"]["prompt"] + template_data["task"]["params"]["negative_prompt"] + template_data["task"][
......@@ -787,7 +793,7 @@ async def api_v1_template_tasks(request: Request, valid=Depends(verify_user_acce
end_idx = start_idx + page_size
paginated_templates = all_templates[start_idx:end_idx]
return {"templates": paginated_templates, "pagination": {"page": page, "page_size": page_size, "total": total_templates, "total_pages": total_pages}}
return {"templates": paginated_templates, "pagination": {"page": page, "page_size": page_size, "total": total_templates, "total_pages": total_pages}, "categories": list(all_categories)}
except Exception as e:
traceback.print_exc()
......
......@@ -92,6 +92,11 @@ class WorkerClient:
if elapse > self.offline_timeout:
logger.warning(f"Worker {self.identity} {self.queue} offline timeout2: {elapse:.2f} s")
return False
# fetching too long
elif self.status == WorkerStatus.FETCHING:
if elapse > self.fetching_timeout:
logger.warning(f"Worker {self.identity} {self.queue} fetching timeout: {elapse:.2f} s")
return False
return True
......@@ -111,7 +116,7 @@ class ServerMonitor:
self.fetching_timeout = self.config.get("fetching_timeout", 1000)
for queue in self.all_queues:
self.subtask_run_timeouts[queue] = self.config["subtask_running_timeouts"].get(queue, 60)
self.subtask_run_timeouts[queue] = self.config["subtask_running_timeouts"].get(queue, 3600)
self.subtask_created_timeout = self.config["subtask_created_timeout"]
self.subtask_pending_timeout = self.config["subtask_pending_timeout"]
self.worker_avg_window = self.config["worker_avg_window"]
......
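The keys read above, collected into one illustrative config; the timeout values mirror defaults visible in this diff, while `worker_avg_window`'s value is an assumption.

```python
# Illustrative monitor config; not the project's shipped defaults file.
monitor_config = {
    "fetching_timeout": 1000,        # drop workers stuck in FETCHING this long (seconds)
    "subtask_running_timeouts": {},  # per-queue overrides; the fallback is now 3600 s
    "subtask_created_timeout": 1800,
    "subtask_pending_timeout": 1800,
    "worker_avg_window": 10,         # assumed value
}
```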
......@@ -273,10 +273,10 @@ class LocalTaskManager(BaseTaskManager):
task, subtasks = self.load(task_id, user_id)
# the task is not finished
if task["status"] not in FinishedStatus:
return False
return "Active task cannot be resumed"
# the task is no need to resume
if not all_subtask and task["status"] == TaskStatus.SUCCEED:
return False
return "Succeed task cannot be resumed"
for sub in subtasks:
if all_subtask or sub["status"] != TaskStatus.SUCCEED:
self.mark_subtask_change(records, sub, None, TaskStatus.CREATED)
......
......@@ -702,10 +702,10 @@ class PostgresSQLTaskManager(BaseTaskManager):
task, subtasks = await self.load(conn, task_id, user_id)
# the task is not finished
if task["status"] not in FinishedStatus:
return False
return "Active task cannot be resumed"
# the task is no need to resume
if not all_subtask and task["status"] == TaskStatus.SUCCEED:
return False
return "Succeed task cannot be resumed"
for sub in subtasks:
if all_subtask or sub["status"] != TaskStatus.SUCCEED:
......
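With this change `resume_task` returns `True` on success and a human-readable refusal string otherwise, which is why the API layer switched from `if ret:` to `if ret is True:`. A minimal illustration:

```python
# Refusal strings are truthy, so a plain truthiness test would treat them as success.
for ret in (True, "Active task cannot be resumed", "Succeed task cannot be resumed"):
    print(f"{ret!r}: resumed={ret is True}")
# True: resumed=True
# 'Active task cannot be resumed': resumed=False
# 'Succeed task cannot be resumed': resumed=False
```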
......@@ -23,15 +23,19 @@ from lightx2v.utils.utils import seed_all
class BaseWorker:
@ProfilingContext4DebugL1("Init Worker Worker Cost:")
def __init__(self, args):
args.save_video_path = ""
config = set_config(args)
config["mode"] = ""
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
seed_all(config.seed)
self.rank = 0
self.world_size = 1
if config.parallel:
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
set_parallel_config(config)
seed_all(config.seed)
# must match the va_recorder rank and the worker main ping rank
self.out_video_rank = self.world_size - 1
torch.set_grad_enabled(False)
self.runner = RUNNER_REGISTER[config.model_cls](config)
# fixed config
......@@ -121,7 +125,7 @@ class BaseWorker:
async def save_output_video(self, tmp_video_path, output_video_path, data_manager):
# save output video
if data_manager.name != "local" and self.rank == 0 and isinstance(tmp_video_path, str):
if data_manager.name != "local" and self.rank == self.out_video_rank and isinstance(tmp_video_path, str):
video_data = open(tmp_video_path, "rb").read()
await data_manager.save_bytes(video_data, output_video_path)
......
......@@ -85,7 +85,7 @@ def main():
help="The file of the source mask. Default None.",
)
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
parser.add_argument("--save_video_path", type=str, default=None, help="The path to save video path/file")
args = parser.parse_args()
# set config
......
......@@ -295,7 +295,9 @@ class DefaultRunner(BaseRunner):
save_to_video(self.gen_video, self.config.save_video_path, fps=fps, method="ffmpeg")
logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
return {"video": self.gen_video}
if self.config.get("return_video", False):
return {"video": self.gen_video}
return {"video": None}
def run_pipeline(self, save_video=True):
if self.config["use_prompt_enhancer"]:
......
import gc
import os
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
......@@ -12,7 +11,6 @@ import torchvision.transforms.functional as TF
from PIL import Image
from einops import rearrange
from loguru import logger
from torchvision.io import write_video
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
......@@ -28,9 +26,7 @@ from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.utils import find_torch_model_path, load_weights, vae_to_comfyui_image
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.io._video_deprecation_warning")
from lightx2v.utils.utils import find_torch_model_path, load_weights, vae_to_comfyui_image_inplace
def get_optimal_patched_size_with_sp(patched_h, patched_w, sp_size):
......@@ -475,10 +471,12 @@ class WanAudioRunner(WanRunner): # type:ignore
def init_run(self):
super().init_run()
self.scheduler.set_audio_adapter(self.audio_adapter)
self.gen_video_list = []
self.cut_audio_list = []
self.prev_video = None
if self.config.get("return_video", False):
self.gen_video_final = torch.zeros((self.inputs["expected_frames"], self.config.tgt_h, self.config.tgt_w, 3), dtype=torch.float32, device="cpu")
else:
self.gen_video_final = None
self.cut_audio_final = None
@ProfilingContext4DebugL1("Init run segment")
def init_run_segment(self, segment_idx, audio_array=None):
......@@ -510,22 +508,31 @@ class WanAudioRunner(WanRunner): # type:ignore
def end_run_segment(self):
self.gen_video = torch.clamp(self.gen_video, -1, 1).to(torch.float)
useful_length = self.segment.end_frame - self.segment.start_frame
self.gen_video_list.append(self.gen_video[:, :, :useful_length].cpu())
self.cut_audio_list.append(self.segment.audio_array[: useful_length * self._audio_processor.audio_frame_rate])
video_seg = self.gen_video[:, :, :useful_length].cpu()
audio_seg = self.segment.audio_array[: useful_length * self._audio_processor.audio_frame_rate]
if self.va_recorder:
cur_video = vae_to_comfyui_image(self.gen_video_list[-1])
self.va_recorder.pub_livestream(cur_video, self.cut_audio_list[-1])
video_seg = vae_to_comfyui_image_inplace(video_seg)
# [Warning] Need to check whether video segment interpolation works...
if "video_frame_interpolation" in self.config and self.vfi_model is not None:
target_fps = self.config["video_frame_interpolation"]["target_fps"]
logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
video_seg = self.vfi_model.interpolate_frames(
video_seg,
source_fps=self.config.get("fps", 16),
target_fps=target_fps,
)
if self.va_reader:
self.gen_video_list.pop()
self.cut_audio_list.pop()
if self.va_recorder:
self.va_recorder.pub_livestream(video_seg, audio_seg)
elif self.config.get("return_video", False):
self.gen_video_final[self.segment.start_frame : self.segment.end_frame].copy_(video_seg)
self.cut_audio_final = np.concatenate([self.cut_audio_final, audio_seg], axis=0).astype(np.float32) if self.cut_audio_final is not None else audio_seg
# Update prev_video for next iteration
self.prev_video = self.gen_video
# Clean up GPU memory after each segment
del self.gen_video
del video_seg, audio_seg
torch.cuda.empty_cache()
def get_rank_and_world_size(self):
......@@ -540,18 +547,19 @@ class WanAudioRunner(WanRunner): # type:ignore
output_video_path = self.config.get("save_video_path", None)
self.va_recorder = None
if isinstance(output_video_path, dict):
assert output_video_path["type"] == "stream", f"unexpected save_video_path: {output_video_path}"
rank, world_size = self.get_rank_and_world_size()
if rank == 2 % world_size:
record_fps = self.config.get("target_fps", 16)
audio_sr = self.config.get("audio_sr", 16000)
if "video_frame_interpolation" in self.config and self.vfi_model is not None:
record_fps = self.config["video_frame_interpolation"]["target_fps"]
self.va_recorder = VARecorder(
livestream_url=output_video_path["data"],
fps=record_fps,
sample_rate=audio_sr,
)
output_video_path = output_video_path["data"]
logger.info(f"init va_recorder with output_video_path: {output_video_path}")
rank, world_size = self.get_rank_and_world_size()
if output_video_path and rank == world_size - 1:
record_fps = self.config.get("target_fps", 16)
audio_sr = self.config.get("audio_sr", 16000)
if "video_frame_interpolation" in self.config and self.vfi_model is not None:
record_fps = self.config["video_frame_interpolation"]["target_fps"]
self.va_recorder = VARecorder(
livestream_url=output_video_path,
fps=record_fps,
sample_rate=audio_sr,
)
def init_va_reader(self):
audio_path = self.config.get("audio_path", None)
......@@ -583,8 +591,8 @@ class WanAudioRunner(WanRunner): # type:ignore
return super().run_main(total_steps)
rank, world_size = self.get_rank_and_world_size()
if rank == 2 % world_size:
assert self.va_recorder is not None, "va_recorder is required for stream audio input for rank 0"
if rank == world_size - 1:
assert self.va_recorder is not None, "va_recorder is required for stream audio input on the last rank"
self.va_reader.start()
self.init_run()
......@@ -627,67 +635,17 @@ class WanAudioRunner(WanRunner): # type:ignore
self.va_recorder = None
@ProfilingContext4DebugL1("Process after vae decoder")
def process_images_after_vae_decoder(self, save_video=True):
# Merge results
gen_lvideo = torch.cat(self.gen_video_list, dim=2).float()
merge_audio = np.concatenate(self.cut_audio_list, axis=0).astype(np.float32)
comfyui_images = vae_to_comfyui_image(gen_lvideo)
# Apply frame interpolation if configured
if "video_frame_interpolation" in self.config and self.vfi_model is not None:
target_fps = self.config["video_frame_interpolation"]["target_fps"]
logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
comfyui_images = self.vfi_model.interpolate_frames(
comfyui_images,
source_fps=self.config.get("fps", 16),
target_fps=target_fps,
)
if save_video and isinstance(self.config["save_video_path"], str):
if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
fps = self.config["video_frame_interpolation"]["target_fps"]
else:
fps = self.config.get("fps", 16)
if not dist.is_initialized() or dist.get_rank() == 0:
logger.info(f"🎬 Start to save video 🎬")
self._save_video_with_audio(comfyui_images, merge_audio, fps)
logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
# Convert audio to ComfyUI format
audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0)
comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}
return {"video": comfyui_images, "audio": comfyui_audio}
def process_images_after_vae_decoder(self, save_video=False):
if self.config.get("return_video", False):
audio_waveform = torch.from_numpy(self.cut_audio_final).unsqueeze(0).unsqueeze(0)
comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}
return {"video": self.gen_video_final, "audio": comfyui_audio}
return {"video": None, "audio": None}
def init_modules(self):
super().init_modules()
self.run_input_encoder = self._run_input_encoder_local_r2v_audio
def _save_video_with_audio(self, images, audio_array, fps):
output_path = self.config.get("save_video_path")
parent_dir = os.path.dirname(output_path)
if parent_dir and not os.path.exists(parent_dir):
os.makedirs(parent_dir, exist_ok=True)
sample_rate = self._audio_processor.audio_sr
if images.dtype != torch.uint8:
images = (images * 255).clamp(0, 255).to(torch.uint8)
write_video(
filename=output_path,
video_array=images,
fps=fps,
video_codec="libx264",
audio_array=torch.tensor(audio_array[None]),
audio_fps=sample_rate,
audio_codec="aac",
options={"preset": "medium", "crf": "23"}, # 可调整视频输出质量
)
def load_transformer(self):
"""Load transformer with LoRA support"""
base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
......
......@@ -31,6 +31,7 @@ def get_default_config():
"tgt_h": None,
"tgt_w": None,
"target_shape": None,
"return_video": False,
}
return default_config
......@@ -73,6 +74,8 @@ def set_config(args):
logger.warning(f"`num_frames - 1` has to be divisible by {config.vae_stride[0]}. Rounding to the nearest number.")
config.target_video_length = config.target_video_length // config.vae_stride[0] * config.vae_stride[0] + 1
assert not (config.save_video_path and config.return_video), "save_video_path and return_video cannot be set at the same time"
return config
......
......@@ -131,6 +131,39 @@ def vae_to_comfyui_image(vae_output: torch.Tensor) -> torch.Tensor:
return images
def vae_to_comfyui_image_inplace(vae_output: torch.Tensor) -> torch.Tensor:
"""
Convert VAE decoder output to ComfyUI Image format (inplace operation)
Args:
vae_output: VAE decoder output tensor, typically in range [-1, 1]
Shape: [B, C, T, H, W] or [B, C, H, W]
WARNING: 4D inputs are normalized in place; 5D inputs are copied by the contiguous() call below
Returns:
ComfyUI Image tensor in range [0, 1]
Shape: [B, H, W, C] for single frame or [B*T, H, W, C] for video
Note: for 4D CPU inputs the returned tensor shares storage with the input (modified in place)
"""
# Handle video tensor (5D) vs image tensor (4D)
if vae_output.dim() == 5:
# Video tensor: [B, C, T, H, W]
B, C, T, H, W = vae_output.shape
# Reshape to [B*T, C, H, W] for processing (inplace view)
vae_output = vae_output.permute(0, 2, 1, 3, 4).contiguous().view(B * T, C, H, W)
# Normalize from [-1, 1] to [0, 1] (inplace)
vae_output.add_(1).div_(2)
# Clamp values to [0, 1] (inplace)
vae_output.clamp_(0, 1)
# Convert from [B, C, H, W] to [B, H, W, C] and move to CPU
vae_output = vae_output.permute(0, 2, 3, 1).cpu()
return vae_output
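A quick shape check for the helper above, as a self-contained sketch; note that for 5D inputs the contiguous() call copies, so the caller's tensor is left intact.

```python
import torch

# Fake 5D VAE output [B, C, T, H, W] with values in [-1, 1].
vae_out = torch.rand(1, 3, 4, 8, 8) * 2 - 1
imgs = vae_to_comfyui_image_inplace(vae_out)
assert imgs.shape == (4, 8, 8, 3)               # [B*T, H, W, C]
assert 0.0 <= imgs.min() and imgs.max() <= 1.0  # normalized to [0, 1]
```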
def save_to_video(
images: torch.Tensor,
output_path: str,
......