Commit 398b598a authored by PengGao, committed by GitHub

Refactor/api (#94)

* fix: correct frequency computation in WanTransformerInfer

* refactor: restructure API server and distributed inference services

- Removed the old api_server_dist.py file and integrated its functionality into a new modular structure.
- Created a new ApiServer class to handle API routes and services.
- Introduced DistributedInferenceService and FileService for better separation of concerns.
- Updated the main entry point to initialize and run the new API server with distributed inference capabilities.
- Added schema definitions for task requests and responses to improve data handling.
- Enhanced error handling and logging throughout the services.

* refactor: enhance API structure and file handling in server

- Introduced APIRouter for modular route management in the ApiServer class.
- Updated task creation and file download endpoints to improve clarity and functionality.
- Implemented a new method for streaming file responses with proper MIME type handling.
- Refactored task request schema to auto-generate task IDs and handle optional video save paths.
- Improved error handling and logging for better debugging and user feedback.

* feat: add configurable parameters for video generation

- Introduced new parameters: infer_steps, target_video_length, and seed to the API and task request schema.
- Updated DefaultRunner and VideoGenerationService to handle the new parameters for enhanced video generation control.
- Improved default values for parameters to ensure consistent behavior.

* refactor: enhance profiling context for async support

* refactor: improve signal handling in API server

* feat: enhance video generation capabilities with audio support

* refactor: improve subprocess call for audio-video merging in wan_audio_runner.py

* refactor: enhance API server argument parsing and improve code readability

* refactor: enhance logging and improve code comments for clarity

* refactor: update response model for task listing endpoint to return a dictionary

* docs: update API endpoints and improve documentation clarity

* refactor: update API endpoints in scripts for task management and remove unused code

* fix: pre-commit
@@ -2,7 +2,6 @@
lightx2v provides asynchronous service functionality. The code entry point is [here](https://github.com/ModelTC/lightx2v/blob/main/lightx2v/api_server.py)
### Start the Service
```shell
@@ -12,36 +11,26 @@ bash scripts/start_server.sh
The `--port 8000` option means the service will bind to port `8000` on the local machine. You can change this as needed.
### Client Sends Request
```shell
python scripts/post.py
```
The service endpoint is: `/v1/local/video/generate`
The service endpoint is: `/v1/tasks/`
The `message` parameter in `scripts/post.py` is as follows:
```python
message = {
"task_id": generate_task_id(),
"task_id_must_unique": True,
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t02.mp4",
}
```
1. `prompt`, `negative_prompt`, and `image_path` are basic inputs for video generation. `image_path` can be an empty string, indicating no image input is needed.
2. `save_video_path` specifies the path where the generated video will be saved on the server. The relative path is relative to the server's startup directory. It is recommended to set an absolute path according to your environment.
3. `task_id` is the ID of the task, which is a string. You can customize a string or use the `generate_task_id()` function to generate a random string. The task ID is used to distinguish between different video generation tasks.
4. `task_id_must_unique` indicates whether each `task_id` must be unique. If set to `False`, there is no such restriction. In this case, if duplicate `task_id`s are sent, the server's `task` record will be overwritten by the newer task with the same `task_id`. If you do not need to keep a record of all tasks for querying, you can set this to `False`.
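For illustration, a minimal client along the lines of `scripts/post.py` (the server address and field values below are placeholder assumptions) might look like this:
```python
import requests

message = {
    "task_id": "test_task_001",  # or generate one with generate_task_id()
    "prompt": "Two anthropomorphic cats in comfy boxing gear fight on a stage.",
    "negative_prompt": "",
    "image_path": "",  # empty string: text-to-video, no image input
    "save_video_path": "/path/to/output.mp4",  # absolute path recommended
}

# The server replies immediately with the task_id and "processing" status;
# the video itself is generated in the background.
response = requests.post("http://localhost:8000/v1/tasks/", json=message)
print(response.json())
```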
### Client Checks Server Status
@@ -51,12 +40,11 @@ python scripts/check_status.py
The service endpoints include:
1. `/v1/local/video/generate/service_status` is used to check the status of the service. It returns whether the service is `busy` or `idle`. The service only accepts new requests when it is `idle`.
1. `/v1/service/status` is used to check the status of the service. It returns whether the service is `busy` or `idle`. The service only accepts new requests when it is `idle`.
2. `/v1/local/video/generate/get_all_tasks` is used to get all tasks received and completed by the server.
2. `/v1/tasks/` is used to get all tasks received and completed by the server.
3. `/v1/local/video/generate/task_status` is used to get the status of a specified `task_id`. It returns whether the task is `processing` or `completed`.
3. `/v1/tasks/{task_id}/status` is used to get the status of a specified `task_id`. It returns whether the task is `processing` or `completed`.
### Client Stops the Current Task on the Server at Any Time
@@ -64,7 +52,7 @@ The service endpoints include:
python scripts/stop_running_task.py
```
The service endpoint is: `/v1/local/video/generate/stop_running_task`
The service endpoint is: `/v1/tasks/running`
After terminating the task, the server will not exit but will return to waiting for new requests.
@@ -78,7 +66,6 @@ num_gpus=8 bash scripts/start_multi_servers.sh
Where `num_gpus` indicates the number of services to start; the services will run on consecutive ports starting from `--start_port`.
### Scheduling Between Multiple Services
```shell
@@ -86,3 +73,16 @@ python scripts/post_multi_servers.py
```
`post_multi_servers.py` will schedule multiple client requests based on the idle status of the services.
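A rough sketch of that scheduling loop (an illustration, not the actual script; it assumes eight services on ports 8000-8007):
```python
import requests

ports = range(8000, 8008)  # --start_port .. --start_port + num_gpus - 1
message = {"prompt": "a cat surfing", "image_path": "", "save_video_path": "out.mp4"}

# Send the task to the first service that reports itself as idle.
for port in ports:
    status = requests.get(f"http://localhost:{port}/v1/service/status").json()
    if status.get("service_status") == "idle":
        response = requests.post(f"http://localhost:{port}/v1/tasks/", json=message)
        print(port, response.json())
        break
```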
### API Endpoints Summary
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/v1/tasks/` | POST | Create video generation task |
| `/v1/tasks/form` | POST | Create video generation task via form |
| `/v1/tasks/` | GET | List all tasks |
| `/v1/tasks/{task_id}/status` | GET | Get status of specified task |
| `/v1/tasks/{task_id}/result` | GET | Get result video file of specified task |
| `/v1/tasks/running` | DELETE | Stop currently running task |
| `/v1/files/download/{file_path}` | GET | Download file |
| `/v1/service/status` | GET | Get service status |
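As a sketch of the full task lifecycle against these endpoints (assuming a server on `localhost:8000` and a known `task_id`):
```python
import requests

base = "http://localhost:8000"
task_id = "test_task_001"  # placeholder

# Poll the task status ("processing", "completed", "failed", or "not_found")
status = requests.get(f"{base}/v1/tasks/{task_id}/status").json()
print(status)

# Once completed, download the result video
if status.get("status") == "completed":
    data = requests.get(f"{base}/v1/tasks/{task_id}/result").content
    with open("result.mp4", "wb") as f:
        f.write(data)

# A running task can be stopped at any time
print(requests.delete(f"{base}/v1/tasks/running").json())
```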
# How to Start the Service
lightx2v provides asynchronous service functionality. The code entry point is [here](https://github.com/ModelTC/lightx2v/blob/main/lightx2v/api_server.py)
### Start the Service
@@ -10,8 +9,7 @@ lightx2v provides asynchronous service functionality. The code entry point is [here](https://github.
bash scripts/start_server.sh
```
The `--port 8000` option means the service binds to port `8000` on the local machine; you can change this as needed.
### Client Sends a Request
@@ -19,65 +17,54 @@ bash scripts/start_server.sh
python scripts/post.py
```
The service endpoint is `/v1/local/video/generate`
The service endpoint is `/v1/tasks/`
The `message` parameter in `scripts/post.py` is as follows:
```python
message = {
"task_id": generate_task_id(),
"task_id_must_unique": True,
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t02.mp4",
"image_path": ""
}
```
1. `prompt`, `negative_prompt`, and `image_path` are basic inputs for video generation; `image_path` can be an empty string, indicating no image input is needed.
2. `save_video_path` is the path where the server saves the generated video; a relative path is resolved against the server's startup directory, so an absolute path suited to your environment is recommended.
3. `task_id` is the task's ID, a string. You can supply a custom string or call `generate_task_id()` to generate a random one; the task ID distinguishes different video generation tasks.
4. `task_id_must_unique` indicates whether each `task_id` must be unique, i.e. whether duplicate `task_id`s are rejected. If `False`, there is no such restriction, and a duplicate `task_id` overwrites the server's `task` record with the newer task of the same ID. Set this to `False` if you do not need to keep all task records for querying.
1. `prompt`, `negative_prompt`, and `image_path` are the basic inputs for video generation. `image_path` can be an empty string, indicating no image input is needed.
### Client Checks Server Status
```shell
python scripts/check_status.py
```
The service endpoints include:
1. `/v1/local/video/generate/service_status` is used to check the service status. It returns whether the service is `busy` or `idle`; the service only accepts new requests when it is `idle`.
1. `/v1/service/status` is used to check the service status. It returns whether the service is `busy` or `idle`; the service only accepts new requests when it is `idle`.
2. `/v1/local/video/generate/get_all_tasks` is used to get all tasks received and completed by the server.
2. `/v1/tasks/` is used to get all tasks received and completed by the server.
3. `/v1/local/video/generate/task_status` is used to get the status of a specified `task_id`. It returns whether the task is `processing` or `completed`.
3. `/v1/tasks/{task_id}/status` is used to get the status of a specified `task_id`. It returns whether the task is `processing` or `completed`.
### Client Stops the Server's Current Task at Any Time
```shell
python scripts/stop_running_task.py
```
The service endpoint is `/v1/local/video/generate/stop_running_task`
The service endpoint is `/v1/tasks/running`
After the task is terminated, the server does not exit; it goes back to waiting for new requests.
### Start Multiple Services on a Single Node
On a single node, you can run `scripts/start_server.sh` multiple times to start several services (note that port numbers under the same IP must differ between services), or start several services at once with `scripts/start_multi_servers.sh`:
```shell
num_gpus=8 bash scripts/start_multi_servers.sh
```
Here `num_gpus` is the number of services to start; the services run on consecutive ports starting from `--start_port`.
### Scheduling Across Multiple Services
@@ -85,4 +72,17 @@ num_gpus=8 bash scripts/start_multi_servers.sh
python scripts/post_multi_servers.py
```
`post_multi_servers.py` schedules multiple client requests based on each service's idle status.
### API Endpoints Summary
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/v1/tasks/` | POST | Create video generation task |
| `/v1/tasks/form` | POST | Create video generation task via form |
| `/v1/tasks/` | GET | List all tasks |
| `/v1/tasks/{task_id}/status` | GET | Get status of specified task |
| `/v1/tasks/{task_id}/result` | GET | Get result video file of specified task |
| `/v1/tasks/running` | DELETE | Stop currently running task |
| `/v1/files/download/{file_path}` | GET | Download file |
| `/v1/service/status` | GET | Get service status |
import asyncio
import argparse
from fastapi import FastAPI
from pydantic import BaseModel
import sys
import signal
import atexit
from pathlib import Path
from loguru import logger
import uvicorn
import json
from typing import Optional
from datetime import datetime
import threading
import ctypes
import gc
import torch
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.infer import init_runner
from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager
from lightx2v.server.api import ApiServer
from lightx2v.server.service import DistributedInferenceService
from lightx2v.server.utils import ProcessManager
# =========================
# FastAPI Related Code
# =========================
def create_signal_handler(inference_service: DistributedInferenceService):
"""Create unified signal handler function"""
runner = None
thread = None
app = FastAPI()
class Message(BaseModel):
task_id: str
task_id_must_unique: bool = False
prompt: str
use_prompt_enhancer: bool = False
negative_prompt: str = ""
image_path: str = ""
num_fragments: int = 1
save_video_path: str
def get(self, key, default=None):
return getattr(self, key, default)
class ApiServerServiceStatus(BaseServiceStatus):
pass
def local_video_generate(message: Message):
try:
global runner
runner.set_inputs(message)
logger.info(f"message: {message}")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
def signal_handler(signum, frame):
logger.info(f"Received signal {signum}, gracefully shutting down...")
try:
loop.run_until_complete(runner.run_pipeline())
if inference_service.is_running:
inference_service.stop_distributed_inference()
except Exception as e:
logger.error(f"Error occurred while shutting down distributed inference service: {str(e)}")
finally:
loop.close()
ApiServerServiceStatus.complete_task(message)
except Exception as e:
logger.error(f"task_id {message.task_id} failed: {str(e)}")
ApiServerServiceStatus.record_failed_task(message, error=str(e))
@app.post("/v1/local/video/generate")
async def v1_local_video_generate(message: Message):
try:
task_id = ApiServerServiceStatus.start_task(message)
# Use background threads to perform long-running tasks
global thread
thread = threading.Thread(target=local_video_generate, args=(message,), daemon=True)
thread.start()
return {"task_id": task_id, "task_status": "processing", "save_video_path": message.save_video_path}
except RuntimeError as e:
return {"error": str(e)}
sys.exit(0)
return signal_handler
@app.get("/v1/local/video/generate/service_status")
async def get_service_status():
return ApiServerServiceStatus.get_status_service()
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_cls",
type=str,
required=True,
choices=[
"wan2.1",
"hunyuan",
"wan2.1_causvid",
"wan2.1_skyreels_v2_df",
"wan2.1_audio",
],
default="hunyuan",
)
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
@app.get("/v1/local/video/generate/get_all_tasks")
async def get_all_tasks():
return ApiServerServiceStatus.get_all_tasks()
parser.add_argument("--split", action="store_true")
parser.add_argument("--lora_path", type=str, required=False, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node for distributed inference")
args = parser.parse_args()
logger.info(f"args: {args}")
@app.post("/v1/local/video/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
return ApiServerServiceStatus.get_status_task_id(message.task_id)
cache_dir = Path(__file__).parent.parent / ".cache"
inference_service = DistributedInferenceService()
api_server = ApiServer()
api_server.initialize_services(cache_dir, inference_service)
def _async_raise(tid, exctype):
"""Force thread tid to raise exception exctype"""
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(exctype))
if res == 0:
raise ValueError("Invalid thread ID")
elif res > 1:
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), 0)
raise SystemError("PyThreadState_SetAsyncExc failed")
signal_handler = create_signal_handler(inference_service)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
logger.info("Starting distributed inference service...")
success = inference_service.start_distributed_inference(args)
if not success:
logger.error("Failed to start distributed inference service, exiting program")
sys.exit(1)
@app.get("/v1/local/video/generate/stop_running_task")
async def stop_running_task():
global thread
if thread and thread.is_alive():
try:
_async_raise(thread.ident, SystemExit)
thread.join()
# Clean up the thread reference
thread = None
ApiServerServiceStatus.clean_stopped_task()
gc.collect()
torch.cuda.empty_cache()
return {"stop_status": "success", "reason": "Task stopped successfully."}
except Exception as e:
return {"stop_status": "error", "reason": str(e)}
else:
return {"stop_status": "do_nothing", "reason": "No running task found."}
atexit.register(inference_service.stop_distributed_inference)
try:
logger.info(f"Starting FastAPI server on port: {args.port}")
uvicorn.run(
api_server.get_app(),
host="0.0.0.0",
port=args.port,
reload=False,
workers=1,
)
except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt, shutting down service...")
except Exception as e:
logger.error(f"Error occurred while running FastAPI server: {str(e)}")
finally:
inference_service.stop_distributed_inference()
# =========================
# Main Entry
# =========================
if __name__ == "__main__":
ProcessManager.register_signal_handler()
parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--split", action="store_true")
parser.add_argument("--port", type=int, default=8000)
args = parser.parse_args()
logger.info(f"args: {args}")
with ProfilingContext("Init Server Cost"):
config = set_config(args)
config["mode"] = "split_server" if args.split else "server"
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = init_runner(config)
uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
main()
@@ -41,7 +41,6 @@ def radial_attn(
head_dim=hidden_dim,
q_data_type=query.dtype,
kv_data_type=key.dtype,
o_data_type=query.dtype,
use_fp16_qk_reduction=True,
)
......
@@ -58,6 +58,19 @@ async def main():
parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
args = parser.parse_args()
if args.prompt_path:
try:
with open(args.prompt_path, "r", encoding="utf-8") as f:
args.prompt = f.read().strip()
logger.info(f"从文件 {args.prompt_path} 读取到prompt: {args.prompt}")
except FileNotFoundError:
logger.error(f"找不到prompt文件: {args.prompt_path}")
raise
except Exception as e:
logger.error(f"读取prompt文件时出错: {e}")
raise
logger.info(f"args: {args}")
with ProfilingContext("Total Cost"):
......
@@ -349,7 +349,7 @@ class WanTransformerInfer(BaseTransformerInfer):
if self.config.get("audio_sr", False):
freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
else:
freqs_i = compute_freqs_dist(q.size(2) // 2, grid_sizes, freqs)
freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
freqs_i = self.zero_temporal_component_in_3DRoPE(seq_lens, freqs_i)
......
@@ -101,6 +101,14 @@ class DefaultRunner:
self.config["negative_prompt"] = inputs.get("negative_prompt", "")
self.config["image_path"] = inputs.get("image_path", "")
self.config["save_video_path"] = inputs.get("save_video_path", "")
self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
self.config["audio_path"] = inputs.get("audio_path", "") # for wan-audio
self.config["video_duration"] = inputs.get("video_duration", 5) # for wan-audio
# self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
# self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))
def run(self):
for step_index in range(self.model.scheduler.infer_steps):
......
@@ -219,8 +219,9 @@ def save_to_video(gen_lvideo, out_path, target_fps):
def save_audio(
audio_array,
audio_name: str,
video_name: str = None,
video_name: str,
sr: int = 16000,
output_path: Optional[str] = None,
):
logger.info(f"Saving audio to {audio_name} type: {type(audio_array)}")
@@ -230,18 +231,21 @@
sample_rate=sr,
)
out_video = f"{video_name[:-4]}_with_audio.mp4"
# Ensure the parent directory exists
if output_path is None:
out_video = f"{video_name[:-4]}_with_audio.mp4"
else:
out_video = output_path
parent_dir = os.path.dirname(out_video)
if parent_dir and not os.path.exists(parent_dir):
os.makedirs(parent_dir, exist_ok=True)
# If the output video already exists, remove it first
if os.path.exists(out_video):
os.remove(out_video)
cmd = f"/usr/bin/ffmpeg -i {video_name} -i {audio_name} {out_video}"
subprocess.call(cmd, shell=True)
subprocess.call(["/usr/bin/ffmpeg", "-y", "-i", video_name, "-i", audio_name, out_video])
return out_video
@RUNNER_REGISTER("wan2.1_audio")
@@ -323,17 +327,16 @@ class WanAudioRunner(WanRunner):
"vae_encode_out": vae_encode_out,
}
logger.info(f"clip_encoder_out:{clip_encoder_out.shape} vae_encode_out:{vae_encode_out.shape}")
with ProfilingContext("Run Text Encoder"):
with open(self.config["prompt_path"], "r", encoding="utf-8") as f:
prompt = f.readline().strip()
logger.info(f"Prompt: {prompt}")
img = Image.open(self.config["image_path"]).convert("RGB")
text_encoder_output = self.run_text_encoder(prompt, img)
logger.info(f"Prompt: {self.config['prompt']}")
img = Image.open(self.config["image_path"]).convert("RGB")
text_encoder_output = self.run_text_encoder(self.config["prompt"], img)
self.set_target_shape()
self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
del self.image_encoder # drop the ref CLIP model; it is only used once
# del self.image_encoder # drop the ref CLIP model; it is only used once
gc.collect()
torch.cuda.empty_cache()
@@ -488,10 +491,10 @@ class WanAudioRunner(WanRunner):
gen_lvideo = torch.cat(gen_video_list, dim=2).float()
merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
out_path = self.config.save_video_path
out_path = os.path.join("./", "video_merge.mp4")
audio_file = os.path.join("./", "audio_merge.wav")
save_to_video(gen_lvideo, out_path, target_fps)
save_audio(merge_audio, audio_file, out_path)
save_audio(merge_audio, audio_file, out_path, output_path=self.config.get("save_video_path", None))
os.remove(out_path)
os.remove(audio_file)
......
import asyncio
from fastapi import FastAPI, UploadFile, HTTPException, Form, File, APIRouter
from fastapi.responses import StreamingResponse
from loguru import logger
import threading
import gc
import torch
from pathlib import Path
import uuid
from typing import Optional
from .schema import (
TaskRequest,
TaskResponse,
ServiceStatusResponse,
StopTaskResponse,
)
from .service import FileService, DistributedInferenceService, VideoGenerationService
from .utils import ServiceStatus
class ApiServer:
def __init__(self):
self.app = FastAPI(title="LightX2V API", version="1.0.0")
self.file_service = None
self.inference_service = None
self.video_service = None
self.thread = None
self.stop_generation_event = threading.Event()
# Create routers
self.tasks_router = APIRouter(prefix="/v1/tasks", tags=["tasks"])
self.files_router = APIRouter(prefix="/v1/files", tags=["files"])
self.service_router = APIRouter(prefix="/v1/service", tags=["service"])
self._setup_routes()
def _setup_routes(self):
"""Setup routes"""
self._setup_task_routes()
self._setup_file_routes()
self._setup_service_routes()
# Register routers
self.app.include_router(self.tasks_router)
self.app.include_router(self.files_router)
self.app.include_router(self.service_router)
def _stream_file_response(self, file_path: Path, filename: str | None = None) -> StreamingResponse:
"""Common file streaming response method"""
assert self.file_service is not None, "File service is not initialized"
try:
resolved_path = file_path.resolve()
# Security check: ensure file is within allowed directory
if not str(resolved_path).startswith(str(self.file_service.output_video_dir.resolve())):
raise HTTPException(status_code=403, detail="Access to this file is not allowed")
if not resolved_path.exists() or not resolved_path.is_file():
raise HTTPException(status_code=404, detail=f"File not found: {file_path}")
file_size = resolved_path.stat().st_size
actual_filename = filename or resolved_path.name
# Set appropriate MIME type
mime_type = "application/octet-stream"
if actual_filename.lower().endswith((".mp4", ".avi", ".mov", ".mkv")):
mime_type = "video/mp4"
elif actual_filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
mime_type = "image/jpeg"
headers = {
"Content-Disposition": f'attachment; filename="{actual_filename}"',
"Content-Length": str(file_size),
"Accept-Ranges": "bytes",
}
def file_stream_generator(file_path: str, chunk_size: int = 1024 * 1024):
with open(file_path, "rb") as file:
while chunk := file.read(chunk_size):
yield chunk
return StreamingResponse(
file_stream_generator(str(resolved_path)),
media_type=mime_type,
headers=headers,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error occurred while processing file stream response: {e}")
raise HTTPException(status_code=500, detail="File transfer failed")
def _setup_task_routes(self):
@self.tasks_router.post("/", response_model=TaskResponse)
async def create_task(message: TaskRequest):
"""Create video generation task"""
try:
task_id = ServiceStatus.start_task(message)
# Use background thread to handle long-running tasks
self.stop_generation_event.clear()
self.thread = threading.Thread(
target=self._process_video_generation,
args=(message, self.stop_generation_event),
daemon=True,
)
self.thread.start()
return TaskResponse(
task_id=task_id,
task_status="processing",
save_video_path=message.save_video_path,
)
except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e))
@self.tasks_router.post("/form", response_model=TaskResponse)
async def create_task_form(
image_file: UploadFile = File(...),
prompt: str = Form(default=""),
save_video_path: str = Form(default=""),
use_prompt_enhancer: bool = Form(default=False),
negative_prompt: str = Form(default=""),
num_fragments: int = Form(default=1),
infer_steps: int = Form(default=5),
target_video_length: int = Form(default=81),
seed: int = Form(default=42),
audio_file: Optional[UploadFile] = File(default=None),
video_duration: int = Form(default=5),
):
"""Create video generation task via form"""
# Process uploaded image file
image_path = ""
assert self.file_service is not None, "File service is not initialized"
if image_file and image_file.filename:
file_extension = Path(image_file.filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}"
image_path = self.file_service.input_image_dir / unique_filename
with open(image_path, "wb") as buffer:
content = await image_file.read()
buffer.write(content)
image_path = str(image_path)
audio_path = ""
if audio_file and audio_file.filename:
file_extension = Path(audio_file.filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}"
audio_path = self.file_service.input_audio_dir / unique_filename
with open(audio_path, "wb") as buffer:
content = await audio_file.read()
buffer.write(content)
audio_path = str(audio_path)
message = TaskRequest(
prompt=prompt,
use_prompt_enhancer=use_prompt_enhancer,
negative_prompt=negative_prompt,
image_path=image_path,
num_fragments=num_fragments,
save_video_path=save_video_path,
infer_steps=infer_steps,
target_video_length=target_video_length,
seed=seed,
audio_path=audio_path,
video_duration=video_duration,
)
try:
task_id = ServiceStatus.start_task(message)
self.stop_generation_event.clear()
self.thread = threading.Thread(
target=self._process_video_generation,
args=(message, self.stop_generation_event),
daemon=True,
)
self.thread.start()
return TaskResponse(
task_id=task_id,
task_status="processing",
save_video_path=message.save_video_path,
)
except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e))
@self.tasks_router.get("/", response_model=dict)
async def list_tasks():
"""Get all task list"""
return ServiceStatus.get_all_tasks()
@self.tasks_router.get("/{task_id}/status")
async def get_task_status(task_id: str):
"""Get status of specified task"""
return ServiceStatus.get_status_task_id(task_id)
@self.tasks_router.get("/{task_id}/result")
async def get_task_result(task_id: str):
"""Get result video file of specified task"""
assert self.video_service is not None, "Video service is not initialized"
assert self.file_service is not None, "File service is not initialized"
try:
task_status = ServiceStatus.get_status_task_id(task_id)
if not task_status or task_status.get("status") != "completed":
raise HTTPException(status_code=404, detail="Task not completed or does not exist")
save_video_path = task_status.get("save_video_path")
if not save_video_path:
raise HTTPException(status_code=404, detail="Task result file does not exist")
full_path = Path(save_video_path)
if not full_path.is_absolute():
full_path = self.file_service.output_video_dir / save_video_path
return self._stream_file_response(full_path)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error occurred while getting task result: {e}")
raise HTTPException(status_code=500, detail="Failed to get task result")
@self.tasks_router.delete("/running", response_model=StopTaskResponse)
async def stop_running_task():
"""Stop currently running task"""
if self.thread and self.thread.is_alive():
try:
logger.info("Sending stop signal to running task thread...")
self.stop_generation_event.set()
self.thread.join(timeout=5)
if self.thread.is_alive():
logger.warning("Task thread did not stop within the specified time, manual intervention may be required.")
return StopTaskResponse(
stop_status="warning",
reason="Task thread did not stop within the specified time, manual intervention may be required.",
)
else:
self.thread = None
ServiceStatus.clean_stopped_task()
gc.collect()
torch.cuda.empty_cache()
logger.info("Task stopped successfully.")
return StopTaskResponse(stop_status="success", reason="Task stopped successfully.")
except Exception as e:
logger.error(f"Error occurred while stopping task: {str(e)}")
return StopTaskResponse(stop_status="error", reason=str(e))
else:
return StopTaskResponse(stop_status="do_nothing", reason="No running task found.")
def _setup_file_routes(self):
@self.files_router.get("/download/{file_path:path}")
async def download_file(file_path: str):
"""Download file"""
assert self.file_service is not None, "File service is not initialized"
try:
full_path = self.file_service.output_video_dir / file_path
return self._stream_file_response(full_path)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error occurred while processing file download request: {e}")
raise HTTPException(status_code=500, detail="File download failed")
def _setup_service_routes(self):
@self.service_router.get("/status", response_model=ServiceStatusResponse)
async def get_service_status():
"""Get service status"""
return ServiceStatus.get_status_service()
def _process_video_generation(self, message: TaskRequest, stop_event: threading.Event):
assert self.video_service is not None, "Video service is not initialized"
try:
if stop_event.is_set():
logger.info(f"Task {message.task_id} received stop signal, terminating")
ServiceStatus.record_failed_task(message, error="Task stopped")
return
# Use video generation service to process task
result = asyncio.run(self.video_service.generate_video(message))
except Exception as e:
logger.error(f"Task {message.task_id} processing failed: {str(e)}")
ServiceStatus.record_failed_task(message, error=str(e))
def initialize_services(self, cache_dir: Path, inference_service: DistributedInferenceService):
self.file_service = FileService(cache_dir)
self.inference_service = inference_service
self.video_service = VideoGenerationService(self.file_service, inference_service)
def get_app(self) -> FastAPI:
return self.app
import os
import torch
import torch.distributed as dist
from loguru import logger
class DistributedManager:
def __init__(self):
self.is_initialized = False
self.rank = 0
self.world_size = 1
def init_process_group(self, rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
try:
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
dist.init_process_group(backend="nccl", init_method=f"tcp://{master_addr}:{master_port}", rank=rank, world_size=world_size)
if torch.cuda.is_available(): # type: ignore
torch.cuda.set_device(rank)
self.is_initialized = True
self.rank = rank
self.world_size = world_size
logger.info(f"Rank {rank}/{world_size - 1} distributed environment initialized successfully")
return True
except Exception as e:
logger.error(f"Rank {rank} distributed environment initialization failed: {str(e)}")
return False
def cleanup(self):
try:
if dist.is_initialized():
dist.destroy_process_group()
logger.info(f"Rank {self.rank} distributed environment cleaned up")
except Exception as e:
logger.error(f"Rank {self.rank} error occurred while cleaning up distributed environment: {str(e)}")
finally:
self.is_initialized = False
def barrier(self):
if self.is_initialized:
dist.barrier()
def is_rank_zero(self) -> bool:
return self.rank == 0
def broadcast_task_data(self, task_data=None): # type: ignore
if not self.is_initialized:
return None
if self.is_rank_zero():
if task_data is None:
stop_signal = torch.tensor([1], dtype=torch.int32, device=f"cuda:{self.rank}")
else:
stop_signal = torch.tensor([0], dtype=torch.int32, device=f"cuda:{self.rank}")
dist.broadcast(stop_signal, src=0)
if task_data is not None:
import pickle
task_bytes = pickle.dumps(task_data)
task_length = torch.tensor([len(task_bytes)], dtype=torch.int32, device=f"cuda:{self.rank}")
dist.broadcast(task_length, src=0)
task_tensor = torch.tensor(list(task_bytes), dtype=torch.uint8, device=f"cuda:{self.rank}")
dist.broadcast(task_tensor, src=0)
return task_data
else:
return None
else:
stop_signal = torch.tensor([0], dtype=torch.int32, device=f"cuda:{self.rank}")
dist.broadcast(stop_signal, src=0)
if stop_signal.item() == 1:
return None
else:
task_length = torch.tensor([0], dtype=torch.int32, device=f"cuda:{self.rank}")
dist.broadcast(task_length, src=0)
task_tensor = torch.empty(int(task_length.item()), dtype=torch.uint8, device=f"cuda:{self.rank}")
dist.broadcast(task_tensor, src=0)
import pickle
task_bytes = bytes(task_tensor.cpu().numpy())
task_data = pickle.loads(task_bytes)
return task_data
class DistributedWorker:
def __init__(self, rank: int, world_size: int, master_addr: str, master_port: str):
self.rank = rank
self.world_size = world_size
self.master_addr = master_addr
self.master_port = master_port
self.dist_manager = DistributedManager()
def init(self) -> bool:
return self.dist_manager.init_process_group(self.rank, self.world_size, self.master_addr, self.master_port)
def cleanup(self):
self.dist_manager.cleanup()
def sync_and_report(self, task_id: str, status: str, result_queue, **kwargs):
# Synchronize all processes
self.dist_manager.barrier()
if self.dist_manager.is_rank_zero():
result = {"task_id": task_id, "status": status, **kwargs}
result_queue.put(result)
logger.info(f"Task {task_id} {status}")
def create_distributed_worker(rank: int, world_size: int, master_addr: str, master_port: str) -> DistributedWorker:
return DistributedWorker(rank, world_size, master_addr, master_port)
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
from ..utils.generate_task_id import generate_task_id
class TaskRequest(BaseModel):
task_id: str = Field(default_factory=generate_task_id, description="Task ID (auto-generated)")
prompt: str = Field("", description="Generation prompt")
use_prompt_enhancer: bool = Field(False, description="Whether to use prompt enhancer")
negative_prompt: str = Field("", description="Negative prompt")
image_path: str = Field("", description="Input image path")
num_fragments: int = Field(1, description="Number of fragments")
save_video_path: str = Field("", description="Save video path (optional, defaults to task_id.mp4)")
infer_steps: int = Field(5, description="Inference steps")
target_video_length: int = Field(81, description="Target video length")
seed: int = Field(42, description="Random seed")
audio_path: str = Field("", description="Input audio path (Wan-Audio)")
video_duration: int = Field(5, description="Video duration (Wan-Audio)")
def __init__(self, **data):
super().__init__(**data)
# If save_video_path is empty, use task_id.mp4
if not self.save_video_path:
self.save_video_path = f"{self.task_id}.mp4"
def get(self, key, default=None):
return getattr(self, key, default)
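# Illustrative usage sketch (editor's note, not part of this commit), assuming
# this model is importable as lightx2v.server.schema.TaskRequest:
#
#   from lightx2v.server.schema import TaskRequest
#   req = TaskRequest(prompt="a cat surfing")
#   req.task_id          # auto-generated, e.g. "N1PQ-PRM5-N1BN-Z3S1-BGBJ"
#   req.save_video_path  # falls back to f"{req.task_id}.mp4"
#   req2 = TaskRequest(prompt="a cat surfing", save_video_path="/abs/out.mp4")
#   req2.save_video_path # an explicit path is kept unchanged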
class TaskStatusMessage(BaseModel):
task_id: str = Field(..., description="Task ID")
class TaskResponse(BaseModel):
task_id: str
task_status: str
save_video_path: str
class TaskResultResponse(BaseModel):
status: str
task_status: str
filename: Optional[str] = None
file_size: Optional[int] = None
download_url: Optional[str] = None
message: str
class ServiceStatusResponse(BaseModel):
service_status: str
task_id: Optional[str] = None
start_time: Optional[datetime] = None
class StopTaskResponse(BaseModel):
stop_status: str
reason: str
import asyncio
import queue
import time
import uuid
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
import httpx
import torch.multiprocessing as mp
from loguru import logger
from ..utils.set_config import set_config
from ..infer import init_runner
from .utils import ServiceStatus
from .schema import TaskRequest, TaskResponse
from .distributed_utils import create_distributed_worker
mp.set_start_method("spawn", force=True)
class FileService:
def __init__(self, cache_dir: Path):
self.cache_dir = cache_dir
self.input_image_dir = cache_dir / "inputs" / "imgs"
self.input_audio_dir = cache_dir / "inputs" / "audios"
self.output_video_dir = cache_dir / "outputs"
# Create directories
for directory in [
self.input_image_dir,
self.output_video_dir,
self.input_audio_dir,
]:
directory.mkdir(parents=True, exist_ok=True)
async def download_image(self, image_url: str) -> Path:
try:
async with httpx.AsyncClient(verify=False) as client:
response = await client.get(image_url)
if response.status_code != 200:
raise ValueError(f"Failed to download image from {image_url}")
image_name = Path(urlparse(image_url).path).name
if not image_name:
raise ValueError(f"Invalid image URL: {image_url}")
image_path = self.input_image_dir / image_name
image_path.parent.mkdir(parents=True, exist_ok=True)
with open(image_path, "wb") as f:
f.write(response.content)
return image_path
except Exception as e:
logger.error(f"Failed to download image: {e}")
raise
def save_uploaded_file(self, file_content: bytes, filename: str) -> Path:
file_extension = Path(filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}"
file_path = self.input_image_dir / unique_filename
with open(file_path, "wb") as f:
f.write(file_content)
return file_path
def get_output_path(self, save_video_path: str) -> Path:
video_path = Path(save_video_path)
if not video_path.is_absolute():
return self.output_video_dir / save_video_path
return video_path
def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, task_queue, result_queue):
task_data = None
loop = None
worker = None
try:
logger.info(f"Process {rank}/{world_size - 1} initializing distributed inference service...")
# Create and initialize distributed worker process
worker = create_distributed_worker(rank, world_size, master_addr, master_port)
if not worker.init():
raise RuntimeError(f"Rank {rank} distributed environment initialization failed")
# Initialize configuration and model
config = set_config(args)
config["mode"] = "server"
logger.info(f"Rank {rank} config: {config}")
runner = init_runner(config)
logger.info(f"Process {rank}/{world_size - 1} distributed inference service initialization completed")
# Create event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
while True:
# Only rank=0 reads tasks from queue
if rank == 0:
try:
task_data = task_queue.get(timeout=1.0)
if task_data is None: # Stop signal
logger.info(f"Process {rank} received stop signal, exiting inference service")
# Broadcast stop signal to other processes
worker.dist_manager.broadcast_task_data(None)
break
# Broadcast task data to other processes
worker.dist_manager.broadcast_task_data(task_data)
except queue.Empty:
# Queue is empty, continue waiting
continue
else:
# Non-rank=0 processes receive task data from rank=0
task_data = worker.dist_manager.broadcast_task_data()
if task_data is None: # Stop signal
logger.info(f"Process {rank} received stop signal, exiting inference service")
break
# All processes handle the task
if task_data is not None:
logger.info(f"Process {rank} received inference task: {task_data['task_id']}")
try:
# Set inputs and run inference
runner.set_inputs(task_data) # type: ignore
loop.run_until_complete(runner.run_pipeline())
# Synchronize and report results
worker.sync_and_report(
task_data["task_id"],
"success",
result_queue,
save_video_path=task_data["save_video_path"],
message="Inference completed",
)
except Exception as e:
logger.error(f"Process {rank} error occurred while processing task: {str(e)}")
# Synchronize and report error
worker.sync_and_report(
task_data.get("task_id", "unknown"),
"failed",
result_queue,
error=str(e),
message=f"Inference failed: {str(e)}",
)
except KeyboardInterrupt:
logger.info(f"Process {rank} received KeyboardInterrupt, gracefully exiting")
except Exception as e:
logger.error(f"Distributed inference service process {rank} startup failed: {str(e)}")
if rank == 0:
error_result = {
"task_id": "startup",
"status": "startup_failed",
"error": str(e),
"message": f"Inference service startup failed: {str(e)}",
}
result_queue.put(error_result)
finally:
# Clean up resources
try:
if loop and not loop.is_closed():
loop.close()
except: # noqa: E722
pass
try:
if worker:
worker.cleanup()
except: # noqa: E722
pass
class DistributedInferenceService:
def __init__(self):
self.task_queue = None
self.result_queue = None
self.processes = []
self.is_running = False
def start_distributed_inference(self, args) -> bool:
if self.is_running:
logger.warning("Distributed inference service is already running")
return True
nproc_per_node = args.nproc_per_node
if nproc_per_node <= 0:
logger.error("nproc_per_node must be greater than 0")
return False
try:
import random
master_addr = "127.0.0.1"
master_port = str(random.randint(20000, 29999))
logger.info(f"Distributed inference service Master Addr: {master_addr}, Master Port: {master_port}")
# Create shared queues
self.task_queue = mp.Queue()
self.result_queue = mp.Queue()
# Start processes
for rank in range(nproc_per_node):
p = mp.Process(
target=_distributed_inference_worker,
args=(
rank,
nproc_per_node,
master_addr,
master_port,
args,
self.task_queue,
self.result_queue,
),
daemon=True,
)
p.start()
self.processes.append(p)
self.is_running = True
logger.info(f"Distributed inference service started successfully with {nproc_per_node} processes")
return True
except Exception as e:
logger.exception(f"Error occurred while starting distributed inference service: {str(e)}")
self.stop_distributed_inference()
return False
def stop_distributed_inference(self):
if not self.is_running:
return
try:
logger.info(f"Stopping {len(self.processes)} distributed inference service processes...")
# Send stop signal
if self.task_queue:
for _ in self.processes:
self.task_queue.put(None)
# Wait for processes to end
for p in self.processes:
try:
p.join(timeout=10)
if p.is_alive():
logger.warning(f"Process {p.pid} did not end within the specified time, forcing termination...")
p.terminate()
p.join(timeout=5)
except: # noqa: E722
pass
logger.info("All distributed inference service processes have stopped")
except Exception as e:
logger.error(f"Error occurred while stopping distributed inference service: {str(e)}")
finally:
# Clean up resources
self._clean_queues()
self.processes = []
self.task_queue = None
self.result_queue = None
self.is_running = False
def _clean_queues(self):
for queue_obj in [self.task_queue, self.result_queue]:
if queue_obj:
try:
while not queue_obj.empty():
queue_obj.get_nowait()
except: # noqa: E722
pass
def submit_task(self, task_data: dict) -> bool:
if not self.is_running or not self.task_queue:
logger.error("Distributed inference service is not started")
return False
try:
self.task_queue.put(task_data)
return True
except Exception as e:
logger.error(f"Failed to submit task: {str(e)}")
return False
def wait_for_result(self, task_id: str, timeout: int = 300) -> Optional[dict]:
if not self.is_running or not self.result_queue:
return None
start_time = time.time()
while time.time() - start_time < timeout:
try:
result = self.result_queue.get(timeout=1.0)
if result.get("task_id") == task_id:
return result
else:
# Not the result for the current task; put it back in the queue
self.result_queue.put(result)
time.sleep(0.1)
except queue.Empty:
continue
return None
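# Illustrative usage sketch (editor's note, not part of this commit): start the
# worker pool with the server's parsed CLI args, submit one task dict, and block
# for its result, mirroring what VideoGenerationService does below:
#
#   service = DistributedInferenceService()
#   assert service.start_distributed_inference(args)  # args from the server CLI
#   service.submit_task({"task_id": "demo-0001", "prompt": "a cat surfing",
#                        "image_path": "", "save_video_path": "/abs/out.mp4"})
#   result = service.wait_for_result("demo-0001", timeout=300)
#   print(result)  # {"task_id": "demo-0001", "status": "success", ...}, or None on timeout
#   service.stop_distributed_inference()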
class VideoGenerationService:
def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
self.file_service = file_service
self.inference_service = inference_service
async def generate_video(self, message: TaskRequest) -> TaskResponse:
try:
# Process image path
task_data = {
"task_id": message.task_id,
"prompt": message.prompt,
"use_prompt_enhancer": message.use_prompt_enhancer,
"negative_prompt": message.negative_prompt,
"image_path": message.image_path,
"num_fragments": message.num_fragments,
"save_video_path": message.save_video_path,
"infer_steps": message.infer_steps,
"target_video_length": message.target_video_length,
"seed": message.seed,
"audio_path": message.audio_path,
"video_duration": message.video_duration,
}
# Process network image
if message.image_path.startswith("http"):
image_path = await self.file_service.download_image(message.image_path)
task_data["image_path"] = str(image_path)
# Process output path
save_video_path = self.file_service.get_output_path(message.save_video_path)
task_data["save_video_path"] = str(save_video_path)
# Submit task to distributed inference service
if not self.inference_service.submit_task(task_data):
raise RuntimeError("Distributed inference service is not started")
# Wait for result
result = self.inference_service.wait_for_result(message.task_id)
if result is None:
raise RuntimeError("Task processing timeout")
if result.get("status") == "success":
ServiceStatus.complete_task(message)
return TaskResponse(
task_id=message.task_id,
task_status="completed",
save_video_path=str(save_video_path),
)
else:
error_msg = result.get("error", "Inference failed")
ServiceStatus.record_failed_task(message, error=error_msg)
raise RuntimeError(error_msg)
except Exception as e:
logger.error(f"Task {message.task_id} processing failed: {str(e)}")
ServiceStatus.record_failed_task(message, error=str(e))
raise
import sys
import psutil
import signal
import base64
from PIL import Image
from loguru import logger
from typing import Optional
from datetime import datetime
from pydantic import BaseModel
import threading
import torch
import io
class ProcessManager:
@staticmethod
def kill_all_related_processes():
current_process = psutil.Process()
children = current_process.children(recursive=True)
for child in children:
try:
child.kill()
except Exception as e:
logger.info(f"Failed to kill child process {child.pid}: {e}")
try:
current_process.kill()
except Exception as e:
logger.info(f"Failed to kill main process: {e}")
@staticmethod
def signal_handler(sig, frame):
logger.info("\nReceived Ctrl+C, shutting down all related processes...")
ProcessManager.kill_all_related_processes()
sys.exit(0)
@staticmethod
def register_signal_handler():
signal.signal(signal.SIGINT, ProcessManager.signal_handler)
class TaskStatusMessage(BaseModel):
task_id: str
class ServiceStatus:
_lock = threading.Lock()
_current_task = None
_result_store = {}
@classmethod
def start_task(cls, message):
with cls._lock:
if cls._current_task is not None:
raise RuntimeError("Service busy")
if message.task_id in cls._result_store:
raise RuntimeError(f"Task ID {message.task_id} already exists")
cls._current_task = {"message": message, "start_time": datetime.now()}
return message.task_id
@classmethod
def complete_task(cls, message):
with cls._lock:
if cls._current_task:
cls._result_store[message.task_id] = {
"success": True,
"message": message,
"start_time": cls._current_task["start_time"],
"completion_time": datetime.now(),
"save_video_path": message.save_video_path,
}
cls._current_task = None
@classmethod
def record_failed_task(cls, message, error: Optional[str] = None):
with cls._lock:
if cls._current_task:
cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error, "save_video_path": message.save_video_path}
cls._current_task = None
@classmethod
def clean_stopped_task(cls):
with cls._lock:
if cls._current_task:
message = cls._current_task["message"]
error = "Task stopped by user"
cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error, "save_video_path": message.save_video_path}
cls._current_task = None
@classmethod
def get_status_task_id(cls, task_id: str):
with cls._lock:
if cls._current_task and cls._current_task["message"].task_id == task_id:
return {"status": "processing", "task_id": task_id}
if task_id in cls._result_store:
result = cls._result_store[task_id]
return {
"status": "completed" if result["success"] else "failed",
"task_id": task_id,
"success": result["success"],
"start_time": result["start_time"],
"completion_time": result.get("completion_time"),
"error": result.get("error"),
"save_video_path": result.get("save_video_path"),
}
return {"status": "not_found", "task_id": task_id}
@classmethod
def get_status_service(cls):
with cls._lock:
if cls._current_task:
return {"service_status": "busy", "task_id": cls._current_task["message"].task_id, "start_time": cls._current_task["start_time"]}
return {"service_status": "idle"}
@classmethod
def get_all_tasks(cls):
with cls._lock:
return cls._result_store
class TensorTransporter:
def __init__(self):
self.buffer = io.BytesIO()
def to_device(self, data, device):
if isinstance(data, dict):
return {key: self.to_device(value, device) for key, value in data.items()}
elif isinstance(data, list):
return [self.to_device(item, device) for item in data]
elif isinstance(data, torch.Tensor):
return data.to(device)
else:
return data
def prepare_tensor(self, data) -> str:
self.buffer.seek(0)
self.buffer.truncate()
torch.save(self.to_device(data, "cpu"), self.buffer)
return base64.b64encode(self.buffer.getvalue()).decode("utf-8")
def load_tensor(self, tensor_base64: str, device="cuda"):
tensor_bytes = base64.b64decode(tensor_base64)
with io.BytesIO(tensor_bytes) as buffer:
return self.to_device(torch.load(buffer), device)
class ImageTransporter:
def __init__(self):
self.buffer = io.BytesIO()
def prepare_image(self, image: Image.Image):
self.buffer.seek(0)
self.buffer.truncate()
image.save(self.buffer, format="PNG")
return base64.b64encode(self.buffer.getvalue()).decode("utf-8")
def load_image(self, image_base64: bytes) -> Image.Image:
image_bytes = base64.b64decode(image_base64)
with io.BytesIO(image_bytes) as buffer:
return Image.open(buffer).convert("RGB")
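# Illustrative round-trip sketch (editor's note, not part of this commit): both
# transporters serialize to base64 strings that can cross process or HTTP
# boundaries, assuming they are importable from this module:
#
#   tt = TensorTransporter()
#   payload = tt.prepare_tensor({"x": torch.ones(2, 2)})
#   tt.load_tensor(payload, device="cpu")  # {"x": tensor([[1., 1.], [1., 1.]])}
#
#   it = ImageTransporter()
#   img_b64 = it.prepare_image(Image.new("RGB", (64, 64), "white"))
#   it.load_image(img_b64).size  # (64, 64)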
import time
import torch
from contextlib import ContextDecorator
import asyncio
from functools import wraps
from lightx2v.utils.envs import *
from loguru import logger
class _ProfilingContext(ContextDecorator):
class _ProfilingContext:
def __init__(self, name):
self.name = name
self.rank_info = ""
@@ -31,8 +32,44 @@ class _ProfilingContext(ContextDecorator):
logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
return False
async def __aenter__(self):
if torch.cuda.is_available():
torch.cuda.synchronize()
torch.cuda.reset_peak_memory_stats()
self.start_time = time.perf_counter()
return self
class _NullContext(ContextDecorator):
async def __aexit__(self, exc_type, exc_val, exc_tb):
if torch.cuda.is_available():
torch.cuda.synchronize()
peak_memory = torch.cuda.max_memory_allocated() / (1024**3)  # convert to GB
logger.info(f"{self.rank_info}Function '{self.name}' Peak Memory: {peak_memory:.2f} GB")
else:
logger.info(f"{self.rank_info}Function '{self.name}' executed without GPU.")
elapsed = time.perf_counter() - self.start_time
logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
return False
def __call__(self, func):
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
async with self:
return await func(*args, **kwargs)
return async_wrapper
else:
@wraps(func)
def sync_wrapper(*args, **kwargs):
with self:
return func(*args, **kwargs)
return sync_wrapper
class _NullContext:
# Context manager without decision branch logic overhead
def __init__(self, *args, **kwargs):
pass
@@ -43,6 +80,15 @@ class _NullContext(ContextDecorator):
def __exit__(self, *args):
return False
async def __aenter__(self):
return self
async def __aexit__(self, *args):
return False
def __call__(self, func):
return func
ProfilingContext = _ProfilingContext
ProfilingContext4Debug = _ProfilingContext if CHECK_ENABLE_PROFILING_DEBUG() else _NullContext
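# Illustrative usage sketch (editor's note, not part of this commit): with the
# async hooks above, ProfilingContext works as a sync or async context manager
# and, via __call__, as a decorator for both kinds of functions:
#
#   from lightx2v.utils.profiler import ProfilingContext
#
#   with ProfilingContext("warmup"):
#       sum(range(1_000_000))
#
#   @ProfilingContext("pipeline")
#   async def run_pipeline():
#       await asyncio.sleep(0.1)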
@@ -2,13 +2,13 @@ import requests
from loguru import logger
response = requests.get("http://localhost:8000/v1/local/video/generate/service_status")
response = requests.get("http://localhost:8000/v1/service/status")
logger.info(response.json())
response = requests.get("http://localhost:8000/v1/local/video/generate/get_all_tasks")
response = requests.get("http://localhost:8000/v1/tasks/")
logger.info(response.json())
response = requests.post("http://localhost:8000/v1/local/video/generate/task_status", json={"task_id": "test_task_001"})
response = requests.get("http://localhost:8000/v1/tasks/test_task_001/status")
logger.info(response.json())
import requests
from loguru import logger
import random
import string
import time
from datetime import datetime
# same as lightx2v/utils/generate_task_id.py
# from lightx2v.utils.generate_task_id import generate_task_id
def generate_task_id():
"""
Generate a random task ID in the format XXXX-XXXX-XXXX-XXXX-XXXX.
Features:
1. Does not modify the global random state.
2. Each X is an uppercase letter or digit (0-9).
3. Combines time factors to ensure high randomness.
For example: N1PQ-PRM5-N1BN-Z3S1-BGBJ
"""
# Save the current random state (does not affect external randomness)
original_state = random.getstate()
try:
# Define character set (uppercase letters + digits)
characters = string.ascii_uppercase + string.digits
# Create an independent random instance
local_random = random.Random(time.perf_counter_ns())
# Generate 5 groups of 4-character random strings
groups = []
for _ in range(5):
# Mix new time factor for each group
time_mix = int(datetime.now().timestamp())
local_random.seed(time_mix + local_random.getstate()[1][0] + time.perf_counter_ns())
groups.append("".join(local_random.choices(characters, k=4)))
return "-".join(groups)
finally:
# Restore the original random state
random.setstate(original_state)
if __name__ == "__main__":
url = "http://localhost:8000/v1/local/video/generate"
url = "http://localhost:8000/v1/tasks/"
message = {
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t02.mp4", # It is best to set it to an absolute path.
}
logger.info(f"message: {message}")
......
@@ -2,14 +2,12 @@ import requests
from loguru import logger
url = "http://localhost:8000/v1/local/video/generate"
url = "http://localhost:8000/v1/tasks/"
message = {
"task_id": "test_task_001",
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_enhanced.mp4", # It is best to set it to an absolute path.
"use_prompt_enhancer": True,
}
......
import requests
from loguru import logger
import random
import string
import time
from datetime import datetime
# same as lightx2v/utils/generate_task_id.py
# from lightx2v.utils.generate_task_id import generate_task_id
def generate_task_id():
"""
Generate a random task ID in the format XXXX-XXXX-XXXX-XXXX-XXXX.
Features:
1. Does not modify the global random state.
2. Each X is an uppercase letter or digit (0-9).
3. Combines time factors to ensure high randomness.
For example: N1PQ-PRM5-N1BN-Z3S1-BGBJ
"""
# Save the current random state (does not affect external randomness)
original_state = random.getstate()
try:
# Define character set (uppercase letters + digits)
characters = string.ascii_uppercase + string.digits
# Create an independent random instance
local_random = random.Random(time.perf_counter_ns())
# Generate 5 groups of 4-character random strings
groups = []
for _ in range(5):
# Mix new time factor for each group
time_mix = int(datetime.now().timestamp())
local_random.seed(time_mix + local_random.getstate()[1][0] + time.perf_counter_ns())
groups.append("".join(local_random.choices(characters, k=4)))
return "-".join(groups)
finally:
# Restore the original random state
random.setstate(original_state)
if __name__ == "__main__":
url = "http://localhost:8000/v1/local/video/generate"
url = "http://localhost:8000/v1/tasks/"
message = {
"task_id": generate_task_id(), # task_id also can be string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique, otherwise, it will raise an error. Default is False.
"prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "./assets/inputs/imgs/img_0.jpg",
"save_video_path": "./output_lightx2v_wan_i2v_t02.mp4", # It is best to set it to an absolute path.
"image_path": "http://xxx.xxx.xxx.xxx/img_0.jpg", # 图片地址
}
logger.info(f"message: {message}")
......