import asyncio
import os
from typing import Any, Dict, Optional

import torch
from easydict import EasyDict
from loguru import logger

from lightx2v.infer import init_runner
from lightx2v.utils.input_info import set_input_info
from lightx2v.utils.set_config import set_config, set_parallel_config

from ..distributed_utils import DistributedManager


class TorchrunInferenceWorker:
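    """Per-rank inference worker for torchrun-launched processes.

    Rank and world size are read from the torchrun environment; rank 0
    broadcasts incoming task dicts and assembles the response, while all
    ranks execute the shared inference pipeline.
    """
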
    def __init__(self):
        self.rank = int(os.environ.get("LOCAL_RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.runner = None
        self.dist_manager = DistributedManager()
        self.processing = False

    def init(self, args) -> bool:
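        """Initialize the process group (or a single-device fallback), build the config, and create the runner.

        Returns True on success, False if any step fails.
        """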
        try:
            if self.world_size > 1:
                if not self.dist_manager.init_process_group():
                    raise RuntimeError("Failed to initialize distributed process group")
            else:
                self.dist_manager.rank = 0
                self.dist_manager.world_size = 1
                self.dist_manager.device = "cuda:0" if torch.cuda.is_available() else "cpu"
                self.dist_manager.is_initialized = False

            config = set_config(args)

            if config["parallel"]:
                set_parallel_config(config)

            if self.rank == 0:
                logger.info(f"Config:\n {config}")

            self.runner = init_runner(config)
            logger.info(f"Rank {self.rank}/{self.world_size - 1} initialization completed")

            return True

        except Exception as e:
            logger.exception(f"Rank {self.rank} initialization failed: {str(e)}")
            return False

    async def process_request(self, task_data: Dict[str, Any]) -> Optional[Dict[str, Any]]:
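        """Run a single inference task across all ranks.

        On multi-GPU setups rank 0 broadcasts the task to the other ranks.
        Only rank 0 returns a result dict; all other ranks return None.
        """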
        has_error = False
        error_msg = ""

        try:
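            # Rank 0 distributes the task payload so every rank works on the same request.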
            if self.world_size > 1 and self.rank == 0:
                task_data = self.dist_manager.broadcast_task_data(task_data)

            task_data["task"] = self.runner.config["task"]
            task_data["return_result_tensor"] = False
            task_data["negative_prompt"] = task_data.get("negative_prompt", "")

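            # Optional frame interpolation: merge the requested target FPS into the
            # runner's video_frame_interpolation config if one is present.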
            target_fps = task_data.pop("target_fps", None)
            if target_fps is not None:
                vfi_cfg = self.runner.config.get("video_frame_interpolation")
                if vfi_cfg:
                    task_data["video_frame_interpolation"] = {**vfi_cfg, "target_fps": target_fps}
                else:
                    logger.warning(f"Target FPS {target_fps} is set, but video frame interpolation is not configured")

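            # Wrap the task in an EasyDict, derive the pipeline inputs, and run inference on this rank.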
            task_data = EasyDict(task_data)
            input_info = set_input_info(task_data)

            self.runner.set_config(task_data)
            self.runner.run_pipeline(input_info)

            await asyncio.sleep(0)

        except Exception as e:
            has_error = True
            error_msg = str(e)
            logger.exception(f"Rank {self.rank} inference failed: {error_msg}")

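        # Keep all ranks in sync before rank 0 builds the response.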
        if self.world_size > 1:
            self.dist_manager.barrier()

        if self.rank == 0:
            if has_error:
                return {
                    "task_id": task_data.get("task_id", "unknown"),
                    "status": "failed",
                    "error": error_msg,
                    "message": f"Inference failed: {error_msg}",
                }
            else:
                return {
                    "task_id": task_data["task_id"],
                    "status": "success",
                    "save_result_path": task_data.get("video_path", task_data["save_result_path"]),
                    "message": "Inference completed",
                }
        else:
            return None

    async def worker_loop(self):
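        """Receive broadcast tasks in a loop and process them; a None payload is the stop signal."""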
        while True:
            task_data = None
            try:
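                # Block until the next task is broadcast; None signals shutdown.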
                task_data = self.dist_manager.broadcast_task_data()
                if task_data is None:
                    logger.info(f"Rank {self.rank} received stop signal")
                    break

                await self.process_request(task_data)

            except Exception as e:
                logger.error(f"Rank {self.rank} worker loop error: {str(e)}")
                if self.world_size > 1 and task_data is not None:
                    try:
                        self.dist_manager.barrier()
                    except Exception as barrier_error:
                        logger.error(f"Rank {self.rank} barrier failed after error: {barrier_error}")
                        break
                continue

    def cleanup(self):
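        """Release distributed resources held by the DistributedManager."""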
        self.dist_manager.cleanup()