service.py 20.4 KB
Newer Older
gaclove's avatar
gaclove committed
1
import asyncio
PengGao's avatar
PengGao committed
2
3
import json
import os
PengGao's avatar
PengGao committed
4
5
import uuid
from pathlib import Path
PengGao's avatar
PengGao committed
6
from typing import Any, Dict, Optional
PengGao's avatar
PengGao committed
7
8
9
from urllib.parse import urlparse

import httpx
PengGao's avatar
PengGao committed
10
import torch
11
from easydict import EasyDict
PengGao's avatar
PengGao committed
12
13
14
from loguru import logger

from ..infer import init_runner
15
from ..utils.input_info import set_input_info
PengGao's avatar
PengGao committed
16
from ..utils.set_config import set_config
17
from .audio_utils import is_base64_audio, save_base64_audio
PengGao's avatar
PengGao committed
18
from .distributed_utils import DistributedManager
gaclove's avatar
gaclove committed
19
from .image_utils import is_base64_image, save_base64_image
PengGao's avatar
PengGao committed
20
from .schema import TaskRequest, TaskResponse
PengGao's avatar
PengGao committed
21
22
23
24
25
26
27
28
29


class FileService:
    """Manages cached media files for the service.

    Directory layout under ``cache_dir``:
        inputs/imgs   - downloaded / uploaded input images
        inputs/audios - downloaded input audio files
        outputs       - generated result videos
    """

    def __init__(self, cache_dir: Path):
        self.cache_dir = cache_dir
        self.input_image_dir = cache_dir / "inputs" / "imgs"
        self.input_audio_dir = cache_dir / "inputs" / "audios"
        self.output_video_dir = cache_dir / "outputs"

        # Lazily-created shared HTTP client (connection pooling). Guarded by a
        # lock because several coroutines may request the client concurrently.
        self._http_client: Optional["httpx.AsyncClient"] = None
        self._client_lock = asyncio.Lock()

        # Download retry policy: exponential backoff starting at retry_delay,
        # doubling per attempt, capped at max_retry_delay.
        self.max_retries = 3
        self.retry_delay = 1.0
        self.max_retry_delay = 10.0

        for directory in [
            self.input_image_dir,
            self.output_video_dir,
            self.input_audio_dir,
        ]:
            directory.mkdir(parents=True, exist_ok=True)

    async def _get_http_client(self) -> "httpx.AsyncClient":
        """Get or create a persistent HTTP client with connection pooling."""
        async with self._client_lock:
            if self._http_client is None or self._http_client.is_closed:
                timeout = httpx.Timeout(
                    connect=10.0,
                    read=30.0,
                    write=10.0,
                    pool=5.0,
                )
                limits = httpx.Limits(max_keepalive_connections=5, max_connections=10, keepalive_expiry=30.0)
                # NOTE(review): verify=False disables TLS certificate checks.
                # Acceptable only for trusted internal endpoints -- confirm
                # before pointing this at external URLs.
                self._http_client = httpx.AsyncClient(verify=False, timeout=timeout, limits=limits, follow_redirects=True)
            return self._http_client

    async def _download_with_retry(self, url: str, max_retries: Optional[int] = None) -> "httpx.Response":
        """Download *url* with exponential-backoff retry logic.

        Retries on transport errors and 5xx responses; 4xx responses are
        raised immediately (retrying a client error cannot help). Raises
        ``httpx.ConnectError`` after all attempts are exhausted.
        """
        if max_retries is None:
            max_retries = self.max_retries

        last_exception = None

        retry_delay = self.retry_delay

        for attempt in range(max_retries):
            try:
                client = await self._get_http_client()
                response = await client.get(url)

                if response.status_code == 200:
                    return response
                elif response.status_code >= 500:
                    # Server-side errors are treated as transient: retry.
                    logger.warning(f"Server error {response.status_code} for {url}, attempt {attempt + 1}/{max_retries}")
                    last_exception = httpx.HTTPStatusError(f"Server returned {response.status_code}", request=response.request, response=response)
                else:
                    # 4xx and other non-retryable statuses: fail fast.
                    raise httpx.HTTPStatusError(f"Client error {response.status_code}", request=response.request, response=response)

            except (httpx.ConnectError, httpx.TimeoutException, httpx.NetworkError) as e:
                logger.warning(f"Connection error for {url}, attempt {attempt + 1}/{max_retries}: {str(e)}")
                last_exception = e
            except httpx.HTTPStatusError as e:
                # Re-raise the fail-fast 4xx from above; keep retrying 5xx.
                if e.response and e.response.status_code < 500:
                    raise
                last_exception = e
            except Exception as e:
                logger.error(f"Unexpected error downloading {url}: {str(e)}")
                last_exception = e

            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, self.max_retry_delay)

        error_msg = f"All {max_retries} connection attempts failed for {url}"
        if last_exception:
            error_msg += f": {str(last_exception)}"
        raise httpx.ConnectError(error_msg)

    async def _download_media(self, url: str, target_dir: Path, default_suffix: str, kind: str) -> Path:
        """Shared implementation for ``download_image`` / ``download_audio``.

        Validates the URL, downloads with retries, and writes the payload to
        ``target_dir`` using the URL's basename (or a random name with
        ``default_suffix`` when the URL path has none). All failures are
        normalized to ``ValueError`` so callers handle a single type.
        """
        try:
            parsed_url = urlparse(url)
            if not parsed_url.scheme or not parsed_url.netloc:
                raise ValueError(f"Invalid URL format: {url}")

            response = await self._download_with_retry(url)

            # Derive a local filename from the URL path; Path(...).name strips
            # any directory components, so a hostile URL cannot escape target_dir.
            file_name = Path(parsed_url.path).name
            if not file_name:
                file_name = f"{uuid.uuid4()}{default_suffix}"

            file_path = target_dir / file_name
            file_path.parent.mkdir(parents=True, exist_ok=True)

            with open(file_path, "wb") as f:
                f.write(response.content)

            logger.info(f"Successfully downloaded {kind} from {url} to {file_path}")
            return file_path

        except httpx.ConnectError as e:
            logger.error(f"Connection error downloading {kind} from {url}: {str(e)}")
            raise ValueError(f"Failed to connect to {url}: {str(e)}")
        except httpx.TimeoutException as e:
            logger.error(f"Timeout downloading {kind} from {url}: {str(e)}")
            raise ValueError(f"Download timeout for {url}: {str(e)}")
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error downloading {kind} from {url}: {str(e)}")
            raise ValueError(f"HTTP error for {url}: {str(e)}")
        except ValueError:
            # Already a normalized error (e.g. invalid URL): propagate as-is.
            raise
        except Exception as e:
            logger.error(f"Unexpected error downloading {kind} from {url}: {str(e)}")
            raise ValueError(f"Failed to download {kind} from {url}: {str(e)}")

    async def download_image(self, image_url: str) -> Path:
        """Download image with retry logic and proper error handling."""
        return await self._download_media(image_url, self.input_image_dir, ".jpg", "image")

    async def download_audio(self, audio_url: str) -> Path:
        """Download audio with retry logic and proper error handling."""
        return await self._download_media(audio_url, self.input_audio_dir, ".mp3", "audio")

    def save_uploaded_file(self, file_content: bytes, filename: str) -> Path:
        """Persist raw uploaded bytes under a collision-free random name,
        preserving the original file extension. Returns the saved path."""
        file_extension = Path(filename).suffix
        unique_filename = f"{uuid.uuid4()}{file_extension}"
        file_path = self.input_image_dir / unique_filename

        with open(file_path, "wb") as f:
            f.write(file_content)

        return file_path

    def get_output_path(self, save_result_path: str) -> Path:
        """Resolve a result path: relative paths are anchored in the output
        directory, absolute paths are returned unchanged."""
        video_path = Path(save_result_path)
        if not video_path.is_absolute():
            return self.output_video_dir / save_result_path
        return video_path

    async def cleanup(self):
        """Cleanup resources including HTTP client."""
        async with self._client_lock:
            if self._http_client and not self._http_client.is_closed:
                await self._http_client.aclose()
                self._http_client = None

PengGao's avatar
PengGao committed
197

PengGao's avatar
PengGao committed
198
199
class TorchrunInferenceWorker:
    """Worker class for torchrun-based distributed inference.

    One instance runs per torchrun-launched process; rank/world size are read
    from the LOCAL_RANK / WORLD_SIZE environment variables that torchrun sets.
    Rank 0 receives tasks from the API layer and broadcasts them; other ranks
    sit in ``worker_loop`` waiting for broadcasts.
    """

    def __init__(self):
        # torchrun injects these env vars; defaults cover single-process runs.
        self.rank = int(os.environ.get("LOCAL_RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.runner = None  # set by init(); inference runner from init_runner()
        self.dist_manager = DistributedManager()
        self.processing = False  # Track if currently processing a request

    def init(self, args) -> bool:
        """Initialize the worker with model and distributed setup.

        Returns True on success, False on any failure (logged, not raised,
        so the launcher can decide how to react).
        """
        try:
            # Initialize distributed process group using torchrun env vars
            if self.world_size > 1:
                if not self.dist_manager.init_process_group():
                    raise RuntimeError("Failed to initialize distributed process group")
            else:
                # Single GPU mode: populate the manager's fields manually so the
                # rest of the code can treat both modes uniformly.
                self.dist_manager.rank = 0
                self.dist_manager.world_size = 1
                self.dist_manager.device = "cuda:0" if torch.cuda.is_available() else "cpu"
                self.dist_manager.is_initialized = False

            # Initialize model
            config = set_config(args)
            if self.rank == 0:
                logger.info(f"Config:\n {json.dumps(config, ensure_ascii=False, indent=4)}")

            self.runner = init_runner(config)
            logger.info(f"Rank {self.rank}/{self.world_size - 1} initialization completed")

            return True

        except Exception as e:
            logger.error(f"Rank {self.rank} initialization failed: {str(e)}")
            return False

    async def process_request(self, task_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Process a single inference request.

        Returns a result dict on rank 0, None on all other ranks.

        Note: We keep the inference synchronous to maintain NCCL/CUDA context integrity.
        The async wrapper allows FastAPI to handle other requests while this runs.
        """
        try:
            # Only rank 0 broadcasts task data (worker processes already received it in worker_loop)
            if self.world_size > 1 and self.rank == 0:
                task_data = self.dist_manager.broadcast_task_data(task_data)

            # Run inference directly - torchrun handles the parallelization
            # Using asyncio.to_thread would be risky with NCCL operations
            # Instead, we rely on FastAPI's async handling and queue management

            task_data["task"] = self.runner.config["task"]
            task_data["return_result_tensor"] = False
            task_data["negative_prompt"] = task_data.get("negative_prompt", "")

            # EasyDict conversion is required: downstream helpers access fields
            # as attributes, not dict keys.
            task_data = EasyDict(task_data)
            input_info = set_input_info(task_data)

            # update lock config
            self.runner.set_config(task_data)

            self.runner.run_pipeline(input_info)

            # Small yield to allow other async operations if needed
            await asyncio.sleep(0)

            # Synchronize all ranks
            if self.world_size > 1:
                self.dist_manager.barrier()

            # Only rank 0 returns the result
            if self.rank == 0:
                return {
                    "task_id": task_data["task_id"],
                    "status": "success",
                    # Prefer the caller-facing video_path; fall back to the
                    # resolved save_result_path.
                    "save_result_path": task_data.get("video_path", task_data["save_result_path"]),
                    "message": "Inference completed",
                }
            else:
                return None

        except Exception as e:
            logger.exception(f"Rank {self.rank} inference failed: {str(e)}")

            # Still hit the barrier on failure so surviving ranks do not hang
            # waiting for this one.
            if self.world_size > 1:
                self.dist_manager.barrier()

            if self.rank == 0:
                return {
                    # task_data may or may not have been converted/populated at
                    # the point of failure; .get keeps this safe either way.
                    "task_id": task_data.get("task_id", "unknown"),
                    "status": "failed",
                    "error": str(e),
                    "message": f"Inference failed: {str(e)}",
                }
            else:
                return None

    async def worker_loop(self):
        """Non-rank-0 workers: Listen for broadcast tasks.

        Blocks on broadcast_task_data() until rank 0 sends a task; a None
        broadcast is the shutdown signal.
        """
        while True:
            try:
                task_data = self.dist_manager.broadcast_task_data()
                if task_data is None:
                    logger.info(f"Rank {self.rank} received stop signal")
                    break

                await self.process_request(task_data)

            except Exception as e:
                # Keep the loop alive: a single failed task must not kill the rank.
                logger.error(f"Rank {self.rank} worker loop error: {str(e)}")
                continue

    def cleanup(self):
        """Release distributed resources (process group teardown)."""
        self.dist_manager.cleanup()
PengGao's avatar
PengGao committed
316
317
318
319


class DistributedInferenceService:
    """Owns the lifecycle of a TorchrunInferenceWorker for this process.

    Rank 0 submits tasks via ``submit_task_async``; other ranks run
    ``run_worker_loop`` and receive tasks by broadcast.
    """

    def __init__(self):
        self.worker = None       # TorchrunInferenceWorker once started
        self.is_running = False  # True only after successful worker init
        self.args = None         # launch args; kept for server_metadata()

    def start_distributed_inference(self, args) -> bool:
        """Create and initialize the worker. Returns True on success.

        Idempotent: calling while already running is a no-op that returns True.
        """
        self.args = args

        if self.is_running:
            logger.warning("Distributed inference service is already running")
            return True

        try:
            self.worker = TorchrunInferenceWorker()

            if not self.worker.init(args):
                raise RuntimeError("Worker initialization failed")

            self.is_running = True
            logger.info(f"Rank {self.worker.rank} inference service started successfully")
            return True

        except Exception as e:
            logger.error(f"Error starting inference service: {str(e)}")
            # stop_distributed_inference() also handles the partially-initialized
            # case (is_running still False but a worker was created).
            self.stop_distributed_inference()
            return False

    def stop_distributed_inference(self):
        """Tear down the worker. Safe to call repeatedly, and also releases a
        worker left behind by a failed start (is_running never became True)."""
        if not self.is_running and self.worker is None:
            return

        try:
            if self.worker:
                self.worker.cleanup()
            logger.info("Inference service stopped")
        except Exception as e:
            logger.error(f"Error stopping inference service: {str(e)}")
        finally:
            self.worker = None
            self.is_running = False

    async def submit_task_async(self, task_data: dict) -> Optional[dict]:
        """Run one inference task on the rank-0 worker.

        Returns the worker's result dict, a failure dict on error, or None when
        the service is not running or this process is not rank 0.
        """
        if not self.is_running or not self.worker:
            logger.error("Inference service is not started")
            return None

        # Only rank 0 accepts tasks; other ranks get theirs via broadcast.
        if self.worker.rank != 0:
            return None

        try:
            if self.worker.processing:
                # If we want to support queueing, we can add the task to queue
                # For now, we'll process sequentially
                logger.info(f"Waiting for previous task to complete before processing task {task_data.get('task_id')}")

            self.worker.processing = True
            try:
                return await self.worker.process_request(task_data)
            finally:
                # Always clear the busy flag, even if process_request raises.
                self.worker.processing = False
        except Exception as e:
            logger.error(f"Failed to process task: {str(e)}")
            return {
                "task_id": task_data.get("task_id", "unknown"),
                "status": "failed",
                "error": str(e),
                "message": f"Task processing failed: {str(e)}",
            }

    def server_metadata(self):
        """Return static service metadata (world size, model class and path).

        Raises:
            RuntimeError: if the service was never successfully started.
                (The previous ``assert hasattr(self, "args")`` could never fire
                because __init__ always sets the attribute.)
        """
        if not self.is_running or self.worker is None or self.args is None:
            raise RuntimeError("Distributed inference service has not been started. Call start_distributed_inference() first.")
        return {"nproc_per_node": self.worker.world_size, "model_cls": self.args.model_cls, "model_path": self.args.model_path}

    async def run_worker_loop(self):
        """Run the worker loop for non-rank-0 processes"""
        if self.worker and self.worker.rank != 0:
            await self.worker.worker_loop()
395

PengGao's avatar
PengGao committed
396
397
398
399
400
401

class VideoGenerationService:
    """Orchestrates one video-generation task: resolves media inputs (URL,
    base64, or local path) to local files, then submits to the inference
    service."""

    def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
        self.file_service = file_service
        self.inference_service = inference_service

    async def _resolve_image(self, source: str) -> str:
        """Resolve an image given as an http(s) URL, base64 payload, or local
        path into a local filesystem path string."""
        if source.startswith("http"):
            return str(await self.file_service.download_image(source))
        if is_base64_image(source):
            return str(save_base64_image(source, str(self.file_service.input_image_dir)))
        return source

    async def _resolve_audio(self, source: str) -> str:
        """Resolve an audio input given as an http(s) URL, base64 payload, or
        local path into a local filesystem path string."""
        if source.startswith("http"):
            return str(await self.file_service.download_audio(source))
        if is_base64_audio(source):
            return str(save_base64_audio(source, str(self.file_service.input_audio_dir)))
        return source

    async def generate_video_with_stop_event(self, message: TaskRequest, stop_event) -> Optional[TaskResponse]:
        """Generate video using torchrun-based inference.

        Returns a TaskResponse on success, None when cancelled via
        ``stop_event``; re-raises on failure (after logging).
        """
        try:
            # Copy only the fields the caller explicitly set; task_id is added
            # separately so it is always present.
            task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
            task_data["task_id"] = message.task_id

            if stop_event.is_set():
                logger.info(f"Task {message.task_id} cancelled before processing")
                return None

            if "image_path" in message.model_fields_set and message.image_path:
                task_data["image_path"] = await self._resolve_image(message.image_path)
                logger.info(f"Task {message.task_id} image path: {task_data['image_path']}")

            if "audio_path" in message.model_fields_set and message.audio_path:
                task_data["audio_path"] = await self._resolve_audio(message.audio_path)
                logger.info(f"Task {message.task_id} audio path: {task_data['audio_path']}")

            if "talk_objects" in message.model_fields_set and message.talk_objects:
                # Each talk object carries its own audio track and mask image.
                task_data["talk_objects"] = [
                    {
                        "audio": await self._resolve_audio(talk_object.audio),
                        "mask": await self._resolve_image(talk_object.mask),
                    }
                    for talk_object in message.talk_objects
                ]

                # FIXME(xxx): persist as a config.json and assign that
                # config.json's directory to task_data["audio_path"]
                temp_path = self.file_service.cache_dir / uuid.uuid4().hex[:8]
                temp_path.mkdir(parents=True, exist_ok=True)
                task_data["audio_path"] = str(temp_path)

                config_path = temp_path / "config.json"
                with open(config_path, "w") as f:
                    json.dump({"talk_objects": task_data["talk_objects"]}, f)

            # Workers write to the resolved absolute path; the original
            # (possibly relative) path is echoed back to the caller.
            actual_save_path = self.file_service.get_output_path(message.save_result_path)
            task_data["save_result_path"] = str(actual_save_path)
            task_data["video_path"] = message.save_result_path

            result = await self.inference_service.submit_task_async(task_data)

            if result is None:
                if stop_event.is_set():
                    logger.info(f"Task {message.task_id} cancelled during processing")
                    return None
                raise RuntimeError("Task processing failed")

            if result.get("status") == "success":
                return TaskResponse(
                    task_id=message.task_id,
                    task_status="completed",
                    save_result_path=message.save_result_path,  # Return original path
                )
            else:
                error_msg = result.get("error", "Inference failed")
                raise RuntimeError(error_msg)

        except Exception as e:
            logger.exception(f"Task {message.task_id} processing failed: {str(e)}")
            raise