Commit 99a6f046 authored by wangshankun

Merge branch 'main' of https://github.com/ModelTC/LightX2V into main

parents 8bdefedf 068a47db
import base64
import os
import re
import uuid
from pathlib import Path
from typing import Optional, Tuple

from loguru import logger


def is_base64_image(data: str) -> bool:
    """Check if a string is a base64-encoded image."""
    if data.startswith("data:image/"):
        return True
    try:
        if len(data) % 4 != 0:
            # Valid (padded) base64 always has a length that is a multiple of 4.
            return False
        base64.b64decode(data, validate=True)
        # Decode a prefix and inspect the magic bytes.
        decoded = base64.b64decode(data[:100])
        if decoded.startswith(b"\x89PNG\r\n\x1a\n"):
            return True
        if decoded.startswith(b"\xff\xd8\xff"):
            return True
        if decoded.startswith(b"GIF87a") or decoded.startswith(b"GIF89a"):
            return True
        if len(decoded) >= 12 and decoded[8:12] == b"WEBP":
            return True
    except Exception as e:
        logger.warning(f"Error checking base64 image: {e}")
    return False
def extract_base64_data(data: str) -> Tuple[str, Optional[str]]:
    """
    Extract the base64 payload and image format from a data URL or plain base64 string.

    Returns: (base64_data, format)
    """
    if data.startswith("data:"):
        match = re.match(r"data:image/(\w+);base64,(.+)", data)
        if match:
            format_type = match.group(1)
            base64_data = match.group(2)
            return base64_data, format_type
    return data, None
def save_base64_image(base64_data: str, output_dir: str) -> str:
    """
    Save a base64-encoded image to disk and return the file path.
    """
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    data, format_type = extract_base64_data(base64_data)
    file_id = str(uuid.uuid4())
    try:
        image_data = base64.b64decode(data)
    except Exception as e:
        raise ValueError(f"Invalid base64 data: {e}")
    if format_type:
        ext = format_type
    else:
        # No data-URL format given; fall back to sniffing the magic bytes.
        if image_data.startswith(b"\x89PNG\r\n\x1a\n"):
            ext = "png"
        elif image_data.startswith(b"\xff\xd8\xff"):
            ext = "jpg"
        elif image_data.startswith(b"GIF87a") or image_data.startswith(b"GIF89a"):
            ext = "gif"
        elif len(image_data) > 12 and image_data[8:12] == b"WEBP":
            ext = "webp"
        else:
            ext = "png"
    file_path = os.path.join(output_dir, f"{file_id}.{ext}")
    with open(file_path, "wb") as f:
        f.write(image_data)
    return file_path
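
A minimal usage sketch of the helpers above (the file path and output directory are illustrative, not part of this module):

# Hypothetical round-trip: encode a local file, then persist it via the helpers above.
with open("input.png", "rb") as f:  # placeholder path
    payload = "data:image/png;base64," + base64.b64encode(f.read()).decode("utf-8")
assert is_base64_image(payload)
saved_path = save_base64_image(payload, "/tmp/lightx2v_inputs")
print(saved_path)  # e.g. /tmp/lightx2v_inputs/<uuid>.png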
import sys
from pathlib import Path

import uvicorn
from loguru import logger

from .api import ApiServer
from .config import server_config
from .service import DistributedInferenceService


def run_server(args):
    inference_service = None
    try:
        logger.info("Starting LightX2V server...")
        if hasattr(args, "host") and args.host:
            server_config.host = args.host
        if hasattr(args, "port") and args.port:
            server_config.port = args.port
        if not server_config.validate():
            raise RuntimeError("Invalid server configuration")
        inference_service = DistributedInferenceService()
        if not inference_service.start_distributed_inference(args):
            raise RuntimeError("Failed to start distributed inference service")
        logger.info("Inference service started successfully")
        cache_dir = Path(server_config.cache_dir)
        cache_dir.mkdir(parents=True, exist_ok=True)
        api_server = ApiServer(max_queue_size=server_config.max_queue_size)
        api_server.initialize_services(cache_dir, inference_service)
        app = api_server.get_app()
        logger.info(f"Starting server on {server_config.host}:{server_config.port}")
        uvicorn.run(app, host=server_config.host, port=server_config.port, log_level="info")
    except KeyboardInterrupt:
        logger.info("Server interrupted by user")
        if inference_service:
            inference_service.stop_distributed_inference()
    except Exception as e:
        logger.error(f"Server failed: {e}")
        if inference_service:
            inference_service.stop_distributed_inference()
        sys.exit(1)
from datetime import datetime
from typing import Optional

from pydantic import BaseModel, Field

from ..utils.generate_task_id import generate_task_id

@@ -11,7 +8,7 @@ class TaskRequest(BaseModel):
    prompt: str = Field("", description="Generation prompt")
    use_prompt_enhancer: bool = Field(False, description="Whether to use prompt enhancer")
    negative_prompt: str = Field("", description="Negative prompt")
-   image_path: str = Field("", description="Input image path")
+   image_path: str = Field("", description="Base64 encoded image or URL")
    num_fragments: int = Field(1, description="Number of fragments")
    save_video_path: str = Field("", description="Save video path (optional, defaults to task_id.mp4)")
    infer_steps: int = Field(5, description="Inference steps")
@@ -22,7 +19,6 @@ class TaskRequest(BaseModel):
    def __init__(self, **data):
        super().__init__(**data)
        # If save_video_path is empty, default to task_id.mp4
        if not self.save_video_path:
            self.save_video_path = f"{self.task_id}.mp4"
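
With this change, image_path accepts either an HTTP(S) URL or a base64-encoded image. A minimal request sketch (field values are illustrative):

# Hypothetical payload; save_video_path is derived from task_id when left empty.
req = TaskRequest(prompt="a cat on a surfboard", image_path="data:image/png;base64,iVBORw0...")
print(req.save_video_path)  # "<task_id>.mp4"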
@@ -40,21 +36,6 @@ class TaskResponse(BaseModel):
    save_video_path: str


class TaskResultResponse(BaseModel):
    status: str
    task_status: str
    filename: Optional[str] = None
    file_size: Optional[int] = None
    download_url: Optional[str] = None
    message: str


class ServiceStatusResponse(BaseModel):
    service_status: str
    task_id: Optional[str] = None
    start_time: Optional[datetime] = None


class StopTaskResponse(BaseModel):
    stop_status: str
    reason: str
import queue
import asyncio
import threading
import time
import uuid
from pathlib import Path
@@ -11,9 +12,11 @@ from loguru import logger
from ..infer import init_runner
from ..utils.set_config import set_config
from .audio_utils import is_base64_audio, save_base64_audio
from .config import server_config
from .distributed_utils import create_distributed_worker
from .image_utils import is_base64_image, save_base64_image
from .schema import TaskRequest, TaskResponse
from .utils import ServiceStatus
mp.set_start_method("spawn", force=True)
@@ -25,7 +28,13 @@ class FileService:
        self.input_audio_dir = cache_dir / "inputs" / "audios"
        self.output_video_dir = cache_dir / "outputs"
        # Create directories
+       self._http_client = None
+       self._client_lock = asyncio.Lock()
+       self.max_retries = 3
+       self.retry_delay = 1.0
+       self.max_retry_delay = 10.0
        for directory in [
            self.input_image_dir,
            self.output_video_dir,
@@ -33,17 +42,74 @@ class FileService:
        ]:
            directory.mkdir(parents=True, exist_ok=True)
    async def _get_http_client(self) -> httpx.AsyncClient:
        """Get or create a persistent HTTP client with connection pooling."""
        async with self._client_lock:
            if self._http_client is None or self._http_client.is_closed:
                timeout = httpx.Timeout(
                    connect=10.0,
                    read=30.0,
                    write=10.0,
                    pool=5.0,
                )
                limits = httpx.Limits(max_keepalive_connections=5, max_connections=10, keepalive_expiry=30.0)
                self._http_client = httpx.AsyncClient(verify=False, timeout=timeout, limits=limits, follow_redirects=True)
            return self._http_client

    async def _download_with_retry(self, url: str, max_retries: Optional[int] = None) -> httpx.Response:
        """Download with exponential backoff retry logic."""
        if max_retries is None:
            max_retries = self.max_retries
        last_exception = None
        retry_delay = self.retry_delay
        for attempt in range(max_retries):
            try:
                client = await self._get_http_client()
                response = await client.get(url)
                if response.status_code == 200:
                    return response
                elif response.status_code >= 500:
                    logger.warning(f"Server error {response.status_code} for {url}, attempt {attempt + 1}/{max_retries}")
                    last_exception = httpx.HTTPStatusError(f"Server returned {response.status_code}", request=response.request, response=response)
                else:
                    raise httpx.HTTPStatusError(f"Client error {response.status_code}", request=response.request, response=response)
            except (httpx.ConnectError, httpx.TimeoutException, httpx.NetworkError) as e:
                logger.warning(f"Connection error for {url}, attempt {attempt + 1}/{max_retries}: {str(e)}")
                last_exception = e
            except httpx.HTTPStatusError as e:
                if e.response and e.response.status_code < 500:
                    raise
                last_exception = e
            except Exception as e:
                logger.error(f"Unexpected error downloading {url}: {str(e)}")
                last_exception = e
            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, self.max_retry_delay)
        error_msg = f"All {max_retries} connection attempts failed for {url}"
        if last_exception:
            error_msg += f": {str(last_exception)}"
        raise httpx.ConnectError(error_msg)
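
For reference, with the defaults above (retry_delay=1.0, max_retry_delay=10.0) the sleep between attempts follows a capped doubling schedule; a small sketch:

# Delays produced by the doubling-with-cap rule in _download_with_retry.
delay, delays = 1.0, []
for _ in range(5):
    delays.append(delay)
    delay = min(delay * 2, 10.0)
print(delays)  # [1.0, 2.0, 4.0, 8.0, 10.0]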
    async def download_image(self, image_url: str) -> Path:
        """Download image with retry logic and proper error handling."""
        try:
-           async with httpx.AsyncClient(verify=False) as client:
-               response = await client.get(image_url)
-           if response.status_code != 200:
-               raise ValueError(f"Failed to download image from {image_url}")
+           parsed_url = urlparse(image_url)
+           if not parsed_url.scheme or not parsed_url.netloc:
+               raise ValueError(f"Invalid URL format: {image_url}")
+           response = await self._download_with_retry(image_url)
-           image_name = Path(urlparse(image_url).path).name
+           image_name = Path(parsed_url.path).name
            if not image_name:
-               raise ValueError(f"Invalid image URL: {image_url}")
+               image_name = f"{uuid.uuid4()}.jpg"
            image_path = self.input_image_dir / image_name
            image_path.parent.mkdir(parents=True, exist_ok=True)
@@ -51,10 +117,60 @@ class FileService:
            with open(image_path, "wb") as f:
                f.write(response.content)
            logger.info(f"Successfully downloaded image from {image_url} to {image_path}")
            return image_path
+       except httpx.ConnectError as e:
+           logger.error(f"Connection error downloading image from {image_url}: {str(e)}")
+           raise ValueError(f"Failed to connect to {image_url}: {str(e)}")
+       except httpx.TimeoutException as e:
+           logger.error(f"Timeout downloading image from {image_url}: {str(e)}")
+           raise ValueError(f"Download timeout for {image_url}: {str(e)}")
+       except httpx.HTTPStatusError as e:
+           logger.error(f"HTTP error downloading image from {image_url}: {str(e)}")
+           raise ValueError(f"HTTP error for {image_url}: {str(e)}")
+       except ValueError:
+           raise
        except Exception as e:
-           logger.error(f"Failed to download image: {e}")
+           logger.error(f"Unexpected error downloading image from {image_url}: {str(e)}")
            raise ValueError(f"Failed to download image from {image_url}: {str(e)}")
    async def download_audio(self, audio_url: str) -> Path:
        """Download audio with retry logic and proper error handling."""
        try:
            parsed_url = urlparse(audio_url)
            if not parsed_url.scheme or not parsed_url.netloc:
                raise ValueError(f"Invalid URL format: {audio_url}")
            response = await self._download_with_retry(audio_url)
            audio_name = Path(parsed_url.path).name
            if not audio_name:
                audio_name = f"{uuid.uuid4()}.mp3"
            audio_path = self.input_audio_dir / audio_name
            audio_path.parent.mkdir(parents=True, exist_ok=True)
            with open(audio_path, "wb") as f:
                f.write(response.content)
            logger.info(f"Successfully downloaded audio from {audio_url} to {audio_path}")
            return audio_path
        except httpx.ConnectError as e:
            logger.error(f"Connection error downloading audio from {audio_url}: {str(e)}")
            raise ValueError(f"Failed to connect to {audio_url}: {str(e)}")
        except httpx.TimeoutException as e:
            logger.error(f"Timeout downloading audio from {audio_url}: {str(e)}")
            raise ValueError(f"Download timeout for {audio_url}: {str(e)}")
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error downloading audio from {audio_url}: {str(e)}")
            raise ValueError(f"HTTP error for {audio_url}: {str(e)}")
        except ValueError:
            raise
        except Exception as e:
            logger.error(f"Unexpected error downloading audio from {audio_url}: {str(e)}")
            raise ValueError(f"Failed to download audio from {audio_url}: {str(e)}")
    def save_uploaded_file(self, file_content: bytes, filename: str) -> Path:
        file_extension = Path(filename).suffix
@@ -72,20 +188,25 @@ class FileService:
            return self.output_video_dir / save_video_path
        return video_path
    async def cleanup(self):
        """Cleanup resources, including the pooled HTTP client."""
        async with self._client_lock:
            if self._http_client and not self._http_client.is_closed:
                await self._http_client.aclose()
                self._http_client = None
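
A hedged driver for the download path (the constructor argument and URL are assumptions based on the attributes shown above):

import asyncio
from pathlib import Path

async def main():
    fs = FileService(Path("/tmp/lightx2v_cache"))  # cache_dir argument assumed from the __init__ hunk
    image = await fs.download_image("https://example.com/cat.jpg")  # placeholder URL
    print(image)
    await fs.cleanup()  # close the pooled httpx client

asyncio.run(main())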
-def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, task_queue, result_queue):
+def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, shared_data, task_event, result_event):
    task_data = None
    worker = None
    try:
        logger.info(f"Process {rank}/{world_size - 1} initializing distributed inference service...")
        # Create and initialize the distributed worker process
        worker = create_distributed_worker(rank, world_size, master_addr, master_port)
        if not worker.init():
            raise RuntimeError(f"Rank {rank} distributed environment initialization failed")
        # Initialize configuration and model
        config = set_config(args)
        logger.info(f"Rank {rank} config: {config}")
@@ -93,80 +214,88 @@ def _distributed_inference_worker(...)
        logger.info(f"Process {rank}/{world_size - 1} distributed inference service initialization completed")
        while True:
+           if not task_event.wait(timeout=1.0):
+               continue
            if rank == 0:
-               try:
-                   # Only rank 0 reads tasks from the queue
-                   task_data = task_queue.get(timeout=1.0)
-                   if task_data is None:  # Stop signal
-                       logger.info(f"Process {rank} received stop signal, exiting inference service")
-                       # Broadcast stop signal to other processes
-                       worker.dist_manager.broadcast_task_data(None)
-                       break
-                   # Broadcast task data to other processes
-                   worker.dist_manager.broadcast_task_data(task_data)
-               except queue.Empty:
-                   # Queue is empty, continue waiting
-                   continue
+               if shared_data.get("stop", False):
+                   logger.info(f"Process {rank} received stop signal, exiting inference service")
+                   worker.dist_manager.broadcast_task_data(None)
+                   break
+               task_data = shared_data.get("current_task")
+               if task_data:
+                   worker.dist_manager.broadcast_task_data(task_data)
+               else:
+                   continue
            else:
                # Non-rank-0 processes receive task data from rank 0
                task_data = worker.dist_manager.broadcast_task_data()
-               if task_data is None:  # Stop signal
+               if task_data is None:
                    logger.info(f"Process {rank} received stop signal, exiting inference service")
                    break
            # All processes handle the task
            if task_data is not None:
                logger.info(f"Process {rank} received inference task: {task_data['task_id']}")
                try:
                    # Set inputs and run inference
                    runner.set_inputs(task_data)  # type: ignore
                    runner.run_pipeline()
-                   # Synchronize and report results
-                   worker.sync_and_report(
-                       task_data["task_id"],
-                       "success",
-                       result_queue,
-                       save_video_path=task_data["save_video_path"],
-                       message="Inference completed",
-                   )
+                   worker.dist_manager.barrier()
+                   if rank == 0:
+                       # Only rank 0 updates the result
+                       shared_data["result"] = {
+                           "task_id": task_data["task_id"],
+                           "status": "success",
+                           "save_video_path": task_data.get("video_path", task_data["save_video_path"]),  # Return original path for API
+                           "message": "Inference completed",
+                       }
+                       result_event.set()
+                   logger.info(f"Task {task_data['task_id']} success")
                except Exception as e:
                    logger.error(f"Process {rank} error occurred while processing task: {str(e)}")
-                   # Synchronize and report the error
-                   worker.sync_and_report(
-                       task_data.get("task_id", "unknown"),
-                       "failed",
-                       result_queue,
-                       error=str(e),
-                       message=f"Inference failed: {str(e)}",
-                   )
+                   worker.dist_manager.barrier()
+                   if rank == 0:
+                       # Only rank 0 updates the result
+                       shared_data["result"] = {
+                           "task_id": task_data.get("task_id", "unknown"),
+                           "status": "failed",
+                           "error": str(e),
+                           "message": f"Inference failed: {str(e)}",
+                       }
+                       result_event.set()
+                   logger.info(f"Task {task_data.get('task_id', 'unknown')} failed")
    except KeyboardInterrupt:
        logger.info(f"Process {rank} received KeyboardInterrupt, gracefully exiting")
    except Exception as e:
-       logger.error(f"Distributed inference service process {rank} startup failed: {str(e)}")
+       logger.exception(f"Distributed inference service process {rank} startup failed: {str(e)}")
        if rank == 0:
-           error_result = {
+           shared_data["result"] = {
                "task_id": "startup",
                "status": "startup_failed",
                "error": str(e),
                "message": f"Inference service startup failed: {str(e)}",
            }
-           result_queue.put(error_result)
+           result_event.set()
    finally:
        try:
            if worker:
                worker.cleanup()
-       except:  # noqa: E722
-           pass
+       except Exception as e:
+           logger.debug(f"Error cleaning up worker for rank {rank}: {e}")
class DistributedInferenceService:
    def __init__(self):
-       self.task_queue = None
-       self.result_queue = None
+       self.manager = None
+       self.shared_data = None
+       self.task_event = None
+       self.result_event = None
        self.processes = []
        self.is_running = False
@@ -188,17 +317,21 @@ class DistributedInferenceService:
            return False
        try:
-           import random
-           master_addr = "127.0.0.1"
-           master_port = str(random.randint(20000, 29999))
+           master_addr = server_config.master_addr
+           master_port = server_config.find_free_master_port()
            logger.info(f"Distributed inference service Master Addr: {master_addr}, Master Port: {master_port}")
-           # Create shared queues
-           self.task_queue = mp.Queue()
-           self.result_queue = mp.Queue()
+           # Create shared data structures
+           self.manager = mp.Manager()
+           self.shared_data = self.manager.dict()
+           self.task_event = self.manager.Event()
+           self.result_event = self.manager.Event()
+           # Initialize shared data
+           self.shared_data["current_task"] = None
+           self.shared_data["result"] = None
+           self.shared_data["stop"] = False
            # Start processes
            for rank in range(nproc_per_node):
                p = mp.Process(
                    target=_distributed_inference_worker,
@@ -208,10 +341,11 @@
                        master_addr,
                        master_port,
                        args,
-                       self.task_queue,
-                       self.result_queue,
+                       self.shared_data,
+                       self.task_event,
+                       self.result_event,
                    ),
-                   daemon=True,
+                   daemon=False,  # Changed to False for proper cleanup
                )
                p.start()
                self.processes.append(p)
@@ -226,18 +360,19 @@
            return False
    def stop_distributed_inference(self):
        if not self.is_running:
            return
+       assert self.task_event, "Task event is not initialized"
+       assert self.result_event, "Result event is not initialized"
        try:
            logger.info(f"Stopping {len(self.processes)} distributed inference service processes...")
            # Send stop signal
-           if self.task_queue:
-               for _ in self.processes:
-                   self.task_queue.put(None)
+           if self.shared_data is not None:
+               self.shared_data["stop"] = True
+               self.task_event.set()
            # Wait for processes to end
            for p in self.processes:
                try:
                    p.join(timeout=10)
@@ -245,8 +380,8 @@
                    logger.warning(f"Process {p.pid} did not end within the specified time, forcing termination...")
                    p.terminate()
                    p.join(timeout=5)
-               except:  # noqa: E722
-                   pass
+               except Exception as e:
+                   logger.warning(f"Error terminating process {p.pid}: {e}")
            logger.info("All distributed inference service processes have stopped")
@@ -254,52 +389,76 @@
            logger.error(f"Error occurred while stopping distributed inference service: {str(e)}")
        finally:
            # Clean up resources
-           self._clean_queues()
            self.processes = []
-           self.task_queue = None
-           self.result_queue = None
+           self.manager = None
+           self.shared_data = None
+           self.task_event = None
+           self.result_event = None
            self.is_running = False
-   def _clean_queues(self):
-       for queue_obj in [self.task_queue, self.result_queue]:
-           if queue_obj:
-               try:
-                   while not queue_obj.empty():
-                       queue_obj.get_nowait()
-               except:  # noqa: E722
-                   pass
    def submit_task(self, task_data: dict) -> bool:
-       if not self.is_running or not self.task_queue:
+       if not self.is_running or not self.shared_data:
            logger.error("Distributed inference service is not started")
            return False
+       assert self.task_event, "Task event is not initialized"
+       assert self.result_event, "Result event is not initialized"
        try:
-           self.task_queue.put(task_data)
+           self.result_event.clear()
+           self.shared_data["result"] = None
+           self.shared_data["current_task"] = task_data
+           self.task_event.set()  # Signal workers
            return True
        except Exception as e:
            logger.error(f"Failed to submit task: {str(e)}")
            return False
-   def wait_for_result(self, task_id: str, timeout: int = 300) -> Optional[dict]:
-       if not self.is_running or not self.result_queue:
+   def wait_for_result(self, task_id: str, timeout: Optional[int] = None) -> Optional[dict]:
+       if timeout is None:
+           timeout = server_config.task_timeout
+       if not self.is_running or not self.shared_data:
            return None
+       assert self.task_event, "Task event is not initialized"
+       assert self.result_event, "Result event is not initialized"
+       if self.result_event.wait(timeout=timeout):
+           result = self.shared_data.get("result")
+           if result and result.get("task_id") == task_id:
+               self.shared_data["current_task"] = None
+               self.task_event.clear()
+               return result
+       return None
+   def wait_for_result_with_stop(self, task_id: str, stop_event: threading.Event, timeout: Optional[int] = None) -> Optional[dict]:
+       if timeout is None:
+           timeout = server_config.task_timeout
+       if not self.is_running or not self.shared_data:
+           return None
+       assert self.task_event, "Task event is not initialized"
+       assert self.result_event, "Result event is not initialized"
        start_time = time.time()
        while time.time() - start_time < timeout:
-           try:
-               result = self.result_queue.get(timeout=1.0)
-               if result.get("task_id") == task_id:
-                   return result
-               else:
-                   # Not the result for the current task; put it back in the queue
-                   self.result_queue.put(result)
-                   time.sleep(0.1)
-           except queue.Empty:
-               continue
+           if stop_event.is_set():
+               logger.info(f"Task {task_id} stop event triggered during wait")
+               self.shared_data["current_task"] = None
+               self.task_event.clear()
+               return None
+           if self.result_event.wait(timeout=0.5):
+               result = self.shared_data.get("result")
+               if result and result.get("task_id") == task_id:
+                   self.shared_data["current_task"] = None
+                   self.task_event.clear()
+                   return result
        return None
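
The submit/wait handshake above can be exercised roughly as follows (a sketch; it assumes start_distributed_inference(args) already succeeded):

service = DistributedInferenceService()
# ... service.start_distributed_inference(args) ...
task = {"task_id": "demo-1", "save_video_path": "/tmp/out.mp4", "prompt": "a cat"}
if service.submit_task(task):
    result = service.wait_for_result("demo-1", timeout=300)
    print(result)  # e.g. {"task_id": "demo-1", "status": "success", ...}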
@@ -313,39 +472,60 @@ class VideoGenerationService:
        self.file_service = file_service
        self.inference_service = inference_service

-   async def generate_video(self, message: TaskRequest) -> TaskResponse:
+   async def generate_video_with_stop_event(self, message: TaskRequest, stop_event) -> Optional[TaskResponse]:
        try:
            task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
            task_data["task_id"] = message.task_id
-           if "image_path" in message.model_fields_set and message.image_path.startswith("http"):
-               image_path = await self.file_service.download_image(message.image_path)
-               task_data["image_path"] = str(image_path)
+           if stop_event.is_set():
+               logger.info(f"Task {message.task_id} cancelled before processing")
+               return None
+           if "image_path" in message.model_fields_set and message.image_path:
+               if message.image_path.startswith("http"):
+                   image_path = await self.file_service.download_image(message.image_path)
+                   task_data["image_path"] = str(image_path)
+               elif is_base64_image(message.image_path):
+                   image_path = save_base64_image(message.image_path, str(self.file_service.input_image_dir))
+                   task_data["image_path"] = str(image_path)
+               else:
+                   task_data["image_path"] = message.image_path
+           if "audio_path" in message.model_fields_set and message.audio_path:
+               if message.audio_path.startswith("http"):
+                   audio_path = await self.file_service.download_audio(message.audio_path)
+                   task_data["audio_path"] = str(audio_path)
+               elif is_base64_audio(message.audio_path):
+                   audio_path = save_base64_audio(message.audio_path, str(self.file_service.input_audio_dir))
+                   task_data["audio_path"] = str(audio_path)
+               else:
+                   task_data["audio_path"] = message.audio_path
-           save_video_path = self.file_service.get_output_path(message.save_video_path)
-           task_data["save_video_path"] = str(save_video_path)
+           actual_save_path = self.file_service.get_output_path(message.save_video_path)
+           task_data["save_video_path"] = str(actual_save_path)
+           task_data["video_path"] = message.save_video_path
            if not self.inference_service.submit_task(task_data):
                raise RuntimeError("Distributed inference service is not started")
-           result = self.inference_service.wait_for_result(message.task_id)
+           result = self.inference_service.wait_for_result_with_stop(message.task_id, stop_event, timeout=300)
            if result is None:
+               if stop_event.is_set():
+                   logger.info(f"Task {message.task_id} cancelled during processing")
+                   return None
                raise RuntimeError("Task processing timeout")
            if result.get("status") == "success":
                ServiceStatus.complete_task(message)
                return TaskResponse(
                    task_id=message.task_id,
                    task_status="completed",
-                   save_video_path=str(save_video_path),
+                   save_video_path=message.save_video_path,  # Return original path
                )
            else:
                error_msg = result.get("error", "Inference failed")
                ServiceStatus.record_failed_task(message, error=error_msg)
                raise RuntimeError(error_msg)
        except Exception as e:
            logger.error(f"Task {message.task_id} processing failed: {str(e)}")
            ServiceStatus.record_failed_task(message, error=str(e))
            raise
import threading
import uuid
from collections import OrderedDict
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional

from loguru import logger


class TaskStatus(Enum):
    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"


@dataclass
class TaskInfo:
    task_id: str
    status: TaskStatus
    message: Any
    start_time: datetime = field(default_factory=datetime.now)
    end_time: Optional[datetime] = None
    error: Optional[str] = None
    save_video_path: Optional[str] = None
    stop_event: threading.Event = field(default_factory=threading.Event)
    thread: Optional[threading.Thread] = None


class TaskManager:
    def __init__(self, max_queue_size: int = 100):
        self.max_queue_size = max_queue_size
        self._tasks: OrderedDict[str, TaskInfo] = OrderedDict()
        self._lock = threading.RLock()
        self._processing_lock = threading.Lock()
        self._current_processing_task: Optional[str] = None
        self.total_tasks = 0
        self.completed_tasks = 0
        self.failed_tasks = 0

    def create_task(self, message: Any) -> str:
        with self._lock:
            if hasattr(message, "task_id") and message.task_id in self._tasks:
                raise RuntimeError(f"Task ID {message.task_id} already exists")
            active_tasks = sum(1 for t in self._tasks.values() if t.status in [TaskStatus.PENDING, TaskStatus.PROCESSING])
            if active_tasks >= self.max_queue_size:
                raise RuntimeError(f"Task queue is full (max {self.max_queue_size} tasks)")
            task_id = getattr(message, "task_id", str(uuid.uuid4()))
            task_info = TaskInfo(task_id=task_id, status=TaskStatus.PENDING, message=message, save_video_path=getattr(message, "save_video_path", None))
            self._tasks[task_id] = task_info
            self.total_tasks += 1
            self._cleanup_old_tasks()
            return task_id

    def start_task(self, task_id: str) -> TaskInfo:
        with self._lock:
            if task_id not in self._tasks:
                raise KeyError(f"Task {task_id} not found")
            task = self._tasks[task_id]
            task.status = TaskStatus.PROCESSING
            task.start_time = datetime.now()
            self._tasks.move_to_end(task_id)
            return task

    def complete_task(self, task_id: str, save_video_path: Optional[str] = None):
        with self._lock:
            if task_id not in self._tasks:
                logger.warning(f"Task {task_id} not found for completion")
                return
            task = self._tasks[task_id]
            task.status = TaskStatus.COMPLETED
            task.end_time = datetime.now()
            if save_video_path:
                task.save_video_path = save_video_path
            self.completed_tasks += 1

    def fail_task(self, task_id: str, error: str):
        with self._lock:
            if task_id not in self._tasks:
                logger.warning(f"Task {task_id} not found for failure")
                return
            task = self._tasks[task_id]
            task.status = TaskStatus.FAILED
            task.end_time = datetime.now()
            task.error = error
            self.failed_tasks += 1

    def cancel_task(self, task_id: str) -> bool:
        with self._lock:
            if task_id not in self._tasks:
                return False
            task = self._tasks[task_id]
            if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]:
                return False
            task.stop_event.set()
            task.status = TaskStatus.CANCELLED
            task.end_time = datetime.now()
            task.error = "Task cancelled by user"
            if task.thread and task.thread.is_alive():
                task.thread.join(timeout=5)
            return True

    def cancel_all_tasks(self):
        with self._lock:
            for task_id, task in list(self._tasks.items()):
                if task.status in [TaskStatus.PENDING, TaskStatus.PROCESSING]:
                    self.cancel_task(task_id)

    def get_task(self, task_id: str) -> Optional[TaskInfo]:
        with self._lock:
            return self._tasks.get(task_id)

    def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
        task = self.get_task(task_id)
        if not task:
            return None
        return {
            "task_id": task.task_id,
            "status": task.status.value,
            "start_time": task.start_time,
            "end_time": task.end_time,
            "error": task.error,
            "save_video_path": task.save_video_path,
        }

    def get_all_tasks(self):
        with self._lock:
            return {task_id: self.get_task_status(task_id) for task_id in self._tasks}

    def get_active_task_count(self) -> int:
        with self._lock:
            return sum(1 for t in self._tasks.values() if t.status in [TaskStatus.PENDING, TaskStatus.PROCESSING])

    def get_pending_task_count(self) -> int:
        with self._lock:
            return sum(1 for t in self._tasks.values() if t.status == TaskStatus.PENDING)

    def is_processing(self) -> bool:
        with self._lock:
            return self._current_processing_task is not None

    def acquire_processing_lock(self, task_id: str, timeout: Optional[float] = None) -> bool:
        # Non-blocking when no timeout is given; otherwise wait up to `timeout` seconds.
        acquired = self._processing_lock.acquire(timeout=timeout) if timeout else self._processing_lock.acquire(blocking=False)
        if acquired:
            with self._lock:
                self._current_processing_task = task_id
            logger.info(f"Task {task_id} acquired processing lock")
        return acquired

    def release_processing_lock(self, task_id: str):
        with self._lock:
            if self._current_processing_task == task_id:
                self._current_processing_task = None
                try:
                    self._processing_lock.release()
                    logger.info(f"Task {task_id} released processing lock")
                except RuntimeError as e:
                    logger.warning(f"Task {task_id} tried to release lock but failed: {e}")

    def get_next_pending_task(self) -> Optional[str]:
        with self._lock:
            for task_id, task in self._tasks.items():
                if task.status == TaskStatus.PENDING:
                    return task_id
            return None

    def get_service_status(self) -> Dict[str, Any]:
        with self._lock:
            active_tasks = [task_id for task_id, task in self._tasks.items() if task.status == TaskStatus.PROCESSING]
            pending_count = sum(1 for t in self._tasks.values() if t.status == TaskStatus.PENDING)
            return {
                "service_status": "busy" if self._current_processing_task else "idle",
                "current_task": self._current_processing_task,
                "active_tasks": active_tasks,
                "pending_tasks": pending_count,
                "queue_size": self.max_queue_size,
                "total_tasks": self.total_tasks,
                "completed_tasks": self.completed_tasks,
                "failed_tasks": self.failed_tasks,
            }

    def _cleanup_old_tasks(self, keep_count: int = 1000):
        if len(self._tasks) <= keep_count:
            return
        completed_tasks = [(task_id, task) for task_id, task in self._tasks.items() if task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED]]
        completed_tasks.sort(key=lambda x: x[1].end_time or x[1].start_time)
        remove_count = len(self._tasks) - keep_count
        for task_id, _ in completed_tasks[:remove_count]:
            del self._tasks[task_id]
            logger.debug(f"Cleaned up old task: {task_id}")


task_manager = TaskManager()
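
A sketch of the intended lifecycle around the module-level task_manager (the message object is a hypothetical stand-in for a TaskRequest):

class _Msg:  # stand-in exposing the attributes TaskManager reads
    task_id = "demo-task"
    save_video_path = "demo-task.mp4"

tid = task_manager.create_task(_Msg())
if task_manager.acquire_processing_lock(tid, timeout=1.0):
    try:
        task_manager.start_task(tid)
        task_manager.complete_task(tid, save_video_path="demo-task.mp4")
    finally:
        task_manager.release_processing_lock(tid)
print(task_manager.get_task_status(tid)["status"])  # "completed"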
@@ -2,9 +2,6 @@ import base64
import io
import signal
import sys
-import threading
-from datetime import datetime
-from typing import Optional

import psutil
import torch
@@ -43,81 +40,6 @@ class TaskStatusMessage(BaseModel):
    task_id: str


class ServiceStatus:
    _lock = threading.Lock()
    _current_task = None
    _result_store = {}

    @classmethod
    def start_task(cls, message):
        with cls._lock:
            if cls._current_task is not None:
                raise RuntimeError("Service busy")
            if message.task_id in cls._result_store:
                raise RuntimeError(f"Task ID {message.task_id} already exists")
            cls._current_task = {"message": message, "start_time": datetime.now()}
            return message.task_id

    @classmethod
    def complete_task(cls, message):
        with cls._lock:
            if cls._current_task:
                cls._result_store[message.task_id] = {
                    "success": True,
                    "message": message,
                    "start_time": cls._current_task["start_time"],
                    "completion_time": datetime.now(),
                    "save_video_path": message.save_video_path,
                }
                cls._current_task = None

    @classmethod
    def record_failed_task(cls, message, error: Optional[str] = None):
        with cls._lock:
            if cls._current_task:
                cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error, "save_video_path": message.save_video_path}
                cls._current_task = None

    @classmethod
    def clean_stopped_task(cls):
        with cls._lock:
            if cls._current_task:
                message = cls._current_task["message"]
                error = "Task stopped by user"
                cls._result_store[message.task_id] = {"success": False, "message": message, "start_time": cls._current_task["start_time"], "error": error, "save_video_path": message.save_video_path}
                cls._current_task = None

    @classmethod
    def get_status_task_id(cls, task_id: str):
        with cls._lock:
            if cls._current_task and cls._current_task["message"].task_id == task_id:
                return {"status": "processing", "task_id": task_id}
            if task_id in cls._result_store:
                result = cls._result_store[task_id]
                return {
                    "status": "completed" if result["success"] else "failed",
                    "task_id": task_id,
                    "success": result["success"],
                    "start_time": result["start_time"],
                    "completion_time": result.get("completion_time"),
                    "error": result.get("error"),
                    "save_video_path": result.get("save_video_path"),
                }
            return {"status": "not_found", "task_id": task_id}

    @classmethod
    def get_status_service(cls):
        with cls._lock:
            if cls._current_task:
                return {"service_status": "busy", "task_id": cls._current_task["message"].task_id, "start_time": cls._current_task["start_time"]}
            return {"service_status": "idle"}

    @classmethod
    def get_all_tasks(cls):
        with cls._lock:
            return cls._result_store
class TensorTransporter:
    def __init__(self):
        self.buffer = io.BytesIO()
......
import base64

import requests
from loguru import logger


def image_to_base64(image_path):
    """Convert an image file to a base64 string."""
    with open(image_path, "rb") as f:
        image_data = f.read()
    return base64.b64encode(image_data).decode("utf-8")


if __name__ == "__main__":
    url = "http://localhost:8000/v1/tasks/"
    message = {
        "prompt": "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside.",
        "negative_prompt": "镜头晃动,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
-       "image_path": "assets/inputs/imgs/img_0.jpg",  # image path
+       "image_path": image_to_base64("assets/inputs/imgs/img_0.jpg"),  # image path (now base64-encoded)
    }
    logger.info(f"message: {message}")
......
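
The elided tail of this script presumably posts the message; a minimal sketch of that step, mirroring the multi-server client below:

# Hedged sketch of the send step (url and message defined above).
response = requests.post(url, json=message)
logger.info(f"response: {response.json()}")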
import base64
import os
import threading
import time
from typing import Any

import requests
from loguru import logger
from tqdm import tqdm


def image_to_base64(image_path):
    """Convert an image file to a base64 string."""
    with open(image_path, "rb") as f:
        image_data = f.read()
    return base64.b64encode(image_data).decode("utf-8")


def process_image_path(image_path) -> Any | str:
    """Process image_path: convert to base64 if it is a local path; keep it unchanged if it is an HTTP link."""
    if not image_path:
        return image_path
    if image_path.startswith(("http://", "https://")):
        return image_path
    if os.path.exists(image_path):
        return image_to_base64(image_path)
    else:
        logger.warning(f"Image path not found: {image_path}")
        return image_path
def send_and_monitor_task(url, message, task_index, complete_bar, complete_lock):
    """Send a task to the server and monitor it until completion."""
    try:
        # Step 1: Send the task and get a task_id
        if "image_path" in message and message["image_path"]:
            message["image_path"] = process_image_path(message["image_path"])
        response = requests.post(f"{url}/v1/tasks/", json=message)
        response_data = response.json()
        task_id = response_data.get("task_id")
@@ -38,7 +65,6 @@ def send_and_monitor_task(url, message, task_index, complete_bar, complete_lock):
                complete_bar.update(1)  # Still update progress even if the task failed
                return False
            else:
-               # Task still running, wait and check again
                time.sleep(0.5)
    except Exception as e:
@@ -91,7 +117,8 @@ def process_tasks_async(messages, available_urls, show_progress=True):
    logger.info(f"Sending {len(messages)} tasks to available servers...")
    # Create the completion progress bar
    complete_bar = None
+   complete_lock = None
    if show_progress:
        complete_bar = tqdm(total=len(messages), desc="Completing tasks")
        complete_lock = threading.Lock()  # Thread-safe updates to the completion bar
@@ -101,7 +128,7 @@ def process_tasks_async(messages, available_urls, show_progress=True):
        server_url = find_idle_server(available_urls)
        # Create and start a thread to send and monitor the task
-       thread = threading.Thread(target=send_and_monitor_task, args=(server_url, message, idx, complete_bar if show_progress else None, complete_lock if show_progress else None))
+       thread = threading.Thread(target=send_and_monitor_task, args=(server_url, message, idx, complete_bar, complete_lock))
        thread.daemon = False
        thread.start()
        active_threads.append(thread)
@@ -114,7 +141,7 @@ def process_tasks_async(messages, available_urls, show_progress=True):
        thread.join()
    # Close the completion bar
-   if show_progress:
+   if complete_bar:
        complete_bar.close()
    logger.info("All tasks processing completed!")
......
import base64

from loguru import logger
from post_multi_servers import get_available_urls, process_tasks_async


def image_to_base64(image_path):
    """Convert an image file to a base64 string."""
    with open(image_path, "rb") as f:
        image_data = f.read()
    return base64.b64encode(image_data).decode("utf-8")


if __name__ == "__main__":
    urls = [f"http://localhost:{port}" for port in range(8000, 8008)]
    img_prompts = {
@@ -11,7 +21,7 @@ if __name__ == "__main__":
    messages = []
    for i, (image_path, prompt) in enumerate(img_prompts.items()):
-       messages.append({"prompt": prompt, "negative_prompt": negative_prompt, "image_path": image_path, "save_video_path": f"./output_lightx2v_wan_i2v_{i + 1}.mp4"})
+       messages.append({"prompt": prompt, "negative_prompt": negative_prompt, "image_path": image_to_base64(image_path), "save_video_path": f"./output_lightx2v_wan_i2v_{i + 1}.mp4"})
    logger.info(f"urls: {urls}")
......
@@ -9,6 +9,9 @@ export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
export ENABLE_GRAPH_MODE=false
export TORCH_CUDA_ARCH_LIST="9.0"
# Start API server with distributed inference service
python -m lightx2v.api_server \
--model_cls wan2.1_distill \
......
@@ -12,6 +12,9 @@ source ${lightx2v_path}/scripts/base/base.sh
export TORCH_CUDA_ARCH_LIST="9.0"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export ENABLE_GRAPH_MODE=false
export SENSITIVE_LAYER_DTYPE=None
python -m lightx2v.infer \
--model_cls wan2.1_audio \
--task i2v \
......
@@ -13,6 +13,7 @@ export TORCH_CUDA_ARCH_LIST="9.0"
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
export ENABLE_GRAPH_MODE=false
export SENSITIVE_LAYER_DTYPE=None
# For debugging:
#export TORCH_NCCL_BLOCKING_WAIT=1  # Enable NCCL blocking-wait mode (otherwise the watchdog kills stalled processes)
......