rabbitmq_queue_manager.py 4.31 KB
Newer Older
LiangLiu's avatar
LiangLiu committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import asyncio
import json
import traceback

import aio_pika
from loguru import logger

from lightx2v.deploy.common.utils import class_try_catch_async
from lightx2v.deploy.queue_manager import BaseQueueManager


class RabbitMQQueueManager(BaseQueueManager):
    """Subtask queue backed by RabbitMQ via aio_pika.

    Subtasks are published as persistent JSON messages to durable queues through
    the default exchange (routing key = queue name). One robust connection and a
    single channel on it are shared by every operation.
    """

    def __init__(self, conn_url, max_retries=3):
        """Store settings; no connection is opened until ``init``/``get_conn``.

        Args:
            conn_url: AMQP URL handed to ``aio_pika.connect_robust``.
            max_retries: number of connection attempts made by ``get_conn``.
        """
        self.conn_url = conn_url
        self.max_retries = max_retries
        self.conn = None  # robust connection, created lazily by get_conn
        self.chan = None  # channel opened on self.conn
        self.queues = set()  # queue names already declared on the current channel

    async def init(self):
        """Connect to RabbitMQ (see ``get_conn``)."""
        await self.get_conn()

    async def close(self):
        """Close the channel and connection (see ``del_conn``)."""
        await self.del_conn()

    async def get_conn(self):
        """Make sure a connection and channel exist.

        Does nothing when both are already set. Otherwise makes up to
        ``max_retries`` attempts, sleeping 1 s between them, and re-raises the
        last error if every attempt fails. A successful attempt resets the
        declared-queue cache and sets the channel prefetch count to 10.
        """
        if self.chan and self.conn:
            return
        for i in range(self.max_retries):
            try:
                logger.info(f"Connect to RabbitMQ (attempt {i + 1}/{self.max_retries}..)")
                self.conn = await aio_pika.connect_robust(self.conn_url)
                self.chan = await self.conn.channel()
                self.queues = set()
                await self.chan.set_qos(prefetch_count=10)
                logger.info("Successfully connected to RabbitMQ")
                return
            except Exception as e:
                logger.warning(f"Failed to connect to RabbitMQ: {e}")
                # If the connection opened but channel/QoS setup failed, close it.
                # Otherwise the next attempt overwrites (and leaks) it, or a later
                # call sees conn/chan set and skips reconnecting.
                await self._discard_conn()
                if i < self.max_retries - 1:
                    await asyncio.sleep(1)
                else:
                    raise

    async def _discard_conn(self):
        """Best-effort ``del_conn`` for error paths: logs instead of raising."""
        try:
            await self.del_conn()
        except Exception as e:
            logger.warning(f"Failed to close RabbitMQ connection: {e}")

    async def declare_queue(self, queue):
        """Return the aio_pika queue object for ``queue``.

        On first use (per channel) the queue is declared durable and remembered
        in ``self.queues``. The queue object itself is fetched with
        ``chan.get_queue`` on every call.
        """
        if queue not in self.queues:
            await self.get_conn()
            await self.chan.declare_queue(queue, durable=True)
            self.queues.add(queue)
        return await self.chan.get_queue(queue)

    @class_try_catch_async
    async def put_subtask(self, subtask):
        """Publish ``subtask`` as a persistent JSON message to ``subtask["queue"]``.

        Only the keys queue, task_id, worker_name, inputs, outputs and params are
        serialized. Returns True once published. Exceptions are handled by
        ``class_try_catch_async`` (NOTE(review): presumably logged and turned
        into a non-True return — confirm).
        """
        queue = subtask["queue"]
        await self.declare_queue(queue)
        keys = ["queue", "task_id", "worker_name", "inputs", "outputs", "params"]
        msg = json.dumps({k: subtask[k] for k in keys}).encode("utf-8")
        message = aio_pika.Message(body=msg, delivery_mode=aio_pika.DeliveryMode.PERSISTENT, content_type="application/json")
        await self.chan.default_exchange.publish(message, routing_key=queue)
        logger.info(f"Rabbitmq published subtask: ({subtask['task_id']}, {subtask['worker_name']}) to {queue}")
        return True

    async def _ack_and_decode(self, queue, message):
        """Ack ``message`` and return its decoded JSON body, or None if malformed.

        The ack happens before decoding, so a malformed message is removed from
        ``queue`` rather than left on it. Such a message is logged here, and the
        caller skips it instead of failing the whole batch.
        """
        await message.ack()
        try:
            return json.loads(message.body.decode("utf-8"))
        except (UnicodeDecodeError, json.JSONDecodeError) as e:
            logger.warning(f"rabbitmq drop undecodable message from {queue}: {e}")
            return None

    async def get_subtasks(self, queue, max_batch, timeout):
        """Fetch up to ``max_batch`` subtasks from ``queue``.

        Waits at most ``timeout`` seconds for the first message (``None`` waits
        indefinitely). It then drains messages that are immediately available via
        ``Queue.get`` until the batch is full or the queue reports empty. Every
        message is acked as soon as it is received, before it is decoded.
        Messages that are not valid UTF-8 JSON are logged and skipped.

        Returns:
            A non-empty list of decoded subtask dicts, or ``None`` when nothing
            usable was fetched (timeout, cancellation, or an error). If an error
            occurs after some messages were already acked, those subtasks are
            returned rather than discarded. On cancellation, ``None`` is returned
            even if messages were already acked.
        """
        subtasks = []
        try:
            q = await self.declare_queue(queue)
            async with q.iterator() as qiter:
                # Only the first message is waited for; honour the caller's timeout.
                try:
                    message = await asyncio.wait_for(qiter.__anext__(), timeout=timeout)
                except (asyncio.TimeoutError, StopAsyncIteration):
                    return None
                subtask = await self._ack_and_decode(queue, message)
                if subtask is not None:
                    subtasks.append(subtask)
                # NOTE(review): messages already prefetched by the iterator's
                # consumer are presumably not handed out by basic.get, so a batch
                # may be smaller than what is queued — confirm.
                while len(subtasks) < max_batch:
                    message = await q.get(no_ack=False, fail=False)
                    if not message:
                        break
                    subtask = await self._ack_and_decode(queue, message)
                    if subtask is not None:
                        subtasks.append(subtask)
            return subtasks or None
        except asyncio.CancelledError:
            logger.warning(f"rabbitmq get_subtasks for {queue} cancelled")
            return None
        except Exception:
            logger.warning(f"rabbitmq get_subtasks for {queue} failed: {traceback.format_exc()}")
            # Anything in `subtasks` is already acked; returning it avoids losing it.
            return subtasks or None

    @class_try_catch_async
    async def pending_num(self, queue):
        """Return the message count from the declaration result of ``queue``."""
        q = await self.declare_queue(queue)
        return q.declaration_result.message_count

    async def del_conn(self):
        """Close the channel and connection and forget them.

        State is cleared before closing, so a later ``get_conn`` opens a fresh
        connection even if closing raises. The connection is closed even when
        closing the channel fails.
        """
        chan, conn = self.chan, self.conn
        self.chan = None
        self.conn = None
        self.queues = set()
        try:
            if chan:
                await chan.close()
        finally:
            if conn:
                await conn.close()


async def test(conn_url="amqp://username:password@127.0.0.1:5672"):
    """Manual smoke test against a live broker.

    Publishes one subtask to ``test_queue``, waits 5 s, then calls
    ``get_subtasks`` twice and prints each result. Requires a reachable
    RabbitMQ broker at ``conn_url``.
    """
    q = RabbitMQQueueManager(conn_url)
    await q.init()
    try:
        subtask = {
            "task_id": "test-subtask-id",
            "queue": "test_queue",
            "worker_name": "test_worker",
            "inputs": {},
            "outputs": {},
            "params": {},
        }
        await q.put_subtask(subtask)
        await asyncio.sleep(5)
        for _ in range(2):
            subtask = await q.get_subtasks("test_queue", 3, 5)
            print("get subtask:", subtask)
    finally:
        # Close even if publishing or reading raises, so the connection is not leaked.
        await q.close()


if __name__ == "__main__":
    asyncio.run(test())