base.py 6.42 KB
Newer Older
PengGao's avatar
PengGao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import json
import uuid
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

from loguru import logger

from ...media import is_base64_audio, is_base64_image, save_base64_audio, save_base64_image
from ...schema import TaskResponse
from ..file_service import FileService
from ..inference import DistributedInferenceService


class BaseGenerationService(ABC):
    """Shared request-to-inference pipeline for generation services.

    Subclasses declare the output file extension and the task type(s) they
    serve. This base class turns media inputs (remote URL, base64 payload or
    plain path) into paths, resolves the output path and submits the task to
    the inference service.
    """

    def __init__(self, file_service: "FileService", inference_service: "DistributedInferenceService"):
        """Store the collaborators used to fetch inputs and run inference."""
        self.file_service = file_service
        self.inference_service = inference_service

    @abstractmethod
    def get_output_extension(self) -> str:
        """Return the suffix appended to output paths that have none."""

    @abstractmethod
    def get_task_type(self) -> str:
        """Return the supported task type, or several separated by commas."""

    def _is_target_task_type(self) -> bool:
        """Tell whether the loaded runner's task is one this service handles.

        Returns:
            ``True`` when a worker with a runner is present and its configured
            ``task`` (``"t2v"`` when unset) is among the comma-separated names
            from :meth:`get_task_type`; ``False`` otherwise.
        """
        worker = self.inference_service.worker
        if not (worker and worker.runner):
            return False

        task_type = worker.runner.config.get("task", "t2v")
        # Strip each name so a declaration such as "t2v, i2v" still matches.
        supported = [name.strip() for name in self.get_task_type().split(",")]
        return task_type in supported

    @staticmethod
    def _is_remote_url(value: Any) -> bool:
        """Return ``True`` when *value* is an ``http://`` or ``https://`` URL.

        The scheme is compared case-insensitively. Only the first eight
        characters are inspected so large base64 payloads are not copied.
        """
        if not isinstance(value, str):
            return False
        return value[:8].lower().startswith(("http://", "https://"))

    async def _resolve_image(self, source: Any) -> Any:
        """Turn an image reference into a path string.

        Remote URLs are downloaded through the file service, base64 payloads
        are saved into the input image directory, and anything else
        (including empty values) is returned unchanged.
        """
        if not source:
            return source
        if self._is_remote_url(source):
            return str(await self.file_service.download_image(source))
        if is_base64_image(source):
            return str(save_base64_image(source, str(self.file_service.input_image_dir)))
        return source

    async def _resolve_audio(self, source: Any) -> Any:
        """Turn an audio reference into a path string.

        Remote URLs are downloaded through the file service, base64 payloads
        are saved into the input audio directory, and anything else
        (including empty values) is returned unchanged.
        """
        if not source:
            return source
        if self._is_remote_url(source):
            return str(await self.file_service.download_audio(source))
        if is_base64_audio(source):
            return str(save_base64_audio(source, str(self.file_service.input_audio_dir)))
        return source

    def _resolve_output_path(self, save_result_path: str):
        """Return the output path, adding the default extension if it has none."""
        output_path = self.file_service.get_output_path(save_result_path)
        if not output_path.suffix:
            output_path = output_path.with_suffix(self.get_output_extension())
        return output_path

    async def _process_image_path(self, image_path: str, task_data: Dict[str, Any]) -> None:
        """Resolve *image_path* and store it in ``task_data["image_path"]``.

        Nothing is written when *image_path* is empty.
        """
        if not image_path:
            return
        task_data["image_path"] = await self._resolve_image(image_path)

    async def _process_audio_path(self, audio_path: str, task_data: Dict[str, Any]) -> None:
        """Resolve *audio_path* and store it in ``task_data["audio_path"]``.

        Nothing is written when *audio_path* is empty.
        """
        if not audio_path:
            return
        task_data["audio_path"] = await self._resolve_audio(audio_path)

    async def _process_talk_objects(self, talk_objects: list, task_data: Dict[str, Any]) -> None:
        """Resolve every talk object's audio and mask and write a config file.

        ``task_data["talk_objects"]`` is replaced by a list of
        ``{"audio": ..., "mask": ...}`` dicts holding the resolved values.
        The same list is dumped to ``config.json`` inside a fresh directory
        under the file service's cache directory, and
        ``task_data["audio_path"]`` is set to that directory.

        Nothing happens when *talk_objects* is empty.
        """
        if not talk_objects:
            return

        resolved = []
        for talk_object in talk_objects:
            audio = await self._resolve_audio(talk_object.audio)
            mask = await self._resolve_image(talk_object.mask)
            resolved.append({"audio": audio, "mask": mask})
        task_data["talk_objects"] = resolved

        # A short random directory name keeps each task's config separate.
        temp_path = self.file_service.cache_dir / uuid.uuid4().hex[:8]
        temp_path.mkdir(parents=True, exist_ok=True)
        # This overrides any audio path set earlier with the config directory.
        task_data["audio_path"] = str(temp_path)

        config_path = temp_path / "config.json"
        config_path.write_text(json.dumps({"talk_objects": resolved}), encoding="utf-8")

    def _prepare_output_path(self, save_result_path: str, task_data: Dict[str, Any]) -> None:
        """Record the resolved output path and its file name in *task_data*.

        Sets ``task_data["save_result_path"]`` to the full path and
        ``task_data["video_path"]`` to its final component.
        """
        actual_save_path = self._resolve_output_path(save_result_path)
        task_data["save_result_path"] = str(actual_save_path)
        task_data["video_path"] = actual_save_path.name

    async def generate_with_stop_event(self, message: Any, stop_event) -> Optional[Any]:
        """Run one generation task unless it has been cancelled.

        Args:
            message: Request object; the fields listed in its
                ``model_fields_set`` are copied into the task payload, and
                ``task_id``, ``save_result_path`` and ``seed`` are read
                directly.
            stop_event: Object whose ``is_set()`` reports cancellation.

        Returns:
            A ``TaskResponse`` with status ``"completed"`` on success, or
            ``None`` when the task was cancelled before or during processing.

        Raises:
            RuntimeError: If inference returns no result without a
                cancellation, or reports a non-success status.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        try:
            task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
            task_data["task_id"] = message.task_id

            if stop_event.is_set():
                logger.info(f"Task {message.task_id} cancelled before processing")
                return None

            image_path = getattr(message, "image_path", None)
            if image_path:
                await self._process_image_path(image_path, task_data)
                logger.info(f"Task {message.task_id} image path: {task_data.get('image_path')}")

            audio_path = getattr(message, "audio_path", None)
            if audio_path:
                await self._process_audio_path(audio_path, task_data)
                logger.info(f"Task {message.task_id} audio path: {task_data.get('audio_path')}")

            talk_objects = getattr(message, "talk_objects", None)
            if talk_objects:
                await self._process_talk_objects(talk_objects, task_data)

            self._prepare_output_path(message.save_result_path, task_data)
            task_data["seed"] = message.seed

            result = await self.inference_service.submit_task_async(task_data)

            if result is None:
                # No result plus a set stop event means a cancellation, not a failure.
                if stop_event.is_set():
                    logger.info(f"Task {message.task_id} cancelled during processing")
                    return None
                raise RuntimeError("Task processing failed")

            if result.get("status") != "success":
                raise RuntimeError(result.get("error", "Inference failed"))

            return TaskResponse(
                task_id=message.task_id,
                task_status="completed",
                save_result_path=self._resolve_output_path(message.save_result_path).name,
            )

        except Exception as e:
            logger.exception(f"Task {message.task_id} processing failed: {e}")
            raise