import queue
import time
import uuid
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

import httpx
import torch.multiprocessing as mp
from loguru import logger

from ..infer import init_runner
from ..utils.set_config import set_config
from .distributed_utils import create_distributed_worker
16
from .schema import TaskRequest, TaskResponse
from .utils import ServiceStatus

mp.set_start_method("spawn", force=True)
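# "spawn" (rather than "fork") is required here: forked children cannot safely
# reuse the parent's CUDA context, so each worker must initialize CUDA in a
# fresh interpreter.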


class FileService:
    def __init__(self, cache_dir: Path):
        self.cache_dir = cache_dir
        self.input_image_dir = cache_dir / "inputs" / "imgs"
        self.input_audio_dir = cache_dir / "inputs" / "audios"
        self.output_video_dir = cache_dir / "outputs"

        # Create directories
        for directory in [
            self.input_image_dir,
            self.output_video_dir,
            self.input_audio_dir,
        ]:
            directory.mkdir(parents=True, exist_ok=True)

    async def download_image(self, image_url: str) -> Path:
        try:
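            # NOTE: verify=False disables TLS certificate verification; keep it
            # only if image hosts use self-signed or internal certificates.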
            async with httpx.AsyncClient(verify=False) as client:
                response = await client.get(image_url)

            if response.status_code != 200:
                raise ValueError(f"Failed to download image from {image_url}: HTTP {response.status_code}")

            image_name = Path(urlparse(image_url).path).name
            if not image_name:
                raise ValueError(f"Invalid image URL: {image_url}")

            # NOTE: files are keyed by the URL basename, so two URLs that end in
            # the same filename will overwrite each other in the cache.
            image_path = self.input_image_dir / image_name
            image_path.parent.mkdir(parents=True, exist_ok=True)

            with open(image_path, "wb") as f:
                f.write(response.content)

            return image_path
        except Exception as e:
            logger.error(f"Failed to download image: {e}")
            raise

    def save_uploaded_file(self, file_content: bytes, filename: str) -> Path:
        file_extension = Path(filename).suffix
        unique_filename = f"{uuid.uuid4()}{file_extension}"
        file_path = self.input_image_dir / unique_filename

        with open(file_path, "wb") as f:
            f.write(file_content)

        return file_path

    def get_output_path(self, save_video_path: str) -> Path:
        video_path = Path(save_video_path)
        if not video_path.is_absolute():
            return self.output_video_dir / save_video_path
        return video_path
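
# Usage sketch (hypothetical cache path; download_image must be awaited from
# async code):
#
#     fs = FileService(Path("/tmp/video_cache"))
#     image_path = await fs.download_image("https://example.com/ref.png")
#     output = fs.get_output_path("demo.mp4")  # -> <cache>/outputs/demo.mp4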


def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, task_queue, result_queue):
    task_data = None
    worker = None

    try:
        logger.info(f"Process {rank}/{world_size - 1} initializing distributed inference service...")

        # Create and initialize distributed worker process
        worker = create_distributed_worker(rank, world_size, master_addr, master_port)
        if not worker.init():
            raise RuntimeError(f"Rank {rank} distributed environment initialization failed")

        # Initialize configuration and model
        config = set_config(args)
        logger.info(f"Rank {rank} config: {config}")

        runner = init_runner(config)
        logger.info(f"Process {rank}/{world_size - 1} distributed inference service initialization completed")

        while True:
            # Only rank=0 reads tasks from queue
            if rank == 0:
                try:
                    task_data = task_queue.get(timeout=1.0)
                    if task_data is None:  # Stop signal
                        logger.info(f"Process {rank} received stop signal, exiting inference service")
                        # Broadcast stop signal to other processes
                        worker.dist_manager.broadcast_task_data(None)
                        break
                    # Broadcast task data to other processes
                    worker.dist_manager.broadcast_task_data(task_data)
                except queue.Empty:
                    # Queue is empty, continue waiting
                    continue
            else:
                # Non-rank=0 processes receive task data from rank=0
                task_data = worker.dist_manager.broadcast_task_data()
                if task_data is None:  # Stop signal
                    logger.info(f"Process {rank} received stop signal, exiting inference service")
                    break

            # All processes handle the task
            if task_data is not None:
                logger.info(f"Process {rank} received inference task: {task_data['task_id']}")

                try:
                    # Set inputs and run inference
                    runner.set_inputs(task_data)  # type: ignore
                    runner.run_pipeline()

                    # Synchronize and report results
                    worker.sync_and_report(
                        task_data["task_id"],
                        "success",
                        result_queue,
                        save_video_path=task_data["save_video_path"],
                        message="Inference completed",
                    )
                except Exception as e:
                    logger.error(f"Process {rank} error occurred while processing task: {str(e)}")

                    # Synchronize and report error
                    worker.sync_and_report(
                        task_data.get("task_id", "unknown"),
                        "failed",
                        result_queue,
                        error=str(e),
                        message=f"Inference failed: {str(e)}",
                    )

    except KeyboardInterrupt:
        logger.info(f"Process {rank} received KeyboardInterrupt, gracefully exiting")
    except Exception as e:
        logger.error(f"Distributed inference service process {rank} startup failed: {str(e)}")
        if rank == 0:
            error_result = {
                "task_id": "startup",
                "status": "startup_failed",
                "error": str(e),
                "message": f"Inference service startup failed: {str(e)}",
            }
            result_queue.put(error_result)
    finally:
        try:
            if worker:
                worker.cleanup()
        except Exception:  # best-effort cleanup on exit
            pass


class DistributedInferenceService:
    def __init__(self):
        self.task_queue = None
        self.result_queue = None
        self.processes = []
        self.is_running = False

    def start_distributed_inference(self, args) -> bool:
        # Back-compat: fold the legacy single-LoRA arguments into lora_configs.
        if hasattr(args, "lora_path") and args.lora_path:
            args.lora_configs = [{"path": args.lora_path, "strength": getattr(args, "lora_strength", 1.0)}]
            delattr(args, "lora_path")
            if hasattr(args, "lora_strength"):
                delattr(args, "lora_strength")

        self.args = args
        if self.is_running:
            logger.warning("Distributed inference service is already running")
            return True

        nproc_per_node = args.nproc_per_node
        if nproc_per_node <= 0:
            logger.error("nproc_per_node must be greater than 0")
            return False

        try:
            import random

            master_addr = "127.0.0.1"
            master_port = str(random.randint(20000, 29999))
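            # NOTE: a randomly chosen port in this range can collide with one
            # already in use; production code may want to probe for a free port.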
            logger.info(f"Distributed inference service Master Addr: {master_addr}, Master Port: {master_port}")

            # Create shared queues
            self.task_queue = mp.Queue()
            self.result_queue = mp.Queue()

            # Start processes
            for rank in range(nproc_per_node):
                p = mp.Process(
                    target=_distributed_inference_worker,
                    args=(
                        rank,
                        nproc_per_node,
                        master_addr,
                        master_port,
                        args,
                        self.task_queue,
                        self.result_queue,
                    ),
                    daemon=True,
                )
                p.start()
                self.processes.append(p)

            self.is_running = True
            logger.info(f"Distributed inference service started successfully with {nproc_per_node} processes")
            return True

        except Exception as e:
            logger.exception(f"Error occurred while starting distributed inference service: {str(e)}")
            self.stop_distributed_inference()
            return False

    def stop_distributed_inference(self):
        if not self.is_running:
            return

        try:
            logger.info(f"Stopping {len(self.processes)} distributed inference service processes...")

            # Send stop signal
            if self.task_queue:
                for _ in self.processes:
                    self.task_queue.put(None)

            # Wait for processes to end
            for p in self.processes:
                try:
                    p.join(timeout=10)
                    if p.is_alive():
                        logger.warning(f"Process {p.pid} did not end within the specified time, forcing termination...")
                        p.terminate()
                        p.join(timeout=5)
                except Exception:  # best-effort shutdown; keep stopping the rest
                    pass

            logger.info("All distributed inference service processes have stopped")

        except Exception as e:
            logger.error(f"Error occurred while stopping distributed inference service: {str(e)}")
        finally:
            # Clean up resources
            self._clean_queues()
            self.processes = []
            self.task_queue = None
            self.result_queue = None
            self.is_running = False

    def _clean_queues(self):
        for queue_obj in [self.task_queue, self.result_queue]:
            if queue_obj:
                try:
                    while not queue_obj.empty():
                        queue_obj.get_nowait()
                except Exception:  # queue may already be closed or empty
                    pass

    def submit_task(self, task_data: dict) -> bool:
        if not self.is_running or not self.task_queue:
            logger.error("Distributed inference service is not started")
            return False

        try:
            self.task_queue.put(task_data)
            return True
        except Exception as e:
            logger.error(f"Failed to submit task: {str(e)}")
            return False

    def wait_for_result(self, task_id: str, timeout: int = 300) -> Optional[dict]:
        if not self.is_running or not self.result_queue:
            return None

        start_time = time.time()

        while time.time() - start_time < timeout:
            try:
                result = self.result_queue.get(timeout=1.0)

                if result.get("task_id") == task_id:
                    return result
                else:
                    # Not the result for current task, put back in queue
                    self.result_queue.put(result)
                    time.sleep(0.1)
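                    # The brief sleep yields so that, with several tasks in
                    # flight, the waiter that owns the re-queued result gets a
                    # chance to consume it before this loop reads it again.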

            except queue.Empty:
                continue

        return None

    def server_metadata(self):
        # Explicit check rather than assert: asserts are stripped under `python -O`.
        if not hasattr(self, "args"):
            raise RuntimeError("Distributed inference service has not been started. Call start_distributed_inference() first.")
        return {
            "nproc_per_node": self.args.nproc_per_node,
            "model_cls": self.args.model_cls,
            "model_path": self.args.model_path,
        }


class VideoGenerationService:
    def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
        self.file_service = file_service
        self.inference_service = inference_service

    async def generate_video(self, message: TaskRequest) -> TaskResponse:
        try:
            task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
            task_data["task_id"] = message.task_id

            if "image_path" in message.model_fields_set and message.image_path.startswith("http"):
                image_path = await self.file_service.download_image(message.image_path)
                task_data["image_path"] = str(image_path)

            save_video_path = self.file_service.get_output_path(message.save_video_path)
            task_data["save_video_path"] = str(save_video_path)

            if not self.inference_service.submit_task(task_data):
                raise RuntimeError("Distributed inference service is not started")

            result = self.inference_service.wait_for_result(message.task_id)

            if result is None:
                raise RuntimeError("Task processing timeout")

            if result.get("status") == "success":
                ServiceStatus.complete_task(message)
                return TaskResponse(
                    task_id=message.task_id,
                    task_status="completed",
                    save_video_path=str(save_video_path),
                )
            else:
                # Raise here and let the handler below record the failure once.
                raise RuntimeError(result.get("error", "Inference failed"))

        except Exception as e:
            logger.error(f"Task {message.task_id} processing failed: {str(e)}")
            ServiceStatus.record_failed_task(message, error=str(e))
            raise
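
# End-to-end wiring sketch (hypothetical; a real server would build TaskRequest
# objects in its HTTP handlers and pass them in):
#
#     file_service = FileService(Path("./cache"))
#     inference_service = DistributedInferenceService()
#     inference_service.start_distributed_inference(args)
#     video_service = VideoGenerationService(file_service, inference_service)
#     response = await video_service.generate_video(task_request)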