Commit 0c3f4bb1 authored by PengGao's avatar PengGao Committed by GitHub
Browse files

feat: add synchronous file writing method and enhance async file savi… (#101)

* feat: add synchronous file writing method and enhance async file saving in API server

* refactor: clean up image file handling in API server and improve error message in service
parent 5ec2f551
...@@ -45,6 +45,11 @@ class ApiServer: ...@@ -45,6 +45,11 @@ class ApiServer:
self.app.include_router(self.files_router) self.app.include_router(self.files_router)
self.app.include_router(self.service_router) self.app.include_router(self.service_router)
def _write_file_sync(self, file_path: Path, content: bytes) -> None:
"""同步写入文件到指定路径"""
with open(file_path, "wb") as buffer:
buffer.write(content)
def _stream_file_response(self, file_path: Path, filename: str | None = None) -> StreamingResponse: def _stream_file_response(self, file_path: Path, filename: str | None = None) -> StreamingResponse:
"""Common file streaming response method""" """Common file streaming response method"""
assert self.file_service is not None, "File service is not initialized" assert self.file_service is not None, "File service is not initialized"
...@@ -130,32 +135,30 @@ class ApiServer: ...@@ -130,32 +135,30 @@ class ApiServer:
video_duration: int = Form(default=5), video_duration: int = Form(default=5),
): ):
"""Create video generation task via form""" """Create video generation task via form"""
# Process uploaded image file
image_path = ""
assert self.file_service is not None, "File service is not initialized" assert self.file_service is not None, "File service is not initialized"
if image_file and image_file.filename: async def save_file_async(file: UploadFile, target_dir: Path) -> str:
file_extension = Path(image_file.filename).suffix """异步保存文件到指定目录"""
if not file or not file.filename:
return ""
file_extension = Path(file.filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}" unique_filename = f"{uuid.uuid4()}{file_extension}"
image_path = self.file_service.input_image_dir / unique_filename file_path = target_dir / unique_filename
with open(image_path, "wb") as buffer: content = await file.read()
content = await image_file.read()
buffer.write(content)
image_path = str(image_path) await asyncio.to_thread(self._write_file_sync, file_path, content)
audio_path = "" return str(file_path)
if audio_file and audio_file.filename:
file_extension = Path(audio_file.filename).suffix
unique_filename = f"{uuid.uuid4()}{file_extension}"
audio_path = self.file_service.input_audio_dir / unique_filename
with open(audio_path, "wb") as buffer: image_path = ""
content = await audio_file.read() if image_file and image_file.filename:
buffer.write(content) image_path = await save_file_async(image_file, self.file_service.input_image_dir)
audio_path = str(audio_path) audio_path = ""
if audio_file and audio_file.filename:
audio_path = await save_file_async(audio_file, self.file_service.input_audio_dir)
message = TaskRequest( message = TaskRequest(
prompt=prompt, prompt=prompt,
...@@ -276,6 +279,12 @@ class ApiServer: ...@@ -276,6 +279,12 @@ class ApiServer:
"""Get service status""" """Get service status"""
return ServiceStatus.get_status_service() return ServiceStatus.get_status_service()
@self.service_router.get("/metadata", response_model=dict)
async def get_service_metadata():
"""Get service metadata"""
assert self.inference_service is not None, "Inference service is not initialized"
return self.inference_service.server_metadata()
def _process_video_generation(self, message: TaskRequest, stop_event: threading.Event): def _process_video_generation(self, message: TaskRequest, stop_event: threading.Event):
assert self.video_service is not None, "Video service is not initialized" assert self.video_service is not None, "Video service is not initialized"
try: try:
......
...@@ -186,6 +186,7 @@ class DistributedInferenceService: ...@@ -186,6 +186,7 @@ class DistributedInferenceService:
self.is_running = False self.is_running = False
def start_distributed_inference(self, args) -> bool: def start_distributed_inference(self, args) -> bool:
self.args = args
if self.is_running: if self.is_running:
logger.warning("Distributed inference service is already running") logger.warning("Distributed inference service is already running")
return True return True
...@@ -311,6 +312,10 @@ class DistributedInferenceService: ...@@ -311,6 +312,10 @@ class DistributedInferenceService:
return None return None
def server_metadata(self):
assert hasattr(self, "args"), "Distributed inference service has not been started. Call start_distributed_inference() first."
return {"nproc_per_node": self.args.nproc_per_node, "model_cls": self.args.model_cls, "model_path": self.args.model_path}
class VideoGenerationService: class VideoGenerationService:
def __init__(self, file_service: FileService, inference_service: DistributedInferenceService): def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment