# worker.py
import asyncio
import os
from typing import Any, Dict

import torch
from easydict import EasyDict
from loguru import logger

from lightx2v.infer import init_runner
from lightx2v.utils.input_info import set_input_info
from lightx2v.utils.set_config import set_config, set_parallel_config

from ..distributed_utils import DistributedManager


class TorchrunInferenceWorker:
    """Inference worker that runs as one process of a (possibly multi-process) job.

    The worker keeps its rank and world size (read from the environment), a
    ``DistributedManager`` used for process-group setup, task broadcast and
    barriers, and the runner created by :meth:`init`.
    """

    def __init__(self):
        # Rank and world size come from the LOCAL_RANK / WORLD_SIZE environment
        # variables and default to 0 and 1 when they are not set.
        self.rank = int(os.environ.get("LOCAL_RANK", 0))
        self.world_size = int(os.environ.get("WORLD_SIZE", 1))
        self.runner = None  # created by init()
        self.dist_manager = DistributedManager()
        # Not read anywhere in this class; presumably consulted by external
        # callers -- TODO confirm.
        self.processing = False

    def init(self, args) -> bool:
        """Set up the distributed state and build the runner.

        Args:
            args: Passed unchanged to ``set_config`` to produce the config.

        Returns:
            ``True`` when initialization completed, ``False`` when any
            exception was raised (the exception is logged, not propagated).
        """
        try:
            if self.world_size > 1:
                if not self.dist_manager.init_process_group():
                    raise RuntimeError("Failed to initialize distributed process group")
            else:
                # Single-process run: fill in the manager's fields by hand
                # instead of creating a process group.
                self.dist_manager.rank = 0
                self.dist_manager.world_size = 1
                self.dist_manager.device = "cuda:0" if torch.cuda.is_available() else "cpu"
                self.dist_manager.is_initialized = False

            config = set_config(args)

            if config["parallel"]:
                set_parallel_config(config)

            # Only rank 0 prints the config.
            if self.rank == 0:
                logger.info(f"Config:\n {config}")

            self.runner = init_runner(config)
            logger.info(f"Rank {self.rank}/{self.world_size - 1} initialization completed")

            return True

        except Exception as e:
            logger.exception(f"Rank {self.rank} initialization failed: {str(e)}")
            return False

    async def process_request(self, task_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run one inference task and report the outcome.

        ``task_data`` is modified in place before being wrapped in an
        ``EasyDict``: ``task``, ``return_result_tensor`` and ``negative_prompt``
        are set, and ``target_fps`` may be moved into a
        ``video_frame_interpolation`` entry.

        Args:
            task_data: Request parameters for the task.

        Returns:
            On rank 0, a dict with ``task_id``, ``status`` and ``message``,
            plus ``save_result_path`` on success or ``error`` on failure.
            On every other rank, ``None``.

        Raises:
            KeyError: On rank 0 after a successful run, if ``task_id`` is
                missing or neither ``video_path`` nor ``save_result_path``
                is present.
        """
        has_error = False
        error_msg = ""

        try:
            # In a multi-process run rank 0 hands the task to the manager's
            # broadcast and continues with whatever it returns.
            if self.world_size > 1 and self.rank == 0:
                task_data = self.dist_manager.broadcast_task_data(task_data)

            task_data["task"] = self.runner.config["task"]
            task_data["return_result_tensor"] = False
            task_data["negative_prompt"] = task_data.get("negative_prompt", "")

            if task_data.get("target_fps") is not None and "video_frame_interpolation" in self.runner.config:
                # Copy the runner's settings so the per-request target_fps is
                # not written into the runner's own config mapping.
                task_data["video_frame_interpolation"] = dict(self.runner.config["video_frame_interpolation"])
                task_data["video_frame_interpolation"]["target_fps"] = task_data["target_fps"]
                del task_data["target_fps"]

            task_data = EasyDict(task_data)
            input_info = set_input_info(task_data)

            self.runner.set_config(task_data)
            self.runner.run_pipeline(input_info)

            # Give the event loop a chance to run other tasks.
            await asyncio.sleep(0)

        except Exception as e:
            has_error = True
            error_msg = str(e)
            logger.exception(f"Rank {self.rank} inference failed: {error_msg}")

        # The exception (if any) was caught above, so every rank reaches the
        # barrier whether inference succeeded or failed.
        if self.world_size > 1:
            self.dist_manager.barrier()

        if self.rank != 0:
            return None

        if has_error:
            # Guard against a missing payload so building the failure report
            # cannot itself raise and hide the original error.
            task_id = "unknown" if task_data is None else task_data.get("task_id", "unknown")
            return {
                "task_id": task_id,
                "status": "failed",
                "error": error_msg,
                "message": f"Inference failed: {error_msg}",
            }

        return {
            "task_id": task_data["task_id"],
            "status": "success",
            # Look up the fallback key only when "video_path" is absent; a
            # dict.get() default would be evaluated eagerly and raise KeyError
            # for a missing "save_result_path" even when "video_path" exists.
            "save_result_path": task_data["video_path"] if "video_path" in task_data else task_data["save_result_path"],
            "message": "Inference completed",
        }

    async def worker_loop(self):
        """Receive tasks from the distributed manager and process them until stopped.

        Each iteration obtains task data from
        ``dist_manager.broadcast_task_data()``; ``None`` is treated as the stop
        signal and ends the loop. Any exception is logged with its traceback.
        If a task had been received in a multi-process run, a barrier is then
        attempted, and the loop ends if that barrier fails; otherwise the loop
        moves on to the next iteration.
        """
        while True:
            task_data = None
            try:
                task_data = self.dist_manager.broadcast_task_data()
                if task_data is None:
                    logger.info(f"Rank {self.rank} received stop signal")
                    break

                await self.process_request(task_data)

            except Exception as e:
                # logger.exception keeps the traceback, matching the other
                # error handlers in this class.
                logger.exception(f"Rank {self.rank} worker loop error: {str(e)}")
                if self.world_size > 1 and task_data is not None:
                    try:
                        self.dist_manager.barrier()
                    except Exception as barrier_error:
                        logger.error(f"Rank {self.rank} barrier failed after error: {barrier_error}")
                        break

    def cleanup(self):
        """Delegate cleanup to the distributed manager."""
        self.dist_manager.cleanup()