main.py 4.18 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
import socketio
2
3
import asyncio

Timothy J. Baek's avatar
Timothy J. Baek committed
4
5
6
7
8
9
10
11

from apps.webui.models.users import Users
from utils.utils import decode_token

# Socket.IO server running in ASGI mode. Per the python-socketio docs, an empty
# cors_allowed_origins list disables the library's own CORS handling.
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins=[])
# ASGI application exposing the Socket.IO endpoint under /ws/socket.io.
app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")

# In-memory, per-process bookkeeping for connected clients.
# SESSION_POOL: Socket.IO session id -> user id.
SESSION_POOL = {}
# USER_POOL: user id -> list of that user's registered session ids.
USER_POOL = {}
# USAGE_POOL: model id -> dict holding the session ids using the model
# ("sids") plus the bookkeeping for its expiry task.
USAGE_POOL = {}

# Seconds to wait after a "usage" event before the reporting session is
# dropped from a model's usage entry.
TIMEOUT_DURATION = 3


@sio.event
async def connect(sid, environ, auth):
    """Handle a new Socket.IO connection.

    When the client passes ``auth={"token": ...}`` and the token decodes to an
    existing user, the session is recorded in SESSION_POOL and USER_POOL, and
    the current user count and model usage are broadcast to all clients.
    Connections without a usable token are not tracked; this handler does not
    reject them.
    """
    print("connect ", sid)

    user = None
    if auth and "token" in auth:
        token_payload = decode_token(auth["token"])
        if token_payload is not None and "id" in token_payload:
            user = Users.get_user_by_id(token_payload["id"])

    # Nothing to register for anonymous or unresolvable sessions.
    if not user:
        return

    SESSION_POOL[sid] = user.id
    USER_POOL.setdefault(user.id, []).append(sid)

    print(f"user {user.name}({user.id}) connected with session ID {sid}")

    # USER_POOL is keyed by user id, so its length is the distinct-user count.
    online_users = len(USER_POOL)
    print(online_users)
    await sio.emit("user-count", {"count": online_users})
    await sio.emit("usage", {"models": get_models_in_use()})


@sio.on("user-join")
async def user_join(sid, data):
    """Bind an already-connected Socket.IO session to an authenticated user.

    ``data`` is expected to be a dict of the form ``{"auth": {"token": ...}}``.
    When the token decodes to an existing user, the session is recorded in
    SESSION_POOL/USER_POOL and the updated user count is broadcast to all
    clients. Payloads without a usable token, or whose token does not resolve
    to a user, are logged and otherwise ignored.
    """
    print("user-join", sid, data)

    auth = data.get("auth") if isinstance(data, dict) else None
    if not auth or "token" not in auth:
        return

    # Separate name so the incoming payload is not shadowed.
    token_payload = decode_token(auth["token"])

    # Initialise up front so an invalid token falls through to the guard
    # below instead of leaving `user` unbound.
    user = None
    if token_payload is not None and "id" in token_payload:
        user = Users.get_user_by_id(token_payload["id"])
    if not user:
        return

    # If this session was bound to a different user before, remove it from
    # that user's list so it is not counted for both.
    previous_id = SESSION_POOL.get(sid)
    if previous_id is not None and previous_id != user.id:
        stale_sids = USER_POOL.get(previous_id)
        if stale_sids is not None:
            stale_sids[:] = [s for s in stale_sids if s != sid]
            if not stale_sids:
                del USER_POOL[previous_id]

    SESSION_POOL[sid] = user.id
    # connect() may already have registered this sid; never add it twice.
    user_sids = USER_POOL.setdefault(user.id, [])
    if sid not in user_sids:
        user_sids.append(sid)

    print(f"user {user.name}({user.id}) connected with session ID {sid}")

    # USER_POOL is keyed by user id, so its length is the distinct-user count.
    print(len(USER_POOL))
    await sio.emit("user-count", {"count": len(USER_POOL)})


@sio.on("user-count")
async def user_count(sid):
    """Broadcast the number of distinct tracked users to every client."""
    print("user-count", sid)
    # Keys of USER_POOL are user ids, so len() counts distinct users.
    payload = {"count": len(USER_POOL)}
    await sio.emit("user-count", payload)


def get_models_in_use():
    """Return the ids of every model that currently has a USAGE_POOL entry.

    Ids come back in the dict's insertion order; the list is also printed for
    debugging.
    """
    active_model_ids = list(USAGE_POOL)
    print(f"Models in use: {active_model_ids}")
    return active_model_ids


@sio.on("usage")
async def usage(sid, data):
    """Mark a model as in use by session ``sid`` and broadcast current usage.

    ``data`` is expected to be a dict with a ``"model"`` key; any other payload
    is logged and ignored. Every (model, session) pair gets its own expiry
    task running ``remove_after_timeout``. A repeat event from the same
    session restarts only that session's timer.
    """
    print(f'Received "usage" event from {sid}: {data}')

    if not isinstance(data, dict) or "model" not in data:
        return
    model_id = data["model"]

    entry = USAGE_POOL.setdefault(model_id, {"sids": []})

    # Timers are tracked per session. With a single shared timer, an event
    # from a second session cancelled the first session's timer, so the first
    # session was never removed and the model stayed listed as in use.
    timers = entry.setdefault("callbacks", {})
    previous_timer = timers.pop(sid, None)
    if previous_timer is not None:
        previous_timer.cancel()
    # Discard timers of other sessions that have already finished.
    for finished_sid in [s for s, task in timers.items() if task.done()]:
        del timers[finished_sid]

    # Membership check keeps the list duplicate-free without reordering it.
    if sid not in entry["sids"]:
        entry["sids"].append(sid)

    # Schedule removal of this session after TIMEOUT_DURATION.
    timer = asyncio.create_task(remove_after_timeout(sid, model_id))
    timers[sid] = timer
    # "callback" keeps pointing at the most recently scheduled timer.
    entry["callback"] = timer

    # Broadcast the usage data to all clients
    await sio.emit("usage", {"models": get_models_in_use()})


async def remove_after_timeout(sid, model_id):
    """Drop ``sid`` from ``model_id``'s usage entry after TIMEOUT_DURATION.

    If the session list becomes empty, the whole USAGE_POOL entry is deleted.
    In either case, the updated model list is broadcast to all clients. If the
    task is cancelled (a newer "usage" event replaced it), it exits quietly.
    """
    try:
        print("remove_after_timeout", sid, model_id)
        await asyncio.sleep(TIMEOUT_DURATION)

        entry = USAGE_POOL.get(model_id)
        if entry is None:
            return

        print(entry["sids"])
        # Filter rather than list.remove(): a sid that is already gone must
        # not raise ValueError inside this background task, where the error
        # would only surface as "Task exception was never retrieved".
        entry["sids"] = [s for s in entry["sids"] if s != sid]

        if not entry["sids"]:
            del USAGE_POOL[model_id]

        print(f"Removed usage data for {model_id} due to timeout")
        # Broadcast the usage data to all clients
        await sio.emit("usage", {"models": get_models_in_use()})
    except asyncio.CancelledError:
        # Task was cancelled due to new 'usage' event
        pass


@sio.event
async def disconnect(sid):
    """Forget a closed session and broadcast the updated user count.

    Sessions never registered in SESSION_POOL (for example, unauthenticated
    connections) are only logged.
    """
    if sid not in SESSION_POOL:
        print(f"Unknown session ID {sid} disconnected")
        return

    user_id = SESSION_POOL.pop(sid)

    # Remove every occurrence of this sid: user-join can register a session
    # that connect() already added. Removing a single copy left the user
    # counted as online after all their sessions had closed.
    user_sids = USER_POOL.get(user_id)
    if user_sids is not None:
        user_sids[:] = [s for s in user_sids if s != sid]
        if not user_sids:
            del USER_POOL[user_id]

    print(f"user {user_id} disconnected with session ID {sid}")
    print(USER_POOL)

    await sio.emit("user-count", {"count": len(USER_POOL)})