image.py 2.67 KB
Newer Older
PengGao's avatar
PengGao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from typing import Any, Optional

from loguru import logger

from ...schema import TaskResponse
from ..file_service import FileService
from ..inference import DistributedInferenceService
from .base import BaseGenerationService


class ImageGenerationService(BaseGenerationService):
    """Generation service for image tasks (task types "t2i" and "i2i").

    Turns an incoming request message into a task payload, submits it to the
    distributed inference service, and converts a successful inference result
    into a completed ``TaskResponse``. Cancellation is cooperative: the caller
    supplies an event-like ``stop_event`` that is polled via ``is_set()``.
    """

    def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
        # No image-specific state; both services are handed to the base class.
        super().__init__(file_service, inference_service)

    def get_output_extension(self) -> str:
        """Return the suffix given to output paths that have none (``.png``)."""
        png_suffix = ".png"
        return png_suffix

    def get_task_type(self) -> str:
        """Return the handled task types as one comma-separated string."""
        return "t2i,i2i"

    async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
        """Run a single image-generation task with cooperative cancellation.

        Args:
            message: Request object. Its explicitly-set fields
                (``model_fields_set``) are copied into the inference payload;
                ``task_id``, ``save_result_path`` and ``seed`` are read from it
                directly, and ``aspect_ratio`` / ``image_path`` are used when
                present.
            stop_event: Event-like object exposing ``is_set()``.

        Returns:
            A ``TaskResponse`` with ``task_status="completed"`` and the output
            file name on success. ``None`` is returned in two cases: when
            ``stop_event`` is already set before any work is done, or when
            inference returns ``None`` while ``stop_event`` is set.

        Raises:
            RuntimeError: If inference returns ``None`` without a cancellation,
                or reports a status other than ``"success"``.
            Exception: Any error raised while processing is logged with its
                traceback and re-raised unchanged.
        """
        try:
            # Start from the fields the caller explicitly set. task_id is
            # written separately afterwards, so it is always present.
            payload = {
                name: getattr(message, name)
                for name in message.model_fields_set
                if name != "task_id"
            }
            payload["task_id"] = message.task_id

            # Forward aspect_ratio whenever the message defines it, even if
            # the caller did not set it explicitly.
            if hasattr(message, "aspect_ratio"):
                payload["aspect_ratio"] = message.aspect_ratio

            if stop_event.is_set():
                logger.info(f"Task {message.task_id} cancelled before processing")
                return None

            source_image = getattr(message, "image_path", None)
            if source_image:
                # Presumably resolves/rewrites payload["image_path"] (it is
                # read back for the log line) — helper not visible here.
                await self._process_image_path(source_image, payload)
                logger.info(f"Task {message.task_id} image path: {payload.get('image_path')}")

            self._prepare_output_path(message.save_result_path, payload)
            # Like aspect_ratio, seed is forwarded even if not explicitly set.
            payload["seed"] = message.seed

            outcome = await self.inference_service.submit_task_async(payload)

            if outcome is None:
                # A missing result is only benign if the task was cancelled.
                if not stop_event.is_set():
                    raise RuntimeError("Task processing failed")
                logger.info(f"Task {message.task_id} cancelled during processing")
                return None

            if outcome.get("status") != "success":
                raise RuntimeError(outcome.get("error", "Inference failed"))

            output_path = self.file_service.get_output_path(message.save_result_path)
            # Fall back to this service's default extension when none given.
            if not output_path.suffix:
                output_path = output_path.with_suffix(self.get_output_extension())
            return TaskResponse(
                task_id=message.task_id,
                task_status="completed",
                save_result_path=output_path.name,
            )

        except Exception as exc:
            logger.exception(f"Task {message.task_id} processing failed: {exc!s}")
            raise

    async def generate_image_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
        """Image-specific alias that delegates to ``generate_with_stop_event``."""
        return await self.generate_with_stop_event(message, stop_event)