Commit 4e735704 authored by gaclove

Refactor distributed inference worker to remove unused asyncio event loop and streamline task data processing in VideoGenerationService.
parent acac50a6
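
Context for the worker change below: each process previously spun up a dedicated asyncio event loop solely to drive runner.run_pipeline(); since the pipeline is now invoked synchronously, the loop creation and teardown become dead weight. A minimal before/after sketch of the calling convention (the bare run_pipeline below is a stand-in for the runner method, not the project's actual code):

def run_pipeline():
    # Stand-in for runner.run_pipeline(); after the refactor it is an
    # ordinary synchronous function, so no event loop is involved.
    return "frames"

# Before: a per-process loop existed only to run one coroutine to completion.
#   loop = asyncio.new_event_loop()
#   asyncio.set_event_loop(loop)
#   loop.run_until_complete(run_pipeline())  # only needed while run_pipeline was async
#   loop.close()

# After: a direct call, with nothing to create, track, or close.
result = run_pipeline()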
-import asyncio
 import queue
 import time
 import uuid
@@ -77,7 +76,6 @@ class FileService:
 def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, task_queue, result_queue):
     task_data = None
-    loop = None
     worker = None
     try:
@@ -95,10 +93,6 @@ def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, task_queue, result_queue):
         runner = init_runner(config)
         logger.info(f"Process {rank}/{world_size - 1} distributed inference service initialization completed")
 
-        # Create event loop
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
         while True:
             # Only rank=0 reads tasks from queue
             if rank == 0:
@@ -128,7 +122,7 @@ def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, task_queue, result_queue):
             try:
                 # Set inputs and run inference
                 runner.set_inputs(task_data)  # type: ignore
-                loop.run_until_complete(runner.run_pipeline())
+                runner.run_pipeline()
                 # Synchronize and report results
                 worker.sync_and_report(
@@ -163,13 +157,6 @@ def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, task_queue, result_queue):
                 }
                 result_queue.put(error_result)
     finally:
-        # Clean up resources
-        try:
-            if loop and not loop.is_closed():
-                loop.close()
-        except:  # noqa: E722
-            pass
-
         try:
             if worker:
                 worker.cleanup()
@@ -329,36 +316,20 @@ class VideoGenerationService:
     async def generate_video(self, message: TaskRequest) -> TaskResponse:
         try:
-            # Process image path
-            task_data = {
-                "task_id": message.task_id,
-                "prompt": message.prompt,
-                "use_prompt_enhancer": message.use_prompt_enhancer,
-                "negative_prompt": message.negative_prompt,
-                "image_path": message.image_path,
-                "num_fragments": message.num_fragments,
-                "save_video_path": message.save_video_path,
-                "infer_steps": message.infer_steps,
-                "target_video_length": message.target_video_length,
-                "seed": message.seed,
-                "audio_path": message.audio_path,
-                "video_duration": message.video_duration,
-            }
+            # Only include explicitly set fields; exclude fields left at their default values
+            task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
+            task_data["task_id"] = message.task_id
 
-            # Process network image
-            if message.image_path.startswith("http"):
+            if "image_path" in message.model_fields_set and message.image_path.startswith("http"):
                 image_path = await self.file_service.download_image(message.image_path)
                 task_data["image_path"] = str(image_path)
 
-            # Process output path
             save_video_path = self.file_service.get_output_path(message.save_video_path)
             task_data["save_video_path"] = str(save_video_path)
 
-            # Submit task to distributed inference service
             if not self.inference_service.submit_task(task_data):
                 raise RuntimeError("Distributed inference service is not started")
 
-            # Wait for result
             result = self.inference_service.wait_for_result(message.task_id)
             if result is None:
......
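
The generate_video change leans on Pydantic v2's model_fields_set, which records exactly the field names passed at construction, so fields left at their defaults never reach task_data. A minimal sketch with a trimmed-down TaskRequest (illustrative only; the real model carries more fields, such as audio_path and video_duration):

from pydantic import BaseModel

class TaskRequest(BaseModel):
    task_id: str
    prompt: str = ""
    infer_steps: int = 50
    seed: int = 42

message = TaskRequest(task_id="t-123", prompt="a red panda", seed=7)

# Only explicitly passed fields are recorded (set order may vary):
# message.model_fields_set == {"task_id", "prompt", "seed"}
task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
task_data["task_id"] = message.task_id
# task_data == {"prompt": "a red panda", "seed": 7, "task_id": "t-123"}
# infer_steps stayed at its default and is correctly omitted.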