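"""Service layer for distributed video generation: file handling, torchrun-based
inference workers, and request orchestration."""
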
import asyncio
import json
import os
import uuid
from pathlib import Path
from typing import Any, Dict, Optional
from urllib.parse import urlparse

import httpx
import torch
from loguru import logger

from ..infer import init_runner
from ..utils.set_config import set_config
from .audio_utils import is_base64_audio, save_base64_audio
from .distributed_utils import DistributedManager
from .image_utils import is_base64_image, save_base64_image
from .schema import TaskRequest, TaskResponse


class FileService:
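    """Manages cached files for the service: remote input downloads with retry, uploaded files, and output path resolution."""
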
    def __init__(self, cache_dir: Path):
        self.cache_dir = cache_dir
        self.input_image_dir = cache_dir / "inputs" / "imgs"
        self.input_audio_dir = cache_dir / "inputs" / "audios"
        self.output_video_dir = cache_dir / "outputs"

        self._http_client = None
        self._client_lock = asyncio.Lock()

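        # Download retry policy: up to max_retries attempts with exponential backoff,
        # capped at max_retry_delay seconds (see _download_with_retry).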
        self.max_retries = 3
        self.retry_delay = 1.0
        self.max_retry_delay = 10.0

        for directory in [
            self.input_image_dir,
            self.output_video_dir,
            self.input_audio_dir,
        ]:
            directory.mkdir(parents=True, exist_ok=True)

    async def _get_http_client(self) -> httpx.AsyncClient:
        """Get or create a persistent HTTP client with connection pooling."""
        async with self._client_lock:
            if self._http_client is None or self._http_client.is_closed:
                timeout = httpx.Timeout(
                    connect=10.0,
                    read=30.0,
                    write=10.0,
                    pool=5.0,
                )
                limits = httpx.Limits(max_keepalive_connections=5, max_connections=10, keepalive_expiry=30.0)
                self._http_client = httpx.AsyncClient(verify=False, timeout=timeout, limits=limits, follow_redirects=True)
            return self._http_client

    async def _download_with_retry(self, url: str, max_retries: Optional[int] = None) -> httpx.Response:
        """Download with exponential backoff retry logic."""
        if max_retries is None:
            max_retries = self.max_retries

        last_exception = None

        retry_delay = self.retry_delay

        for attempt in range(max_retries):
            try:
                client = await self._get_http_client()
                response = await client.get(url)

                if response.status_code == 200:
                    return response
                elif response.status_code >= 500:
                    logger.warning(f"Server error {response.status_code} for {url}, attempt {attempt + 1}/{max_retries}")
                    last_exception = httpx.HTTPStatusError(f"Server returned {response.status_code}", request=response.request, response=response)
                else:
                    raise httpx.HTTPStatusError(f"Client error {response.status_code}", request=response.request, response=response)

            except (httpx.ConnectError, httpx.TimeoutException, httpx.NetworkError) as e:
                logger.warning(f"Connection error for {url}, attempt {attempt + 1}/{max_retries}: {str(e)}")
                last_exception = e
            except httpx.HTTPStatusError as e:
                if e.response and e.response.status_code < 500:
                    raise
                last_exception = e
            except Exception as e:
                logger.error(f"Unexpected error downloading {url}: {str(e)}")
                last_exception = e

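            # Back off before the next attempt, doubling the delay up to the cap.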
            if attempt < max_retries - 1:
                await asyncio.sleep(retry_delay)
                retry_delay = min(retry_delay * 2, self.max_retry_delay)

        error_msg = f"All {max_retries} connection attempts failed for {url}"
        if last_exception:
            error_msg += f": {str(last_exception)}"
        raise httpx.ConnectError(error_msg)

    async def download_image(self, image_url: str) -> Path:
        """Download image with retry logic and proper error handling."""
        try:
            parsed_url = urlparse(image_url)
            if not parsed_url.scheme or not parsed_url.netloc:
                raise ValueError(f"Invalid URL format: {image_url}")

            response = await self._download_with_retry(image_url)

            image_name = Path(parsed_url.path).name
            if not image_name:
                image_name = f"{uuid.uuid4()}.jpg"

            image_path = self.input_image_dir / image_name
            image_path.parent.mkdir(parents=True, exist_ok=True)

            with open(image_path, "wb") as f:
                f.write(response.content)

            logger.info(f"Successfully downloaded image from {image_url} to {image_path}")
            return image_path

        except httpx.ConnectError as e:
            logger.error(f"Connection error downloading image from {image_url}: {str(e)}")
            raise ValueError(f"Failed to connect to {image_url}: {str(e)}")
        except httpx.TimeoutException as e:
            logger.error(f"Timeout downloading image from {image_url}: {str(e)}")
            raise ValueError(f"Download timeout for {image_url}: {str(e)}")
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error downloading image from {image_url}: {str(e)}")
            raise ValueError(f"HTTP error for {image_url}: {str(e)}")
        except ValueError:
            raise
        except Exception as e:
            logger.error(f"Unexpected error downloading image from {image_url}: {str(e)}")
            raise ValueError(f"Failed to download image from {image_url}: {str(e)}")

    async def download_audio(self, audio_url: str) -> Path:
        """Download audio with retry logic and proper error handling."""
        try:
            parsed_url = urlparse(audio_url)
            if not parsed_url.scheme or not parsed_url.netloc:
                raise ValueError(f"Invalid URL format: {audio_url}")

            response = await self._download_with_retry(audio_url)

            audio_name = Path(parsed_url.path).name
            if not audio_name:
                audio_name = f"{uuid.uuid4()}.mp3"

            audio_path = self.input_audio_dir / audio_name
            audio_path.parent.mkdir(parents=True, exist_ok=True)

            with open(audio_path, "wb") as f:
                f.write(response.content)

            logger.info(f"Successfully downloaded audio from {audio_url} to {audio_path}")
            return audio_path

        except httpx.ConnectError as e:
            logger.error(f"Connection error downloading audio from {audio_url}: {str(e)}")
            raise ValueError(f"Failed to connect to {audio_url}: {str(e)}")
        except httpx.TimeoutException as e:
            logger.error(f"Timeout downloading audio from {audio_url}: {str(e)}")
            raise ValueError(f"Download timeout for {audio_url}: {str(e)}")
        except httpx.HTTPStatusError as e:
            logger.error(f"HTTP error downloading audio from {audio_url}: {str(e)}")
            raise ValueError(f"HTTP error for {audio_url}: {str(e)}")
        except ValueError:
            raise
        except Exception as e:
            logger.error(f"Unexpected error downloading audio from {audio_url}: {str(e)}")
            raise ValueError(f"Failed to download audio from {audio_url}: {str(e)}")

    def save_uploaded_file(self, file_content: bytes, filename: str) -> Path:
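        """Persist uploaded file bytes under the input image directory with a unique filename."""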
        file_extension = Path(filename).suffix
        unique_filename = f"{uuid.uuid4()}{file_extension}"
        file_path = self.input_image_dir / unique_filename

        with open(file_path, "wb") as f:
            f.write(file_content)

        return file_path

    def get_output_path(self, save_video_path: str) -> Path:
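        """Resolve a relative save path against the output directory; absolute paths are returned unchanged."""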
        video_path = Path(save_video_path)
        if not video_path.is_absolute():
            return self.output_video_dir / save_video_path
        return video_path

    async def cleanup(self):
        """Cleanup resources including HTTP client."""
        async with self._client_lock:
            if self._http_client and not self._http_client.is_closed:
                await self._http_client.aclose()
                self._http_client = None


class TorchrunInferenceWorker:
    """Worker class for torchrun-based distributed inference"""

    def __init__(self):
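        # LOCAL_RANK and WORLD_SIZE are provided by torchrun for each spawned process.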
        self.rank = int(os.environ.get("LOCAL_RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.runner = None
        self.dist_manager = DistributedManager()
        self.processing = False  # Track if currently processing a request

    def init(self, args) -> bool:
        """Initialize the worker with model and distributed setup"""
        try:
            # Initialize distributed process group using torchrun env vars
            if self.world_size > 1:
                if not self.dist_manager.init_process_group():
                    raise RuntimeError("Failed to initialize distributed process group")
            else:
                # Single GPU mode
                self.dist_manager.rank = 0
                self.dist_manager.world_size = 1
                self.dist_manager.device = "cuda:0" if torch.cuda.is_available() else "cpu"
                self.dist_manager.is_initialized = False

            # Initialize model
            config = set_config(args)
            if self.rank == 0:
                logger.info(f"Config:\n {json.dumps(config, ensure_ascii=False, indent=4)}")

            self.runner = init_runner(config)
            logger.info(f"Rank {self.rank}/{self.world_size - 1} initialization completed")

            return True

        except Exception as e:
            logger.error(f"Rank {self.rank} initialization failed: {str(e)}")
            return False

    async def process_request(self, task_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Process a single inference request.

        Note: We keep the inference synchronous to maintain NCCL/CUDA context integrity.
        The async wrapper allows FastAPI to handle other requests while this runs.
        """
        try:
            # Only rank 0 broadcasts task data (worker processes already received it in worker_loop)
            if self.world_size > 1 and self.rank == 0:
                task_data = self.dist_manager.broadcast_task_data(task_data)

            # Run inference directly - torchrun handles the parallelization
            # Using asyncio.to_thread would be risky with NCCL operations
            # Instead, we rely on FastAPI's async handling and queue management
            self.runner.set_inputs(task_data)
            self.runner.run_pipeline()

            # Small yield to allow other async operations if needed
            await asyncio.sleep(0)

            # Synchronize all ranks
            if self.world_size > 1:
                self.dist_manager.barrier()

            # Only rank 0 returns the result
            if self.rank == 0:
                return {
                    "task_id": task_data["task_id"],
                    "status": "success",
                    "save_video_path": task_data.get("video_path", task_data["save_video_path"]),
                    "message": "Inference completed",
                }
            else:
                return None

        except Exception as e:
            logger.error(f"Rank {self.rank} inference failed: {str(e)}")
            if self.world_size > 1:
                self.dist_manager.barrier()

            if self.rank == 0:
                return {
                    "task_id": task_data.get("task_id", "unknown"),
                    "status": "failed",
                    "error": str(e),
                    "message": f"Inference failed: {str(e)}",
                }
            else:
                return None

    async def worker_loop(self):
        """Non-rank-0 workers: Listen for broadcast tasks"""
        while True:
            try:
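                # Block until rank 0 broadcasts the next task; a payload of None is the stop signal.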
                task_data = self.dist_manager.broadcast_task_data()
                if task_data is None:
                    logger.info(f"Rank {self.rank} received stop signal")
                    break

                await self.process_request(task_data)

            except Exception as e:
                logger.error(f"Rank {self.rank} worker loop error: {str(e)}")
                continue

    def cleanup(self):
        self.dist_manager.cleanup()


class DistributedInferenceService:
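    """Owns the per-process TorchrunInferenceWorker and exposes task submission (rank 0) and the worker loop (other ranks)."""
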
    def __init__(self):
        self.worker = None
        self.is_running = False
        self.args = None

    def start_distributed_inference(self, args) -> bool:
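        """Initialize the torchrun worker for this rank and mark the service as running."""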
        self.args = args
        if self.is_running:
            logger.warning("Distributed inference service is already running")
            return True

        try:
            self.worker = TorchrunInferenceWorker()

            if not self.worker.init(args):
                raise RuntimeError("Worker initialization failed")

            self.is_running = True
            logger.info(f"Rank {self.worker.rank} inference service started successfully")
            return True

        except Exception as e:
            logger.error(f"Error starting inference service: {str(e)}")
            self.stop_distributed_inference()
            return False

    def stop_distributed_inference(self):
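        """Tear down the worker and distributed state; safe to call when the service is not running."""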
        if not self.is_running:
            return

        try:
            if self.worker:
                self.worker.cleanup()
            logger.info("Inference service stopped")
        except Exception as e:
            logger.error(f"Error stopping inference service: {str(e)}")
        finally:
            self.worker = None
            self.is_running = False

    async def submit_task_async(self, task_data: dict) -> Optional[dict]:
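        """Run a single inference task and return its result dict; returns None if the service is not running or this process is not rank 0."""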
        if not self.is_running or not self.worker:
            logger.error("Inference service is not started")
            return None

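        # Only rank 0 accepts tasks here; other ranks receive them via broadcast in worker_loop.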
        if self.worker.rank != 0:
            return None

        try:
            if self.worker.processing:
                # If we want to support queueing, we can add the task to a queue here.
                # For now, tasks are processed sequentially.
                logger.info(f"Waiting for previous task to complete before processing task {task_data.get('task_id')}")

            self.worker.processing = True
            result = await self.worker.process_request(task_data)
            self.worker.processing = False
            return result
        except Exception as e:
            self.worker.processing = False
            logger.error(f"Failed to process task: {str(e)}")
            return {
                "task_id": task_data.get("task_id", "unknown"),
                "status": "failed",
                "error": str(e),
                "message": f"Task processing failed: {str(e)}",
            }

    def server_metadata(self):
        assert self.args is not None, "Distributed inference service has not been started. Call start_distributed_inference() first."
        return {
            "nproc_per_node": self.worker.world_size,
            "model_cls": self.args.model_cls,
            "model_path": self.args.model_path,
        }

    async def run_worker_loop(self):
        """Run the worker loop for non-rank-0 processes"""
        if self.worker and self.worker.rank != 0:
            await self.worker.worker_loop()


class VideoGenerationService:
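    """Orchestrates a single generation request: prepares image/audio inputs via FileService and runs inference via DistributedInferenceService."""
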
    def __init__(self, file_service: FileService, inference_service: DistributedInferenceService):
        self.file_service = file_service
        self.inference_service = inference_service

    async def generate_video_with_stop_event(self, message: TaskRequest, stop_event) -> Optional[TaskResponse]:
        """Generate video using torchrun-based inference"""
        try:
            task_data = {field: getattr(message, field) for field in message.model_fields_set if field != "task_id"}
            task_data["task_id"] = message.task_id

            if stop_event.is_set():
                logger.info(f"Task {message.task_id} cancelled before processing")
                return None

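            # Normalize the image input: download remote URLs, decode base64 payloads to a file,
            # and pass local paths through unchanged.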
            if "image_path" in message.model_fields_set and message.image_path:
                if message.image_path.startswith("http"):
                    image_path = await self.file_service.download_image(message.image_path)
                    task_data["image_path"] = str(image_path)
                elif is_base64_image(message.image_path):
                    image_path = save_base64_image(message.image_path, str(self.file_service.input_image_dir))
                    task_data["image_path"] = str(image_path)
                else:
                    task_data["image_path"] = message.image_path

                logger.info(f"Task {message.task_id} image path: {task_data['image_path']}")

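            # Apply the same normalization to the audio input.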
            if "audio_path" in message.model_fields_set and message.audio_path:
                if message.audio_path.startswith("http"):
                    audio_path = await self.file_service.download_audio(message.audio_path)
                    task_data["audio_path"] = str(audio_path)
                elif is_base64_audio(message.audio_path):
                    audio_path = save_base64_audio(message.audio_path, str(self.file_service.input_audio_dir))
                    task_data["audio_path"] = str(audio_path)
                else:
                    task_data["audio_path"] = message.audio_path

                logger.info(f"Task {message.task_id} audio path: {task_data['audio_path']}")

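            # Resolve the save path under the service output directory; the original
            # client-supplied path is kept in "video_path" for reporting back to the client.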
            actual_save_path = self.file_service.get_output_path(message.save_video_path)
            task_data["save_video_path"] = str(actual_save_path)
            task_data["video_path"] = message.save_video_path

            result = await self.inference_service.submit_task_async(task_data)

            if result is None:
                if stop_event.is_set():
                    logger.info(f"Task {message.task_id} cancelled during processing")
                    return None
                raise RuntimeError("Task processing failed")

            if result.get("status") == "success":
                return TaskResponse(
                    task_id=message.task_id,
                    task_status="completed",
                    save_video_path=message.save_video_path,  # Return original path
                )
            else:
                error_msg = result.get("error", "Inference failed")
                raise RuntimeError(error_msg)

        except Exception as e:
            logger.error(f"Task {message.task_id} processing failed: {str(e)}")
            raise