Commit 99a6f046 authored by wangshankun's avatar wangshankun
Browse files

Merge branch 'main' of https://github.com/ModelTC/LightX2V into main

parents 8bdefedf 068a47db
import base64
import os
import re
import uuid
from pathlib import Path
from typing import Optional, Tuple
from loguru import logger
def is_base64_image(data: str) -> bool:
    """Heuristically check whether *data* is a base64-encoded image.

    Returns True for image data URLs ("data:image/...") and for base64
    payloads whose decoded prefix carries a PNG, JPEG, GIF or WEBP magic
    number; False otherwise.
    """
    if data.startswith("data:image/"):
        return True
    try:
        # Decode only a short prefix: every magic number sits in the first
        # few bytes, so decoding (and validating) the entire payload — as
        # the previous version did — was wasted work for large images.
        decoded = base64.b64decode(data[:100])
    except Exception as e:
        logger.warning(f"Error checking base64 image: {e}")
        return False
    if decoded.startswith(b"\x89PNG\r\n\x1a\n"):
        return True
    if decoded.startswith(b"\xff\xd8\xff"):
        return True
    if decoded.startswith((b"GIF87a", b"GIF89a")):
        return True
    # RIFF container: bytes 8-12 spell "WEBP"; guard against short buffers.
    if len(decoded) >= 12 and decoded[8:12] == b"WEBP":
        return True
    return False
def extract_base64_data(data: str) -> Tuple[str, Optional[str]]:
    """
    Extract base64 data and format from a data URL or plain base64 string.

    Returns:
        (base64_data, format) — format is None when *data* is not a
        recognised image data URL, in which case *data* is passed through
        unchanged.
    """
    if data.startswith("data:"):
        # re.DOTALL: base64 payloads may be wrapped across lines (MIME
        # style); without it `.+` stopped at the first newline and the
        # payload was silently truncated.
        match = re.match(r"data:image/(\w+);base64,(.+)", data, re.DOTALL)
        if match:
            format_type = match.group(1)
            base64_data = match.group(2)
            return base64_data, format_type
    return data, None
def save_base64_image(base64_data: str, output_dir: str) -> str:
    """
    Save a base64-encoded image to disk and return the file path.

    Args:
        base64_data: raw base64 string or a full image data URL.
        output_dir: target directory (created if missing).

    Returns:
        Path of the written file, named "<uuid4>.<ext>".

    Raises:
        ValueError: if the payload is not valid base64.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    data, format_type = extract_base64_data(base64_data)
    file_id = str(uuid.uuid4())
    try:
        image_data = base64.b64decode(data)
    except Exception as e:
        raise ValueError(f"Invalid base64 data: {e}")
    if format_type:
        # Trust the format declared by the data URL.
        ext = format_type
    elif image_data.startswith(b"\x89PNG\r\n\x1a\n"):
        ext = "png"
    elif image_data.startswith(b"\xff\xd8\xff"):
        ext = "jpg"
    elif image_data.startswith((b"GIF87a", b"GIF89a")):
        ext = "gif"
    # was `> 12`, which wrongly rejected an exactly-12-byte RIFF/WEBP header
    elif len(image_data) >= 12 and image_data[8:12] == b"WEBP":
        ext = "webp"
    else:
        # Unknown magic number: default to .png rather than failing.
        ext = "png"
    file_path = os.path.join(output_dir, f"{file_id}.{ext}")
    with open(file_path, "wb") as f:
        f.write(image_data)
    return file_path
import sys
from pathlib import Path
import uvicorn
from loguru import logger
from .api import ApiServer
from .config import server_config
from .service import DistributedInferenceService
def run_server(args):
    """Start the LightX2V API server backed by the distributed inference service.

    Args:
        args: parsed CLI namespace; ``host``/``port`` override the values in
            ``server_config`` when present and truthy.

    Exits the process with status 1 on any failure other than Ctrl-C.
    """
    inference_service = None
    try:
        logger.info("Starting LightX2V server...")
        # CLI overrides take precedence over the static configuration.
        if getattr(args, "host", None):
            server_config.host = args.host
        if getattr(args, "port", None):
            server_config.port = args.port
        if not server_config.validate():
            raise RuntimeError("Invalid server configuration")
        inference_service = DistributedInferenceService()
        if not inference_service.start_distributed_inference(args):
            raise RuntimeError("Failed to start distributed inference service")
        logger.info("Inference service started successfully")
        cache_dir = Path(server_config.cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        api_server = ApiServer(max_queue_size=server_config.max_queue_size)
        api_server.initialize_services(cache_dir, inference_service)
        app = api_server.get_app()
        logger.info(f"Starting server on {server_config.host}:{server_config.port}")
        uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
    except KeyboardInterrupt:
        logger.info("Server interrupted by user")
    except Exception as e:
        logger.error(f"Server failed: {e}")
        sys.exit(1)
    finally:
        # Previously the service was only stopped on the exception paths and
        # leaked when uvicorn returned normally; finally covers every exit.
        if inference_service:
            inference_service.stop_distributed_inference()
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
from ..utils.generate_task_id import generate_task_id
......@@ -11,7 +8,7 @@ class TaskRequest(BaseModel):
prompt: str = Field("", description="Generation prompt")
use_prompt_enhancer: bool = Field(False, description="Whether to use prompt enhancer")
negative_prompt: str = Field("", description="Negative prompt")
image_path: str = Field("", description="Input image path")
image_path: str = Field("", description="Base64 encoded image or URL")
num_fragments: int = Field(1, description="Number of fragments")
save_video_path: str = Field("", description="Save video path (optional, defaults to task_id.mp4)")
infer_steps: int = Field(5, description="Inference steps")
......@@ -22,7 +19,6 @@ class TaskRequest(BaseModel):
def __init__(self, **data):
super().__init__(**data)
# If save_video_path is empty, use task_id.mp4
if not self.save_video_path:
self.save_video_path = f"{self.task_id}.mp4"
......@@ -40,21 +36,6 @@ class TaskResponse(BaseModel):
save_video_path: str
class TaskResultResponse(BaseModel):
    """Response payload for a task-result query."""
    status: str  # status of this API lookup — NOTE(review): exact semantics vs task_status not visible here; confirm
    task_status: str  # lifecycle state of the generation task itself
    filename: Optional[str] = None  # produced file name, when available
    file_size: Optional[int] = None  # produced file size in bytes, when available
    download_url: Optional[str] = None  # URL for fetching the result, when available
    message: str  # human-readable detail
class ServiceStatusResponse(BaseModel):
    """Response payload describing the inference service's busy/idle state."""
    service_status: str  # service state string (e.g. "busy" / "idle")
    task_id: Optional[str] = None  # currently running task, if any
    start_time: Optional[datetime] = None  # when the running task started, if any
class StopTaskResponse(BaseModel):
    """Response payload for a stop-task request."""
    stop_status: str  # outcome of the stop attempt
    reason: str  # explanation for the outcome
This diff is collapsed.
import threading
import uuid
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional
from loguru import logger
class TaskStatus(Enum):
    """Lifecycle states of a generation task tracked by TaskManager."""
    PENDING = "pending"  # queued, not yet picked up by a worker
    PROCESSING = "processing"  # currently being executed
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"  # finished with an error
    CANCELLED = "cancelled"  # stopped by user request
@dataclass
class TaskInfo:
    """Bookkeeping record for a single task tracked by TaskManager."""
    task_id: str  # unique identifier (caller-supplied or generated)
    status: TaskStatus  # current lifecycle state
    message: Any  # original request payload for the task
    start_time: datetime = field(default_factory=datetime.now)  # creation/start timestamp
    end_time: Optional[datetime] = None  # set when the task reaches a terminal state
    error: Optional[str] = None  # error text for failed/cancelled tasks
    save_video_path: Optional[str] = None  # where the produced video is written
    stop_event: threading.Event = field(default_factory=threading.Event)  # signals cancellation to the worker
    thread: Optional[threading.Thread] = None  # worker thread executing the task, if any
class TaskManager:
    """Thread-safe registry for generation tasks.

    Tracks each task's lifecycle (pending -> processing -> completed /
    failed / cancelled), bounds the number of simultaneously active tasks,
    and serialises actual processing through a dedicated processing lock.
    """

    def __init__(self, max_queue_size: int = 100):
        self.max_queue_size = max_queue_size  # cap on PENDING+PROCESSING tasks
        # Insertion-ordered so move_to_end()/cleanup can reason about age.
        self._tasks: OrderedDict[str, TaskInfo] = OrderedDict()
        self._lock = threading.RLock()  # re-entrant: public methods call each other
        self._processing_lock = threading.Lock()  # held while a task is actually running
        self._current_processing_task: Optional[str] = None
        # Lifetime counters (never reset).
        self.total_tasks = 0
        self.completed_tasks = 0
        self.failed_tasks = 0

    def create_task(self, message: Any) -> str:
        """Register a new PENDING task and return its id.

        Raises:
            RuntimeError: duplicate task_id, or active-task queue is full.
        """
        with self._lock:
            if hasattr(message, "task_id") and message.task_id in self._tasks:
                raise RuntimeError(f"Task ID {message.task_id} already exists")
            active_tasks = sum(1 for t in self._tasks.values() if t.status in [TaskStatus.PENDING, TaskStatus.PROCESSING])
            if active_tasks >= self.max_queue_size:
                raise RuntimeError(f"Task queue is full (max {self.max_queue_size} tasks)")
            task_id = getattr(message, "task_id", str(uuid.uuid4()))
            task_info = TaskInfo(task_id=task_id, status=TaskStatus.PENDING, message=message, save_video_path=getattr(message, "save_video_path", None))
            self._tasks[task_id] = task_info
            self.total_tasks += 1
            self._cleanup_old_tasks()
            return task_id

    def start_task(self, task_id: str) -> TaskInfo:
        """Mark *task_id* PROCESSING, reset its start time, and return its info.

        Raises:
            KeyError: unknown task_id.
        """
        with self._lock:
            if task_id not in self._tasks:
                raise KeyError(f"Task {task_id} not found")
            task = self._tasks[task_id]
            task.status = TaskStatus.PROCESSING
            task.start_time = datetime.now()
            # Keep the most recently started task at the back of the dict.
            self._tasks.move_to_end(task_id)
            return task

    def complete_task(self, task_id: str, save_video_path: Optional[str] = None):
        """Mark *task_id* COMPLETED; optionally record the output path."""
        with self._lock:
            if task_id not in self._tasks:
                logger.warning(f"Task {task_id} not found for completion")
                return
            task = self._tasks[task_id]
            task.status = TaskStatus.COMPLETED
            task.end_time = datetime.now()
            if save_video_path:
                task.save_video_path = save_video_path
            self.completed_tasks += 1

    def fail_task(self, task_id: str, error: str):
        """Mark *task_id* FAILED and record the error message."""
        with self._lock:
            if task_id not in self._tasks:
                logger.warning(f"Task {task_id} not found for failure")
                return
            task = self._tasks[task_id]
            task.status = TaskStatus.FAILED
            task.end_time = datetime.now()
            task.error = error
            self.failed_tasks += 1

    def cancel_task(self, task_id: str) -> bool:
        """Cancel a non-terminal task; True when a cancellation took effect.

        Signals the worker via the task's stop_event, then waits up to 5s
        for its thread to exit.
        """
        with self._lock:
            if task_id not in self._tasks:
                return False
            task = self._tasks[task_id]
            if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
                return False
            task.stop_event.set()
            task.status = TaskStatus.CANCELLED
            task.end_time = datetime.now()
            task.error = "Task cancelled by user"
            # NOTE(review): join() runs while self._lock is held; a worker
            # that needs this lock to finish could stall us for the full 5s.
            if task.thread and task.thread.is_alive():
                task.thread.join(timeout=5)
            return True

    def cancel_all_tasks(self):
        """Cancel every PENDING/PROCESSING task (RLock makes the nested call safe)."""
        with self._lock:
            for task_id, task in list(self._tasks.items()):
                if task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
                    self.cancel_task(task_id)

    def get_task(self, task_id: str) -> Optional[TaskInfo]:
        """Return the TaskInfo for *task_id*, or None when unknown."""
        with self._lock:
            return self._tasks.get(task_id)

    def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return a JSON-friendly status dict for *task_id*, or None when unknown."""
        task = self.get_task(task_id)
        if not task:
            return None
        return {"task_id": task.task_id, "status": task.status.value, "start_time": task.start_time, "end_time": task.end_time, "error": task.error, "save_video_path": task.save_video_path}

    def get_all_tasks(self):
        """Return {task_id: status-dict} for every tracked task."""
        with self._lock:
            return {task_id: self.get_task_status(task_id) for task_id in self._tasks}

    def get_active_task_count(self) -> int:
        """Number of PENDING or PROCESSING tasks."""
        with self._lock:
            return sum(1 for t in self._tasks.values() if t.status in [TaskStatus.PENDING, TaskStatus.PROCESSING])

    def get_pending_task_count(self) -> int:
        """Number of PENDING tasks."""
        with self._lock:
            return sum(1 for t in self._tasks.values() if t.status == TaskStatus.PENDING)

    def is_processing(self) -> bool:
        """True while some task holds the processing slot."""
        with self._lock:
            return self._current_processing_task is not None

    def acquire_processing_lock(self, task_id: str, timeout: Optional[float] = None) -> bool:
        """Try to take the processing slot for *task_id*; True on success.

        NOTE(review): when timeout is None (or 0) the conditional passes
        ``timeout=False`` i.e. 0, so this is a non-blocking attempt rather
        than "wait forever" — confirm this is the intended semantics.
        """
        acquired = self._processing_lock.acquire(timeout=timeout if timeout else False)
        if acquired:
            with self._lock:
                self._current_processing_task = task_id
            logger.info(f"Task {task_id} acquired processing lock")
        return acquired

    def release_processing_lock(self, task_id: str):
        """Release the processing slot if *task_id* currently owns it."""
        with self._lock:
            if self._current_processing_task == task_id:
                self._current_processing_task = None
                try:
                    self._processing_lock.release()
                    logger.info(f"Task {task_id} released processing lock")
                except RuntimeError as e:
                    # release() raises RuntimeError if the lock is already unlocked.
                    logger.warning(f"Task {task_id} tried to release lock but failed: {e}")

    def get_next_pending_task(self) -> Optional[str]:
        """Return the oldest PENDING task id, or None when none are queued."""
        with self._lock:
            for task_id, task in self._tasks.items():
                if task.status == TaskStatus.PENDING:
                    return task_id
            return None

    def get_service_status(self) -> Dict[str, Any]:
        """Return an aggregate snapshot for the status endpoint."""
        with self._lock:
            active_tasks = [task_id for task_id, task in self._tasks.items() if task.status == TaskStatus.PROCESSING]
            pending_count = sum(1 for t in self._tasks.values() if t.status == TaskStatus.PENDING)
            return {
                "service_status": "busy" if self._current_processing_task else "idle",
                "current_task": self._current_processing_task,
                "active_tasks": active_tasks,
                "pending_tasks": pending_count,
                "queue_size": self.max_queue_size,
                "total_tasks": self.total_tasks,
                "completed_tasks": self.completed_tasks,
                "failed_tasks": self.failed_tasks,
            }

    def _cleanup_old_tasks(self, keep_count: int = 1000):
        """Drop the oldest terminal tasks once more than *keep_count* are tracked."""
        if len(self._tasks) <= keep_count:
            return
        completed_tasks = [(task_id, task) for task_id, task in self._tasks.items() if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]]
        # Oldest-finished first; fall back to start_time when end_time is unset.
        completed_tasks.sort(key=lambda x: x[1].end_time or x[1].start_time)
        remove_count = len(self._tasks) - keep_count
        for task_id, _ in completed_tasks[:remove_count]:
            del self._tasks[task_id]
            logger.debug(f"Cleaned up old task: {task_id}")


# Module-level singleton shared across the server.
task_manager = TaskManager()
......@@ -2,9 +2,6 @@ import base64
import io
import signal
import sys
import threading
from datetime import datetime
from typing import Optional
import psutil
import torch
......@@ -43,81 +40,6 @@ class TaskStatusMessage(BaseModel):
task_id: str
class ServiceStatus:
    """Legacy, process-wide tracker for a single in-flight task.

    At most one task is considered running at a time (_current_task);
    finished tasks are archived in _result_store keyed by task_id. All
    access is serialised through a class-level lock.
    """
    _lock = threading.Lock()  # guards _current_task and _result_store
    _current_task = None  # {"message", "start_time"} while busy, else None
    _result_store = {}  # task_id -> result record; NOTE(review): never pruned, grows unbounded

    @classmethod
    def start_task(cls, message):
        """Claim the service for *message*; raise if busy or the id was seen before."""
        with cls._lock:
            if cls._current_task is not None:
                raise RuntimeError("Service busy")
            if message.task_id in cls._result_store:
                raise RuntimeError(f"Task ID {message.task_id} already exists")
            cls._current_task = {"message": message, "start_time": datetime.now()}
            return message.task_id

    @classmethod
    def complete_task(cls, message):
        """Archive the current task as successful and free the busy slot."""
        with cls._lock:
            if cls._current_task:
                cls._result_store[message.task_id] = {
                    "success": True,
                    "message": message,
                    "start_time": cls._current_task["start_time"],
                    "completion_time": datetime.now(),
                    "save_video_path": message.save_video_path,
                }
                cls._current_task = None

    @classmethod
    def record_failed_task(cls, message, error: Optional[str] = None):
        """Archive the current task as failed (optional error text) and free the slot."""
        with cls._lock:
            if cls._current_task:
                cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error, "save_video_path": message.save_video_path}
                cls._current_task = None

    @classmethod
    def clean_stopped_task(cls):
        """Archive the current task (if any) as user-stopped and free the slot."""
        with cls._lock:
            if cls._current_task:
                message = cls._current_task["message"]
                error = "Task stopped by user"
                cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error, "save_video_path": message.save_video_path}
                cls._current_task = None

    @classmethod
    def get_status_task_id(cls, task_id: str):
        """Return a status dict for *task_id*: processing / completed / failed / not_found."""
        with cls._lock:
            if cls._current_task and cls._current_task["message"].task_id == task_id:
                return {"status": "processing", "task_id": task_id}
            if task_id in cls._result_store:
                result = cls._result_store[task_id]
                return {
                    "status": "completed" if result["success"] else "failed",
                    "task_id": task_id,
                    "success": result["success"],
                    "start_time": result["start_time"],
                    "completion_time": result.get("completion_time"),
                    "error": result.get("error"),
                    "save_video_path": result.get("save_video_path"),
                }
            return {"status": "not_found", "task_id": task_id}

    @classmethod
    def get_status_service(cls):
        """Return busy/idle status plus the current task's id and start time."""
        with cls._lock:
            if cls._current_task:
                return {"service_status": "busy", "task_id": cls._current_task["message"].task_id, "start_time": cls._current_task["start_time"]}
            return {"service_status": "idle"}

    @classmethod
    def get_all_tasks(cls):
        """Return the raw result archive (live reference, not a copy)."""
        with cls._lock:
            return cls._result_store
class TensorTransporter:
def __init__(self):
self.buffer = io.BytesIO()
......
import base64
import requests
from loguru import logger
def image_to_base64(image_path):
    """Read the file at *image_path* and return its contents base64-encoded."""
    with open(image_path, "rb") as handle:
        return base64.b64encode(handle.read()).decode("utf-8")
if __name__ == "__main__":
url = "http://localhost:8000/v1/tasks/"
message = {
"prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
"negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
"image_path": "assets/inputs/imgs/img_0.jpg", # 图片地址
"image_path": image_to_base64("assets/inputs/imgs/img_0.jpg"), # 图片地址
}
logger.info(f"message: {message}")
......
import base64
import os
import threading
import time
from typing import Any
import requests
from loguru import logger
from tqdm import tqdm
def image_to_base64(image_path):
    """Return the base64 text encoding of the image file at *image_path*."""
    with open(image_path, "rb") as img_file:
        payload = img_file.read()
    encoded = base64.b64encode(payload)
    return encoded.decode("utf-8")
def process_image_path(image_path) -> str:
    """Normalise an image reference before sending it to the server.

    - Empty values and HTTP(S) URLs pass through unchanged.
    - Existing local paths are converted to base64.
    - Missing local paths are logged and returned unchanged so the failure
      surfaces server-side.

    Note: the previous return annotation ``Any | str`` was meaningless
    (``Any`` absorbs every other member); the function only returns str.
    """
    if not image_path:
        return image_path
    if image_path.startswith(("http://", "https://")):
        return image_path
    if os.path.exists(image_path):
        return image_to_base64(image_path)
    logger.warning(f"Image path not found: {image_path}")
    return image_path
def send_and_monitor_task(url, message, task_index, complete_bar, complete_lock):
"""Send task to server and monitor until completion"""
try:
# Step 1: Send task and get task_id
if "image_path" in message and message["image_path"]:
message["image_path"] = process_image_path(message["image_path"])
response = requests.post(f"{url}/v1/tasks/", json=message)
response_data = response.json()
task_id = response_data.get("task_id")
......@@ -38,7 +65,6 @@ def send_and_monitor_task(url, message, task_index, complete_bar, complete_lock)
complete_bar.update(1) # Still update progress even if failed
return False
else:
# Task still running, wait and check again
time.sleep(0.5)
except Exception as e:
......@@ -91,7 +117,8 @@ def process_tasks_async(messages, available_urls, show_progress=True):
logger.info(f"Sending {len(messages)} tasks to available servers...")
# Create completion progress bar
complete_bar = None
complete_lock = None
if show_progress:
complete_bar = tqdm(total=len(messages), desc="Completing tasks")
complete_lock = threading.Lock() # Thread-safe updates to completion bar
......@@ -101,7 +128,7 @@ def process_tasks_async(messages, available_urls, show_progress=True):
server_url = find_idle_server(available_urls)
# Create and start thread for sending and monitoring task
thread = threading.Thread(target=send_and_monitor_task, args=(server_url, message, idx, complete_bar if show_progress else None, complete_lock if show_progress else None))
thread = threading.Thread(target=send_and_monitor_task, args=(server_url, message, idx, complete_bar, complete_lock))
thread.daemon = False
thread.start()
active_threads.append(thread)
......@@ -114,7 +141,7 @@ def process_tasks_async(messages, available_urls, show_progress=True):
thread.join()
# Close completion bar
if show_progress:
if complete_bar:
complete_bar.close()
logger.info("All tasks processing completed!")
......
import base64
from loguru import logger
from post_multi_servers import get_available_urls, process_tasks_async
def image_to_base64(image_path):
    """Convert an image file into its base64 string representation."""
    with open(image_path, "rb") as source:
        content = source.read()
    return base64.b64encode(content).decode("utf-8")
if __name__ == "__main__":
urls = [f"http://localhost:{port}" for port in range(8000, 8008)]
img_prompts = {
......@@ -11,7 +21,7 @@ if __name__ == "__main__":
messages = []
for i, (image_path, prompt) in enumerate(img_prompts.items()):
messages.append({"prompt": prompt, "negative_prompt": negative_prompt, "image_path": image_path, "save_video_path": f"./output_lightx2v_wan_i2v_{i + 1}.mp4"})
messages.append({"prompt": prompt, "negative_prompt": negative_prompt, "image_path": image_to_base64(image_path), "save_video_path": f"./output_lightx2v_wan_i2v_{i + 1}.mp4"})
logger.info(f"urls: {urls}")
......
......@@ -9,6 +9,9 @@ export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
export ENABLE_GRAPH_MODE=false
export TORCH_CUDA_ARCH_LIST="9.0"
# Start API server with distributed inference service
python -m lightx2v.api_server \
--model_cls wan2.1_distill \
......
......@@ -12,6 +12,9 @@ source ${lightx2v_path}/scripts/base/base.sh
export TORCH_CUDA_ARCH_LIST="9.0"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export ENABLE_GRAPH_MODE=false
export SENSITIVE_LAYER_DTYPE=None
python -m lightx2v.infer \
--model_cls wan2.1_audio \
--task i2v \
......
......@@ -13,6 +13,7 @@ export TORCH_CUDA_ARCH_LIST="9.0"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export ENABLE_GRAPH_MODE=false
export SENSITIVE_LAYER_DTYPE=None
#for debugging
#export TORCH_NCCL_BLOCKING_WAIT=1 # enable NCCL blocking-wait mode (otherwise the watchdog kills stalled processes)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment