Commit 398b598a, authored by PengGao, committed by GitHub

Refactor/api (#94)

* fix: correct frequency computation in WanTransformerInfer

* refactor: restructure API server and distributed inference services

- Removed the old api_server_dist.py file and integrated its functionality into a new modular structure.
- Created a new ApiServer class to handle API routes and services.
- Introduced DistributedInferenceService and FileService for better separation of concerns.
- Updated the main entry point to initialize and run the new API server with distributed inference capabilities.
- Added schema definitions for task requests and responses to improve data handling.
- Enhanced error handling and logging throughout the services.

* refactor: enhance API structure and file handling in server

- Introduced APIRouter for modular route management in the ApiServer class.
- Updated task creation and file download endpoints to improve clarity and functionality.
- Implemented a new method for streaming file responses with proper MIME type handling.
- Refactored task request schema to auto-generate task IDs and handle optional video save paths.
- Improved error handling and logging for better debugging and user feedback.

* feat: add configurable parameters for video generation

- Introduced new parameters: infer_steps, target_video_length, and seed to the API and task request schema.
- Updated DefaultRunner and VideoGenerationService to handle the new parameters for enhanced video generation control.
- Improved default values for parameters to ensure consistent behavior.

* refactor: enhance profiling context for async support

* refactor: improve signal handling in API server

* feat: enhance video generation capabilities with audio support

* refactor: improve subprocess call for audio-video merging in wan_audio_runner.py

* refactor: enhance API server argument parsing and improve code readability

* refactor: enhance logging and improve code comments for clarity

* refactor: update response model for task listing endpoint to return a dictionary

* docs: update API endpoints and improve documentation clarity

* refactor: update API endpoints in scripts for task management and remove unused code

* fix: pre-commit
parent 1e422663
......@@ -2,7 +2,6 @@
lightx2v provides asynchronous service functionality. The code entry point is [here](https://github.com/ModelTC/lightx2v/blob/main/lightx2v/api_server.py)
### Start the Service
```shell
......@@ -12,36 +11,26 @@ bash scripts/start_server.sh
The `--port 8000` option means the service will bind to port `8000` on the local machine. You can change this as needed.
### Client Sends Request
```shell
python scripts/post.py
```
The service endpoint is: `/v1/local/video/generate`
The service endpoint is: `/v1/tasks/`
The `message` parameter in `scripts/post.py` is as follows:
```python
message = {
"task_id": generate_task_id(),
"task_id_must_unique": True,
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t02.mp4",
}
```
1. `prompt`, `negative_prompt`, and `image_path` are basic inputs for video generation. `image_path` can be an empty string, indicating no image input is needed.
2. `save_video_path` specifies the path where the generated video will be saved on the server. The relative path is relative to the server's startup directory. It is recommended to set an absolute path according to your environment.
3. `task_id` is the ID of the task, which is a string. You can customize a string or use the `generate_task_id()` function to generate a random string. The task ID is used to distinguish between different video generation tasks.
4. `task_id_must_unique` indicates whether each `task_id` must be unique. If set to `False`, there is no such restriction. In this case, if duplicate `task_id`s are sent, the server's `task` record will be overwritten by the newer task with the same `task_id`. If you do not need to keep a record of all tasks for querying, you can set this to `False`.
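
As a minimal sketch of what a client request to the `/v1/tasks/` endpoint might look like, the snippet below builds (but does not send) a JSON POST request using only the standard library. The helper name `build_task_request`, the `localhost:8000` address, and the choice to send only `prompt`, `negative_prompt`, and `image_path` are assumptions for illustration; the actual `scripts/post.py` may differ.

```python
import json
import urllib.request


def build_task_request(
    base_url: str,
    prompt: str,
    negative_prompt: str = "",
    image_path: str = "",
) -> urllib.request.Request:
    """Build (but do not send) a POST request for the /v1/tasks/ endpoint."""
    payload = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "image_path": image_path,  # empty string means no image input
    }
    return urllib.request.Request(
        url=f"{base_url.rstrip('/')}/v1/tasks/",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


# Hypothetical server address; adjust to wherever the service is running.
request = build_task_request("http://localhost:8000", "Two cats boxing on a stage.")
```

Passing `request` to `urllib.request.urlopen` would submit the task, which requires a running server; the shape of the response is not assumed here.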
### Client Checks Server Status
......@@ -51,12 +40,11 @@ python scripts/check_status.py
The service endpoints include:
1. `/v1/local/video/generate/service_status` is used to check the status of the service. It returns whether the service is `busy` or `idle`. The service only accepts new requests when it is `idle`.
2. `/v1/local/video/generate/get_all_tasks` is used to get all tasks received and completed by the server.
1. `/v1/service/status` is used to check the status of the service. It returns whether the service is `busy` or `idle`. The service only accepts new requests when it is `idle`.
3. `/v1/local/video/generate/task_status` is used to get the status of a specified `task_id`. It returns whether the task is `processing` or `completed`.
2. `/v1/tasks/` is used to get all tasks received and completed by the server.
3. `/v1/tasks/{task_id}/status` is used to get the status of a specified `task_id`. It returns whether the task is `processing` or `completed`.
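
The status endpoint lends itself to a simple polling loop. The sketch below is one way to wait for a task to leave the `processing` state. The `fetch_status` callable stands in for an HTTP GET on `/v1/tasks/{task_id}/status`, and the assumption that the response is a JSON object with a `status` key is hypothetical; check the real response schema before relying on it.

```python
import time
from typing import Callable


def wait_for_task(
    fetch_status: Callable[[str], dict],
    task_id: str,
    poll_interval: float = 1.0,
    max_polls: int = 600,
) -> str:
    """Poll until the task is no longer `processing`, then return its status."""
    for _ in range(max_polls):
        # Assumed response shape: {"status": "processing" | "completed" | ...}
        status = fetch_status(task_id).get("status", "unknown")
        if status != "processing":
            return status
        time.sleep(poll_interval)
    raise TimeoutError(f"task {task_id} still processing after {max_polls} polls")


# Stand-in for real HTTP calls: two `processing` replies, then `completed`.
_responses = iter(
    [{"status": "processing"}, {"status": "processing"}, {"status": "completed"}]
)
final_status = wait_for_task(lambda task_id: next(_responses), "demo-task", poll_interval=0.0)
```

Injecting `fetch_status` keeps the loop independent of any particular HTTP client and makes it easy to exercise without a server.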
### Client Stops the Current Task on the Server at Any Time
......@@ -64,7 +52,7 @@ The service endpoints include:
python scripts/stop_running_task.py
```
The service endpoint is: `/v1/local/video/generate/stop_running_task`
The service endpoint is: `/v1/tasks/running`
After terminating the task, the server will not exit but will return to waiting for new requests.
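
For illustration, a stop request against the new endpoint could be built as follows. This sketch assumes the `DELETE` method listed for `/v1/tasks/running` in the endpoint summary; the helper name `build_stop_request` and the `localhost:8000` address are hypothetical, and the request is only constructed, not sent.

```python
import urllib.request


def build_stop_request(base_url: str) -> urllib.request.Request:
    """Build (but do not send) a DELETE request that stops the running task."""
    return urllib.request.Request(
        url=f"{base_url.rstrip('/')}/v1/tasks/running",
        method="DELETE",
    )


# Hypothetical server address; a trailing slash is tolerated.
stop_request = build_stop_request("http://localhost:8000/")
```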
......@@ -78,7 +66,6 @@ num_gpus=8 bash scripts/start_multi_servers.sh
Where `num_gpus` indicates the number of services to start; the services will run on consecutive ports starting from `--start_port`.
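
As a rough illustration of the port layout described above, the sketch below lists the base URLs such a launch would produce. The `start_port` value of 8000, the `localhost` host, and the helper name `service_urls` are assumptions for illustration, not values read from `scripts/start_multi_servers.sh`.

```python
from typing import List


def service_urls(start_port: int, num_gpus: int, host: str = "localhost") -> List[str]:
    """Return one base URL per service, on consecutive ports from start_port."""
    return [f"http://{host}:{start_port + i}" for i in range(num_gpus)]


urls = service_urls(start_port=8000, num_gpus=4)
```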
### Scheduling Between Multiple Services
```shell
......@@ -86,3 +73,16 @@ python scripts/post_multi_servers.py
```
`post_multi_servers.py` will schedule multiple client requests based on the idle status of the services.
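
The idea of scheduling by idle status can be sketched as a single dispatch pass: hand the next pending request to each server that reports `idle`. This is not the logic of `post_multi_servers.py` itself, which is not shown here. The function name `dispatch_once` and the `get_status` callable (a stand-in for querying `/v1/service/status` on each server) are assumptions for illustration.

```python
from collections import deque
from typing import Callable, Deque, List, Optional, Tuple


def dispatch_once(
    servers: List[str],
    pending: Deque[str],
    get_status: Callable[[str], Optional[str]],
) -> List[Tuple[str, str]]:
    """Assign at most one pending task to each server that reports `idle`."""
    assignments = []
    for server in servers:
        if not pending:
            break  # nothing left to schedule
        if get_status(server) == "idle":
            assignments.append((server, pending.popleft()))
    return assignments


# Stand-in statuses; a real client would query each server over HTTP.
statuses = {
    "http://localhost:8000": "idle",
    "http://localhost:8001": "busy",
    "http://localhost:8002": "idle",
}
pending = deque(["prompt A", "prompt B", "prompt C"])
assignments = dispatch_once(list(statuses), pending, statuses.get)
```

Tasks left in `pending` after a pass would be retried once a server becomes idle again, for example by calling `dispatch_once` in a loop with a short sleep.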
### API Endpoints Summary
| Endpoint | Method | Description |
|----------|--------|-------------|
| `/v1/tasks/` | POST | Create video generation task |
| `/v1/tasks/form` | POST | Create video generation task via form |
| `/v1/tasks/` | GET | List all tasks |
| `/v1/tasks/{task_id}/status` | GET | Get status of specified task |
| `/v1/tasks/{task_id}/result` | GET | Get result video file of specified task |
| `/v1/tasks/running` | DELETE | Stop currently running task |
| `/v1/files/download/{file_path}` | GET | Download file |
| `/v1/service/status` | GET | Get service status |
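# How to Start the Service
lightx2v provides an asynchronous service feature; the code entry point is [here](https://github.com/ModelTC/lightx2v/blob/main/lightx2v/api_server.py)
lightx2v provides asynchronous service functionality. The code entry point is [here](https://github.com/ModelTC/lightx2v/blob/main/lightx2v/api_server.py)
### Start the Service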
......@@ -10,8 +9,7 @@ lightx2v提供了异步服务功能,代码入口处在[这里](https://github.
bash scripts/start_server.sh
```
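Here `--port 8000` means the service binds to port `8000` on the local machine; you can change it yourself
The `--port 8000` option means the service will bind to port `8000` on the local machine. You can change this port as needed.
### Client Sends Request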
......@@ -19,65 +17,54 @@ bash scripts/start_server.sh
python scripts/post.py
```
The service interface is `/v1/local/video/generate`
The service endpoint is: `/v1/tasks/`
The `message` parameter in `scripts/post.py` is as follows:
The `message` parameter in `scripts/post.py` is as follows:
```python
message = {
"task_id": generate_task_id(),
"task_id_must_unique": True,
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t02.mp4",
"image_path": ""
}
```
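1. `prompt`, `negative_prompt`, and `image_path` are basic inputs for video generation. `image_path` can be an empty string, indicating that no image input is needed
2. `save_video_path` specifies the path of the video generated on the server. A relative path is resolved against the server's startup directory. It is recommended to set an absolute path suited to your environment.
3. `task_id` is the ID of the task, formatted as a string. You can supply a custom string or call the `generate_task_id()` function to generate a random one. The task ID is used to distinguish between different video generation tasks.
4. `task_id_must_unique` indicates whether each `task_id` must be unique, that is, whether duplicate `task_id`s are rejected. If it is `False`, there is no such requirement; in that case, if a duplicate `task_id` is sent, the server's `task` record is overwritten by the newer `task` with the same `task_id`. If you do not need to keep a record of every `task` for querying, you can set this to `False`
1. `prompt`, `negative_prompt`, and `image_path` are the basic inputs for video generation. `image_path` can be an empty string, indicating that no image input is needed.
### Client Gets the Server Status
### Client Checks Server Status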
```shell
python scripts/check_status.py
```
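The service interfaces are:
The service endpoints include:
1. `/v1/local/video/generate/service_status` is used to check the service status. It returns whether the service is `busy` or `idle`; the service accepts new requests only in the `idle` state.
1. `/v1/service/status` is used to check the service status. It returns whether the service is `busy` or `idle`. The service accepts new requests only when it is `idle`.
2. `/v1/local/video/generate/get_all_tasks` is used to get all tasks received and completed by the service.
2. `/v1/tasks/` is used to get all tasks received and completed by the service.
3. `/v1/local/video/generate/task_status` is used to get the status of the specified `task_id`. It returns whether the task is `processing` or `completed`
3. `/v1/tasks/{task_id}/status` is used to get the status of the task with the specified `task_id`. It returns whether the task is `processing` or `completed`
### Client Terminates the Server's Current Task at Any Time
### Client Stops the Current Task on the Server at Any Time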
```shell
python scripts/stop_running_task.py
```
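The service interface is `/v1/local/video/generate/stop_running_task`
The service endpoint is: `/v1/tasks/running`
After the task is terminated, the server does not exit the service but returns to the state of waiting for new requests.
After the task is terminated, the service does not exit but returns to waiting for new requests.
### Starting Multiple Services at Once on a Single Node
### Starting Multiple Services on a Single Node
On a single node, you can run `scripts/start_server.sh` multiple times to start several services at once (note that services under the same IP must use different port numbers), or start several services in one go with `scripts/start_multi_servers.sh`:
On a single node, you can use `scripts/start_server.sh` to start multiple services (note that port numbers under the same IP must differ), or use `scripts/start_multi_servers.sh` to start multiple services at once: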
```shell
num_gpus=8 bash scripts/start_multi_servers.sh
```
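Where `num_gpus` indicates the number of services to start; the services will run on `num_gpus` consecutive ports starting from `--start_port`.
Where `num_gpus` indicates the number of services to start; the services will run on consecutive ports starting from `--start_port`.
### Scheduling Between Multiple Services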
......@@ -85,4 +72,17 @@ num_gpus=8 bash scripts/start_multi_servers.sh
python scripts/post_multi_servers.py
```
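`post_multi_servers.py` schedules the multiple requests issued by clients based on the idle status of the services.
`post_multi_servers.py` will schedule multiple client requests based on the idle status of the services.
### API Endpoints Summary
| Endpoint | Method | Description |
|------|------|------|
| `/v1/tasks/` | POST | Create video generation task |
| `/v1/tasks/form` | POST | Create video generation task via form |
| `/v1/tasks/` | GET | List all tasks |
| `/v1/tasks/{task_id}/status` | GET | Get status of specified task |
| `/v1/tasks/{task_id}/result` | GET | Get result video file of specified task |
| `/v1/tasks/running` | DELETE | Stop currently running task |
| `/v1/files/download/{file_path}` | GET | Download file |
| `/v1/service/status` | GET | Get service status |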
import asyncio
import argparse
from fastapi import FastAPI
from pydantic import BaseModel
import sys
import signal
import atexit
from pathlib import Path
from loguru import logger
import uvicorn
import json
from typing import Optional
from datetime import datetime
import threading
import ctypes
import gc
import torch
from lightx2v.utils.profiler import ProfilingContext
from lightx2v.utils.set_config import set_config
from lightx2v.infer import init_runner
from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager
from lightx2v.server.api import ApiServer
from lightx2v.server.service import DistributedInferenceService
from lightx2v.server.utils import ProcessManager
# =========================
# FastAPI Related Code
# =========================
def create_signal_handler(inference_service: DistributedInferenceService):
"""Create unified signal handler function"""
runner = None
thread = None
app = FastAPI()
class Message(BaseModel):
task_id: str
task_id_must_unique: bool = False
prompt: str
use_prompt_enhancer: bool = False
negative_prompt: str = ""
image_path: str = ""
num_fragments: int = 1
save_video_path: str
def get(self, key, default=None):
return getattr(self, key, default)
class ApiServerServiceStatus(BaseServiceStatus):
pass
def local_video_generate(message: Message):
def signal_handler(signum, frame):
logger.info(f"Received signal {signum}, gracefully shutting down...")
try:
global runner
runner.set_inputs(message)
logger.info(f"message: {message}")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(runner.run_pipeline())
finally:
loop.close()
ApiServerServiceStatus.complete_task(message)
if inference_service.is_running:
inference_service.stop_distributed_inference()
except Exception as e:
logger.error(f"task_id {message.task_id} failed: {str(e)}")
ApiServerServiceStatus.record_failed_task(message, error=str(e))
logger.error(f"Error occurred while shutting down distributed inference service: {str(e)}")
finally:
sys.exit(0)
@app.post("/v1/local/video/generate")
async def v1_local_video_generate(message: Message):
try:
task_id = ApiServerServiceStatus.start_task(message)
# Use background threads to perform long-running tasks
global thread
thread = threading.Thread(target=local_video_generate, args=(message,), daemon=True)
thread.start()
return {"task_id": task_id, "task_status": "processing", "save_video_path": message.save_video_path}
except RuntimeError as e:
return {"error": str(e)}
return signal_handler
@app.get("/v1/local/video/generate/service_status")
async def get_service_status():
return ApiServerServiceStatus.get_status_service()
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_cls",
type=str,
required=True,
choices=[
"wan2.1",
"hunyuan",
"wan2.1_causvid",
"wan2.1_skyreels_v2_df",
"wan2.1_audio",
],
default="hunyuan",
)
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--split", action="store_true")
parser.add_argument("--lora_path", type=str, required=False, default=None)
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node for distributed inference")
@app.get("/v1/local/video/generate/get_all_tasks")
async def get_all_tasks():
return ApiServerServiceStatus.get_all_tasks()
args = parser.parse_args()
logger.info(f"args: {args}")
cache_dir = Path(__file__).parent.parent / ".cache"
inference_service = DistributedInferenceService()
@app.post("/v1/local/video/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
return ApiServerServiceStatus.get_status_task_id(message.task_id)
api_server = ApiServer()
api_server.initialize_services(cache_dir, inference_service)
signal_handler = create_signal_handler(inference_service)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
def _async_raise(tid, exctype):
"""Force thread tid to raise exception exctype"""
res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(exctype))
if res == 0:
raise ValueError("Invalid thread ID")
elif res > 1:
ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), 0)
raise SystemError("PyThreadState_SetAsyncExc failed")
logger.info("Starting distributed inference service...")
success = inference_service.start_distributed_inference(args)
if not success:
logger.error("Failed to start distributed inference service, exiting program")
sys.exit(1)
atexit.register(inference_service.stop_distributed_inference)
@app.get("/v1/local/video/generate/stop_running_task")
async def stop_running_task():
global thread
if thread and thread.is_alive():
try:
_async_raise(thread.ident, SystemExit)
thread.join()
# Clean up the thread reference
thread = None
ApiServerServiceStatus.clean_stopped_task()
gc.collect()
torch.cuda.empty_cache()
return {"stop_status": "success", "reason": "Task stopped successfully."}
logger.info(f"Starting FastAPI server on port: {args.port}")
uvicorn.run(
api_server.get_app(),
host="0.0.0.0",
port=args.port,
reload=False,
workers=1,
)
except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt, shutting down service...")
except Exception as e:
return {"stop_status": "error", "reason": str(e)}
else:
return {"stop_status": "do_nothing", "reason": "No running task found."}
logger.error(f"Error occurred while running FastAPI server: {str(e)}")
finally:
inference_service.stop_distributed_inference()
# =========================
# Main Entry
# =========================
if __name__ == "__main__":
ProcessManager.register_signal_handler()
parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_distill", "wan2.1_causvid", "wan2.1_skyreels_v2_df", "cogvideox"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--split", action="store_true")
parser.add_argument("--port", type=int, default=8000)
args = parser.parse_args()
logger.info(f"args: {args}")
with ProfilingContext("Init Server Cost"):
config = set_config(args)
config["mode"] = "split_server" if args.split else "server"
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
runner = init_runner(config)
uvicorn.run(app, host="0.0.0.0", port=config.port, reload=False, workers=1)
main()
import asyncio
import argparse
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from loguru import logger
import uvicorn
import threading
import ctypes
import gc
import torch
import os
import sys
import time
import torch.multiprocessing as mp
import queue
import torch.distributed as dist
import random
import uuid
from lightx2v.utils.set_config import set_config
from lightx2v.infer import init_runner
from lightx2v.utils.service_utils import TaskStatusMessage, BaseServiceStatus, ProcessManager
import httpx
from pathlib import Path
from urllib.parse import urlparse
# =========================
# FastAPI Related Code
# =========================
runner = None
thread = None
app = FastAPI()
CACHE_DIR = Path(__file__).parent.parent / "cache"
INPUT_IMAGE_DIR = CACHE_DIR / "inputs" / "imgs"
OUTPUT_VIDEO_DIR = CACHE_DIR / "outputs"
for directory in [INPUT_IMAGE_DIR, OUTPUT_VIDEO_DIR]:
    directory.mkdir(parents=True, exist_ok=True)
class Message(BaseModel):
    task_id: str
    task_id_must_unique: bool = False
    prompt: str
    use_prompt_enhancer: bool = False
    negative_prompt: str = ""
    image_path: str = ""
    num_fragments: int = 1
    save_video_path: str

    def get(self, key, default=None):
        return getattr(self, key, default)


class ApiServerServiceStatus(BaseServiceStatus):
    pass
def download_image(image_url: str):
    with httpx.Client(verify=False) as client:
        response = client.get(image_url)
    image_name = Path(urlparse(image_url).path).name
    if not image_name:
        raise ValueError(f"Invalid image URL: {image_url}")
    image_path = INPUT_IMAGE_DIR / image_name
    image_path.parent.mkdir(parents=True, exist_ok=True)
    if response.status_code == 200:
        with open(image_path, "wb") as f:
            f.write(response.content)
        return image_path
    else:
        raise ValueError(f"Failed to download image from {image_url}")
stop_generation_event = threading.Event()
def local_video_generate(message: Message, stop_event: threading.Event):
try:
global input_queues, output_queues
if input_queues is None or output_queues is None:
logger.error("分布式推理服务未启动")
ApiServerServiceStatus.record_failed_task(message, error="分布式推理服务未启动")
return
logger.info(f"提交任务到分布式推理服务: {message.task_id}")
# 将任务数据转换为字典
task_data = {
"task_id": message.task_id,
"prompt": message.prompt,
"use_prompt_enhancer": message.use_prompt_enhancer,
"negative_prompt": message.negative_prompt,
"image_path": message.image_path,
"num_fragments": message.num_fragments,
"save_video_path": message.save_video_path,
}
if message.image_path.startswith("http"):
image_path = download_image(message.image_path)
task_data["image_path"] = str(image_path)
save_video_path = Path(message.save_video_path)
if not save_video_path.is_absolute():
task_data["save_video_path"] = str(OUTPUT_VIDEO_DIR / message.save_video_path)
# 将任务放入输入队列
for input_queue in input_queues:
input_queue.put(task_data)
# 等待结果
timeout = 300 # 5分钟超时
start_time = time.time()
while time.time() - start_time < timeout:
if stop_event.is_set():
logger.info(f"任务 {message.task_id} 收到停止信号,正在终止")
ApiServerServiceStatus.record_failed_task(message, error="任务被停止")
return
try:
result = output_queues[0].get(timeout=1.0)
# 检查是否是当前任务的结果
if result.get("task_id") == message.task_id:
if result.get("status") == "success":
logger.info(f"任务 {message.task_id} 推理成功")
ApiServerServiceStatus.complete_task(message)
else:
error_msg = result.get("error", "推理失败")
logger.error(f"任务 {message.task_id} 推理失败: {error_msg}")
ApiServerServiceStatus.record_failed_task(message, error=error_msg)
return
else:
# 不是当前任务的结果,放回队列
# 注意:如果并发任务很多,这种做法可能导致当前任务的结果被延迟。
# 更健壮的并发结果处理需要更复杂的设计,例如每个任务有独立的输出队列。
output_queues[0].put(result)
time.sleep(0.1)
except queue.Empty:
# 队列为空,继续等待
continue
# 超时
logger.error(f"任务 {message.task_id} 处理超时")
ApiServerServiceStatus.record_failed_task(message, error="处理超时")
except Exception as e:
logger.error(f"任务 {message.task_id} 处理失败: {str(e)}")
ApiServerServiceStatus.record_failed_task(message, error=str(e))
@app.post("/v1/local/video/generate")
async def v1_local_video_generate(message: Message):
try:
task_id = ApiServerServiceStatus.start_task(message)
# Use background threads to perform long-running tasks
global thread, stop_generation_event
stop_generation_event.clear()
thread = threading.Thread(
target=local_video_generate,
args=(
message,
stop_generation_event,
),
daemon=True,
)
thread.start()
return {"task_id": task_id, "task_status": "processing", "save_video_path": message.save_video_path}
except RuntimeError as e:
return {"error": str(e)}
@app.post("/v1/local/video/generate_form")
async def v1_local_video_generate_form(
task_id: str,
prompt: str,
save_video_path: str,
task_id_must_unique: bool = False,
use_prompt_enhancer: bool = False,
negative_prompt: str = "",
num_fragments: int = 1,
image_file: UploadFile = File(None),
):
# 处理上传的图片文件
image_path = ""
if image_file and image_file.filename:
# 生成唯一的文件名
file_extension = Path(image_file.filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}"
image_path = INPUT_IMAGE_DIR / unique_filename
# 保存上传的文件
with open(image_path, "wb") as buffer:
content = await image_file.read()
buffer.write(content)
image_path = str(image_path)
message = Message(
task_id=task_id,
task_id_must_unique=task_id_must_unique,
prompt=prompt,
use_prompt_enhancer=use_prompt_enhancer,
negative_prompt=negative_prompt,
image_path=image_path,
num_fragments=num_fragments,
save_video_path=save_video_path,
)
try:
task_id = ApiServerServiceStatus.start_task(message)
global thread, stop_generation_event
stop_generation_event.clear()
thread = threading.Thread(
target=local_video_generate,
args=(
message,
stop_generation_event,
),
daemon=True,
)
thread.start()
return {"task_id": task_id, "task_status": "processing", "save_video_path": message.save_video_path}
except RuntimeError as e:
return {"error": str(e)}
@app.get("/v1/local/video/generate/service_status")
async def get_service_status():
return ApiServerServiceStatus.get_status_service()
@app.get("/v1/local/video/generate/get_all_tasks")
async def get_all_tasks():
return ApiServerServiceStatus.get_all_tasks()
@app.post("/v1/local/video/generate/task_status")
async def get_task_status(message: TaskStatusMessage):
return ApiServerServiceStatus.get_status_task_id(message.task_id)
@app.get("/v1/local/video/generate/get_task_result")
async def get_task_result(message: TaskStatusMessage):
result = ApiServerServiceStatus.get_status_task_id(message.task_id)
# 传输save_video_path内容到外部
save_video_path = result.get("save_video_path")
if save_video_path and Path(save_video_path).is_absolute() and Path(save_video_path).exists():
file_path = Path(save_video_path)
relative_path = file_path.relative_to(OUTPUT_VIDEO_DIR.resolve()) if str(file_path).startswith(str(OUTPUT_VIDEO_DIR.resolve())) else file_path.name
return {
"status": "success",
"task_status": result.get("status", "unknown"),
"filename": file_path.name,
"file_size": file_path.stat().st_size,
"download_url": f"/v1/file/download/{relative_path}",
"message": "任务结果已准备就绪",
}
elif save_video_path and not Path(save_video_path).is_absolute():
video_path = OUTPUT_VIDEO_DIR / save_video_path
if video_path.exists():
return {
"status": "success",
"task_status": result.get("status", "unknown"),
"filename": video_path.name,
"file_size": video_path.stat().st_size,
"download_url": f"/v1/file/download/{save_video_path}",
"message": "任务结果已准备就绪",
}
return {"status": "not_found", "message": "Task result not found", "task_status": result.get("status", "unknown")}
def file_stream_generator(file_path: str, chunk_size: int = 1024 * 1024):
    """File stream generator that reads the file chunk by chunk"""
    with open(file_path, "rb") as file:
        while chunk := file.read(chunk_size):
            yield chunk
@app.get(
"/v1/file/download/{file_path:path}",
response_class=StreamingResponse,
summary="下载文件",
description="流式下载指定的文件",
responses={200: {"description": "文件下载成功", "content": {"application/octet-stream": {}}}, 404: {"description": "文件未找到"}, 500: {"description": "服务器错误"}},
)
async def download_file(file_path: str):
try:
full_path = OUTPUT_VIDEO_DIR / file_path
resolved_path = full_path.resolve()
# 安全检查:确保文件在允许的目录内
if not str(resolved_path).startswith(str(OUTPUT_VIDEO_DIR.resolve())):
return {"status": "forbidden", "message": "不允许访问该文件"}
if resolved_path.exists() and resolved_path.is_file():
file_size = resolved_path.stat().st_size
filename = resolved_path.name
# 设置适当的 MIME 类型
mime_type = "application/octet-stream"
if filename.lower().endswith((".mp4", ".avi", ".mov", ".mkv")):
mime_type = "video/mp4"
elif filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
mime_type = "image/jpeg"
headers = {
"Content-Disposition": f'attachment; filename="{filename}"',
"Content-Length": str(file_size),
"Accept-Ranges": "bytes",
}
return StreamingResponse(file_stream_generator(str(resolved_path)), media_type=mime_type, headers=headers)
else:
return {"status": "not_found", "message": f"文件未找到: {file_path}"}
except Exception as e:
logger.error(f"处理文件下载请求时发生错误: {e}")
return {"status": "error", "message": "文件下载失败"}
@app.get("/v1/local/video/generate/stop_running_task")
async def stop_running_task():
global thread, stop_generation_event
if thread and thread.is_alive():
try:
logger.info("正在发送停止信号给运行中的任务线程...")
stop_generation_event.set() # 设置事件,通知线程停止
thread.join(timeout=5) # 等待线程结束,设置超时时间
if thread.is_alive():
logger.warning("任务线程未在规定时间内停止,可能需要手动干预。")
return {"stop_status": "warning", "reason": "任务线程未在规定时间内停止,可能需要手动干预。"}
else:
# 清理线程引用
thread = None
ApiServerServiceStatus.clean_stopped_task()
gc.collect()
torch.cuda.empty_cache()
logger.info("任务已成功停止。")
return {"stop_status": "success", "reason": "Task stopped successfully."}
except Exception as e:
logger.error(f"停止任务时发生错误: {str(e)}")
return {"stop_status": "error", "reason": str(e)}
else:
return {"stop_status": "do_nothing", "reason": "No running task found."}
# Use multiprocessing queues for communication
input_queues = []
output_queues = []
distributed_runners = []
def distributed_inference_worker(rank, world_size, master_addr, master_port, args, input_queue, output_queue):
"""分布式推理服务工作进程"""
try:
# 设置环境变量
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["ENABLE_PROFILING_DEBUG"] = "true"
os.environ["ENABLE_GRAPH_MODE"] = "false"
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["CUDA_VISIBLE_DEVICES"] = str(rank)
logger.info(f"进程 {rank}/{world_size - 1} 正在初始化分布式推理服务...")
dist.init_process_group(backend="nccl", init_method=f"tcp://{master_addr}:{master_port}", rank=rank, world_size=world_size)
config = set_config(args)
config["mode"] = "server"
logger.info(f"config: {config}")
runner = init_runner(config)
logger.info(f"进程 {rank}/{world_size - 1} 分布式推理服务初始化完成,等待任务...")
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
while True:
try:
task_data = input_queue.get(timeout=1.0) # 1秒超时
if task_data is None: # 停止信号
logger.info(f"进程 {rank}/{world_size - 1} 收到停止信号,退出推理服务")
break
logger.info(f"进程 {rank}/{world_size - 1} 收到推理任务: {task_data['task_id']}")
runner.set_inputs(task_data)
# 运行推理,复用已创建的事件循环
try:
loop.run_until_complete(runner.run_pipeline())
# 只有 Rank 0 负责将结果放入输出队列,避免重复
if rank == 0:
result = {"task_id": task_data["task_id"], "status": "success", "save_video_path": task_data["save_video_path"], "message": "推理完成"}
output_queue.put(result)
logger.info(f"任务 {task_data['task_id']} 处理完成 (由 Rank 0 报告)")
if dist.is_initialized():
dist.barrier()
except Exception as e:
# 只有 Rank 0 负责报告错误
if rank == 0:
result = {"task_id": task_data["task_id"], "status": "failed", "error": str(e), "message": f"推理失败: {str(e)}"}
output_queue.put(result)
logger.error(f"任务 {task_data['task_id']} 推理失败: {str(e)} (由 Rank 0 报告)")
if dist.is_initialized():
dist.barrier()
except queue.Empty:
# 队列为空,继续等待
continue
except KeyboardInterrupt:
logger.info(f"进程 {rank}/{world_size - 1} 收到 KeyboardInterrupt,优雅退出")
break
except Exception as e:
logger.error(f"进程 {rank}/{world_size - 1} 处理任务时发生错误: {str(e)}")
# 只有 Rank 0 负责发送错误结果
task_data = task_data if "task_data" in locals() else {}
if rank == 0:
error_result = {
"task_id": task_data.get("task_id", "unknown"),
"status": "error",
"error": str(e),
"message": f"处理任务时发生错误: {str(e)}",
}
try:
output_queue.put(error_result)
except: # noqa: E722
pass
if dist.is_initialized():
try:
dist.barrier()
except: # noqa: E722
pass
except KeyboardInterrupt:
logger.info(f"进程 {rank}/{world_size - 1} 主循环收到 KeyboardInterrupt,正在退出")
except Exception as e:
logger.error(f"分布式推理服务进程 {rank}/{world_size - 1} 启动失败: {str(e)}")
# 只有 Rank 0 负责报告启动失败
if rank == 0:
try:
error_result = {"task_id": "startup", "status": "startup_failed", "error": str(e), "message": f"推理服务启动失败: {str(e)}"}
output_queue.put(error_result)
except: # noqa: E722
pass
# 在进程最终退出时关闭事件循环和销毁分布式组
finally:
try:
if "loop" in locals() and loop and not loop.is_closed():
loop.close()
except: # noqa: E722
pass
try:
if dist.is_initialized():
dist.destroy_process_group()
except: # noqa: E722
pass
def start_distributed_inference_with_queue(args):
"""使用队列启动分布式推理服务,并模拟torchrun的多进程模式"""
global input_queues, output_queues, distributed_runners
nproc_per_node = args.nproc_per_node
if nproc_per_node <= 0:
logger.error("nproc_per_node 必须大于0")
return False
try:
master_addr = "127.0.0.1"
master_port = str(random.randint(20000, 29999))
logger.info(f"分布式推理服务 Master Addr: {master_addr}, Master Port: {master_port}")
processes = []
ctx = mp.get_context("spawn")
for rank in range(nproc_per_node):
input_queue = ctx.Queue()
output_queue = ctx.Queue()
p = ctx.Process(target=distributed_inference_worker, args=(rank, nproc_per_node, master_addr, master_port, args, input_queue, output_queue), daemon=True)
p.start()
processes.append(p)
input_queues.append(input_queue)
output_queues.append(output_queue)
distributed_runners = processes
return True
except Exception as e:
logger.exception(f"启动分布式推理服务时发生错误: {str(e)}")
stop_distributed_inference_with_queue()
return False
def stop_distributed_inference_with_queue():
"""停止分布式推理服务"""
global input_queues, output_queues, distributed_runners
try:
if distributed_runners:
logger.info(f"正在停止 {len(distributed_runners)} 个分布式推理服务进程...")
# 向所有工作进程发送停止信号
if input_queues:
for input_queue in input_queues:
try:
input_queue.put(None)
except: # noqa: E722
pass
# 等待所有进程结束
for p in distributed_runners:
try:
p.join(timeout=10)
except: # noqa: E722
pass
# 强制终止任何未结束的进程
for p in distributed_runners:
try:
if p.is_alive():
logger.warning(f"推理服务进程 {p.pid} 未在规定时间内结束,强制终止...")
p.terminate()
p.join(timeout=5)
except: # noqa: E722
pass
logger.info("所有分布式推理服务进程已停止")
# 清理队列
if input_queues:
try:
for input_queue in input_queues:
try:
while not input_queue.empty():
input_queue.get_nowait()
except: # noqa: E722
pass
except: # noqa: E722
pass
if output_queues:
try:
for output_queue in output_queues:
try:
while not output_queue.empty():
output_queue.get_nowait()
except: # noqa: E722
pass
except: # noqa: E722
pass
distributed_runners = []
input_queues = []
output_queues = []
except Exception as e:
logger.error(f"停止分布式推理服务时发生错误: {str(e)}")
except KeyboardInterrupt:
logger.info("停止分布式推理服务时收到 KeyboardInterrupt,强制清理")
# =========================
# Main Entry
# =========================
if __name__ == "__main__":
global startup_args
ProcessManager.register_signal_handler()
parser = argparse.ArgumentParser()
parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causvid", "wan2.1_skyreels_v2_df"], default="hunyuan")
parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--config_json", type=str, required=True)
parser.add_argument("--split", action="store_true")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--seed", type=int, default=42)
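parser.add_argument("--start_inference", action="store_true", help="Whether to start the distributed inference service before starting the API server")
parser.add_argument("--nproc_per_node", type=int, default=4, help="Number of processes per node for distributed inference")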
args = parser.parse_args()
logger.info(f"args: {args}")
# Save startup arguments for the restart feature
startup_args = args
if args.start_inference:
logger.info("Starting distributed inference service...")
success = start_distributed_inference_with_queue(args)
if not success:
logger.error("Failed to start distributed inference service, exiting")
sys.exit(1)
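# Register a cleanup function to run at program exit
import atexit
atexit.register(stop_distributed_inference_with_queue)
# Register signal handlers for graceful shutdown
import signal
def signal_handler(signum, frame):
logger.info(f"Received signal {signum}, shutting down gracefully...")
try:
stop_distributed_inference_with_queue()
except: # noqa: E722
logger.error("Error occurred while shutting down distributed inference service")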
finally:
sys.exit(0)
signal.signal(signal.SIGTERM, signal_handler)
signal.signal(signal.SIGINT, signal_handler)
try:
logger.info(f"Starting FastAPI server on port {args.port}")
uvicorn.run(app, host="0.0.0.0", port=args.port, reload=False, workers=1)
except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt, shutting down service...")
except Exception as e:
logger.error(f"Error occurred while running FastAPI server: {str(e)}")
finally:
# Ensure the inference service is stopped when the program ends
if args.start_inference:
stop_distributed_inference_with_queue()
"""
curl -X 'POST' \
'http://localhost:8000/v1/local/video/generate_form?task_id=abc&prompt=%E8%B7%B3%E8%88%9E&save_video_path=a.mp4&task_id_must_unique=false&use_prompt_enhancer=false&num_fragments=1' \
-H 'accept: application/json' \
-H 'Content-Type: multipart/form-data' \
-F 'image_file=@图片1.png;type=image/png'
curl -X 'POST' \
'http://localhost:8000/v1/local/video/generate' \
-H 'accept: application/json' \
-H 'Content-Type: application/json' \
-d '{
"task_id": "abcde",
"task_id_must_unique": false,
"prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline'\''s intricate details and the refreshing atmosphere of the seaside.",
"use_prompt_enhancer": false,
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "/mnt/aigc/users/gaopeng1/ComfyUI-Lightx2vWrapper/lightx2v/assets/inputs/imgs/img_0.jpg",
"num_fragments": 1,
"save_video_path": "/mnt/aigc/users/lijiaqi2/ComfyUI/custom_nodes/ComfyUI-Lightx2vWrapper/lightx2v/save_results/img_0.mp4"
}'
"""
@@ -41,7 +41,6 @@ def radial_attn(
head_dim=hidden_dim,
q_data_type=query.dtype,
kv_data_type=key.dtype,
o_data_type=query.dtype,
use_fp16_qk_reduction=True,
)
......
@@ -58,6 +58,19 @@ async def main():
parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")
args = parser.parse_args()
if args.prompt_path:
try:
with open(args.prompt_path, "r", encoding="utf-8") as f:
args.prompt = f.read().strip()
logger.info(f"Read prompt from file {args.prompt_path}: {args.prompt}")
except FileNotFoundError:
logger.error(f"Prompt file not found: {args.prompt_path}")
raise
except Exception as e:
logger.error(f"Error reading prompt file: {e}")
raise
logger.info(f"args: {args}")
with ProfilingContext("Total Cost"):
......
@@ -349,7 +349,7 @@ class WanTransformerInfer(BaseTransformerInfer):
if self.config.get("audio_sr", False):
freqs_i = compute_freqs_audio_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
else:
-freqs_i = compute_freqs_dist(q.size(2) // 2, grid_sizes, freqs)
+freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs)
freqs_i = self.zero_temporal_component_in_3DRoPE(seq_lens, freqs_i)
......
@@ -101,6 +101,14 @@ class DefaultRunner:
self.config["negative_prompt"] = inputs.get("negative_prompt", "")
self.config["image_path"] = inputs.get("image_path", "")
self.config["save_video_path"] = inputs.get("save_video_path", "")
self.config["infer_steps"] = inputs.get("infer_steps", self.config.get("infer_steps", 5))
self.config["target_video_length"] = inputs.get("target_video_length", self.config.get("target_video_length", 81))
self.config["seed"] = inputs.get("seed", self.config.get("seed", 42))
self.config["audio_path"] = inputs.get("audio_path", "") # for wan-audio
self.config["video_duration"] = inputs.get("video_duration", 5) # for wan-audio
# self.config["sample_shift"] = inputs.get("sample_shift", self.config.get("sample_shift", 5))
# self.config["sample_guide_scale"] = inputs.get("sample_guide_scale", self.config.get("sample_guide_scale", 5))
def run(self):
for step_index in range(self.model.scheduler.infer_steps):
......
@@ -219,8 +219,9 @@ def save_to_video(gen_lvideo, out_path, target_fps):
def save_audio(
audio_array,
audio_name: str,
-video_name: str = None,
+video_name: str,
sr: int = 16000,
output_path: Optional[str] = None,
):
logger.info(f"Saving audio to {audio_name} type: {type(audio_array)}")
@@ -230,18 +231,21 @@ def save_audio(
sample_rate=sr,
)
if output_path is None:
out_video = f"{video_name[:-4]}_with_audio.mp4"
# Ensure the parent directory exists
else:
out_video = output_path
parent_dir = os.path.dirname(out_video)
if parent_dir and not os.path.exists(parent_dir):
os.makedirs(parent_dir, exist_ok=True)
# If the output video already exists, delete it first
if os.path.exists(out_video):
os.remove(out_video)
-cmd = f"/usr/bin/ffmpeg -i {video_name} -i {audio_name} {out_video}"
-subprocess.call(cmd, shell=True)
+subprocess.call(["/usr/bin/ffmpeg", "-y", "-i", video_name, "-i", audio_name, out_video])
return out_video
@RUNNER_REGISTER("wan2.1_audio")
@@ -323,17 +327,16 @@ class WanAudioRunner(WanRunner):
"vae_encode_out": vae_encode_out,
}
logger.info(f"clip_encoder_out:{clip_encoder_out.shape} vae_encode_out:{vae_encode_out.shape}")
with ProfilingContext("Run Text Encoder"):
-with open(self.config["prompt_path"], "r", encoding="utf-8") as f:
-prompt = f.readline().strip()
-logger.info(f"Prompt: {prompt}")
+logger.info(f"Prompt: {self.config['prompt']}")
img = Image.open(self.config["image_path"]).convert("RGB")
-text_encoder_output = self.run_text_encoder(prompt, img)
+text_encoder_output = self.run_text_encoder(self.config["prompt"], img)
self.set_target_shape()
self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
-del self.image_encoder # Delete the reference CLIP model; it is used only once
+# del self.image_encoder # Delete the reference CLIP model; it is used only once
gc.collect()
torch.cuda.empty_cache()
@@ -488,10 +491,10 @@ class WanAudioRunner(WanRunner):
gen_lvideo = torch.cat(gen_video_list, dim=2).float()
merge_audio = np.concatenate(cut_audio_list, axis=0).astype(np.float32)
-out_path = self.config.save_video_path
+out_path = os.path.join("./", "video_merge.mp4")
audio_file = os.path.join("./", "audio_merge.wav")
save_to_video(gen_lvideo, out_path, target_fps)
-save_audio(merge_audio, audio_file, out_path)
+save_audio(merge_audio, audio_file, out_path, output_path=self.config.get("save_video_path", None))
os.remove(out_path)
os.remove(audio_file)
......
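The `save_audio` change above replaces a shell command string with an argument list. A minimal sketch of why that matters, assuming a hypothetical helper named `build_merge_command` (not part of the commit): with a list, a path containing spaces stays a single argument, whereas a shell string is split on whitespace.

```python
import shlex


def build_merge_command(video_path: str, audio_path: str, out_path: str, ffmpeg: str = "/usr/bin/ffmpeg") -> list[str]:
    """Return the argv list for muxing a video file and an audio file.

    Hypothetical helper mirroring the list-form call in save_audio.
    "-y" tells ffmpeg to overwrite the output without prompting.
    """
    return [ffmpeg, "-y", "-i", video_path, "-i", audio_path, out_path]


if __name__ == "__main__":
    cmd = build_merge_command("my video.mp4", "audio_merge.wav", "out.mp4")
    # In list form, "my video.mp4" is passed to ffmpeg as one argument.
    print(cmd)
    # A shell-style string is tokenized on the space inside the file name.
    print(shlex.split("/usr/bin/ffmpeg -i my video.mp4 -i audio_merge.wav out.mp4"))
```

The list can be passed directly to `subprocess.call` or `subprocess.run` without `shell=True`, so no quoting of user-supplied paths is needed.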
import asyncio
from fastapi import FastAPI, UploadFile, HTTPException, Form, File, APIRouter
from fastapi.responses import StreamingResponse
from loguru import logger
import threading
import gc
import torch
from pathlib import Path
import uuid
from typing import Optional
from .schema import (
TaskRequest,
TaskResponse,
ServiceStatusResponse,
StopTaskResponse,
)
from .service import FileService, DistributedInferenceService, VideoGenerationService
from .utils import ServiceStatus
class ApiServer:
def __init__(self):
self.app = FastAPI(title="LightX2V API", version="1.0.0")
self.file_service = None
self.inference_service = None
self.video_service = None
self.thread = None
self.stop_generation_event = threading.Event()
# Create routers
self.tasks_router = APIRouter(prefix="/v1/tasks", tags=["tasks"])
self.files_router = APIRouter(prefix="/v1/files", tags=["files"])
self.service_router = APIRouter(prefix="/v1/service", tags=["service"])
self._setup_routes()
def _setup_routes(self):
"""Setup routes"""
self._setup_task_routes()
self._setup_file_routes()
self._setup_service_routes()
# Register routers
self.app.include_router(self.tasks_router)
self.app.include_router(self.files_router)
self.app.include_router(self.service_router)
def _stream_file_response(self, file_path: Path, filename: str | None = None) -> StreamingResponse:
"""Common file streaming response method"""
assert self.file_service is not None, "File service is not initialized"
try:
resolved_path = file_path.resolve()
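# Security check: ensure file is within allowed directory. Compare path components
# (Python 3.9+) so that a sibling directory sharing the same name prefix is rejected.
if not resolved_path.is_relative_to(self.file_service.output_video_dir.resolve()):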
raise HTTPException(status_code=403, detail="Access to this file is not allowed")
if not resolved_path.exists() or not resolved_path.is_file():
raise HTTPException(status_code=404, detail=f"File not found: {file_path}")
file_size = resolved_path.stat().st_size
actual_filename = filename or resolved_path.name
# Set appropriate MIME type
mime_type = "application/octet-stream"
if actual_filename.lower().endswith((".mp4", ".avi", ".mov", ".mkv")):
mime_type = "video/mp4"
elif actual_filename.lower().endswith((".jpg", ".jpeg", ".png", ".gif")):
mime_type = "image/jpeg"
headers = {
"Content-Disposition": f'attachment; filename="{actual_filename}"',
"Content-Length": str(file_size),
"Accept-Ranges": "bytes",
}
def file_stream_generator(file_path: str, chunk_size: int = 1024 * 1024):
with open(file_path, "rb") as file:
while chunk := file.read(chunk_size):
yield chunk
return StreamingResponse(
file_stream_generator(str(resolved_path)),
media_type=mime_type,
headers=headers,
)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error occurred while processing file stream response: {e}")
raise HTTPException(status_code=500, detail="File transfer failed")
def _setup_task_routes(self):
@self.tasks_router.post("/", response_model=TaskResponse)
async def create_task(message: TaskRequest):
"""Create video generation task"""
try:
task_id = ServiceStatus.start_task(message)
# Use background thread to handle long-running tasks
self.stop_generation_event.clear()
self.thread = threading.Thread(
target=self._process_video_generation,
args=(message, self.stop_generation_event),
daemon=True,
)
self.thread.start()
return TaskResponse(
task_id=task_id,
task_status="processing",
save_video_path=message.save_video_path,
)
except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e))
@self.tasks_router.post("/form", response_model=TaskResponse)
async def create_task_form(
image_file: UploadFile = File(...),
prompt: str = Form(default=""),
save_video_path: str = Form(default=""),
use_prompt_enhancer: bool = Form(default=False),
negative_prompt: str = Form(default=""),
num_fragments: int = Form(default=1),
infer_steps: int = Form(default=5),
target_video_length: int = Form(default=81),
seed: int = Form(default=42),
audio_file: Optional[UploadFile] = File(default=None),
video_duration: int = Form(default=5),
):
"""Create video generation task via form"""
# Process uploaded image file
image_path = ""
assert self.file_service is not None, "File service is not initialized"
if image_file and image_file.filename:
file_extension = Path(image_file.filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}"
image_path = self.file_service.input_image_dir / unique_filename
with open(image_path, "wb") as buffer:
content = await image_file.read()
buffer.write(content)
image_path = str(image_path)
audio_path = ""
if audio_file and audio_file.filename:
file_extension = Path(audio_file.filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}"
audio_path = self.file_service.input_audio_dir / unique_filename
with open(audio_path, "wb") as buffer:
content = await audio_file.read()
buffer.write(content)
audio_path = str(audio_path)
message = TaskRequest(
prompt=prompt,
use_prompt_enhancer=use_prompt_enhancer,
negative_prompt=negative_prompt,
image_path=image_path,
num_fragments=num_fragments,
save_video_path=save_video_path,
infer_steps=infer_steps,
target_video_length=target_video_length,
seed=seed,
audio_path=audio_path,
video_duration=video_duration,
)
try:
task_id = ServiceStatus.start_task(message)
self.stop_generation_event.clear()
self.thread = threading.Thread(
target=self._process_video_generation,
args=(message, self.stop_generation_event),
daemon=True,
)
self.thread.start()
return TaskResponse(
task_id=task_id,
task_status="processing",
save_video_path=message.save_video_path,
)
except RuntimeError as e:
raise HTTPException(status_code=400, detail=str(e))
@self.tasks_router.get("/", response_model=dict)
async def list_tasks():
"""List all tasks"""
return ServiceStatus.get_all_tasks()
@self.tasks_router.get("/{task_id}/status")
async def get_task_status(task_id: str):
"""Get status of specified task"""
return ServiceStatus.get_status_task_id(task_id)
@self.tasks_router.get("/{task_id}/result")
async def get_task_result(task_id: str):
"""Get result video file of specified task"""
assert self.video_service is not None, "Video service is not initialized"
assert self.file_service is not None, "File service is not initialized"
try:
task_status = ServiceStatus.get_status_task_id(task_id)
if not task_status or task_status.get("status") != "completed":
raise HTTPException(status_code=404, detail="Task not completed or does not exist")
save_video_path = task_status.get("save_video_path")
if not save_video_path:
raise HTTPException(status_code=404, detail="Task result file does not exist")
full_path = Path(save_video_path)
if not full_path.is_absolute():
full_path = self.file_service.output_video_dir / save_video_path
return self._stream_file_response(full_path)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error occurred while getting task result: {e}")
raise HTTPException(status_code=500, detail="Failed to get task result")
@self.tasks_router.delete("/running", response_model=StopTaskResponse)
async def stop_running_task():
"""Stop currently running task"""
if self.thread and self.thread.is_alive():
try:
logger.info("Sending stop signal to running task thread...")
self.stop_generation_event.set()
self.thread.join(timeout=5)
if self.thread.is_alive():
logger.warning("Task thread did not stop within the specified time, manual intervention may be required.")
return StopTaskResponse(
stop_status="warning",
reason="Task thread did not stop within the specified time, manual intervention may be required.",
)
else:
self.thread = None
ServiceStatus.clean_stopped_task()
gc.collect()
torch.cuda.empty_cache()
logger.info("Task stopped successfully.")
return StopTaskResponse(stop_status="success", reason="Task stopped successfully.")
except Exception as e:
logger.error(f"Error occurred while stopping task: {str(e)}")
return StopTaskResponse(stop_status="error", reason=str(e))
else:
return StopTaskResponse(stop_status="do_nothing", reason="No running task found.")
def _setup_file_routes(self):
@self.files_router.get("/download/{file_path:path}")
async def download_file(file_path: str):
"""Download file"""
assert self.file_service is not None, "File service is not initialized"
try:
full_path = self.file_service.output_video_dir / file_path
return self._stream_file_response(full_path)
except HTTPException:
raise
except Exception as e:
logger.error(f"Error occurred while processing file download request: {e}")
raise HTTPException(status_code=500, detail="File download failed")
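A containment check for served files is easy to get subtly wrong when it is done on strings: a prefix comparison also matches sibling directories whose names merely extend the allowed directory's name. The sketch below contrasts the two approaches and shows extension-based MIME lookup with the standard library. The helper names `is_within` and `guess_mime` are hypothetical and not part of the commit.

```python
import mimetypes
import tempfile
from pathlib import Path


def is_within(base_dir: Path, candidate: Path) -> bool:
    """Return True if candidate resolves to a location inside base_dir."""
    # Path.is_relative_to compares whole path components (Python 3.9+).
    return candidate.resolve().is_relative_to(base_dir.resolve())


def guess_mime(filename: str) -> str:
    """Guess a MIME type from the file extension, with a binary fallback."""
    mime, _encoding = mimetypes.guess_type(filename)
    return mime or "application/octet-stream"


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        outputs = Path(tmp) / "outputs"
        sibling = Path(tmp) / "outputs_private"
        outputs.mkdir()
        sibling.mkdir()
        leaked = sibling / "secret.mp4"
        # A plain string-prefix test accepts the sibling directory.
        print(str(leaked.resolve()).startswith(str(outputs.resolve())))
        # The component-wise check rejects it.
        print(is_within(outputs, leaked))
        # ".." segments are normalized before the comparison.
        print(is_within(outputs, outputs / ".." / "clip.mp4"))
    print(guess_mime("clip.mp4"), guess_mime("frame.png"))
```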
def _setup_service_routes(self):
@self.service_router.get("/status", response_model=ServiceStatusResponse)
async def get_service_status():
"""Get service status"""
return ServiceStatus.get_status_service()
def _process_video_generation(self, message: TaskRequest, stop_event: threading.Event):
assert self.video_service is not None, "Video service is not initialized"
try:
if stop_event.is_set():
logger.info(f"Task {message.task_id} received stop signal, terminating")
ServiceStatus.record_failed_task(message, error="Task stopped")
return
# Use video generation service to process task
result = asyncio.run(self.video_service.generate_video(message))
except Exception as e:
logger.error(f"Task {message.task_id} processing failed: {str(e)}")
ServiceStatus.record_failed_task(message, error=str(e))
def initialize_services(self, cache_dir: Path, inference_service: DistributedInferenceService):
self.file_service = FileService(cache_dir)
self.inference_service = inference_service
self.video_service = VideoGenerationService(self.file_service, inference_service)
def get_app(self) -> FastAPI:
return self.app
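The task routes registered above are mounted under `/v1/tasks`, and `POST /v1/tasks/` accepts the `TaskRequest` fields as JSON. A minimal client-side sketch with the standard library follows. The helper name `build_create_task_request` and the base URL `http://localhost:8000` are assumptions for illustration; the port of the refactored server is not shown in this part of the commit.

```python
import json
import urllib.request


def build_create_task_request(base_url: str, prompt: str, **fields) -> urllib.request.Request:
    """Build (but do not send) a POST request for the /v1/tasks/ route.

    Hypothetical helper, not part of the commit. Extra keyword arguments are
    passed through as TaskRequest fields, e.g. infer_steps or seed.
    """
    payload = {"prompt": prompt, **fields}
    return urllib.request.Request(
        url=f"{base_url.rstrip('/')}/v1/tasks/",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json", "accept": "application/json"},
        method="POST",
    )


if __name__ == "__main__":
    req = build_create_task_request(
        "http://localhost:8000",  # assumed address
        "a white cat wearing sunglasses sits on a surfboard",
        infer_steps=5,
        seed=42,
    )
    print(req.get_method(), req.full_url)
    # Sending the request needs a running server:
    # with urllib.request.urlopen(req) as resp:
    #     print(json.load(resp))
```

The response carries a `task_id`, which can then be used with `GET /v1/tasks/{task_id}/status` and `GET /v1/tasks/{task_id}/result`.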
import os
import torch
import torch.distributed as dist
from loguru import logger
class DistributedManager:
def __init__(self):
self.is_initialized = False
self.rank = 0
self.world_size = 1
def init_process_group(self, rank: int, world_size: int, master_addr: str, master_port: str) -> bool:
try:
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = master_addr
os.environ["MASTER_PORT"] = master_port
dist.init_process_group(backend="nccl", init_method=f"tcp://{master_addr}:{master_port}", rank=rank, world_size=world_size)
if torch.cuda.is_available(): # type: ignore
torch.cuda.set_device(rank)
self.is_initialized = True
self.rank = rank
self.world_size = world_size
logger.info(f"Rank {rank}/{world_size - 1} distributed environment initialized successfully")
return True
except Exception as e:
logger.error(f"Rank {rank} distributed environment initialization failed: {str(e)}")
return False
def cleanup(self):
try:
if dist.is_initialized():
dist.destroy_process_group()
logger.info(f"Rank {self.rank} distributed environment cleaned up")
except Exception as e:
logger.error(f"Rank {self.rank} error occurred while cleaning up distributed environment: {str(e)}")
finally:
self.is_initialized = False
def barrier(self):
if self.is_initialized:
dist.barrier()
def is_rank_zero(self) -> bool:
return self.rank == 0
def broadcast_task_data(self, task_data=None): # type: ignore
if not self.is_initialized:
return None
if self.is_rank_zero():
if task_data is None:
stop_signal = torch.tensor([1], dtype=torch.int32, device=f"cuda:{self.rank}")
else:
stop_signal = torch.tensor([0], dtype=torch.int32, device=f"cuda:{self.rank}")
dist.broadcast(stop_signal, src=0)
if task_data is not None:
import pickle
task_bytes = pickle.dumps(task_data)
task_length = torch.tensor([len(task_bytes)], dtype=torch.int32, device=f"cuda:{self.rank}")
dist.broadcast(task_length, src=0)
task_tensor = torch.tensor(list(task_bytes), dtype=torch.uint8, device=f"cuda:{self.rank}")
dist.broadcast(task_tensor, src=0)
return task_data
else:
return None
else:
stop_signal = torch.tensor([0], dtype=torch.int32, device=f"cuda:{self.rank}")
dist.broadcast(stop_signal, src=0)
if stop_signal.item() == 1:
return None
else:
task_length = torch.tensor([0], dtype=torch.int32, device=f"cuda:{self.rank}")
dist.broadcast(task_length, src=0)
task_tensor = torch.empty(int(task_length.item()), dtype=torch.uint8, device=f"cuda:{self.rank}")
dist.broadcast(task_tensor, src=0)
import pickle
task_bytes = bytes(task_tensor.cpu().numpy())
task_data = pickle.loads(task_bytes)
return task_data
class DistributedWorker:
def __init__(self, rank: int, world_size: int, master_addr: str, master_port: str):
self.rank = rank
self.world_size = world_size
self.master_addr = master_addr
self.master_port = master_port
self.dist_manager = DistributedManager()
def init(self) -> bool:
return self.dist_manager.init_process_group(self.rank, self.world_size, self.master_addr, self.master_port)
def cleanup(self):
self.dist_manager.cleanup()
def sync_and_report(self, task_id: str, status: str, result_queue, **kwargs):
# Synchronize all processes
self.dist_manager.barrier()
if self.dist_manager.is_rank_zero():
result = {"task_id": task_id, "status": status, **kwargs}
result_queue.put(result)
logger.info(f"Task {task_id} {status}")
def create_distributed_worker(rank: int, world_size: int, master_addr: str, master_port: str) -> DistributedWorker:
return DistributedWorker(rank, world_size, master_addr, master_port)
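`broadcast_task_data` sends three messages in a fixed order: a stop flag, the payload length, and the pickled payload. The sketch below models only that message order, with a `deque` standing in for `dist.broadcast`; it involves no GPUs or process groups. The function names `encode_broadcast` and `decode_broadcast` are hypothetical.

```python
import pickle
from collections import deque


def encode_broadcast(task_data):
    """Yield the messages rank 0 sends, in order."""
    if task_data is None:
        yield 1  # stop flag set: nothing else follows
        return
    payload = pickle.dumps(task_data)
    yield 0  # stop flag clear
    yield len(payload)  # receivers size their buffer from this value
    yield payload


def decode_broadcast(channel: deque):
    """Consume the messages in the same order on a receiving rank."""
    if channel.popleft() == 1:
        return None  # stop signal
    length = channel.popleft()
    payload = channel.popleft()
    assert len(payload) == length, "payload does not match announced length"
    # pickle.loads must only be used on data from a trusted sender.
    return pickle.loads(payload)


if __name__ == "__main__":
    channel = deque(encode_broadcast({"task_id": "demo", "seed": 42}))
    print(decode_broadcast(channel))
    print(decode_broadcast(deque(encode_broadcast(None))))
```

Sending the length first is what lets the receiving ranks allocate a buffer of the right size before the payload broadcast.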
from pydantic import BaseModel, Field
from typing import Optional
from datetime import datetime
from ..utils.generate_task_id import generate_task_id
class TaskRequest(BaseModel):
task_id: str = Field(default_factory=generate_task_id, description="Task ID (auto-generated)")
prompt: str = Field("", description="Generation prompt")
use_prompt_enhancer: bool = Field(False, description="Whether to use prompt enhancer")
negative_prompt: str = Field("", description="Negative prompt")
image_path: str = Field("", description="Input image path")
num_fragments: int = Field(1, description="Number of fragments")
save_video_path: str = Field("", description="Save video path (optional, defaults to task_id.mp4)")
infer_steps: int = Field(5, description="Inference steps")
target_video_length: int = Field(81, description="Target video length")
seed: int = Field(42, description="Random seed")
audio_path: str = Field("", description="Input audio path (Wan-Audio)")
video_duration: int = Field(5, description="Video duration (Wan-Audio)")
def __init__(self, **data):
super().__init__(**data)
# If save_video_path is empty, use task_id.mp4
if not self.save_video_path:
self.save_video_path = f"{self.task_id}.mp4"
def get(self, key, default=None):
return getattr(self, key, default)
class TaskStatusMessage(BaseModel):
task_id: str = Field(..., description="Task ID")
class TaskResponse(BaseModel):
task_id: str
task_status: str
save_video_path: str
class TaskResultResponse(BaseModel):
status: str
task_status: str
filename: Optional[str] = None
file_size: Optional[int] = None
download_url: Optional[str] = None
message: str
class ServiceStatusResponse(BaseModel):
service_status: str
task_id: Optional[str] = None
start_time: Optional[datetime] = None
class StopTaskResponse(BaseModel):
stop_status: str
reason: str
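`TaskRequest` fills in `save_video_path` from the task ID when the caller leaves it empty. The same defaulting rule can be sketched with a standard-library dataclass, used here only as a stand-in for the pydantic model. `TaskRequestSketch` and `make_task_id` are hypothetical names; the real `generate_task_id` is not shown in this commit view.

```python
import uuid
from dataclasses import dataclass, field


def make_task_id() -> str:
    """Hypothetical stand-in for generate_task_id."""
    return uuid.uuid4().hex


@dataclass
class TaskRequestSketch:
    task_id: str = field(default_factory=make_task_id)
    prompt: str = ""
    save_video_path: str = ""

    def __post_init__(self):
        # Mirror TaskRequest: an empty path falls back to "<task_id>.mp4".
        if not self.save_video_path:
            self.save_video_path = f"{self.task_id}.mp4"


if __name__ == "__main__":
    print(TaskRequestSketch(task_id="abc").save_video_path)
    print(TaskRequestSketch(save_video_path="clips/a.mp4").save_video_path)
```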
import asyncio
import queue
import time
import uuid
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
import httpx
import torch.multiprocessing as mp
from loguru import logger
from ..utils.set_config import set_config
from ..infer import init_runner
from .utils import ServiceStatus
from .schema import TaskRequest, TaskResponse
from .distributed_utils import create_distributed_worker
mp.set_start_method("spawn", force=True)
class FileService:
def __init__(self, cache_dir: Path):
self.cache_dir = cache_dir
self.input_image_dir = cache_dir / "inputs" / "imgs"
self.input_audio_dir = cache_dir / "inputs" / "audios"
self.output_video_dir = cache_dir / "outputs"
# Create directories
for directory in [
self.input_image_dir,
self.output_video_dir,
self.input_audio_dir,
]:
directory.mkdir(parents=True, exist_ok=True)
async def download_image(self, image_url: str) -> Path:
try:
# NOTE: verify=False disables TLS certificate verification for this download
async with httpx.AsyncClient(verify=False) as client:
response = await client.get(image_url)
if response.status_code != 200:
raise ValueError(f"Failed to download image from {image_url}")
image_name = Path(urlparse(image_url).path).name
if not image_name:
raise ValueError(f"Invalid image URL: {image_url}")
image_path = self.input_image_dir / image_name
image_path.parent.mkdir(parents=True, exist_ok=True)
with open(image_path, "wb") as f:
f.write(response.content)
return image_path
except Exception as e:
logger.error(f"Failed to download image: {e}")
raise
def save_uploaded_file(self, file_content: bytes, filename: str) -> Path:
file_extension = Path(filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}"
file_path = self.input_image_dir / unique_filename
with open(file_path, "wb") as f:
f.write(file_content)
return file_path
def get_output_path(self, save_video_path: str) -> Path:
video_path = Path(save_video_path)
if not video_path.is_absolute():
return self.output_video_dir / save_video_path
return video_path
def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, task_queue, result_queue):
task_data = None
loop = None
worker = None
try:
logger.info(f"Process {rank}/{world_size - 1} initializing distributed inference service...")
# Create and initialize distributed worker process
worker = create_distributed_worker(rank, world_size, master_addr, master_port)
if not worker.init():
raise RuntimeError(f"Rank {rank} distributed environment initialization failed")
# Initialize configuration and model
config = set_config(args)
config["mode"] = "server"
logger.info(f"Rank {rank} config: {config}")
runner = init_runner(config)
logger.info(f"Process {rank}/{world_size - 1} distributed inference service initialization completed")
# Create event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
while True:
# Only rank=0 reads tasks from queue
if rank == 0:
try:
task_data = task_queue.get(timeout=1.0)
if task_data is None: # Stop signal
logger.info(f"Process {rank} received stop signal, exiting inference service")
# Broadcast stop signal to other processes
worker.dist_manager.broadcast_task_data(None)
break
# Broadcast task data to other processes
worker.dist_manager.broadcast_task_data(task_data)
except queue.Empty:
# Queue is empty, continue waiting
continue
else:
# Non-rank=0 processes receive task data from rank=0
task_data = worker.dist_manager.broadcast_task_data()
if task_data is None: # Stop signal
logger.info(f"Process {rank} received stop signal, exiting inference service")
break
# All processes handle the task
if task_data is not None:
logger.info(f"Process {rank} received inference task: {task_data['task_id']}")
try:
# Set inputs and run inference
runner.set_inputs(task_data) # type: ignore
loop.run_until_complete(runner.run_pipeline())
# Synchronize and report results
worker.sync_and_report(
task_data["task_id"],
"success",
result_queue,
save_video_path=task_data["save_video_path"],
message="Inference completed",
)
except Exception as e:
logger.error(f"Process {rank} error occurred while processing task: {str(e)}")
# Synchronize and report error
worker.sync_and_report(
task_data.get("task_id", "unknown"),
"failed",
result_queue,
error=str(e),
message=f"Inference failed: {str(e)}",
)
except KeyboardInterrupt:
logger.info(f"Process {rank} received KeyboardInterrupt, gracefully exiting")
except Exception as e:
logger.error(f"Distributed inference service process {rank} startup failed: {str(e)}")
if rank == 0:
error_result = {
"task_id": "startup",
"status": "startup_failed",
"error": str(e),
"message": f"Inference service startup failed: {str(e)}",
}
result_queue.put(error_result)
finally:
# Clean up resources
try:
if loop and not loop.is_closed():
loop.close()
except: # noqa: E722
pass
try:
if worker:
worker.cleanup()
except: # noqa: E722
pass
class DistributedInferenceService:
def __init__(self):
self.task_queue = None
self.result_queue = None
self.processes = []
self.is_running = False
def start_distributed_inference(self, args) -> bool:
if self.is_running:
logger.warning("Distributed inference service is already running")
return True
nproc_per_node = args.nproc_per_node
if nproc_per_node <= 0:
logger.error("nproc_per_node must be greater than 0")
return False
try:
import random
master_addr = "127.0.0.1"
master_port = str(random.randint(20000, 29999))
logger.info(f"Distributed inference service Master Addr: {master_addr}, Master Port: {master_port}")
# Create shared queues
self.task_queue = mp.Queue()
self.result_queue = mp.Queue()
# Start processes
for rank in range(nproc_per_node):
p = mp.Process(
target=_distributed_inference_worker,
args=(
rank,
nproc_per_node,
master_addr,
master_port,
args,
self.task_queue,
self.result_queue,
),
daemon=True,
)
p.start()
self.processes.append(p)
self.is_running = True
logger.info(f"Distributed inference service started successfully with {nproc_per_node} processes")
return True
except Exception as e:
logger.exception(f"Error occurred while starting distributed inference service: {str(e)}")
self.stop_distributed_inference()
return False
def stop_distributed_inference(self):
if not self.is_running:
return
try:
logger.info(f"Stopping {len(self.processes)} distributed inference service processes...")
# Send stop signal
if self.task_queue:
for _ in self.processes:
self.task_queue.put(None)
# Wait for processes to end
for p in self.processes:
try:
p.join(timeout=10)
if p.is_alive():
logger.warning(f"Process {p.pid} did not end within the specified time, forcing termination...")
p.terminate()
p.join(timeout=5)
except: # noqa: E722
pass
logger.info("All distributed inference service processes have stopped")
except Exception as e:
logger.error(f"Error occurred while stopping distributed inference service: {str(e)}")
finally:
# Clean up resources
self._clean_queues()
self.processes = []
self.task_queue = None
self.result_queue = None
self.is_running = False
def _clean_queues(self):
for queue_obj in [self.task_queue, self.result_queue]:
if queue_obj:
try:
while not queue_obj.empty():
queue_obj.get_nowait()
except: # noqa: E722
pass
def submit_task(self, task_data: dict) -> bool:
if not self.is_running or not self.task_queue:
logger.error("Distributed inference service is not started")
return False
try:
self.task_queue.put(task_data)
return True
except Exception as e:
logger.error(f"Failed to submit task: {str(e)}")
return False
def wait_for_result(self, task_id: str, timeout: int = 300) -> Optional[dict]:
if not self.is_running or not self.result_queue:
return None
start_time = time.time()
while time.time() - start_time < timeout:
try:
result = self.result_queue.get(timeout=1.0)
if result.get("task_id") == task_id:
return result
else:
# Not the result for current task, put back in queue
self.result_queue.put(result)
time.sleep(0.1)
except queue.Empty:
continue
return None
class VideoGenerationService:
def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
self.file_service = file_service
self.inference_service = inference_service
async def generate_video(self, message: TaskRequest) -> TaskResponse:
try:
# Build the task payload
task_data = {
"task_id": message.task_id,
"prompt": message.prompt,
"use_prompt_enhancer": message.use_prompt_enhancer,
"negative_prompt": message.negative_prompt,
"image_path": message.image_path,
"num_fragments": message.num_fragments,
"save_video_path": message.save_video_path,
"infer_steps": message.infer_steps,
"target_video_length": message.target_video_length,
"seed": message.seed,
"audio_path": message.audio_path,
"video_duration": message.video_duration,
}
# Process network image
if message.image_path.startswith("http"):
image_path = await self.file_service.download_image(message.image_path)
task_data["image_path"] = str(image_path)
# Process output path
save_video_path = self.file_service.get_output_path(message.save_video_path)
task_data["save_video_path"] = str(save_video_path)
# Submit task to distributed inference service
if not self.inference_service.submit_task(task_data):
raise RuntimeError("Distributed inference service is not started")
# Wait for result
result = self.inference_service.wait_for_result(message.task_id)
if result is None:
raise RuntimeError("Task processing timeout")
if result.get("status") == "success":
ServiceStatus.complete_task(message)
return TaskResponse(
task_id=message.task_id,
task_status="completed",
save_video_path=str(save_video_path),
)
else:
error_msg = result.get("error", "Inference failed")
ServiceStatus.record_failed_task(message, error=error_msg)
raise RuntimeError(error_msg)
except Exception as e:
logger.error(f"Task {message.task_id} processing failed: {str(e)}")
ServiceStatus.record_failed_task(message, error=str(e))
raise
import sys
import psutil
import signal
import base64
from PIL import Image
from loguru import logger
from typing import Optional
from datetime import datetime
from pydantic import BaseModel
import threading
import torch
import io
class ProcessManager:
@staticmethod
def kill_all_related_processes():
current_process = psutil.Process()
children = current_process.children(recursive=True)
for child in children:
try:
child.kill()
except Exception as e:
logger.info(f"Failed to kill child process {child.pid}: {e}")
try:
current_process.kill()
except Exception as e:
logger.info(f"Failed to kill main process: {e}")
@staticmethod
def signal_handler(sig, frame):
logger.info("\nReceived Ctrl+C, shutting down all related processes...")
ProcessManager.kill_all_related_processes()
sys.exit(0)
@staticmethod
def register_signal_handler():
signal.signal(signal.SIGINT, ProcessManager.signal_handler)
class TaskStatusMessage(BaseModel):
task_id: str
class ServiceStatus:
_lock = threading.Lock()
_current_task = None
_result_store = {}
@classmethod
def start_task(cls, message):
with cls._lock:
if cls._current_task is not None:
raise RuntimeError("Service busy")
if message.task_id in cls._result_store:
raise RuntimeError(f"Task ID {message.task_id} already exists")
cls._current_task = {"message": message, "start_time": datetime.now()}
return message.task_id
@classmethod
def complete_task(cls, message):
with cls._lock:
if cls._current_task:
cls._result_store[message.task_id] = {
"success": True,
"message": message,
"start_time": cls._current_task["start_time"],
"completion_time": datetime.now(),
"save_video_path": message.save_video_path,
}
cls._current_task = None
@classmethod
def record_failed_task(cls, message, error: Optional[str] = None):
with cls._lock:
if cls._current_task:
cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error, "save_video_path": message.save_video_path}
cls._current_task = None
@classmethod
def clean_stopped_task(cls):
with cls._lock:
if cls._current_task:
message = cls._current_task["message"]
error = "Task stopped by user"
cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error, "save_video_path": message.save_video_path}
cls._current_task = None
@classmethod
def get_status_task_id(cls, task_id: str):
with cls._lock:
if cls._current_task and cls._current_task["message"].task_id == task_id:
return {"status": "processing", "task_id": task_id}
if task_id in cls._result_store:
result = cls._result_store[task_id]
return {
"status": "completed" if result["success"] else "failed",
"task_id": task_id,
"success": result["success"],
"start_time": result["start_time"],
"completion_time": result.get("completion_time"),
"error": result.get("error"),
"save_video_path": result.get("save_video_path"),
}
return {"status": "not_found", "task_id": task_id}
@classmethod
def get_status_service(cls):
with cls._lock:
if cls._current_task:
return {"service_status": "busy", "task_id": cls._current_task["message"].task_id, "start_time": cls._current_task["start_time"]}
return {"service_status": "idle"}
@classmethod
def get_all_tasks(cls):
with cls._lock:
return cls._result_store
class TensorTransporter:
def __init__(self):
self.buffer = io.BytesIO()
def to_device(self, data, device):
if isinstance(data, dict):
return {key: self.to_device(value, device) for key, value in data.items()}
elif isinstance(data, list):
return [self.to_device(item, device) for item in data]
elif isinstance(data, torch.Tensor):
return data.to(device)
else:
return data
def prepare_tensor(self, data) -> str:
self.buffer.seek(0)
self.buffer.truncate()
torch.save(self.to_device(data, "cpu"), self.buffer)
return base64.b64encode(self.buffer.getvalue()).decode("utf-8")
def load_tensor(self, tensor_base64: str, device="cuda"):
tensor_bytes = base64.b64decode(tensor_base64)
with io.BytesIO(tensor_bytes) as buffer:
return self.to_device(torch.load(buffer), device)
class ImageTransporter:
def __init__(self):
self.buffer = io.BytesIO()
def prepare_image(self, image: Image.Image):
self.buffer.seek(0)
self.buffer.truncate()
image.save(self.buffer, format="PNG")
return base64.b64encode(self.buffer.getvalue()).decode("utf-8")
def load_image(self, image_base64: bytes) -> Image.Image:
image_bytes = base64.b64decode(image_base64)
with io.BytesIO(image_bytes) as buffer:
return Image.open(buffer).convert("RGB")
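Both transporter classes follow the same pattern: serialize into an in-memory buffer, base64-encode the bytes into a string, and reverse the steps on the receiving side. The sketch below shows that round trip with the standard library only, using `pickle` as a stand-in for `torch.save` and `Image.save`. The names `encode_payload` and `decode_payload` are hypothetical and do not appear in the commit.

```python
import base64
import io
import json
import pickle


def encode_payload(obj) -> str:
    """Serialize obj into a buffer and return it as base64 text."""
    buffer = io.BytesIO()
    pickle.dump(obj, buffer)  # stand-in for torch.save(data, buffer)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def decode_payload(text: str):
    """Reverse encode_payload: base64 text -> bytes -> object."""
    with io.BytesIO(base64.b64decode(text)) as buffer:
        return pickle.load(buffer)  # stand-in for torch.load(buffer)


original = {"latents": [0.5, 1.5], "shape": (1, 2)}
# The base64 string is plain text, so it can travel inside a JSON body.
wire = json.dumps({"payload": encode_payload(original)})
restored = decode_payload(json.loads(wire)["payload"])
```

Unpickling can execute arbitrary code, so a scheme like this is only appropriate between trusted peers.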
import time
import torch
from contextlib import ContextDecorator
import asyncio
from functools import wraps
from lightx2v.utils.envs import *
from loguru import logger
class _ProfilingContext(ContextDecorator):
class _ProfilingContext:
def __init__(self, name):
self.name = name
self.rank_info = ""
......@@ -31,8 +32,44 @@ class _ProfilingContext(ContextDecorator):
logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
return False
async def __aenter__(self):
torch.cuda.synchronize()
if torch.cuda.is_available():
torch.cuda.reset_peak_memory_stats()
self.start_time = time.perf_counter()
return self
class _NullContext(ContextDecorator):
async def __aexit__(self, exc_type, exc_val, exc_tb):
torch.cuda.synchronize()
if torch.cuda.is_available():
peak_memory = torch.cuda.max_memory_allocated() / (1024**3) # convert to GB
logger.info(f"{self.rank_info}Function '{self.name}' Peak Memory: {peak_memory:.2f} GB")
else:
logger.info(f"{self.rank_info}Function '{self.name}' executed without GPU.")
elapsed = time.perf_counter() - self.start_time
logger.info(f"[Profile] {self.name} cost {elapsed:.6f} seconds")
return False
def __call__(self, func):
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
async with self:
return await func(*args, **kwargs)
return async_wrapper
else:
@wraps(func)
def sync_wrapper(*args, **kwargs):
with self:
return func(*args, **kwargs)
return sync_wrapper
class _NullContext:
# Context manager without decision branch logic overhead
def __init__(self, *args, **kwargs):
pass
......@@ -43,6 +80,15 @@ class _NullContext(ContextDecorator):
def __exit__(self, *args):
return False
async def __aenter__(self):
return self
async def __aexit__(self, *args):
return False
def __call__(self, func):
return func
ProfilingContext = _ProfilingContext
ProfilingContext4Debug = _ProfilingContext if CHECK_ENABLE_PROFILING_DEBUG() else _NullContext
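The `__call__` method above lets one object act as a decorator for both regular and coroutine functions by choosing between `with` and `async with`. A minimal sketch of that dispatch, without the CUDA calls, might look as follows. `TimingContext` is a hypothetical stand-in for `_ProfilingContext`, and it uses `inspect.iscoroutinefunction` where the commit uses `asyncio.iscoroutinefunction`.

```python
import asyncio
import inspect
import time
from functools import wraps


class TimingContext:
    """Hypothetical stand-in for _ProfilingContext, timing only."""

    def __init__(self, name):
        self.name = name
        self.elapsed = None

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.elapsed = time.perf_counter() - self.start_time
        return False  # never swallow exceptions

    async def __aenter__(self):
        return self.__enter__()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        return self.__exit__(exc_type, exc_val, exc_tb)

    def __call__(self, func):
        # Coroutine functions need an async wrapper so the body is awaited
        # inside the context rather than merely created inside it.
        if inspect.iscoroutinefunction(func):
            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                async with self:
                    return await func(*args, **kwargs)
            return async_wrapper

        @wraps(func)
        def sync_wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return sync_wrapper


@TimingContext("sync_add")
def add(a, b):
    return a + b


@TimingContext("async_add")
async def add_async(a, b):
    await asyncio.sleep(0)
    return a + b
```

Without the coroutine branch, a plain `with` wrapper around an `async def` would exit the context as soon as the coroutine object is created, so the measured time would not cover the awaited work.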
......@@ -2,13 +2,13 @@ import requests
from loguru import logger
response = requests.get("http://localhost:8000/v1/local/video/generate/service_status")
response = requests.get("http://localhost:8000/v1/service/status")
logger.info(response.json())
response = requests.get("http://localhost:8000/v1/local/video/generate/get_all_tasks")
response = requests.get("http://localhost:8000/v1/tasks/")
logger.info(response.json())
response = requests.post("http://localhost:8000/v1/local/video/generate/task_status", json={"task_id": "test_task_001"})
response = requests.get("http://localhost:8000/v1/tasks/test_task_001/status")
logger.info(response.json())
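The script above moves from the old `/v1/local/video/generate/...` paths to resource-style routes, with the task ID in the path instead of a POST body. A small sketch of helpers that build the three new URLs is given below. The function names are hypothetical, and only the paths themselves come from the diff.

```python
from urllib.parse import quote

BASE_URL = "http://localhost:8000"  # host and port as used in the scripts


def service_status_url(base_url: str = BASE_URL) -> str:
    return f"{base_url}/v1/service/status"


def tasks_url(base_url: str = BASE_URL) -> str:
    return f"{base_url}/v1/tasks/"


def task_status_url(task_id: str, base_url: str = BASE_URL) -> str:
    # The task ID is now a path segment, so escape anything unsafe in it.
    return f"{base_url}/v1/tasks/{quote(task_id, safe='')}/status"
```

With these helpers, the status query becomes `requests.get(task_status_url("test_task_001"))`.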
import requests
from loguru import logger
import random
import string
import time
from datetime import datetime
# same as lightx2v/utils/generate_task_id.py
# from lightx2v.utils.generate_task_id import generate_task_id
def generate_task_id():
"""
Generate a random task ID in the format XXXX-XXXX-XXXX-XXXX-XXXX.
Features:
1. Does not modify the global random state.
2. Each X is an uppercase letter or digit (0-9).
3. Combines time factors to ensure high randomness.
For example: N1PQ-PRM5-N1BN-Z3S1-BGBJ
"""
# Save the current random state (does not affect external randomness)
original_state = random.getstate()
try:
# Define character set (uppercase letters + digits)
characters = string.ascii_uppercase + string.digits
# Create an independent random instance
local_random = random.Random(time.perf_counter_ns())
# Generate 5 groups of 4-character random strings
groups = []
for _ in range(5):
# Mix new time factor for each group
time_mix = int(datetime.now().timestamp())
local_random.seed(time_mix + local_random.getstate()[1][0] + time.perf_counter_ns())
groups.append("".join(local_random.choices(characters, k=4)))
return "-".join(groups)
finally:
# Restore the original random state
random.setstate(original_state)
if __name__ == "__main__":
url = "http://localhost:8000/v1/local/video/generate"
url = "http://localhost:8000/v1/tasks/"
message = {
"task_id": generate_task_id(), # task_id can also be any string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique; otherwise an error is raised. Default is False.
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_t02.mp4", # It is best to set it to an absolute path.
}
logger.info(f"message: {message}")
......
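`generate_task_id` in the script above draws from a private `random.Random` instance, which already leaves the module-level generator untouched, so the save and restore of the global state acts as an extra safety net. As an alternative that is not part of this commit, the same `XXXX-XXXX-XXXX-XXXX-XXXX` format can be produced with the standard-library `secrets` module, which needs no seeding. The name `generate_task_id_secrets` is hypothetical.

```python
import secrets
import string

# Same character set as the script: uppercase letters plus digits.
ALPHABET = string.ascii_uppercase + string.digits


def generate_task_id_secrets(groups: int = 5, group_len: int = 4) -> str:
    """Return an ID made of `groups` blocks of `group_len` characters joined by '-'."""
    return "-".join(
        "".join(secrets.choice(ALPHABET) for _ in range(group_len))
        for _ in range(groups)
    )


task_id = generate_task_id_secrets()
```

`secrets` reads from the operating system's randomness source, so no time-based seed mixing is needed.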
......@@ -2,14 +2,12 @@ import requests
from loguru import logger
url = "http://localhost:8000/v1/local/video/generate"
url = "http://localhost:8000/v1/tasks/"
message = {
"task_id": "test_task_001",
"prompt": "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "",
"save_video_path": "./output_lightx2v_wan_t2v_enhanced.mp4", # It is best to set it to an absolute path.
"use_prompt_enhancer": True,
}
......
import requests
from loguru import logger
import random
import string
import time
from datetime import datetime
# same as lightx2v/utils/generate_task_id.py
# from lightx2v.utils.generate_task_id import generate_task_id
def generate_task_id():
"""
Generate a random task ID in the format XXXX-XXXX-XXXX-XXXX-XXXX.
Features:
1. Does not modify the global random state.
2. Each X is an uppercase letter or digit (0-9).
3. Combines time factors to ensure high randomness.
For example: N1PQ-PRM5-N1BN-Z3S1-BGBJ
"""
# Save the current random state (does not affect external randomness)
original_state = random.getstate()
try:
# Define character set (uppercase letters + digits)
characters = string.ascii_uppercase + string.digits
# Create an independent random instance
local_random = random.Random(time.perf_counter_ns())
# Generate 5 groups of 4-character random strings
groups = []
for _ in range(5):
# Mix new time factor for each group
time_mix = int(datetime.now().timestamp())
local_random.seed(time_mix + local_random.getstate()[1][0] + time.perf_counter_ns())
groups.append("".join(local_random.choices(characters, k=4)))
return "-".join(groups)
finally:
# Restore the original random state
random.setstate(original_state)
if __name__ == "__main__":
url = "http://localhost:8000/v1/local/video/generate"
url = "http://localhost:8000/v1/tasks/"
message = {
"task_id": generate_task_id(), # task_id can also be any string you like, such as "test_task_001"
"task_id_must_unique": True, # If True, the task_id must be unique; otherwise an error is raised. Default is False.
"prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "./assets/inputs/imgs/img_0.jpg",
"save_video_path": "./output_lightx2v_wan_i2v_t02.mp4", # It is best to set it to an absolute path.
"image_path": "http://xxx.xxx.xxx.xxx/img_0.jpg", # image URL
}
logger.info(f"message: {message}")
......