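# redis_monitor.py
# Redis-backed implementation of ServerMonitor. Worker heartbeats, a rolling
# window of inference costs, and pending-subtask counters all live in Redis
# (under workers:<queue>:* and pendings:<queue>:* keys), so monitor state
# survives server restarts.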
import asyncio
import json
import time

from loguru import logger

from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.server.monitor import ServerMonitor, WorkerStatus
from lightx2v.deploy.server.redis_client import RedisClient


class RedisServerMonitor(ServerMonitor):
    def __init__(self, model_pipelines, task_manager, queue_manager, redis_url):
        super().__init__(model_pipelines, task_manager, queue_manager)
        self.redis_url = redis_url
        self.redis_client = RedisClient(redis_url)
        self.last_correct = None
        self.correct_interval = 60 * 60 * 24  # resync pending counters at most once a day

    async def init(self):
        await self.redis_client.init()
        await self.init_pending_subtasks()

    async def loop(self):
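        # Maintenance loop: resync the pending counters at most once per
        # correct_interval, then prune stale workers and subtasks every
        # `interval` seconds until stop is set.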
        while not self.stop:
            if self.last_correct is None or time.time() - self.last_correct > self.correct_interval:
                self.last_correct = time.time()
                await self.correct_pending_info()
            await self.clean_workers()
            await self.clean_subtasks()
            await asyncio.sleep(self.interval)
        logger.info("RedisServerMonitor stopped")

    async def close(self):
        await super().close()
        await self.redis_client.close()

    @class_try_catch_async
    async def worker_update(self, queue, identity, status):
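        # Persist the worker's latest status and heartbeat time under
        # workers:<queue>:workers, keyed by worker identity.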
        status = status.name
        key = f"workers:{queue}:workers"
        infer_key = f"workers:{queue}:infer_cost"

        update_t = time.time()
        worker = await self.redis_client.hget(key, identity)
        if worker is None:
            worker = {"status": "", "fetched_t": 0, "update_t": update_t}
            await self.redis_client.hset(key, identity, json.dumps(worker))
        else:
            worker = json.loads(worker)

        pre_fetched_t = float(worker["fetched_t"])
        worker["status"] = status
        worker["update_t"] = update_t

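        # A REPORT that follows a FETCHED closes one inference cycle: record
        # its duration in the rolling window, skipping outliers that ran past
        # the per-queue run timeout.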
        if status == WorkerStatus.REPORT.name and pre_fetched_t > 0:
            cur_cost = update_t - pre_fetched_t
            worker["fetched_t"] = 0.0
            if cur_cost < self.subtask_run_timeouts[queue]:
                await self.redis_client.list_push(infer_key, max(cur_cost, 1), self.worker_avg_window)
                logger.info(f"Worker {identity} {queue} avg infer cost update: {cur_cost:.2f} s")

        elif status == WorkerStatus.FETCHED.name:
            worker["fetched_t"] = update_t

        await self.redis_client.hset(key, identity, json.dumps(worker))
        logger.info(f"Worker {identity} {queue} update [{status}]")

    @class_try_catch_async
    async def clean_workers(self):
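        # Evict worker records that look dead: an inference running past the
        # per-queue run timeout, a FETCHED/PING heartbeat older than
        # ping_timeout, a REPORT/DISCONNECT entry past the offline timeout,
        # or a FETCHING state that never resolved.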
        for queue in self.all_queues:
            key = f"workers:{queue}:workers"
            workers = await self.redis_client.hgetall(key)

            for identity, worker in workers.items():
                worker = json.loads(worker)
                fetched_t = float(worker["fetched_t"])
                update_t = float(worker["update_t"])
                status = worker["status"]

                # infer too long
                if fetched_t > 0:
                    elapse = time.time() - fetched_t
                    if elapse > self.subtask_run_timeouts[queue]:
                        logger.warning(f"Worker {identity} {queue} infer timeout2: {elapse:.2f} s")
                        await self.redis_client.hdel(key, identity)
                        continue

                elapse = time.time() - update_t
                # no ping too long
                if status in [WorkerStatus.FETCHED.name, WorkerStatus.PING.name]:
                    if elapse > self.ping_timeout:
                        logger.warning(f"Worker {identity} {queue} ping timeout: {elapse:.2f} s")
                        await self.redis_client.hdel(key, identity)
                        continue

                # offline too long
                elif status in [WorkerStatus.DISCONNECT.name, WorkerStatus.REPORT.name]:
                    if elapse > self.worker_offline_timeout:
                        logger.warning(f"Worker {identity} {queue} offline timeout2: {elapse:.2f} s")
                        await self.redis_client.hdel(key, identity)
                        continue

                # fetching too long
                elif status == WorkerStatus.FETCHING.name:
                    if elapse > self.fetching_timeout:
                        logger.warning(f"Worker {identity} {queue} fetching timeout: {elapse:.2f} s")
                        await self.redis_client.hdel(key, identity)
                        continue

    async def get_ready_worker_count(self, queue):
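        # Counts every worker record still in the hash; clean_workers is
        # responsible for evicting entries that have timed out.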
        key = f"workers:{queue}:workers"
        worker_count = await self.redis_client.hlen(key)
        return worker_count

    async def get_avg_worker_infer_cost(self, queue):
        infer_key = f"workers:{queue}:infer_cost"
        infer_cost = await self.redis_client.list_avg(infer_key, self.worker_avg_window)
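        # A negative average means the window has no samples yet; fall back to
        # the queue's run timeout as a conservative estimate.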
        if infer_cost < 0:
            return self.subtask_run_timeouts[queue]
        return infer_cost

    async def correct_pending_info(self):
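        # Resync the Redis pending counters with the actual queue depth to
        # bound any drift from lost increments/decrements.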
        for queue in self.all_queues:
            pending_num = await self.queue_manager.pending_num(queue)
            await self.redis_client.correct_pending_info(f"pendings:{queue}:info", pending_num)

    @class_try_catch_async
    async def init_pending_subtasks(self):
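        # Seed Redis with locally discovered pending subtasks. Both writes are
        # create-only (create_if_not_exists / nx=True), so existing counters
        # from a previous run are never clobbered.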
        await super().init_pending_subtasks()
        # save to redis if not exists
        for queue, v in self.pending_subtasks.items():
            subtasks = v.pop("subtasks", {})
            await self.redis_client.create_if_not_exists(f"pendings:{queue}:info", v)
            for task_id, order_id in subtasks.items():
                await self.redis_client.set(f"pendings:{queue}:subtasks:{task_id}", order_id, nx=True)
        self.pending_subtasks = None
        logger.info(f"Inited pending subtasks to redis")

    @class_try_catch_async
    async def pending_subtasks_add(self, queue, task_id):
        max_count = await self.redis_client.increment_and_get(f"pendings:{queue}:info", "max_count", 1)
        await self.redis_client.set(f"pendings:{queue}:subtasks:{task_id}", max_count)

    @class_try_catch_async
    async def pending_subtasks_sub(self, queue, task_id):
        consume_count = await self.redis_client.increment_and_get(f"pendings:{queue}:info", "consume_count", 1)
        await self.redis_client.delete_key(f"pendings:{queue}:subtasks:{task_id}")

    @class_try_catch_async
    async def pending_subtasks_get_order(self, queue, task_id):
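        # A task's live queue position is its monotonically increasing order id
        # minus the number of subtasks consumed so far, clamped to >= 1.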
        order = await self.redis_client.get(f"pendings:{queue}:subtasks:{task_id}")
        if order is None:
            return None
        consume = await self.redis_client.hget(f"pendings:{queue}:info", "consume_count")
        if consume is None:
            return None
        real_order = max(int(order) - int(consume), 1)
        return real_order
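

# A minimal usage sketch (assumptions: `pipelines`, `tasks`, and `queues` stand
# in for the model_pipelines, task_manager, and queue_manager objects that the
# surrounding deploy server provides; the Redis URL is illustrative):
#
#     monitor = RedisServerMonitor(pipelines, tasks, queues, "redis://localhost:6379/0")
#     await monitor.init()
#     asyncio.create_task(monitor.loop())
#     ...
#     await monitor.close()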