Commit 4e735704 authored by gaclove

Refactor distributed inference worker to remove unused asyncio event loop and streamline task data processing in VideoGenerationService.
parent acac50a6
-import asyncio
 import queue
 import time
 import uuid
@@ -77,7 +76,6 @@ class FileService:
 def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, task_queue, result_queue):
     task_data = None
-    loop = None
     worker = None
     try:
@@ -95,10 +93,6 @@ def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, task_queue, result_queue):
         runner = init_runner(config)
         logger.info(f"Process {rank}/{world_size - 1} distributed inference service initialization completed")

-        # Create event loop
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
         while True:
             # Only rank=0 reads tasks from queue
             if rank == 0:
@@ -128,7 +122,7 @@ def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, task_queue, result_queue):
             try:
                 # Set inputs and run inference
                 runner.set_inputs(task_data)  # type: ignore
-                loop.run_until_complete(runner.run_pipeline())
+                runner.run_pipeline()

                 # Synchronize and report results
                 worker.sync_and_report(
@@ -163,13 +157,6 @@ def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, task_queue, result_queue):
            }
            result_queue.put(error_result)
    finally:
-        # Clean up resources
-        try:
-            if loop and not loop.is_closed():
-                loop.close()
-        except:  # noqa: E722
-            pass
-
        try:
            if worker:
                worker.cleanup()
@@ -329,36 +316,20 @@ class VideoGenerationService:
     async def generate_video(self, message: TaskRequest) -> TaskResponse:
         try:
-            # Process image path
-            task_data = {
-                "task_id": message.task_id,
-                "prompt": message.prompt,
-                "use_prompt_enhancer": message.use_prompt_enhancer,
-                "negative_prompt": message.negative_prompt,
-                "image_path": message.image_path,
-                "num_fragments": message.num_fragments,
-                "save_video_path": message.save_video_path,
-                "infer_steps": message.infer_steps,
-                "target_video_length": message.target_video_length,
-                "seed": message.seed,
-                "audio_path": message.audio_path,
-                "video_duration": message.video_duration,
-            }
+            # Only include explicitly set fields; exclude fields left at their default values
+            task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
+            task_data["task_id"] = message.task_id

             # Process network image
-            if message.image_path.startswith("http"):
+            if "image_path" in message.model_fields_set and message.image_path.startswith("http"):
                 image_path = await self.file_service.download_image(message.image_path)
                 task_data["image_path"] = str(image_path)

             # Process output path
             save_video_path = self.file_service.get_output_path(message.save_video_path)
             task_data["save_video_path"] = str(save_video_path)

             # Submit task to distributed inference service
             if not self.inference_service.submit_task(task_data):
                 raise RuntimeError("Distributed inference service is not started")

             # Wait for result
             result = self.inference_service.wait_for_result(message.task_id)
             if result is None:
......
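Note on the task-data change above: the rewritten generate_video builds task_data from Pydantic v2's model_fields_set, which holds only the field names the caller explicitly provided, so request-model defaults are left out of the payload. A minimal, self-contained sketch of that pattern follows; the TaskRequest fields here are illustrative placeholders, not the service's real schema.

# Minimal sketch of the model_fields_set pattern (illustrative fields only).
from pydantic import BaseModel


class TaskRequest(BaseModel):
    # Hypothetical, trimmed-down request model for demonstration.
    task_id: str
    prompt: str = ""
    seed: int = 42
    infer_steps: int = 50


message = TaskRequest(task_id="t-1", prompt="a cat", seed=7)

# Copy only the fields the caller explicitly set, then re-attach task_id
# so it is always present in the payload.
task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
task_data["task_id"] = message.task_id

print(task_data)
# e.g. {'prompt': 'a cat', 'seed': 7, 'task_id': 't-1'}  (infer_steps omitted: it was left at its default)

Forwarding only explicitly set fields presumably lets the inference runner keep its own defaults for anything the request did not specify, instead of having them overwritten by the request model's defaults.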