Unverified Commit 0379ae88 authored by PengGao's avatar PengGao Committed by GitHub
Browse files

fix server use new config system (#362)

parent 544435d1
......@@ -40,11 +40,14 @@ class ApiServer:
self._setup_routes()
def _setup_routes(self):
@self.app.get("/")
def redirect_to_docs():
return RedirectResponse(url="/docs")
self._setup_task_routes()
self._setup_file_routes()
self._setup_service_routes()
# Register routers
self.app.include_router(self.tasks_router)
self.app.include_router(self.files_router)
self.app.include_router(self.service_router)
......@@ -133,7 +136,7 @@ class ApiServer:
infer_steps: int = Form(default=5),
target_video_length: int = Form(default=81),
seed: int = Form(default=42),
audio_file: Optional[UploadFile] = File(default=None),
audio_file: UploadFile = File(None),
video_duration: int = Form(default=5),
):
assert self.file_service is not None, "File service is not initialized"
......@@ -305,7 +308,7 @@ class ApiServer:
if not parsed_url.scheme or not parsed_url.netloc:
return False
timeout = httpx.Timeout(connect=5.0, read=5.0)
timeout = httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=5.0)
async with httpx.AsyncClient(verify=False, timeout=timeout) as client:
response = await client.head(image_url, follow_redirects=True)
return response.status_code < 400
......@@ -375,7 +378,7 @@ class ApiServer:
logger.error(f"Task {task_id} generation failed")
except Exception as e:
logger.error(f"Task {task_id} processing failed: {str(e)}")
logger.exception(f"Task {task_id} processing failed: {str(e)}")
task_manager.fail_task(task_id, str(e))
finally:
if lock_acquired:
......
from typing import Optional
from pydantic import BaseModel, Field
from ..utils.generate_task_id import generate_task_id
class TalkObject(BaseModel):
    """One speaker entry for Wan-Audio multi-speaker generation.

    Pairs an audio source with a spatial mask. Per the service-side
    resolution logic, each field may hold a local path, an http(s) URL,
    or a base64-encoded payload (URLs are downloaded and base64 data is
    saved to disk before inference).
    """

    # Audio source: path, http(s) URL, or base64 audio string.
    audio: str = Field(..., description="Audio path")
    # Mask source: same accepted forms as `audio`.
    mask: str = Field(..., description="Mask path")
class TaskRequest(BaseModel):
task_id: str = Field(default_factory=generate_task_id, description="Task ID (auto-generated)")
prompt: str = Field("", description="Generation prompt")
......@@ -16,6 +23,7 @@ class TaskRequest(BaseModel):
seed: int = Field(42, description="Random seed")
audio_path: str = Field("", description="Input audio path (Wan-Audio)")
video_duration: int = Field(5, description="Video duration (Wan-Audio)")
talk_objects: Optional[list[TalkObject]] = Field(None, description="Talk objects (Wan-Audio)")
def __init__(self, **data):
super().__init__(**data)
......
......@@ -8,9 +8,11 @@ from urllib.parse import urlparse
import httpx
import torch
from easydict import EasyDict
from loguru import logger
from ..infer import init_runner
from ..utils.input_info import set_input_info
from ..utils.set_config import set_config
from .audio_utils import is_base64_audio, save_base64_audio
from .distributed_utils import DistributedManager
......@@ -245,8 +247,21 @@ class TorchrunInferenceWorker:
# Run inference directly - torchrun handles the parallelization
# Using asyncio.to_thread would be risky with NCCL operations
# Instead, we rely on FastAPI's async handling and queue management
self.runner.set_inputs(task_data)
self.runner.run_pipeline()
task_data["task"] = self.runner.config["task"]
task_data["return_result_tensor"] = False
task_data["negative_prompt"] = task_data.get("negative_prompt", "")
# must be converted to EasyDict so downstream code can access fields as attributes
task_data = EasyDict(task_data)
input_info = set_input_info(task_data)
# update lock config
self.runner.set_config(task_data)
# print("input_info==>", input_info)
self.runner.run_pipeline(input_info)
# Small yield to allow other async operations if needed
await asyncio.sleep(0)
......@@ -267,7 +282,7 @@ class TorchrunInferenceWorker:
return None
except Exception as e:
logger.error(f"Rank {self.rank} inference failed: {str(e)}")
logger.exception(f"Rank {self.rank} inference failed: {str(e)}")
if self.world_size > 1:
self.dist_manager.barrier()
......@@ -418,6 +433,37 @@ class VideoGenerationService:
logger.info(f"Task {message.task_id} audio path: {task_data['audio_path']}")
if "talk_objects" in message.model_fields_set and message.talk_objects:
task_data["talk_objects"] = [{} for _ in range(len(message.talk_objects))]
for index, talk_object in enumerate(message.talk_objects):
if talk_object.audio.startswith("http"):
audio_path = await self.file_service.download_audio(talk_object.audio)
task_data["talk_objects"][index]["audio"] = str(audio_path)
elif is_base64_audio(talk_object.audio):
audio_path = save_base64_audio(talk_object.audio, str(self.file_service.input_audio_dir))
task_data["talk_objects"][index]["audio"] = str(audio_path)
else:
task_data["talk_objects"][index]["audio"] = talk_object.audio
if talk_object.mask.startswith("http"):
mask_path = await self.file_service.download_image(talk_object.mask)
task_data["talk_objects"][index]["mask"] = str(mask_path)
elif is_base64_image(talk_object.mask):
mask_path = save_base64_image(talk_object.mask, str(self.file_service.input_image_dir))
task_data["talk_objects"][index]["mask"] = str(mask_path)
else:
task_data["talk_objects"][index]["mask"] = talk_object.mask
# FIXME(xxx): persist talk_objects as a config.json, then assign that config.json's path to task_data["audio_path"]
temp_path = self.file_service.cache_dir / uuid.uuid4().hex[:8]
temp_path.mkdir(parents=True, exist_ok=True)
task_data["audio_path"] = str(temp_path)
config_path = temp_path / "config.json"
with open(config_path, "w") as f:
json.dump({"talk_objects": task_data["talk_objects"]}, f)
actual_save_path = self.file_service.get_output_path(message.save_result_path)
task_data["save_result_path"] = str(actual_save_path)
task_data["video_path"] = message.save_result_path
......@@ -441,5 +487,5 @@ class VideoGenerationService:
raise RuntimeError(error_msg)
except Exception as e:
logger.error(f"Task {message.task_id} processing failed: {str(e)}")
logger.exception(f"Task {message.task_id} processing failed: {str(e)}")
raise
......@@ -163,6 +163,8 @@ def set_input_info(args):
)
else:
raise ValueError(f"Unsupported task: {args.task}")
assert not (input_info.save_result_path and input_info.return_result_tensor), "save_result_path and return_result_tensor cannot be set at the same time"
return input_info
......
......@@ -34,8 +34,6 @@ def get_default_config():
def set_config(args):
assert not (args.save_result_path and args.return_result_tensor), "save_result_path and return_result_tensor cannot be set at the same time"
config = get_default_config()
config.update({k: v for k, v in vars(args).items() if k not in ALL_INPUT_INFO_KEYS})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment