Unverified Commit dea872a2 authored by PengGao, committed by GitHub

Api image (#515)

parent 1892a3db
import base64
from typing import Dict

from .base import MediaHandler


class ImageHandler(MediaHandler):
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_media_signatures(self) -> Dict[bytes, str]:
        return {
            b"\x89PNG\r\n\x1a\n": "png",
            b"\xff\xd8\xff": "jpg",
            b"GIF87a": "gif",
            b"GIF89a": "gif",
        }

    def get_data_url_prefix(self) -> str:
        return "data:image/"

    def get_data_url_pattern(self) -> str:
        return r"data:image/(\w+);base64,(.+)"

    def get_default_extension(self) -> str:
        return "png"

    def is_base64(self, data: str) -> bool:
        if data.startswith(self.get_data_url_prefix()):
            return True
        try:
            if len(data) % 4 == 0:
                base64.b64decode(data, validate=True)
            decoded = base64.b64decode(data[:100])
            for signature in self.get_media_signatures().keys():
                if decoded.startswith(signature):
                    return True
            # RIFF container: bytes 8-11 spell "WEBP" for WebP images
            if len(decoded) > 12 and decoded[8:12] == b"WEBP":
                return True
        except Exception:
            return False
        return False

    def detect_extension(self, data: bytes) -> str:
        for signature, ext in self.get_media_signatures().items():
            if data.startswith(signature):
                return ext
        if len(data) > 12 and data[8:12] == b"WEBP":
            return "webp"
        return self.get_default_extension()


_handler = ImageHandler()


def is_base64_image(data: str) -> bool:
    return _handler.is_base64(data)


def save_base64_image(base64_data: str, output_dir: str) -> str:
    return _handler.save_base64(base64_data, output_dir)
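For reference, a minimal sketch (not part of this commit) of the two input shapes is_base64_image accepts: a data:image/ URL, or a bare base64 string whose decoded prefix carries one of the magic numbers above:

import base64

png_magic = b"\x89PNG\r\n\x1a\n"                # PNG signature from get_media_signatures()
fake_png = png_magic + b"\x00" * 16             # placeholder bytes, not a real image

bare_b64 = base64.b64encode(fake_png).decode()  # recognized via the decoded magic bytes
data_url = "data:image/png;base64," + bare_b64  # recognized via the data-URL prefix
print(bare_b64[:12], data_url[:28])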
#!/usr/bin/env python
"""Example script to run the LightX2V server."""
import argparse
import sys
from pathlib import Path

sys.path.insert(0, str(Path(__file__).parent.parent.parent))

from lightx2v.server.main import run_server


def main():
    parser = argparse.ArgumentParser(description="Run LightX2V inference server")
    parser.add_argument("--model_path", type=str, required=True, help="Path to model")
    parser.add_argument("--model_cls", type=str, required=True, help="Model class name")
    parser.add_argument("--config_json", type=str, help="Path to model config JSON file")
    parser.add_argument("--task", type=str, default="i2v", help="Task type (i2v, etc.)")
    parser.add_argument("--nproc_per_node", type=int, default=1, help="Number of processes per node (GPUs to use)")
    parser.add_argument("--port", type=int, default=8000, help="Server port")
    parser.add_argument("--host", type=str, default="127.0.0.1", help="Server host")
    args = parser.parse_args()
    run_server(args)


if __name__ == "__main__":
    main()
import random
from typing import Optional
from pydantic import BaseModel, Field
@@ -5,35 +6,55 @@ from pydantic import BaseModel, Field
from ..utils.generate_task_id import generate_task_id


def generate_random_seed() -> int:
    return random.randint(0, 2**32 - 1)


class TalkObject(BaseModel):
    audio: str = Field(..., description="Audio path")
    mask: str = Field(..., description="Mask path")


class BaseTaskRequest(BaseModel):
    task_id: str = Field(default_factory=generate_task_id, description="Task ID (auto-generated)")
    prompt: str = Field("", description="Generation prompt")
    use_prompt_enhancer: bool = Field(False, description="Whether to use prompt enhancer")
    negative_prompt: str = Field("", description="Negative prompt")
    image_path: str = Field("", description="Base64 encoded image or URL")
    save_result_path: str = Field("", description="Save result path (optional, defaults to task_id, suffix auto-detected)")
    infer_steps: int = Field(5, description="Inference steps")
    seed: int = Field(default_factory=generate_random_seed, description="Random seed (auto-generated if not set)")

    def __init__(self, **data):
        super().__init__(**data)
        if not self.save_result_path:
            self.save_result_path = f"{self.task_id}"

    def get(self, key, default=None):
        return getattr(self, key, default)


class VideoTaskRequest(BaseTaskRequest):
    num_fragments: int = Field(1, description="Number of fragments")
    target_video_length: int = Field(81, description="Target video length")
    audio_path: str = Field("", description="Input audio path (Wan-Audio)")
    video_duration: int = Field(5, description="Video duration (Wan-Audio)")
    talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)")


class ImageTaskRequest(BaseTaskRequest):
    aspect_ratio: str = Field("16:9", description="Output aspect ratio")


class TaskRequest(BaseTaskRequest):
    num_fragments: int = Field(1, description="Number of fragments")
    target_video_length: int = Field(81, description="Target video length (video only)")
    audio_path: str = Field("", description="Input audio path (Wan-Audio)")
    video_duration: int = Field(5, description="Video duration (Wan-Audio)")
    talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)")
    aspect_ratio: str = Field("16:9", description="Output aspect ratio (T2I only)")


class TaskStatusMessage(BaseModel):
    task_id: str = Field(..., description="Task ID")
...
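Since the defaulting above happens partly in field factories and partly in __init__, here is a standalone sketch (hypothetical DemoRequest, pydantic assumed installed) of the pattern BaseTaskRequest uses:

import random
import uuid
from pydantic import BaseModel, Field

class DemoRequest(BaseModel):
    task_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:8])  # stand-in for generate_task_id
    seed: int = Field(default_factory=lambda: random.randint(0, 2**32 - 1))
    save_result_path: str = ""

    def __init__(self, **data):
        super().__init__(**data)
        if not self.save_result_path:
            self.save_result_path = self.task_id  # file suffix is appended later by the service

req = DemoRequest()
print(req.task_id, req.seed, req.save_result_path)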
This diff is collapsed.
from .file_service import FileService
from .generation import ImageGenerationService, VideoGenerationService
from .inference import DistributedInferenceService, TorchrunInferenceWorker

__all__ = [
    "FileService",
    "DistributedInferenceService",
    "TorchrunInferenceWorker",
    "VideoGenerationService",
    "ImageGenerationService",
]
...
@@ -17,20 +17,15 @@ class DistributedManager:
    CHUNK_SIZE = 1024 * 1024

    def init_process_group(self) -> bool:
        try:
            self.rank = int(os.environ.get("LOCAL_RANK", 0))
            self.world_size = int(os.environ.get("WORLD_SIZE", 1))
            if self.world_size > 1:
                backend = "nccl" if torch.cuda.is_available() else "gloo"
                dist.init_process_group(backend=backend, init_method="env://")
                logger.info(f"Setup backend: {backend}")
                if torch.cuda.is_available():
                    torch.cuda.set_device(self.rank)
                    self.device = f"cuda:{self.rank}"
...
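init_method="env://" works here because torchrun exports MASTER_ADDR, MASTER_PORT, RANK, and WORLD_SIZE before each worker starts. A single-process sketch of the same handshake (torch assumed installed; the variables are set by hand only so the sketch runs standalone):

import os
import torch
import torch.distributed as dist

# torchrun sets these automatically; we fake a 1-process world for the demo.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
os.environ.setdefault("RANK", "0")
os.environ.setdefault("WORLD_SIZE", "1")

backend = "nccl" if torch.cuda.is_available() else "gloo"
dist.init_process_group(backend=backend, init_method="env://")
print(dist.get_rank(), dist.get_world_size())
dist.destroy_process_group()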
import asyncio
import uuid
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

import httpx
from loguru import logger


class FileService:
    def __init__(self, cache_dir: Path):
        self.cache_dir = cache_dir
        self.input_image_dir = cache_dir / "inputs" / "imgs"
        self.input_audio_dir = cache_dir / "inputs" / "audios"
        self.output_video_dir = cache_dir / "outputs"
        self._http_client = None
        self._client_lock = asyncio.Lock()
        self.max_retries = 3
        self.retry_delay = 1.0
        self.max_retry_delay = 10.0
        for directory in [
            self.input_image_dir,
            self.output_video_dir,
            self.input_audio_dir,
        ]:
            directory.mkdir(parents=True, exist_ok=True)

    async def _get_http_client(self) -> httpx.AsyncClient:
        async with self._client_lock:
            if self._http_client is None or self._http_client.is_closed:
                timeout = httpx.Timeout(
                    connect=10.0,
                    read=30.0,
                    write=10.0,
                    pool=5.0,
                )
                limits = httpx.Limits(max_keepalive_connections=5, max_connections=10, keepalive_expiry=30.0)
                self._http_client = httpx.AsyncClient(verify=False, timeout=timeout, limits=limits, follow_redirects=True)
            return self._http_client

    async def _download_with_retry(self, url: str, max_retries: Optional[int] = None) -> httpx.Response:
        if max_retries is None:
            max_retries = self.max_retries
        last_exception = None
        retry_delay = self.retry_delay
        for attempt in range(max_retries):
            try:
                client = await self._get_http_client()
                response = await client.get(url)
                if response.status_code == 200:
                    return response
                elif response.status_code >= 500:
                    logger.warning(f"Server error {response.status_code} for {url}, attempt {attempt + 1}/{max_retries}")
                    last_exception = httpx.HTTPStatusError(f"Server returned {response.status_code}", request=response.request, response=response)
                else:
                    raise httpx.HTTPStatusError(f"Client error {response.status_code}", request=response.request, response=response)
            except (httpx.ConnectError, httpx.TimeoutException, httpx.NetworkError) as e:
                logger.warning(f"Connection error for {url}, attempt {attempt + 1}/{max_retries}: {str(e)}")
                last_exception = e
            except httpx.HTTPStatusError as e:
                if e.response and e.response.status_code < 500:
                    raise
                last_exception = e
            except Exception as e:
                logger.error(f"Unexpected error downloading {url}: {str(e)}")
                last_exception = e
            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, self.max_retry_delay)
        error_msg = f"All {max_retries} connection attempts failed for {url}"
        if last_exception:
            error_msg += f": {str(last_exception)}"
        raise httpx.ConnectError(error_msg)

    async def download_media(self, url: str, media_type: str = "image") -> Path:
        try:
            parsed_url = urlparse(url)
            if not parsed_url.scheme or not parsed_url.netloc:
                raise ValueError(f"Invalid URL format: {url}")
            response = await self._download_with_retry(url)
            media_name = Path(parsed_url.path).name
            if not media_name:
                default_ext = "jpg" if media_type == "image" else "mp3"
                media_name = f"{uuid.uuid4()}.{default_ext}"
            if media_type == "image":
                target_dir = self.input_image_dir
            else:
                target_dir = self.input_audio_dir
            media_path = target_dir / media_name
            media_path.parent.mkdir(parents=True, exist_ok=True)
            with open(media_path, "wb") as f:
                f.write(response.content)
            logger.info(f"Successfully downloaded {media_type} from {url} to {media_path}")
            return media_path
        except httpx.ConnectError as e:
            logger.error(f"Connection error downloading {media_type} from {url}: {str(e)}")
            raise ValueError(f"Failed to connect to {url}: {str(e)}")
        except httpx.TimeoutException as e:
            logger.error(f"Timeout downloading {media_type} from {url}: {str(e)}")
            raise ValueError(f"Download timeout for {url}: {str(e)}")
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error downloading {media_type} from {url}: {str(e)}")
            raise ValueError(f"HTTP error for {url}: {str(e)}")
        except ValueError:
            raise
        except Exception as e:
            logger.error(f"Unexpected error downloading {media_type} from {url}: {str(e)}")
            raise ValueError(f"Failed to download {media_type} from {url}: {str(e)}")

    async def download_image(self, image_url: str) -> Path:
        return await self.download_media(image_url, "image")

    async def download_audio(self, audio_url: str) -> Path:
        return await self.download_media(audio_url, "audio")

    def save_uploaded_file(self, file_content: bytes, filename: str) -> Path:
        file_extension = Path(filename).suffix
        unique_filename = f"{uuid.uuid4()}{file_extension}"
        file_path = self.input_image_dir / unique_filename
        with open(file_path, "wb") as f:
            f.write(file_content)
        return file_path

    def get_output_path(self, save_result_path: str) -> Path:
        video_path = Path(save_result_path)
        if not video_path.is_absolute():
            return self.output_video_dir / save_result_path
        return video_path

    async def cleanup(self):
        async with self._client_lock:
            if self._http_client and not self._http_client.is_closed:
                await self._http_client.aclose()
                self._http_client = None
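The retry policy in _download_with_retry is exponential backoff: the delay starts at retry_delay, doubles after each failed attempt, and is capped at max_retry_delay. A quick sketch of the resulting schedule:

retry_delay, max_retry_delay = 1.0, 10.0   # the service defaults above
delays = []
for attempt in range(5):                   # 5 attempts for illustration; max_retries defaults to 3
    delays.append(retry_delay)
    retry_delay = min(retry_delay * 2, max_retry_delay)
print(delays)  # [1.0, 2.0, 4.0, 8.0, 10.0]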
from .base import BaseGenerationService
from .image import ImageGenerationService
from .video import VideoGenerationService

__all__ = [
    "BaseGenerationService",
    "VideoGenerationService",
    "ImageGenerationService",
]
import json
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from loguru import logger

from ...media import is_base64_audio, is_base64_image, save_base64_audio, save_base64_image
from ...schema import TaskResponse
from ..file_service import FileService
from ..inference import DistributedInferenceService


class BaseGenerationService(ABC):
    def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
        self.file_service = file_service
        self.inference_service = inference_service

    @abstractmethod
    def get_output_extension(self) -> str:
        pass

    @abstractmethod
    def get_task_type(self) -> str:
        pass

    def _is_target_task_type(self) -> bool:
        if self.inference_service.worker and self.inference_service.worker.runner:
            task_type = self.inference_service.worker.runner.config.get("task", "t2v")
            return task_type in self.get_task_type().split(",")
        return False

    async def _process_image_path(self, image_path: str, task_data: Dict[str, Any]) -> None:
        if not image_path:
            return
        if image_path.startswith("http"):
            downloaded_path = await self.file_service.download_image(image_path)
            task_data["image_path"] = str(downloaded_path)
        elif is_base64_image(image_path):
            saved_path = save_base64_image(image_path, str(self.file_service.input_image_dir))
            task_data["image_path"] = str(saved_path)
        else:
            task_data["image_path"] = image_path

    async def _process_audio_path(self, audio_path: str, task_data: Dict[str, Any]) -> None:
        if not audio_path:
            return
        if audio_path.startswith("http"):
            downloaded_path = await self.file_service.download_audio(audio_path)
            task_data["audio_path"] = str(downloaded_path)
        elif is_base64_audio(audio_path):
            saved_path = save_base64_audio(audio_path, str(self.file_service.input_audio_dir))
            task_data["audio_path"] = str(saved_path)
        else:
            task_data["audio_path"] = audio_path

    async def _process_talk_objects(self, talk_objects: list, task_data: Dict[str, Any]) -> None:
        if not talk_objects:
            return
        task_data["talk_objects"] = [{} for _ in range(len(talk_objects))]
        for index, talk_object in enumerate(talk_objects):
            if talk_object.audio.startswith("http"):
                audio_path = await self.file_service.download_audio(talk_object.audio)
                task_data["talk_objects"][index]["audio"] = str(audio_path)
            elif is_base64_audio(talk_object.audio):
                audio_path = save_base64_audio(talk_object.audio, str(self.file_service.input_audio_dir))
                task_data["talk_objects"][index]["audio"] = str(audio_path)
            else:
                task_data["talk_objects"][index]["audio"] = talk_object.audio
            if talk_object.mask.startswith("http"):
                mask_path = await self.file_service.download_image(talk_object.mask)
                task_data["talk_objects"][index]["mask"] = str(mask_path)
            elif is_base64_image(talk_object.mask):
                mask_path = save_base64_image(talk_object.mask, str(self.file_service.input_image_dir))
                task_data["talk_objects"][index]["mask"] = str(mask_path)
            else:
                task_data["talk_objects"][index]["mask"] = talk_object.mask
        temp_path = self.file_service.cache_dir / uuid.uuid4().hex[:8]
        temp_path.mkdir(parents=True, exist_ok=True)
        task_data["audio_path"] = str(temp_path)
        config_path = temp_path / "config.json"
        with open(config_path, "w") as f:
            json.dump({"talk_objects": task_data["talk_objects"]}, f)

    def _prepare_output_path(self, save_result_path: str, task_data: Dict[str, Any]) -> None:
        actual_save_path = self.file_service.get_output_path(save_result_path)
        if not actual_save_path.suffix:
            actual_save_path = actual_save_path.with_suffix(self.get_output_extension())
        task_data["save_result_path"] = str(actual_save_path)
        task_data["video_path"] = actual_save_path.name

    async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
        try:
            task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
            task_data["task_id"] = message.task_id
            if stop_event.is_set():
                logger.info(f"Task {message.task_id} cancelled before processing")
                return None
            if hasattr(message, "image_path") and message.image_path:
                await self._process_image_path(message.image_path, task_data)
                logger.info(f"Task {message.task_id} image path: {task_data.get('image_path')}")
            if hasattr(message, "audio_path") and message.audio_path:
                await self._process_audio_path(message.audio_path, task_data)
                logger.info(f"Task {message.task_id} audio path: {task_data.get('audio_path')}")
            if hasattr(message, "talk_objects") and message.talk_objects:
                await self._process_talk_objects(message.talk_objects, task_data)
            self._prepare_output_path(message.save_result_path, task_data)
            task_data["seed"] = message.seed
            result = await self.inference_service.submit_task_async(task_data)
            if result is None:
                if stop_event.is_set():
                    logger.info(f"Task {message.task_id} cancelled during processing")
                    return None
                raise RuntimeError("Task processing failed")
            if result.get("status") == "success":
                actual_save_path = self.file_service.get_output_path(message.save_result_path)
                if not actual_save_path.suffix:
                    actual_save_path = actual_save_path.with_suffix(self.get_output_extension())
                return TaskResponse(
                    task_id=message.task_id,
                    task_status="completed",
                    save_result_path=actual_save_path.name,
                )
            else:
                error_msg = result.get("error", "Inference failed")
                raise RuntimeError(error_msg)
        except Exception as e:
            logger.exception(f"Task {message.task_id} processing failed: {str(e)}")
            raise
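_prepare_output_path leans on pathlib for the suffix rule mentioned in the schema: a bare file name gets the service's extension, an explicit one is kept. A minimal sketch:

from pathlib import Path

for name in ("abc123", "abc123.webm"):
    p = Path(name)
    if not p.suffix:                 # same check as _prepare_output_path
        p = p.with_suffix(".mp4")    # get_output_extension() of the video service
    print(p)                         # abc123.mp4, then abc123.webm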
from typing import Any, Optional

from loguru import logger

from ...schema import TaskResponse
from ..file_service import FileService
from ..inference import DistributedInferenceService
from .base import BaseGenerationService


class ImageGenerationService(BaseGenerationService):
    def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
        super().__init__(file_service, inference_service)

    def get_output_extension(self) -> str:
        return ".png"

    def get_task_type(self) -> str:
        return "t2i,i2i"

    async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
        try:
            task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
            task_data["task_id"] = message.task_id
            if hasattr(message, "aspect_ratio"):
                task_data["aspect_ratio"] = message.aspect_ratio
            if stop_event.is_set():
                logger.info(f"Task {message.task_id} cancelled before processing")
                return None
            if hasattr(message, "image_path") and message.image_path:
                await self._process_image_path(message.image_path, task_data)
                logger.info(f"Task {message.task_id} image path: {task_data.get('image_path')}")
            self._prepare_output_path(message.save_result_path, task_data)
            task_data["seed"] = message.seed
            result = await self.inference_service.submit_task_async(task_data)
            if result is None:
                if stop_event.is_set():
                    logger.info(f"Task {message.task_id} cancelled during processing")
                    return None
                raise RuntimeError("Task processing failed")
            if result.get("status") == "success":
                actual_save_path = self.file_service.get_output_path(message.save_result_path)
                if not actual_save_path.suffix:
                    actual_save_path = actual_save_path.with_suffix(self.get_output_extension())
                return TaskResponse(
                    task_id=message.task_id,
                    task_status="completed",
                    save_result_path=actual_save_path.name,
                )
            else:
                error_msg = result.get("error", "Inference failed")
                raise RuntimeError(error_msg)
        except Exception as e:
            logger.exception(f"Task {message.task_id} processing failed: {str(e)}")
            raise

    async def generate_image_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
        return await self.generate_with_stop_event(message, stop_event)
from typing import Any, Optional

from ..file_service import FileService
from ..inference import DistributedInferenceService
from .base import BaseGenerationService


class VideoGenerationService(BaseGenerationService):
    def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
        super().__init__(file_service, inference_service)

    def get_output_extension(self) -> str:
        return ".mp4"

    def get_task_type(self) -> str:
        return "t2v,i2v,s2v"

    async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
        return await super().generate_with_stop_event(message, stop_event)

    async def generate_video_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
        return await self.generate_with_stop_event(message, stop_event)
from .service import DistributedInferenceService
from .worker import TorchrunInferenceWorker

__all__ = [
    "TorchrunInferenceWorker",
    "DistributedInferenceService",
]
from typing import Optional

from loguru import logger

from .worker import TorchrunInferenceWorker


class DistributedInferenceService:
    def __init__(self):
        self.worker = None
        self.is_running = False
        self.args = None

    def start_distributed_inference(self, args) -> bool:
        self.args = args
        if self.is_running:
            logger.warning("Distributed inference service is already running")
            return True
        try:
            self.worker = TorchrunInferenceWorker()
            if not self.worker.init(args):
                raise RuntimeError("Worker initialization failed")
            self.is_running = True
            logger.info(f"Rank {self.worker.rank} inference service started successfully")
            return True
        except Exception as e:
            logger.error(f"Error starting inference service: {str(e)}")
            self.stop_distributed_inference()
            return False

    def stop_distributed_inference(self):
        if not self.is_running:
            return
        try:
            if self.worker:
                self.worker.cleanup()
            logger.info("Inference service stopped")
        except Exception as e:
            logger.error(f"Error stopping inference service: {str(e)}")
        finally:
            self.worker = None
            self.is_running = False

    async def submit_task_async(self, task_data: dict) -> Optional[dict]:
        if not self.is_running or not self.worker:
            logger.error("Inference service is not started")
            return None
        if self.worker.rank != 0:
            return None
        try:
            if self.worker.processing:
                # Note: this only logs; it does not block until the previous task finishes.
                logger.info(f"Waiting for previous task to complete before processing task {task_data.get('task_id')}")
            self.worker.processing = True
            result = await self.worker.process_request(task_data)
            self.worker.processing = False
            return result
        except Exception as e:
            self.worker.processing = False
            logger.error(f"Failed to process task: {str(e)}")
            return {
                "task_id": task_data.get("task_id", "unknown"),
                "status": "failed",
                "error": str(e),
                "message": f"Task processing failed: {str(e)}",
            }

    def server_metadata(self):
        assert self.args is not None, "Distributed inference service has not been started. Call start_distributed_inference() first."
        return {"nproc_per_node": self.worker.world_size, "model_cls": self.args.model_cls, "model_path": self.args.model_path}

    async def run_worker_loop(self):
        if self.worker and self.worker.rank != 0:
            await self.worker.worker_loop()
import asyncio
import json
import os
from typing import Any, Dict, Optional

import torch
from easydict import EasyDict
from loguru import logger

from lightx2v.infer import init_runner
from lightx2v.utils.input_info import set_input_info
from lightx2v.utils.set_config import set_config

from ..distributed_utils import DistributedManager


class TorchrunInferenceWorker:
    def __init__(self):
        self.rank = int(os.environ.get("LOCAL_RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.runner = None
        self.dist_manager = DistributedManager()
        self.processing = False

    def init(self, args) -> bool:
        try:
            if self.world_size > 1:
                if not self.dist_manager.init_process_group():
                    raise RuntimeError("Failed to initialize distributed process group")
            else:
                self.dist_manager.rank = 0
                self.dist_manager.world_size = 1
                self.dist_manager.device = "cuda:0" if torch.cuda.is_available() else "cpu"
                self.dist_manager.is_initialized = False
            config = set_config(args)
            if self.rank == 0:
                logger.info(f"Config:\n {json.dumps(config, ensure_ascii=False, indent=4)}")
            self.runner = init_runner(config)
            logger.info(f"Rank {self.rank}/{self.world_size - 1} initialization completed")
            return True
        except Exception as e:
            logger.exception(f"Rank {self.rank} initialization failed: {str(e)}")
            return False

    async def process_request(self, task_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        try:
            if self.world_size > 1 and self.rank == 0:
                task_data = self.dist_manager.broadcast_task_data(task_data)
            task_data["task"] = self.runner.config["task"]
            task_data["return_result_tensor"] = False
            task_data["negative_prompt"] = task_data.get("negative_prompt", "")
            task_data = EasyDict(task_data)
            input_info = set_input_info(task_data)
            self.runner.set_config(task_data)
            self.runner.run_pipeline(input_info)
            await asyncio.sleep(0)
            if self.world_size > 1:
                self.dist_manager.barrier()
            if self.rank == 0:
                return {
                    "task_id": task_data["task_id"],
                    "status": "success",
                    "save_result_path": task_data.get("video_path", task_data["save_result_path"]),
                    "message": "Inference completed",
                }
            else:
                return None
        except Exception as e:
            logger.exception(f"Rank {self.rank} inference failed: {str(e)}")
            if self.world_size > 1:
                self.dist_manager.barrier()
            if self.rank == 0:
                return {
                    "task_id": task_data.get("task_id", "unknown"),
                    "status": "failed",
                    "error": str(e),
                    "message": f"Inference failed: {str(e)}",
                }
            else:
                return None

    async def worker_loop(self):
        while True:
            try:
                task_data = self.dist_manager.broadcast_task_data()
                if task_data is None:
                    logger.info(f"Rank {self.rank} received stop signal")
                    break
                await self.process_request(task_data)
            except Exception as e:
                logger.error(f"Rank {self.rank} worker loop error: {str(e)}")
                continue

    def cleanup(self):
        self.dist_manager.cleanup()
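For orientation, the control flow between ranks: rank 0 receives HTTP tasks and broadcasts them, while every other rank blocks in worker_loop until a payload (or a None stop signal) arrives. A minimal sketch of that dispatch pattern, using torch.distributed.broadcast_object_list as a stand-in for DistributedManager.broadcast_task_data, whose body is not part of this diff (assumes an already-initialized process group):

import torch.distributed as dist

def broadcast_task(task=None):
    # Rank 0 supplies the real task dict; other ranks pass a placeholder
    # and receive the broadcast value in-place.
    payload = [task]
    dist.broadcast_object_list(payload, src=0)
    return payload[0]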
#!/bin/bash
# set paths first
lightx2v_path=
model_path=

export CUDA_VISIBLE_DEVICES=0

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh

# Start API server with distributed inference service
python -m lightx2v.server \
    --model_cls qwen_image \
    --task i2i \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i.json \
    --port 8000

echo "Service stopped"

# Example request payload:
# {
#     "prompt": "turn the style of the photo to vintage comic book",
#     "image_path": "assets/inputs/imgs/snake.png",
#     "infer_steps": 50
# }
#!/bin/bash
# set paths first
lightx2v_path=
model_path=

export CUDA_VISIBLE_DEVICES=0

# set environment variables
source ${lightx2v_path}/scripts/base/base.sh

# Start API server with distributed inference service
python -m lightx2v.server \
    --model_cls qwen_image \
    --task t2i \
    --model_path $model_path \
    --config_json ${lightx2v_path}/configs/qwen_image/qwen_image_t2i.json \
    --port 8000

echo "Service stopped"

# Example request payload:
# {
#     "prompt": "a beautiful sunset over the ocean",
#     "aspect_ratio": "16:9",
#     "infer_steps": 50
# }