Commit a1db6fa0 authored by LiangLiu's avatar LiangLiu Committed by GitHub

update deploy (#323)



update deploy

---------
Co-authored-by: liuliang1 <liuliang1@sensetime.com>
Co-authored-by: qinxinyi <qinxinyi@sensetime.com>
Co-authored-by: Yang Yong(雍洋) <yongyang1030@163.com>
parent 99158e75
-FROM pytorch/pytorch:2.7.1-cuda12.8-cudnn9-devel AS base
+FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel AS base
 WORKDIR /app

 ENV DEBIAN_FRONTEND=noninteractive
 ENV LANG=C.UTF-8
 ENV LC_ALL=C.UTF-8
+ENV LD_LIBRARY_PATH=/usr/local/lib/x86_64-linux-gnu:$LD_LIBRARY_PATH

-# use tsinghua source
-RUN sed -i 's|http://archive.ubuntu.com/ubuntu/|https://mirrors.tuna.tsinghua.edu.cn/ubuntu/|g' /etc/apt/sources.list \
-    && sed -i 's|http://security.ubuntu.com/ubuntu/|https://mirrors.tuna.tsinghua.edu.cn/ubuntu/|g' /etc/apt/sources.list

 RUN apt-get update && apt-get install -y vim tmux zip unzip wget git build-essential libibverbs-dev ca-certificates \
-    curl iproute2 ffmpeg libsm6 libxext6 kmod ccache libnuma-dev \
+    curl iproute2 libsm6 libxext6 kmod ccache libnuma-dev libssl-dev flex bison libgtk-3-dev libpango1.0-dev \
+    libsoup2.4-dev libnice-dev libopus-dev libvpx-dev libx264-dev libsrtp2-dev libglib2.0-dev libdrm-dev \
     && apt-get clean && rm -rf /var/lib/apt/lists/*

-RUN pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
-RUN pip install --no-cache-dir packaging ninja cmake scikit-build-core uv ruff pre-commit -U
+RUN pip install --no-cache-dir packaging ninja cmake scikit-build-core uv meson ruff pre-commit fastapi uvicorn requests -U

 RUN git clone https://github.com/vllm-project/vllm.git && cd vllm \
     && python use_existing_torch.py && pip install -r requirements/build.txt \
@@ -28,6 +24,8 @@ RUN git clone https://github.com/sgl-project/sglang.git && cd sglang/sgl-kernel
 RUN pip install --no-cache-dir diffusers transformers tokenizers accelerate safetensors opencv-python numpy imageio \
     imageio-ffmpeg einops loguru qtorch ftfy easydict
+RUN conda install conda-forge::ffmpeg=8.0.0 -y && ln -s /opt/conda/bin/ffmpeg /usr/bin/ffmpeg

 RUN git clone https://github.com/Dao-AILab/flash-attention.git --recursive
 RUN cd flash-attention && python setup.py install && rm -rf build
@@ -42,4 +40,34 @@ RUN git clone https://github.com/KONAKONA666/q8_kernels.git
 RUN cd q8_kernels && git submodule init && git submodule update && python setup.py install && rm -rf build

+# cloud deploy
+RUN pip install --no-cache-dir aio-pika "asyncpg>=0.27.0" "aioboto3>=12.0.0" PyJWT alibabacloud_dypnsapi20170525==1.2.2 redis==6.4.0 tos -U
+RUN cd /opt \
+    && wget https://mirrors.tuna.tsinghua.edu.cn/gnu//libiconv/libiconv-1.15.tar.gz \
+    && tar zxvf libiconv-1.15.tar.gz \
+    && cd libiconv-1.15 \
+    && ./configure \
+    && make \
+    && make install
+RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
+ENV PATH=/root/.cargo/bin:$PATH
+RUN cd /opt \
+    && git clone https://github.com/GStreamer/gstreamer.git -b 1.24.12 --depth 1 \
+    && cd gstreamer \
+    && meson setup builddir \
+    && meson compile -C builddir \
+    && meson install -C builddir \
+    && ldconfig
+RUN cd /opt \
+    && git clone https://github.com/GStreamer/gst-plugins-rs.git -b gstreamer-1.24.12 --depth 1 \
+    && cd gst-plugins-rs \
+    && cargo build --package gst-plugin-webrtchttp --release \
+    && install -m 644 target/release/libgstwebrtchttp.so $(pkg-config --variable=pluginsdir gstreamer-1.0)/
+RUN ldconfig

 WORKDIR /workspace
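The GStreamer and gst-plugins-rs steps above build the WHEP/WHIP plugin that the old multi-stage deploy Dockerfile (removed later in this diff) used to verify with `gst-launch-1.0 --version` and `gst-inspect-1.0 whepsrc`. A minimal sketch of running the same sanity check from Python inside the built image — an illustration, not part of this commit:

```python
import subprocess

# Both commands come from the removed verification steps in the old deploy
# Dockerfile; check=True raises if the plugin is missing from the image.
for cmd in (["gst-launch-1.0", "--version"], ["gst-inspect-1.0", "whepsrc"]):
    result = subprocess.run(cmd, capture_output=True, text=True, check=True)
    print(result.stdout.splitlines()[0])  # first line is enough to confirm presence
```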
@@ -11,7 +11,7 @@ RUN sed -i 's|http://archive.ubuntu.com/ubuntu/|https://mirrors.tuna.tsinghua.edu.cn/ubuntu/|g' /etc/apt/sources.list \
     && sed -i 's|http://security.ubuntu.com/ubuntu/|https://mirrors.tuna.tsinghua.edu.cn/ubuntu/|g' /etc/apt/sources.list

 RUN apt-get update && apt-get install -y vim tmux zip unzip wget git build-essential libibverbs-dev ca-certificates \
-    curl iproute2 ffmpeg libsm6 libxext6 kmod ccache libnuma-dev \
+    curl iproute2 libsm6 libxext6 kmod ccache libnuma-dev \
     && apt-get clean && rm -rf /var/lib/apt/lists/*

 RUN pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
@@ -28,6 +28,8 @@ RUN git clone https://github.com/sgl-project/sglang.git && cd sglang/sgl-kernel
 RUN pip install --no-cache-dir diffusers transformers tokenizers accelerate safetensors opencv-python numpy imageio \
     imageio-ffmpeg einops loguru qtorch ftfy easydict
+RUN conda install conda-forge::ffmpeg=8.0.0 -y && ln -s /opt/conda/bin/ffmpeg /usr/bin/ffmpeg

 RUN git clone https://github.com/Dao-AILab/flash-attention.git --recursive
 RUN cd flash-attention && python setup.py install && rm -rf build
......
-# For rtc whep, build gstreamer whith whepsrc plugin
-FROM registry.ms-sc-01.maoshanwangtech.com/ms-ccr/lightx2v:25080601-cu128-SageSm90 AS gstreamer-base
-RUN apt update -y \
-    && apt update -y \
-    && apt install -y libssl-dev flex bison \
-    libgtk-3-dev libpango1.0-dev libsoup2.4-dev \
-    libnice-dev libopus-dev libvpx-dev libx264-dev \
-    libsrtp2-dev libglib2.0-dev libdrm-dev
-RUN cd /opt \
-    && wget https://mirrors.tuna.tsinghua.edu.cn/gnu//libiconv/libiconv-1.15.tar.gz \
-    && tar zxvf libiconv-1.15.tar.gz \
-    && cd libiconv-1.15 \
-    && ./configure \
-    && make \
-    && make install
-RUN pip install meson
-RUN curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh -s -- -y --default-toolchain stable
-ENV PATH=/root/.cargo/bin:$PATH
-RUN cd /opt \
-    && git clone https://github.com/GStreamer/gstreamer.git -b 1.24.12 --depth 1 \
-    && cd gstreamer \
-    && meson setup builddir \
-    && meson compile -C builddir \
-    && meson install -C builddir \
-    && ldconfig
-RUN cd /opt \
-    && git clone https://github.com/GStreamer/gst-plugins-rs.git -b gstreamer-1.24.12 --depth 1 \
-    && cd gst-plugins-rs \
-    && cargo build --package gst-plugin-webrtchttp --release \
-    && install -m 644 target/release/libgstwebrtchttp.so $(pkg-config --variable=pluginsdir gstreamer-1.0)/
-
-# Lightx2v deploy image
-FROM registry.ms-sc-01.maoshanwangtech.com/ms-ccr/lightx2v:25080601-cu128-SageSm90
-RUN mkdir /workspace/lightx2v
-WORKDIR /workspace/lightx2v
-ENV PYTHONPATH=/workspace/lightx2v
-COPY requirements.txt requirements.txt
-RUN pip install -r requirements.txt
-RUN conda install conda-forge::ffmpeg=8.0.0 -y
-RUN rm /usr/bin/ffmpeg && ln -s /opt/conda/bin/ffmpeg /usr/bin/ffmpeg
-RUN apt update -y \
-    && apt install -y libssl-dev \
-    libgtk-3-dev libpango1.0-dev libsoup2.4-dev \
-    libnice-dev libopus-dev libvpx-dev libx264-dev \
-    libsrtp2-dev libglib2.0-dev libdrm-dev
-ENV LBDIR=/usr/local/lib/x86_64-linux-gnu
-COPY --from=gstreamer-base /usr/local/bin/gst-* /usr/local/bin/
-COPY --from=gstreamer-base $LBDIR $LBDIR
-RUN ldconfig
-ENV LD_LIBRARY_PATH=$LBDIR:$LD_LIBRARY_PATH
-RUN gst-launch-1.0 --version
-RUN gst-inspect-1.0 whepsrc
+FROM lightx2v/lightx2v:25091903-cu128 AS base
+RUN mkdir /workspace/LightX2V
+WORKDIR /workspace/LightX2V
+ENV PYTHONPATH=/workspace/LightX2V

 COPY assets assets
 COPY configs configs
......
@@ -102,6 +102,10 @@
     "latents": "TENSOR",
     "output_video": "VIDEO"
   },
+  "model_name_inner_to_outer": {
+    "seko_talk": "SekoTalk"
+  },
+  "model_name_outer_to_inner": {},
   "monitor": {
     "subtask_created_timeout": 1800,
     "subtask_pending_timeout": 1800,
......
@@ -27,16 +27,16 @@ We strongly recommend using the Docker environment, which is the simplest and fa
 #### 1. Pull Image

-Visit LightX2V's [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags), select a tag with the latest date, such as `25090503-cu128`:
+Visit LightX2V's [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags), select a tag with the latest date, such as `25091903-cu128`:

 ```bash
-docker pull lightx2v/lightx2v:25090503-cu128
+docker pull lightx2v/lightx2v:25091903-cu128
 ```

 We recommend using the `cuda128` environment for faster inference speed. If you need to use the `cuda124` environment, you can use image versions with the `-cu124` suffix:

 ```bash
-docker pull lightx2v/lightx2v:25090503-cu124
+docker pull lightx2v/lightx2v:25091903-cu124
 ```

 #### 2. Run Container

@@ -51,10 +51,10 @@ For mainland China, if the network is unstable when pulling images, you can pull
 ```bash
 # cuda128
-docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25090503-cu128
+docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25091903-cu128

 # cuda124
-docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25090503-cu124
+docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25091903-cu124
 ```

 ### 🐍 Conda Environment Setup
......
@@ -27,16 +27,16 @@
 #### 1. Pull Image

-Visit LightX2V's [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags) and pick a tag with the latest date, such as `25090503-cu128`:
+Visit LightX2V's [Docker Hub](https://hub.docker.com/r/lightx2v/lightx2v/tags) and pick a tag with the latest date, such as `25091903-cu128`:

 ```bash
-docker pull lightx2v/lightx2v:25090503-cu128
+docker pull lightx2v/lightx2v:25091903-cu128
 ```

 We recommend the `cuda128` environment for faster inference; if you need a `cuda124` environment, use the image versions with the `-cu124` suffix:

 ```bash
-docker pull lightx2v/lightx2v:25090503-cu124
+docker pull lightx2v/lightx2v:25091903-cu124
 ```

 #### 2. Run Container

@@ -51,10 +51,10 @@ docker run --gpus all -itd --ipc=host --name [container name] -v [mount settings] --ent
 ```bash
 # cuda128
-docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25090503-cu128
+docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25091903-cu128

 # cuda124
-docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25090503-cu124
+docker pull registry.cn-hangzhou.aliyuncs.com/yongyang/lightx2v:25091903-cu124
 ```

 ### 🐍 Conda Environment Setup
......
@@ -16,6 +16,8 @@ class Pipeline:
         self.model_lists = []
         self.types = {}
         self.queues = set()
+        self.model_name_inner_to_outer = self.meta.get("model_name_inner_to_outer", {})
+        self.model_name_outer_to_inner = self.meta.get("model_name_outer_to_inner", {})
         self.tidy_pipeline()

     def init_dict(self, base, task, model_cls):
@@ -132,6 +134,14 @@ class Pipeline:
             item = item[k]
         return item

+    def check_item_by_keys(self, keys):
+        item = self.data
+        for k in keys:
+            if k not in item:
+                return False
+            item = item[k]
+        return True
+
     def get_model_lists(self):
         return self.model_lists

@@ -144,6 +154,12 @@ class Pipeline:
     def get_queues(self):
         return self.queues

+    def inner_model_name(self, name):
+        return self.model_name_outer_to_inner.get(name, name)
+
+    def outer_model_name(self, name):
+        return self.model_name_inner_to_outer.get(name, name)
+
 if __name__ == "__main__":
     pipeline = Pipeline(sys.argv[1])
......
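The alias maps above let the API expose `SekoTalk` while workers keep using `seko_talk`; unmapped names pass through unchanged. A minimal standalone sketch of that behavior, with plain dicts standing in for the `Pipeline` object and values mirroring the config hunk earlier in this diff:

```python
inner_to_outer = {"seko_talk": "SekoTalk"}  # from the config hunk above
outer_to_inner = {}                         # left empty in the same config

def outer_model_name(name):
    # responses show the public alias when one is registered
    return inner_to_outer.get(name, name)

def inner_model_name(name):
    # submissions fall back to the name as given
    return outer_to_inner.get(name, name)

print(outer_model_name("seko_talk"))  # SekoTalk
print(inner_model_name("SekoTalk"))   # SekoTalk (no reverse alias registered)
```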
+import asyncio
 import base64
 import io
 import os
@@ -87,6 +88,72 @@ async def fetch_resource(url, timeout):
     return content
+# check, resize, read rotate meta info
+def format_image_data(data, max_size=1280):
+    image = Image.open(io.BytesIO(data)).convert("RGB")
+    exif = image.getexif()
+    changed = False
+    w, h = image.size
+    assert w > 0 and h > 0, "image is empty"
+    logger.info(f"load image: {w}x{h}, exif: {exif}")
+    if w > max_size or h > max_size:
+        ratio = max_size / max(w, h)
+        w = int(w * ratio)
+        h = int(h * ratio)
+        image = image.resize((w, h))
+        logger.info(f"resize image to: {image.size}")
+        changed = True
+    orientation_key = 274
+    if orientation_key and orientation_key in exif:
+        orientation = exif[orientation_key]
+        if orientation == 2:
+            image = image.transpose(Image.FLIP_LEFT_RIGHT)
+        elif orientation == 3:
+            image = image.rotate(180, expand=True)
+        elif orientation == 4:
+            image = image.transpose(Image.FLIP_TOP_BOTTOM)
+        elif orientation == 5:
+            image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(90, expand=True)
+        elif orientation == 6:
+            image = image.rotate(270, expand=True)
+        elif orientation == 7:
+            image = image.transpose(Image.FLIP_LEFT_RIGHT).rotate(270, expand=True)
+        elif orientation == 8:
+            image = image.rotate(90, expand=True)
+        # reset orientation to 1
+        if orientation != 1:
+            logger.info(f"reset orientation from {orientation} to 1")
+            exif[orientation_key] = 1
+            changed = True
+    if not changed:
+        return data
+    output = io.BytesIO()
+    image.save(output, format=image.format or "JPEG", exif=exif.tobytes())
+    return output.getvalue()
+
+def format_audio_data(data):
+    if len(data) < 4:
+        raise ValueError("Audio file too short")
+    try:
+        waveform, sample_rate = torchaudio.load(io.BytesIO(data), num_frames=10)
+        logger.info(f"load audio: {waveform.size()}, {sample_rate}")
+        assert waveform.size(0) > 0, "audio is empty"
+        assert sample_rate > 0, "audio sample rate is not valid"
+    except Exception as e:
+        logger.warning(f"torchaudio failed to load audio, trying alternative method: {e}")
+        # check audio headers
+        audio_headers = [b"RIFF", b"ID3", b"\xff\xfb", b"\xff\xf3", b"\xff\xf2", b"OggS"]
+        if not any(data.startswith(header) for header in audio_headers):
+            logger.warning("Audio file doesn't have recognized header, but continuing...")
+        logger.info(f"Audio validation passed (alternative method), size: {len(data)} bytes")
+    return data
 async def preload_data(inp, inp_type, typ, val):
     try:
         if typ == "url":
@@ -102,27 +169,10 @@ async def preload_data(inp, inp_type, typ, val):
         # check if valid image bytes
         if inp_type == "IMAGE":
-            image = Image.open(io.BytesIO(data))
-            logger.info(f"load image: {image.size}")
-            assert image.size[0] > 0 and image.size[1] > 0, "image is empty"
+            data = await asyncio.to_thread(format_image_data, data)
         elif inp_type == "AUDIO":
             if typ != "stream":
-                try:
-                    waveform, sample_rate = torchaudio.load(io.BytesIO(data), num_frames=10)
-                    logger.info(f"load audio: {waveform.size()}, {sample_rate}")
-                    assert waveform.size(0) > 0, "audio is empty"
-                    assert sample_rate > 0, "audio sample rate is not valid"
-                except Exception as e:
-                    logger.warning(f"torchaudio failed to load audio, trying alternative method: {e}")
-                    # try other methods to validate the audio file
-                    # check whether the file header is a valid audio format
-                    if len(data) < 4:
-                        raise ValueError("Audio file too short")
-                    # check common audio file headers
-                    audio_headers = [b"RIFF", b"ID3", b"\xff\xfb", b"\xff\xf3", b"\xff\xf2", b"OggS"]
-                    if not any(data.startswith(header) for header in audio_headers):
-                        logger.warning("Audio file doesn't have recognized header, but continuing...")
-                    logger.info(f"Audio validation passed (alternative method), size: {len(data)} bytes")
+                data = await asyncio.to_thread(format_audio_data, data)
         else:
             raise Exception(f"cannot parse inp_type={inp_type} data")
         return data
@@ -152,3 +202,21 @@ def check_params(params, raw_inputs, raw_outputs, types):
             assert stream_audio, "stream audio is not supported, please set env STREAM_AUDIO=1"
         elif types[x] == "VIDEO":
             assert stream_video, "stream video is not supported, please set env STREAM_VIDEO=1"
+
+if __name__ == "__main__":
+    # https://github.com/recurser/exif-orientation-examples
+    exif_dir = "/data/nvme0/liuliang1/exif-orientation-examples"
+    out_dir = "/data/nvme0/liuliang1/exif-orientation-examples/outs"
+    os.makedirs(out_dir, exist_ok=True)
+    for base_name in ["Landscape", "Portrait"]:
+        for i in range(9):
+            fin_name = os.path.join(exif_dir, f"{base_name}_{i}.jpg")
+            fout_name = os.path.join(out_dir, f"{base_name}_{i}_formatted.jpg")
+            logger.info(f"format image: {fin_name} -> {fout_name}")
+            with open(fin_name, "rb") as f:
+                data = f.read()
+            data = format_image_data(data)
+            with open(fout_name, "wb") as f:
+                f.write(data)
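With this change, `preload_data` pushes the PIL and torchaudio decoding into a worker thread via `asyncio.to_thread`, so one large upload cannot stall the event loop. A self-contained sketch of the pattern — the blocking function here is a stand-in, not the project's code:

```python
import asyncio
import time

def blocking_validate(data: bytes) -> bytes:
    time.sleep(0.1)  # stands in for Image.open / torchaudio.load work
    return data

async def main():
    # Runs in a thread from the default executor; the event loop stays responsive
    out = await asyncio.to_thread(blocking_validate, b"fake-bytes")
    print(len(out))

asyncio.run(main())
```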
 import queue
+import random
 import signal
 import socket
 import subprocess
@@ -18,14 +19,13 @@ class VARecorder:
         livestream_url: str,
         fps: float = 16.0,
         sample_rate: int = 16000,
-        audio_port: int = 30200,
-        video_port: int = 30201,
     ):
         self.livestream_url = livestream_url
         self.fps = fps
         self.sample_rate = sample_rate
-        self.audio_port = audio_port
-        self.video_port = video_port
+        self.audio_port = random.choice(range(32000, 40000))
+        self.video_port = self.audio_port + 1
+        logger.info(f"VARecorder audio port: {self.audio_port}, video port: {self.video_port}")
         self.width = None
         self.height = None
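Randomizing the port pair in [32000, 40000) avoids the collisions that the fixed 30200/30201 defaults caused when several recorders shared a host, though two recorders can still draw the same number. A sketch of the usual collision-free alternative — an assumption about requirements, not what this commit does:

```python
import socket

def pick_free_port(host: str = "127.0.0.1") -> int:
    # Bind to port 0 and let the kernel pick an unused ephemeral port
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind((host, 0))
        return s.getsockname()[1]
```

The trade-off: VARecorder needs two adjacent ports (`video_port = audio_port + 1`), which kernel assignment does not guarantee, so the random-range approach keeps that invariant simple.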
@@ -116,6 +116,58 @@ class VARecorder:
         finally:
             logger.info("Video push worker thread stopped")

+    def start_ffmpeg_process_local(self):
+        """Start ffmpeg process that connects to our TCP sockets"""
+        ffmpeg_cmd = [
+            "/opt/conda/bin/ffmpeg",
+            "-re",
+            "-f",
+            "s16le",
+            "-ar",
+            str(self.sample_rate),
+            "-ac",
+            "1",
+            "-i",
+            f"tcp://127.0.0.1:{self.audio_port}",
+            "-f",
+            "rawvideo",
+            "-re",
+            "-pix_fmt",
+            "rgb24",
+            "-r",
+            str(self.fps),
+            "-s",
+            f"{self.width}x{self.height}",
+            "-i",
+            f"tcp://127.0.0.1:{self.video_port}",
+            "-ar",
+            "44100",
+            "-b:v",
+            "4M",
+            "-c:v",
+            "libx264",
+            "-preset",
+            "ultrafast",
+            "-tune",
+            "zerolatency",
+            "-g",
+            f"{self.fps}",
+            "-pix_fmt",
+            "yuv420p",
+            "-f",
+            "mp4",
+            self.livestream_url,
+            "-y",
+            "-loglevel",
+            "info",
+        ]
+        try:
+            self.ffmpeg_process = subprocess.Popen(ffmpeg_cmd)
+            logger.info(f"FFmpeg streaming started with PID: {self.ffmpeg_process.pid}")
+            logger.info(f"FFmpeg command: {' '.join(ffmpeg_cmd)}")
+        except Exception as e:
+            logger.error(f"Failed to start FFmpeg: {e}")
+
     def start_ffmpeg_process_rtmp(self):
         """Start ffmpeg process that connects to our TCP sockets"""
         ffmpeg_cmd = [
@@ -240,7 +292,7 @@ class VARecorder:
         elif self.livestream_url.startswith("http"):
             self.start_ffmpeg_process_whip()
         else:
-            raise Exception(f"Unsupported livestream URL: {self.livestream_url}")
+            self.start_ffmpeg_process_local()
         self.audio_thread = threading.Thread(target=self.audio_worker)
         self.video_thread = threading.Thread(target=self.video_worker)
         self.audio_thread.start()
@@ -353,12 +405,13 @@ if __name__ == "__main__":
     recorder = VARecorder(
         # livestream_url="rtmp://localhost/live/test",
-        livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
+        # livestream_url="https://reverse.st-oc-01.chielo.org/10.5.64.49:8000/rtc/v1/whip/?app=live&stream=ll_test_video&eip=127.0.0.1:8000",
+        livestream_url="/path/to/output_video.mp4",
         fps=fps,
         sample_rate=sample_rate,
     )
-    audio_path = "/mtc/liuliang1/lightx2v/test_deploy/media_test/test_b_2min.wav"
+    audio_path = "/path/to/test_b_2min.wav"
     audio_array, ori_sr = ta.load(audio_path)
     audio_array = ta.functional.resample(audio_array.mean(0), orig_freq=ori_sr, new_freq=16000)
     audio_array = audio_array.numpy().reshape(-1)
......
@@ -236,6 +236,11 @@ async def prepare_subtasks(task_id):
         await server_monitor.pending_subtasks_add(sub["queue"], sub["task_id"])

+def format_task(task):
+    task["status"] = task["status"].name
+    task["model_cls"] = model_pipelines.outer_model_name(task["model_cls"])
+
 @app.get("/api/v1/model/list")
 async def api_v1_model_list(user=Depends(verify_user_access)):
     try:
@@ -254,6 +259,7 @@ async def api_v1_task_submit(request: Request, user=Depends(verify_user_access))
         return error_response(msg, 400)
     params = await request.json()
     keys = [params.pop("task"), params.pop("model_cls"), params.pop("stage")]
+    keys[1] = model_pipelines.inner_model_name(keys[1])
     assert len(params["prompt"]) > 0, "valid prompt is required"

     # get worker infos, model input names
@@ -303,7 +309,7 @@ async def api_v1_task_query(request: Request, user=Depends(verify_user_access)):
             task, subtasks = await task_manager.query_task(task_id, user["user_id"], only_task=False)
             if task is not None:
                 task["subtasks"] = await server_monitor.format_subtask(subtasks)
-                task["status"] = task["status"].name
+                format_task(task)
                 tasks.append(task)
             return {"tasks": tasks}
@@ -313,7 +319,7 @@ async def api_v1_task_query(request: Request, user=Depends(verify_user_access)):
         if task is None:
             return error_response(f"Task {task_id} not found", 404)
         task["subtasks"] = await server_monitor.format_subtask(subtasks)
-        task["status"] = task["status"].name
+        format_task(task)
         return task
     except Exception as e:
@@ -344,7 +350,7 @@ async def api_v1_task_list(request: Request, user=Depends(verify_user_access)):
         tasks = await task_manager.list_tasks(**query_params)
         for task in tasks:
-            task["status"] = task["status"].name
+            format_task(task)
         return {"tasks": tasks, "pagination": page_info}
     except Exception as e:
@@ -457,12 +463,18 @@ async def api_v1_task_cancel(request: Request, user=Depends(verify_user_access))
 async def api_v1_task_resume(request: Request, user=Depends(verify_user_access)):
     try:
         task_id = request.query_params["task_id"]
+        task = await task_manager.query_task(task_id, user_id=user["user_id"])
+        keys = [task["task_type"], task["model_cls"], task["stage"]]
+        if not model_pipelines.check_item_by_keys(keys):
+            return error_response(f"Model {keys} is not supported now, please submit a new task", 400)
+
         ret = await task_manager.resume_task(task_id, user_id=user["user_id"], all_subtask=False)
-        if ret:
+        if ret is True:
             await prepare_subtasks(task_id)
             return {"msg": "ok"}
         else:
-            return error_response(f"Task {task_id} resume failed", 400)
+            return error_response(f"Task {task_id} resume failed: {ret}", 400)
     except Exception as e:
         traceback.print_exc()
         return error_response(str(e), 500)
@@ -605,7 +617,7 @@ async def api_v1_worker_ping_subtask(request: Request, valid=Depends(verify_work
     queue = params.pop("queue")

     task = await task_manager.query_task(task_id)
-    if task["status"] != TaskStatus.RUNNING:
+    if task is None or task["status"] != TaskStatus.RUNNING:
         return {"msg": "delete"}
     assert await task_manager.ping_subtask(task_id, worker_name, identity)
@@ -714,27 +726,18 @@ async def api_v1_template_list(request: Request, valid=Depends(verify_user_acces
     if page <= total_pages:
         start_idx = (page - 1) * page_size
         end_idx = start_idx + page_size
-        all_images.sort(key=lambda x: x)
-        all_audios.sort(key=lambda x: x)
-        all_videos.sort(key=lambda x: x)
-
-        for image in all_images[start_idx:end_idx]:
-            url = await data_manager.presign_template_url("images", image)
-            if url is None:
-                url = f"./assets/template/images/{image}"
-            paginated_image_templates.append({"filename": image, "url": url})
-
-        for audio in all_audios[start_idx:end_idx]:
-            url = await data_manager.presign_template_url("audios", audio)
-            if url is None:
-                url = f"./assets/template/audios/{audio}"
-            paginated_audio_templates.append({"filename": audio, "url": url})
-
-        for video in all_videos[start_idx:end_idx]:
-            url = await data_manager.presign_template_url("videos", video)
-            if url is None:
-                url = f"./assets/template/videos/{video}"
-            paginated_video_templates.append({"filename": video, "url": url})
+        async def handle_media(media_type, media_names, paginated_media_templates):
+            media_names.sort(key=lambda x: x)
+            for media_name in media_names[start_idx:end_idx]:
+                url = await data_manager.presign_template_url(media_type, media_name)
+                if url is None:
+                    url = f"./assets/template/{media_type}/{media_name}"
+                paginated_media_templates.append({"filename": media_name, "url": url})
+
+        await handle_media("images", all_images, paginated_image_templates)
+        await handle_media("audios", all_audios, paginated_audio_templates)
+        await handle_media("videos", all_videos, paginated_video_templates)

     return {
         "templates": {"images": paginated_image_templates, "audios": paginated_audio_templates, "videos": paginated_video_templates},
@@ -760,6 +763,7 @@ async def api_v1_template_tasks(request: Request, valid=Depends(verify_user_acce
     page_size = min(page_size, 100)

     all_templates = []
+    all_categories = set()
     template_files = await data_manager.list_template_files("tasks")
     template_files = [] if template_files is None else template_files
@@ -767,6 +771,8 @@ async def api_v1_template_tasks(request: Request, valid=Depends(verify_user_acce
         try:
             bytes_data = await data_manager.load_template_file("tasks", template_file)
             template_data = json.loads(bytes_data)
+            template_data["task"]["model_cls"] = model_pipelines.outer_model_name(template_data["task"]["model_cls"])
+            all_categories.update(template_data["task"]["tags"])
             if category is not None and category != "all" and category not in template_data["task"]["tags"]:
                 continue
             if search is not None and search not in template_data["task"]["params"]["prompt"] + template_data["task"]["params"]["negative_prompt"] + template_data["task"][
@@ -787,7 +793,7 @@ async def api_v1_template_tasks(request: Request, valid=Depends(verify_user_acce
         end_idx = start_idx + page_size
         paginated_templates = all_templates[start_idx:end_idx]

-        return {"templates": paginated_templates, "pagination": {"page": page, "page_size": page_size, "total": total_templates, "total_pages": total_pages}}
+        return {"templates": paginated_templates, "pagination": {"page": page, "page_size": page_size, "total": total_templates, "total_pages": total_pages}, "categories": list(all_categories)}
......
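The `handle_media` refactor above leans on closure capture: `start_idx` and `end_idx` come from the enclosing request handler, so all three media types share one pagination window instead of three copies of the loop. A minimal standalone sketch of the pattern; the presigner and URL are hypothetical, not the project's API:

```python
import asyncio

async def presign(media_type: str, name: str) -> str:
    return f"https://cdn.example.com/{media_type}/{name}"  # hypothetical presigner

async def list_templates(all_images, all_audios, page, page_size):
    start_idx, end_idx = (page - 1) * page_size, page * page_size
    out = {"images": [], "audios": []}

    async def handle_media(media_type, media_names, bucket):
        media_names.sort()
        for name in media_names[start_idx:end_idx]:  # window captured from outer scope
            bucket.append({"filename": name, "url": await presign(media_type, name)})

    await handle_media("images", all_images, out["images"])
    await handle_media("audios", all_audios, out["audios"])
    return out

print(asyncio.run(list_templates(["b.png", "a.png"], ["x.wav"], 1, 1)))
```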
@@ -92,6 +92,11 @@ class WorkerClient:
             if elapse > self.offline_timeout:
                 logger.warning(f"Worker {self.identity} {self.queue} offline timeout2: {elapse:.2f} s")
                 return False
+        # fetching too long
+        elif self.status == WorkerStatus.FETCHING:
+            if elapse > self.fetching_timeout:
+                logger.warning(f"Worker {self.identity} {self.queue} fetching timeout: {elapse:.2f} s")
+                return False
         return True
@@ -111,7 +116,7 @@ class ServerMonitor:
         self.fetching_timeout = self.config.get("fetching_timeout", 1000)
         for queue in self.all_queues:
-            self.subtask_run_timeouts[queue] = self.config["subtask_running_timeouts"].get(queue, 60)
+            self.subtask_run_timeouts[queue] = self.config["subtask_running_timeouts"].get(queue, 3600)
         self.subtask_created_timeout = self.config["subtask_created_timeout"]
         self.subtask_pending_timeout = self.config["subtask_pending_timeout"]
         self.worker_avg_window = self.config["worker_avg_window"]
......
This diff is collapsed.
@@ -273,10 +273,10 @@ class LocalTaskManager(BaseTaskManager):
         task, subtasks = self.load(task_id, user_id)
         # the task is not finished
         if task["status"] not in FinishedStatus:
-            return False
+            return "Active task cannot be resumed"
         # the task is no need to resume
         if not all_subtask and task["status"] == TaskStatus.SUCCEED:
-            return False
+            return "Succeed task cannot be resumed"
         for sub in subtasks:
             if all_subtask or sub["status"] != TaskStatus.SUCCEED:
                 self.mark_subtask_change(records, sub, None, TaskStatus.CREATED)
......
@@ -702,10 +702,10 @@ class PostgresSQLTaskManager(BaseTaskManager):
         task, subtasks = await self.load(conn, task_id, user_id)
         # the task is not finished
         if task["status"] not in FinishedStatus:
-            return False
+            return "Active task cannot be resumed"
         # the task is no need to resume
         if not all_subtask and task["status"] == TaskStatus.SUCCEED:
-            return False
+            return "Succeed task cannot be resumed"
         for sub in subtasks:
             if all_subtask or sub["status"] != TaskStatus.SUCCEED:
......
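Both task managers previously returned a bare `False` from `resume_task`; they now return `True` on success or a human-readable failure string, which is why the endpoint above switched from `if ret:` to `if ret is True:`. A three-line reminder of the pitfall the identity check avoids:

```python
ret = "Active task cannot be resumed"  # failure now arrives as a string
assert bool(ret)        # truthy! a plain `if ret:` would wrongly take the success path
assert ret is not True  # the identity check routes it to the error response instead
```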
@@ -23,15 +23,19 @@ from lightx2v.utils.utils import seed_all
 class BaseWorker:
     @ProfilingContext4DebugL1("Init Worker Worker Cost:")
     def __init__(self, args):
+        args.save_video_path = ""
         config = set_config(args)
+        config["mode"] = ""
         logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         seed_all(config.seed)
         self.rank = 0
+        self.world_size = 1
         if config.parallel:
             self.rank = dist.get_rank()
+            self.world_size = dist.get_world_size()
             set_parallel_config(config)
             seed_all(config.seed)
+        # same as va_recorder rank and worker main ping rank
+        self.out_video_rank = self.world_size - 1
         torch.set_grad_enabled(False)
         self.runner = RUNNER_REGISTER[config.model_cls](config)
         # fixed config
@@ -121,7 +125,7 @@ class BaseWorker:
     async def save_output_video(self, tmp_video_path, output_video_path, data_manager):
         # save output video
-        if data_manager.name != "local" and self.rank == 0 and isinstance(tmp_video_path, str):
+        if data_manager.name != "local" and self.rank == self.out_video_rank and isinstance(tmp_video_path, str):
             video_data = open(tmp_video_path, "rb").read()
             await data_manager.save_bytes(video_data, output_video_path)
......
@@ -85,7 +85,7 @@ def main():
         help="The file of the source mask. Default None.",
     )
-    parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
+    parser.add_argument("--save_video_path", type=str, default=None, help="The path to save video path/file")
     args = parser.parse_args()

     # set config
......
@@ -295,7 +295,9 @@ class DefaultRunner(BaseRunner):
             save_to_video(self.gen_video, self.config.save_video_path, fps=fps, method="ffmpeg")
             logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
-        return {"video": self.gen_video}
+        if self.config.get("return_video", False):
+            return {"video": self.gen_video}
+        return {"video": None}

     def run_pipeline(self, save_video=True):
         if self.config["use_prompt_enhancer"]:
......
 import gc
 import os
-import warnings
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
@@ -12,7 +11,6 @@ import torchvision.transforms.functional as TF
 from PIL import Image
 from einops import rearrange
 from loguru import logger
-from torchvision.io import write_video
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms.functional import resize
@@ -28,9 +26,7 @@ from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
 from lightx2v.utils.envs import *
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
-from lightx2v.utils.utils import find_torch_model_path, load_weights, vae_to_comfyui_image
-
-warnings.filterwarnings("ignore", category=UserWarning, module="torchvision.io._video_deprecation_warning")
+from lightx2v.utils.utils import find_torch_model_path, load_weights, vae_to_comfyui_image_inplace

 def get_optimal_patched_size_with_sp(patched_h, patched_w, sp_size):
@@ -475,10 +471,12 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def init_run(self):
         super().init_run()
         self.scheduler.set_audio_adapter(self.audio_adapter)
-        self.gen_video_list = []
-        self.cut_audio_list = []
         self.prev_video = None
+        if self.config.get("return_video", False):
+            self.gen_video_final = torch.zeros((self.inputs["expected_frames"], self.config.tgt_h, self.config.tgt_w, 3), dtype=torch.float32, device="cpu")
+        else:
+            self.gen_video_final = None
+        self.cut_audio_final = None

     @ProfilingContext4DebugL1("Init run segment")
     def init_run_segment(self, segment_idx, audio_array=None):
@@ -510,22 +508,31 @@ class WanAudioRunner(WanRunner):  # type:ignore
     def end_run_segment(self):
         self.gen_video = torch.clamp(self.gen_video, -1, 1).to(torch.float)
         useful_length = self.segment.end_frame - self.segment.start_frame
-        self.gen_video_list.append(self.gen_video[:, :, :useful_length].cpu())
-        self.cut_audio_list.append(self.segment.audio_array[: useful_length * self._audio_processor.audio_frame_rate])
-        if self.va_recorder:
-            cur_video = vae_to_comfyui_image(self.gen_video_list[-1])
-            self.va_recorder.pub_livestream(cur_video, self.cut_audio_list[-1])
-
-        if self.va_reader:
-            self.gen_video_list.pop()
-            self.cut_audio_list.pop()
+        video_seg = self.gen_video[:, :, :useful_length].cpu()
+        audio_seg = self.segment.audio_array[: useful_length * self._audio_processor.audio_frame_rate]
+        video_seg = vae_to_comfyui_image_inplace(video_seg)
+
+        # [Warning] Need check whether video segment interpolation works...
+        if "video_frame_interpolation" in self.config and self.vfi_model is not None:
+            target_fps = self.config["video_frame_interpolation"]["target_fps"]
+            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
+            video_seg = self.vfi_model.interpolate_frames(
+                video_seg,
+                source_fps=self.config.get("fps", 16),
+                target_fps=target_fps,
+            )
+
+        if self.va_recorder:
+            self.va_recorder.pub_livestream(video_seg, audio_seg)
+        elif self.config.get("return_video", False):
+            self.gen_video_final[self.segment.start_frame : self.segment.end_frame].copy_(video_seg)
+            self.cut_audio_final = np.concatenate([self.cut_audio_final, audio_seg], axis=0).astype(np.float32) if self.cut_audio_final is not None else audio_seg

         # Update prev_video for next iteration
         self.prev_video = self.gen_video
-        # Clean up GPU memory after each segment
-        del self.gen_video
+        del video_seg, audio_seg
         torch.cuda.empty_cache()

     def get_rank_and_world_size(self):
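`end_run_segment` no longer accumulates segments in a Python list for a final `torch.cat`; when `return_video` is set, frames are copied into `gen_video_final`, a CPU tensor allocated once up front. A minimal sketch of why that matters — shapes here are illustrative only:

```python
import torch

frames, h, w, seg = 64, 480, 832, 16
final = torch.zeros((frames, h, w, 3), dtype=torch.float32, device="cpu")  # one allocation
for start in range(0, frames, seg):
    segment = torch.rand((seg, h, w, 3))     # stand-in for a decoded segment
    final[start:start + seg].copy_(segment)  # in-place copy; no list + torch.cat
    del segment                              # segment memory is reclaimable immediately
```

`torch.cat` over a list briefly needs memory for both the accumulated segments and the concatenated result; writing into a preallocated buffer caps peak usage at the final size.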
@@ -540,18 +547,19 @@
         output_video_path = self.config.get("save_video_path", None)
         self.va_recorder = None
         if isinstance(output_video_path, dict):
-            assert output_video_path["type"] == "stream", f"unexcept save_video_path: {output_video_path}"
-            rank, world_size = self.get_rank_and_world_size()
-            if rank == 2 % world_size:
-                record_fps = self.config.get("target_fps", 16)
-                audio_sr = self.config.get("audio_sr", 16000)
-                if "video_frame_interpolation" in self.config and self.vfi_model is not None:
-                    record_fps = self.config["video_frame_interpolation"]["target_fps"]
-                self.va_recorder = VARecorder(
-                    livestream_url=output_video_path["data"],
-                    fps=record_fps,
-                    sample_rate=audio_sr,
-                )
+            output_video_path = output_video_path["data"]
+            logger.info(f"init va_recorder with output_video_path: {output_video_path}")
+        rank, world_size = self.get_rank_and_world_size()
+        if output_video_path and rank == world_size - 1:
+            record_fps = self.config.get("target_fps", 16)
+            audio_sr = self.config.get("audio_sr", 16000)
+            if "video_frame_interpolation" in self.config and self.vfi_model is not None:
+                record_fps = self.config["video_frame_interpolation"]["target_fps"]
+            self.va_recorder = VARecorder(
+                livestream_url=output_video_path,
+                fps=record_fps,
+                sample_rate=audio_sr,
+            )

     def init_va_reader(self):
         audio_path = self.config.get("audio_path", None)
@@ -583,8 +591,8 @@ class WanAudioRunner(WanRunner):  # type:ignore
             return super().run_main(total_steps)

         rank, world_size = self.get_rank_and_world_size()
-        if rank == 2 % world_size:
-            assert self.va_recorder is not None, "va_recorder is required for stream audio input for rank 0"
+        if rank == world_size - 1:
+            assert self.va_recorder is not None, "va_recorder is required for stream audio input for rank 2"
             self.va_reader.start()

         self.init_run()
@@ -627,67 +635,17 @@
         self.va_recorder = None

     @ProfilingContext4DebugL1("Process after vae decoder")
-    def process_images_after_vae_decoder(self, save_video=True):
-        # Merge results
-        gen_lvideo = torch.cat(self.gen_video_list, dim=2).float()
-        merge_audio = np.concatenate(self.cut_audio_list, axis=0).astype(np.float32)
-
-        comfyui_images = vae_to_comfyui_image(gen_lvideo)
-
-        # Apply frame interpolation if configured
-        if "video_frame_interpolation" in self.config and self.vfi_model is not None:
-            target_fps = self.config["video_frame_interpolation"]["target_fps"]
-            logger.info(f"Interpolating frames from {self.config.get('fps', 16)} to {target_fps}")
-            comfyui_images = self.vfi_model.interpolate_frames(
-                comfyui_images,
-                source_fps=self.config.get("fps", 16),
-                target_fps=target_fps,
-            )
-
-        if save_video and isinstance(self.config["save_video_path"], str):
-            if "video_frame_interpolation" in self.config and self.config["video_frame_interpolation"].get("target_fps"):
-                fps = self.config["video_frame_interpolation"]["target_fps"]
-            else:
-                fps = self.config.get("fps", 16)
-            if not dist.is_initialized() or dist.get_rank() == 0:
-                logger.info(f"🎬 Start to save video 🎬")
-                self._save_video_with_audio(comfyui_images, merge_audio, fps)
-                logger.info(f"✅ Video saved successfully to: {self.config.save_video_path} ✅")
-
-        # Convert audio to ComfyUI format
-        audio_waveform = torch.from_numpy(merge_audio).unsqueeze(0).unsqueeze(0)
-        comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}
-
-        return {"video": comfyui_images, "audio": comfyui_audio}
+    def process_images_after_vae_decoder(self, save_video=False):
+        if self.config.get("return_video", False):
+            audio_waveform = torch.from_numpy(self.cut_audio_final).unsqueeze(0).unsqueeze(0)
+            comfyui_audio = {"waveform": audio_waveform, "sample_rate": self._audio_processor.audio_sr}
+            return {"video": self.gen_video_final, "audio": comfyui_audio}
+        return {"video": None, "audio": None}

     def init_modules(self):
         super().init_modules()
         self.run_input_encoder = self._run_input_encoder_local_r2v_audio

-    def _save_video_with_audio(self, images, audio_array, fps):
-        output_path = self.config.get("save_video_path")
-        parent_dir = os.path.dirname(output_path)
-        if parent_dir and not os.path.exists(parent_dir):
-            os.makedirs(parent_dir, exist_ok=True)
-
-        sample_rate = self._audio_processor.audio_sr
-        if images.dtype != torch.uint8:
-            images = (images * 255).clamp(0, 255).to(torch.uint8)
-
-        write_video(
-            filename=output_path,
-            video_array=images,
-            fps=fps,
-            video_codec="libx264",
-            audio_array=torch.tensor(audio_array[None]),
-            audio_fps=sample_rate,
-            audio_codec="aac",
-            options={"preset": "medium", "crf": "23"},  # adjustable video output quality
-        )
-
     def load_transformer(self):
         """Load transformer with LoRA support"""
         base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
......
@@ -31,6 +31,7 @@ def get_default_config():
         "tgt_h": None,
         "tgt_w": None,
         "target_shape": None,
+        "return_video": False,
     }
     return default_config
@@ -73,6 +74,8 @@ def set_config(args):
         logger.warning(f"`num_frames - 1` has to be divisible by {config.vae_stride[0]}. Rounding to the nearest number.")
         config.target_video_length = config.target_video_length // config.vae_stride[0] * config.vae_stride[0] + 1

+    assert not (config.save_video_path and config.return_video), "save_video_path and return_video cannot be set at the same time"
+
     return config
......
@@ -131,6 +131,39 @@ def vae_to_comfyui_image(vae_output: torch.Tensor) -> torch.Tensor:
     return images

+def vae_to_comfyui_image_inplace(vae_output: torch.Tensor) -> torch.Tensor:
+    """
+    Convert VAE decoder output to ComfyUI Image format (in-place operation).
+
+    Args:
+        vae_output: VAE decoder output tensor, typically in range [-1, 1].
+            Shape: [B, C, T, H, W] or [B, C, H, W].
+            WARNING: This tensor will be modified in-place!
+
+    Returns:
+        ComfyUI Image tensor in range [0, 1].
+        Shape: [B, H, W, C] for a single frame or [B*T, H, W, C] for video.
+        Note: The returned tensor is the same object as the input (modified in-place).
+    """
+    # Handle video tensor (5D) vs image tensor (4D)
+    if vae_output.dim() == 5:
+        # Video tensor: [B, C, T, H, W]
+        B, C, T, H, W = vae_output.shape
+        # Reshape to [B*T, C, H, W] for processing (inplace view)
+        vae_output = vae_output.permute(0, 2, 1, 3, 4).contiguous().view(B * T, C, H, W)
+
+    # Normalize from [-1, 1] to [0, 1] (inplace)
+    vae_output.add_(1).div_(2)
+
+    # Clamp values to [0, 1] (inplace)
+    vae_output.clamp_(0, 1)
+
+    # Convert from [B, C, H, W] to [B, H, W, C] and move to CPU
+    vae_output = vae_output.permute(0, 2, 3, 1).cpu()
+
+    return vae_output
 def save_to_video(
     images: torch.Tensor,
     output_path: str,
......
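`vae_to_comfyui_image_inplace` trades safety for memory: the caller's tensor is rescaled in place (for 4D inputs; the 5D path makes one contiguous copy first), so the input must be treated as consumed. A small sketch of the contract — illustrative values, not project code:

```python
import torch
from lightx2v.utils.utils import vae_to_comfyui_image_inplace

decoded = torch.rand(2, 3, 64, 64) * 2 - 1      # stand-in VAE output in [-1, 1]
images = vae_to_comfyui_image_inplace(decoded)  # -> [2, 64, 64, 3] in [0, 1]
print(decoded.min().item() >= 0)                # True: the input itself was rescaled
```

This is why `end_run_segment` rebinds `video_seg` to the converted result and deletes it after publishing, rather than keeping a reference to the pre-conversion data.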