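"""Service layer for distributed video generation.

Ties together three pieces: FileService (input/output file handling under a
local cache directory), DistributedInferenceService (a pool of
torch.multiprocessing worker processes fed through shared task/result queues),
and VideoGenerationService (request-level orchestration of the two).
"""
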
import asyncio
import queue
import time
import uuid
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse

import httpx
import torch.multiprocessing as mp
from loguru import logger

from ..utils.set_config import set_config
from ..infer import init_runner
from .utils import ServiceStatus
from .schema import TaskRequest, TaskResponse
from .distributed_utils import create_distributed_worker


mp.set_start_method("spawn", force=True)


class FileService:
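    """Manages input images/audio and output videos under a local cache directory."""
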
    def __init__(self, cache_dir: Path):
        self.cache_dir = cache_dir
        self.input_image_dir = cache_dir / "inputs" / "imgs"
        self.input_audio_dir = cache_dir / "inputs" / "audios"
        self.output_video_dir = cache_dir / "outputs"

        # Create directories
        for directory in [
            self.input_image_dir,
            self.output_video_dir,
            self.input_audio_dir,
        ]:
            directory.mkdir(parents=True, exist_ok=True)

    async def download_image(self, image_url: str) -> Path:
        try:
            async with httpx.AsyncClient(verify=False) as client:
                response = await client.get(image_url)

            if response.status_code != 200:
                raise ValueError(f"Failed to download image from {image_url}")

            image_name = Path(urlparse(image_url).path).name
            if not image_name:
                raise ValueError(f"Invalid image URL: {image_url}")

            image_path = self.input_image_dir / image_name
            image_path.parent.mkdir(parents=True, exist_ok=True)

            with open(image_path, "wb") as f:
                f.write(response.content)

            return image_path
        except Exception as e:
            logger.error(f"Failed to download image: {e}")
            raise

    def save_uploaded_file(self, file_content: bytes, filename: str) -> Path:
        file_extension = Path(filename).suffix
        unique_filename = f"{uuid.uuid4()}{file_extension}"
        file_path = self.input_image_dir / unique_filename

        with open(file_path, "wb") as f:
            f.write(file_content)

        return file_path

    def get_output_path(self, save_video_path: str) -> Path:
        video_path = Path(save_video_path)
        if not video_path.is_absolute():
            return self.output_video_dir / save_video_path
        return video_path


def _distributed_inference_worker(rank, world_size, master_addr, master_port, args, task_queue, result_queue):
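    """Entry point for one rank of the distributed inference worker pool.

    Rank 0 pulls tasks from the shared task_queue and broadcasts them to the
    other ranks through the worker's dist_manager; every rank then runs the
    same pipeline and reports a synchronized result via result_queue. A None
    task acts as the stop signal for all ranks.
    """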
    task_data = None
    loop = None
    worker = None

    try:
        logger.info(f"Process {rank}/{world_size - 1} initializing distributed inference service...")

        # Create and initialize distributed worker process
        worker = create_distributed_worker(rank, world_size, master_addr, master_port)
        if not worker.init():
            raise RuntimeError(f"Rank {rank} distributed environment initialization failed")

        # Initialize configuration and model
        config = set_config(args)
        config["mode"] = "server"
        logger.info(f"Rank {rank} config: {config}")

        runner = init_runner(config)
        logger.info(f"Process {rank}/{world_size - 1} distributed inference service initialization completed")

        # Create event loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        while True:
            # Only rank=0 reads tasks from queue
            if rank == 0:
                try:
                    task_data = task_queue.get(timeout=1.0)
                    if task_data is None:  # Stop signal
                        logger.info(f"Process {rank} received stop signal, exiting inference service")
                        # Broadcast stop signal to other processes
                        worker.dist_manager.broadcast_task_data(None)
                        break
                    # Broadcast task data to other processes
                    worker.dist_manager.broadcast_task_data(task_data)
                except queue.Empty:
                    # Queue is empty, continue waiting
                    continue
            else:
                # Non-rank=0 processes receive task data from rank=0
                task_data = worker.dist_manager.broadcast_task_data()
                if task_data is None:  # Stop signal
                    logger.info(f"Process {rank} received stop signal, exiting inference service")
                    break

            # All processes handle the task
            if task_data is not None:
                logger.info(f"Process {rank} received inference task: {task_data['task_id']}")

                try:
                    # Set inputs and run inference
                    runner.set_inputs(task_data)  # type: ignore
                    loop.run_until_complete(runner.run_pipeline())

                    # Synchronize and report results
                    worker.sync_and_report(
                        task_data["task_id"],
                        "success",
                        result_queue,
                        save_video_path=task_data["save_video_path"],
                        message="Inference completed",
                    )
                except Exception as e:
                    logger.error(f"Process {rank} error occurred while processing task: {str(e)}")

                    # Synchronize and report error
                    worker.sync_and_report(
                        task_data.get("task_id", "unknown"),
                        "failed",
                        result_queue,
                        error=str(e),
                        message=f"Inference failed: {str(e)}",
                    )

    except KeyboardInterrupt:
        logger.info(f"Process {rank} received KeyboardInterrupt, gracefully exiting")
    except Exception as e:
        logger.error(f"Distributed inference service process {rank} startup failed: {str(e)}")
        if rank == 0:
            error_result = {
                "task_id": "startup",
                "status": "startup_failed",
                "error": str(e),
                "message": f"Inference service startup failed: {str(e)}",
            }
            result_queue.put(error_result)
    finally:
        # Clean up resources
        try:
            if loop and not loop.is_closed():
                loop.close()
        except:  # noqa: E722
            pass

        try:
            if worker:
                worker.cleanup()
        except:  # noqa: E722
            pass


class DistributedInferenceService:
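    """Owns the pool of inference worker processes and their shared task/result queues.

    A minimal usage sketch (assuming ``args`` carries the attributes read in
    this module, e.g. ``nproc_per_node``, ``model_cls`` and ``model_path``,
    plus whatever ``set_config`` expects; the task dict mirrors the payload
    built in ``VideoGenerationService.generate_video``)::

        service = DistributedInferenceService()
        if service.start_distributed_inference(args):
            service.submit_task(task_data)
            result = service.wait_for_result(task_data["task_id"], timeout=300)
            service.stop_distributed_inference()
    """
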
    def __init__(self):
        self.task_queue = None
        self.result_queue = None
        self.processes = []
        self.is_running = False

    def start_distributed_inference(self, args) -> bool:
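        """Spawn ``args.nproc_per_node`` daemon worker processes sharing one task queue and one result queue.

        Returns True if the workers were launched (or the service is already
        running), False otherwise.
        """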
        if self.is_running:
            logger.warning("Distributed inference service is already running")
            return True

        self.args = args

        nproc_per_node = args.nproc_per_node
        if nproc_per_node <= 0:
            logger.error("nproc_per_node must be greater than 0")
            return False

        try:
            import random

            master_addr = "127.0.0.1"
            master_port = str(random.randint(20000, 29999))
            logger.info(f"Distributed inference service Master Addr: {master_addr}, Master Port: {master_port}")

            # Create shared queues
            self.task_queue = mp.Queue()
            self.result_queue = mp.Queue()

            # Start processes
            for rank in range(nproc_per_node):
                p = mp.Process(
                    target=_distributed_inference_worker,
                    args=(
                        rank,
                        nproc_per_node,
                        master_addr,
                        master_port,
                        args,
                        self.task_queue,
                        self.result_queue,
                    ),
                    daemon=True,
                )
                p.start()
                self.processes.append(p)

            self.is_running = True
            logger.info(f"Distributed inference service started successfully with {nproc_per_node} processes")
            return True

        except Exception as e:
            logger.exception(f"Error occurred while starting distributed inference service: {str(e)}")
            self.stop_distributed_inference()
            return False

    def stop_distributed_inference(self):
        if not self.is_running:
            return

        try:
            logger.info(f"Stopping {len(self.processes)} distributed inference service processes...")

            # Send stop signal
            if self.task_queue:
                for _ in self.processes:
                    self.task_queue.put(None)

            # Wait for processes to end
            for p in self.processes:
                try:
                    p.join(timeout=10)
                    if p.is_alive():
                        logger.warning(f"Process {p.pid} did not end within the specified time, forcing termination...")
                        p.terminate()
                        p.join(timeout=5)
                except:  # noqa: E722
                    pass

            logger.info("All distributed inference service processes have stopped")

        except Exception as e:
            logger.error(f"Error occurred while stopping distributed inference service: {str(e)}")
        finally:
            # Clean up resources
            self._clean_queues()
            self.processes = []
            self.task_queue = None
            self.result_queue = None
            self.is_running = False

    def _clean_queues(self):
        for queue_obj in [self.task_queue, self.result_queue]:
            if queue_obj:
                try:
                    while not queue_obj.empty():
                        queue_obj.get_nowait()
                except:  # noqa: E722
                    pass

    def submit_task(self, task_data: dict) -> bool:
        if not self.is_running or not self.task_queue:
            logger.error("Distributed inference service is not started")
            return False

        try:
            self.task_queue.put(task_data)
            return True
        except Exception as e:
            logger.error(f"Failed to submit task: {str(e)}")
            return False

    def wait_for_result(self, task_id: str, timeout: int = 300) -> Optional[dict]:
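        """Poll the result queue until a result for ``task_id`` arrives or ``timeout`` seconds elapse.

        Results belonging to other tasks are put back on the queue. Returns
        None on timeout or if the service is not running.
        """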
        if not self.is_running or not self.result_queue:
            return None

        start_time = time.time()

        while time.time() - start_time < timeout:
            try:
                result = self.result_queue.get(timeout=1.0)

                if result.get("task_id") == task_id:
                    return result
                else:
                    # Not the result for current task, put back in queue
                    self.result_queue.put(result)
                    time.sleep(0.1)

            except queue.Empty:
                continue

        return None

    def server_metadata(self):
        if not hasattr(self, "args"):
            raise RuntimeError(
                "Distributed inference service has not been started. Call start_distributed_inference() first."
            )
        return {
            "nproc_per_node": self.args.nproc_per_node,
            "model_cls": self.args.model_cls,
            "model_path": self.args.model_path,
        }


class VideoGenerationService:
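    """Turns a TaskRequest into an inference task, submits it, and waits for the result."""
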
    def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
        self.file_service = file_service
        self.inference_service = inference_service

    async def generate_video(self, message: TaskRequest) -> TaskResponse:
        try:
            # Build the task payload from the request
            task_data = {
                "task_id": message.task_id,
                "prompt": message.prompt,
                "use_prompt_enhancer": message.use_prompt_enhancer,
                "negative_prompt": message.negative_prompt,
                "image_path": message.image_path,
                "num_fragments": message.num_fragments,
                "save_video_path": message.save_video_path,
                "infer_steps": message.infer_steps,
                "target_video_length": message.target_video_length,
                "seed": message.seed,
                "audio_path": message.audio_path,
                "video_duration": message.video_duration,
            }

            # Process network image
            if message.image_path and message.image_path.startswith("http"):
                image_path = await self.file_service.download_image(message.image_path)
                task_data["image_path"] = str(image_path)

            # Process output path
            save_video_path = self.file_service.get_output_path(message.save_video_path)
            task_data["save_video_path"] = str(save_video_path)

            # Submit task to distributed inference service
            if not self.inference_service.submit_task(task_data):
                raise RuntimeError("Distributed inference service is not started")

            # Wait for result
            result = self.inference_service.wait_for_result(message.task_id)

            if result is None:
                raise RuntimeError("Task processing timeout")

            if result.get("status") == "success":
                ServiceStatus.complete_task(message)
                return TaskResponse(
                    task_id=message.task_id,
                    task_status="completed",
                    save_video_path=str(save_video_path),
                )
            else:
                error_msg = result.get("error", "Inference failed")
                ServiceStatus.record_failed_task(message, error=error_msg)
                raise RuntimeError(error_msg)

        except Exception as e:
            logger.error(f"Task {message.task_id} processing failed: {str(e)}")
            ServiceStatus.record_failed_task(message, error=str(e))
            raise