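"""Service layer: file handling, torchrun-based distributed inference, and video generation orchestration."""
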
import asyncio
import json
import os
import uuid
from pathlib import Path
from typing import Any, Dict, Optional
from urllib.parse import urlparse

import httpx
import torch
from loguru import logger

from ..infer import init_runner
PengGao's avatar
PengGao committed
14
from ..utils.set_config import set_config
from .audio_utils import is_base64_audio, save_base64_audio
from .distributed_utils import DistributedManager
from .image_utils import is_base64_image, save_base64_image
from .schema import TaskRequest, TaskResponse


class FileService:
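    """Manage cached media files: download images/audio over HTTP with retries,
    save uploads, and resolve output video paths.

    Illustrative usage (the cache directory and URL are placeholders):

        file_service = FileService(Path("/tmp/cache"))
        image_path = await file_service.download_image("https://example.com/input.jpg")
    """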
    def __init__(self, cache_dir: Path):
        self.cache_dir = cache_dir
        self.input_image_dir = cache_dir / "inputs" / "imgs"
        self.input_audio_dir = cache_dir / "inputs" / "audios"
        self.output_video_dir = cache_dir / "outputs"

        self._http_client = None
        self._client_lock = asyncio.Lock()

        self.max_retries = 3
        self.retry_delay = 1.0
        self.max_retry_delay = 10.0

        for directory in [
            self.input_image_dir,
            self.output_video_dir,
            self.input_audio_dir,
        ]:
            directory.mkdir(parents=True, exist_ok=True)

    async def _get_http_client(self) -> httpx.AsyncClient:
        """Get or create a persistent HTTP client with connection pooling."""
        async with self._client_lock:
            if self._http_client is None or self._http_client.is_closed:
                timeout = httpx.Timeout(
                    connect=10.0,
                    read=30.0,
                    write=10.0,
                    pool=5.0,
                )
                limits = httpx.Limits(max_keepalive_connections=5, max_connections=10, keepalive_expiry=30.0)
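                # NOTE: TLS certificate verification is disabled (verify=False), so
                # downloads from hosts with self-signed certificates will not fail.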
                self._http_client = httpx.AsyncClient(verify=False, timeout=timeout, limits=limits, follow_redirects=True)
            return self._http_client

    async def _download_with_retry(self, url: str, max_retries: Optional[int] = None) -> httpx.Response:
        """Download with exponential backoff retry logic."""
        if max_retries is None:
            max_retries = self.max_retries

        last_exception = None

        retry_delay = self.retry_delay

        for attempt in range(max_retries):
            try:
                client = await self._get_http_client()
                response = await client.get(url)

                if response.status_code == 200:
                    return response
                elif response.status_code >= 500:
                    logger.warning(f"Server error {response.status_code} for {url}, attempt {attempt + 1}/{max_retries}")
                    last_exception = httpx.HTTPStatusError(f"Server returned {response.status_code}", request=response.request, response=response)
                else:
                    raise httpx.HTTPStatusError(f"Client error {response.status_code}", request=response.request, response=response)

            except (httpx.ConnectError, httpx.TimeoutException, httpx.NetworkError) as e:
                logger.warning(f"Connection error for {url}, attempt {attempt + 1}/{max_retries}: {str(e)}")
                last_exception = e
            except httpx.HTTPStatusError as e:
                if e.response and e.response.status_code < 500:
                    raise
                last_exception = e
            except Exception as e:
                logger.error(f"Unexpected error downloading {url}: {str(e)}")
                last_exception = e

            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, self.max_retry_delay)

        error_msg = f"All {max_retries} connection attempts failed for {url}"
        if last_exception:
            error_msg += f": {str(last_exception)}"
        raise httpx.ConnectError(error_msg)

    async def download_image(self, image_url: str) -> Path:
        """Download image with retry logic and proper error handling."""
        try:
            parsed_url = urlparse(image_url)
            if not parsed_url.scheme or not parsed_url.netloc:
                raise ValueError(f"Invalid URL format: {image_url}")

            response = await self._download_with_retry(image_url)

            image_name = Path(parsed_url.path).name
            if not image_name:
                image_name = f"{uuid.uuid4()}.jpg"

            image_path = self.input_image_dir / image_name
            image_path.parent.mkdir(parents=True, exist_ok=True)

            with open(image_path, "wb") as f:
                f.write(response.content)

            logger.info(f"Successfully downloaded image from {image_url} to {image_path}")
            return image_path

        except httpx.ConnectError as e:
            logger.error(f"Connection error downloading image from {image_url}: {str(e)}")
            raise ValueError(f"Failed to connect to {image_url}: {str(e)}")
        except httpx.TimeoutException as e:
            logger.error(f"Timeout downloading image from {image_url}: {str(e)}")
            raise ValueError(f"Download timeout for {image_url}: {str(e)}")
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error downloading image from {image_url}: {str(e)}")
            raise ValueError(f"HTTP error for {image_url}: {str(e)}")
        except ValueError:
            raise
        except Exception as e:
            logger.error(f"Unexpected error downloading image from {image_url}: {str(e)}")
            raise ValueError(f"Failed to download image from {image_url}: {str(e)}")

    async def download_audio(self, audio_url: str) -> Path:
        """Download audio with retry logic and proper error handling."""
        try:
            parsed_url = urlparse(audio_url)
            if not parsed_url.scheme or not parsed_url.netloc:
                raise ValueError(f"Invalid URL format: {audio_url}")

            response = await self._download_with_retry(audio_url)

            audio_name = Path(parsed_url.path).name
            if not audio_name:
                audio_name = f"{uuid.uuid4()}.mp3"

            audio_path = self.input_audio_dir / audio_name
            audio_path.parent.mkdir(parents=True, exist_ok=True)

            with open(audio_path, "wb") as f:
                f.write(response.content)

            logger.info(f"Successfully downloaded audio from {audio_url} to {audio_path}")
            return audio_path

        except httpx.ConnectError as e:
            logger.error(f"Connection error downloading audio from {audio_url}: {str(e)}")
            raise ValueError(f"Failed to connect to {audio_url}: {str(e)}")
        except httpx.TimeoutException as e:
            logger.error(f"Timeout downloading audio from {audio_url}: {str(e)}")
            raise ValueError(f"Download timeout for {audio_url}: {str(e)}")
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error downloading audio from {audio_url}: {str(e)}")
            raise ValueError(f"HTTP error for {audio_url}: {str(e)}")
        except ValueError:
            raise
        except Exception as e:
            logger.error(f"Unexpected error downloading audio from {audio_url}: {str(e)}")
            raise ValueError(f"Failed to download audio from {audio_url}: {str(e)}")

    def save_uploaded_file(self, file_content: bytes, filename: str) -> Path:
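        """Save raw uploaded bytes to the input image directory under a UUID-based filename."""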
        file_extension = Path(filename).suffix
        unique_filename = f"{uuid.uuid4()}{file_extension}"
        file_path = self.input_image_dir / unique_filename

        with open(file_path, "wb") as f:
            f.write(file_content)

        return file_path

    def get_output_path(self, save_video_path: str) -> Path:
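        """Resolve a save path: relative paths land under the output video directory, absolute paths are returned unchanged."""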
        video_path = Path(save_video_path)
        if not video_path.is_absolute():
            return self.output_video_dir / save_video_path
        return video_path

    async def cleanup(self):
        """Cleanup resources including HTTP client."""
        async with self._client_lock:
            if self._http_client and not self._http_client.is_closed:
                await self._http_client.aclose()
                self._http_client = None


class TorchrunInferenceWorker:
    """Worker class for torchrun-based distributed inference"""

    def __init__(self):
        self.rank = int(os.environ.get("LOCAL_RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.runner = None
        self.dist_manager = DistributedManager()
        self.request_queue = asyncio.Queue() if self.rank == 0 else None
        self.processing = False  # Track if currently processing a request

    def init(self, args) -> bool:
        """Initialize the worker with model and distributed setup"""
        try:
            # Initialize distributed process group using torchrun env vars
            if self.world_size > 1:
                if not self.dist_manager.init_process_group():
                    raise RuntimeError("Failed to initialize distributed process group")
            else:
                # Single GPU mode
                self.dist_manager.rank = 0
                self.dist_manager.world_size = 1
                self.dist_manager.device = "cuda:0" if torch.cuda.is_available() else "cpu"
                self.dist_manager.is_initialized = False

            # Initialize model
            config = set_config(args)
            if self.rank == 0:
                logger.info(f"Config:\n {json.dumps(config, ensure_ascii=False, indent=4)}")

            self.runner = init_runner(config)
            logger.info(f"Rank {self.rank}/{self.world_size - 1} initialization completed")

            return True

        except Exception as e:
            logger.error(f"Rank {self.rank} initialization failed: {str(e)}")
            return False

    async def process_request(self, task_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Process a single inference request

        Note: We keep the inference synchronous to maintain NCCL/CUDA context integrity.
        The async wrapper allows FastAPI to handle other requests while this runs.
        """
        try:
            # Only rank 0 broadcasts task data (worker processes already received it in worker_loop)
            if self.world_size > 1 and self.rank == 0:
                task_data = self.dist_manager.broadcast_task_data(task_data)

            # Run inference directly - torchrun handles the parallelization
            # Using asyncio.to_thread would be risky with NCCL operations
            # Instead, we rely on FastAPI's async handling and queue management
            self.runner.set_inputs(task_data)
            self.runner.run_pipeline()

            # Small yield to allow other async operations if needed
            await asyncio.sleep(0)

            # Synchronize all ranks
            if self.world_size > 1:
                self.dist_manager.barrier()

            # Only rank 0 returns the result
            if self.rank == 0:
                return {
                    "task_id": task_data["task_id"],
                    "status": "success",
                    "save_video_path": task_data.get("video_path", task_data["save_video_path"]),
                    "message": "Inference completed",
                }
            else:
                return None

        except Exception as e:
            logger.error(f"Rank {self.rank} inference failed: {str(e)}")
            if self.world_size > 1:
                self.dist_manager.barrier()

            if self.rank == 0:
                return {
                    "task_id": task_data.get("task_id", "unknown"),
                    "status": "failed",
                    "error": str(e),
                    "message": f"Inference failed: {str(e)}",
                }
            else:
                return None

    async def worker_loop(self):
        """Non-rank-0 workers: Listen for broadcast tasks"""
        while True:
            try:
                task_data = self.dist_manager.broadcast_task_data()
                if task_data is None:
                    logger.info(f"Rank {self.rank} received stop signal")
                    break

                await self.process_request(task_data)

            except Exception as e:
                logger.error(f"Rank {self.rank} worker loop error: {str(e)}")
                continue

    def cleanup(self):
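        """Delegate resource cleanup to the distributed manager."""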
        self.dist_manager.cleanup()


class DistributedInferenceService:
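    """Own a TorchrunInferenceWorker and expose start/stop, task submission
    (rank 0 only), and the worker loop for non-zero ranks.

    Illustrative flow (the `args` object is whatever set_config() expects):

        service = DistributedInferenceService()
        if service.start_distributed_inference(args):
            result = await service.submit_task_async(task_data)
    """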
    def __init__(self):
        self.worker = None
        self.is_running = False
        self.args = None

    def start_distributed_inference(self, args) -> bool:
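        """Create and initialize the inference worker for this rank; return True on success."""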
        self.args = args
        if self.is_running:
            logger.warning("Distributed inference service is already running")
            return True

        try:
            self.worker = TorchrunInferenceWorker()

            if not self.worker.init(args):
                raise RuntimeError("Worker initialization failed")

            self.is_running = True
            logger.info(f"Rank {self.worker.rank} inference service started successfully")
            return True

        except Exception as e:
            logger.error(f"Error starting inference service: {str(e)}")
            self.stop_distributed_inference()
            return False

    def stop_distributed_inference(self):
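        """Clean up the worker (if any) and mark the service as stopped."""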
        if not self.is_running:
            return

        try:
            if self.worker:
                self.worker.cleanup()
            logger.info("Inference service stopped")
        except Exception as e:
            logger.error(f"Error stopping inference service: {str(e)}")
        finally:
            self.worker = None
            self.is_running = False

    async def submit_task_async(self, task_data: dict) -> Optional[dict]:
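        """Run a single task on rank 0 and return its result dict.

        Returns None if the service is not running or when called on a non-zero rank;
        tasks are processed sequentially.
        """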
        if not self.is_running or not self.worker:
            logger.error("Inference service is not started")
            return None

        if self.worker.rank != 0:
            return None

        try:
            if self.worker.processing:
                # If we want to support queueing, we can add the task to queue
                # For now, we'll process sequentially
                logger.info(f"Waiting for previous task to complete before processing task {task_data.get('task_id')}")

            self.worker.processing = True
            result = await self.worker.process_request(task_data)
            self.worker.processing = False
            return result
        except Exception as e:
            self.worker.processing = False
            logger.error(f"Failed to process task: {str(e)}")
            return {
                "task_id": task_data.get("task_id", "unknown"),
                "status": "failed",
                "error": str(e),
                "message": f"Task processing failed: {str(e)}",
            }

    def server_metadata(self):
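        """Return basic metadata about the running service (world size, model class, model path)."""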
        assert self.args is not None and self.worker is not None, "Distributed inference service has not been started. Call start_distributed_inference() first."
        return {"nproc_per_node": self.worker.world_size, "model_cls": self.args.model_cls, "model_path": self.args.model_path}

    async def run_worker_loop(self):
        """Run the worker loop for non-rank-0 processes"""
        if self.worker and self.worker.rank != 0:
            await self.worker.worker_loop()


class VideoGenerationService:
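    """Coordinate a single video generation request: resolve image/audio inputs
    (URL, base64, or local path) via FileService, then submit the task to the
    DistributedInferenceService."""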
    def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
        self.file_service = file_service
        self.inference_service = inference_service

    async def generate_video_with_stop_event(self, message: TaskRequest, stop_event) -> Optional[TaskResponse]:
        """Generate video using torchrun-based inference"""
        try:
            task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
            task_data["task_id"] = message.task_id

            if stop_event.is_set():
                logger.info(f"Task {message.task_id} cancelled before processing")
                return None

            if "image_path" in message.model_fields_set and message.image_path:
                if message.image_path.startswith("http"):
                    image_path = await self.file_service.download_image(message.image_path)
                    task_data["image_path"] = str(image_path)
                elif is_base64_image(message.image_path):
                    image_path = save_base64_image(message.image_path, str(self.file_service.input_image_dir))
                    task_data["image_path"] = str(image_path)
                else:
                    task_data["image_path"] = message.image_path

                logger.info(f"Task {message.task_id} image path: {task_data['image_path']}")

            if "audio_path" in message.model_fields_set and message.audio_path:
                if message.audio_path.startswith("http"):
                    audio_path = await self.file_service.download_audio(message.audio_path)
                    task_data["audio_path"] = str(audio_path)
                elif is_base64_audio(message.audio_path):
                    audio_path = save_base64_audio(message.audio_path, str(self.file_service.input_audio_dir))
                    task_data["audio_path"] = str(audio_path)
                else:
                    task_data["audio_path"] = message.audio_path

                logger.info(f"Task {message.task_id} audio path: {task_data['audio_path']}")

            actual_save_path = self.file_service.get_output_path(message.save_video_path)
            task_data["save_video_path"] = str(actual_save_path)
            task_data["video_path"] = message.save_video_path

            result = await self.inference_service.submit_task_async(task_data)

            if result is None:
                if stop_event.is_set():
                    logger.info(f"Task {message.task_id} cancelled during processing")
                    return None
                raise RuntimeError("Task processing failed")

            if result.get("status") == "success":
                return TaskResponse(
                    task_id=message.task_id,
                    task_status="completed",
                    save_video_path=message.save_video_path,  # Return original path
                )
            else:
                error_msg = result.get("error", "Inference failed")
                raise RuntimeError(error_msg)

        except Exception as e:
            logger.error(f"Task {message.task_id} processing failed: {str(e)}")
            raise