main.py 3.81 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
import socketio
2
3
import asyncio

Timothy J. Baek's avatar
Timothy J. Baek committed
4
5
6
7
8
9
10
11

from apps.webui.models.users import Users
from utils.utils import decode_token

# Asynchronous Socket.IO server configured for ASGI.
# NOTE(review): per the python-socketio documentation, an empty
# cors_allowed_origins list disables the server's own CORS handling —
# confirm that origin checks are enforced elsewhere in the stack.
sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")
# ASGI application wrapping the server; Socket.IO is served at /ws/socket.io.
app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")

# In-memory dictionaries tracking connected sessions, users, and model usage
Timothy J. Baek's avatar
Timothy J. Baek committed
12

Timothy J. Baek's avatar
Timothy J. Baek committed
13
# Maps a socket session ID (sid) to the ID of the user it was registered for.
SESSION_POOL = {}
Timothy J. Baek's avatar
Timothy J. Baek committed
14
# Maps a user ID to the list of session IDs (sids) registered for that user.
USER_POOL = {}
15
16
17
# Maps a model ID to a dict with "sids" (session IDs that reported using the
# model) and "callback" (the asyncio task scheduled to expire a session).
USAGE_POOL = {}
# Seconds to wait after a "usage" event before the reporting session is
# removed from USAGE_POOL again.
TIMEOUT_DURATION = 3
Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22


@sio.event
async def connect(sid, environ, auth):
    """Register a newly connected socket session for an authenticated user.

    Nothing is recorded unless ``auth`` carries a ``"token"`` that
    ``decode_token`` turns into a payload containing an ``"id"`` for which
    ``Users.get_user_by_id`` returns a truthy user. On success the session is
    stored in ``SESSION_POOL`` and ``USER_POOL`` and the current user count
    and models in use are emitted. The handler always returns ``None``.

    Args:
        sid: Session ID of the connecting socket.
        environ: Request environment supplied by the server (unused here).
        auth: Authentication payload sent by the client, or a falsy value.
    """
    if not auth or "token" not in auth:
        return

    claims = decode_token(auth["token"])
    if claims is None or "id" not in claims:
        return

    # Deferred import: only resolved once a token has been decoded.
    from apps.webui.internal.db import SessionLocal

    # NOTE(review): the object created by SessionLocal() is never closed in
    # this handler — confirm whether get_user_by_id takes ownership of it.
    user = Users.get_user_by_id(SessionLocal(), claims["id"])
    if not user:
        return

    user_id = user.id
    SESSION_POOL[sid] = user_id
    USER_POOL.setdefault(user_id, []).append(sid)

    print(f"user {user.name}({user.id}) connected with session ID {sid}")

    # USER_POOL is keyed by user ID, so its length is the distinct-user count.
    await sio.emit("user-count", {"count": len(USER_POOL)})
    await sio.emit("usage", {"models": get_models_in_use()})
Timothy J. Baek's avatar
Timothy J. Baek committed
42
43


Timothy J. Baek's avatar
Timothy J. Baek committed
44
45
46
47
48
49
50
51
52
53
54
55
56
@sio.on("user-join")
async def user_join(sid, data):
    """Handle a ``user-join`` event by registering ``sid`` for its user.

    The event payload is expected to look like ``{"auth": {"token": ...}}``.
    If the payload is missing, has no token, the token does not decode to a
    payload with an ``"id"``, or no user is found, the event is ignored.

    Registration is idempotent: a session that is already registered for the
    user (for example by the ``connect`` handler) is not added a second time,
    so a single removal on disconnect is enough to drop it again.

    Args:
        sid: Session ID of the socket that sent the event.
        data: Event payload sent by the client.
    """
    # Log only the event name and session ID; the payload carries the
    # client's auth token and must not end up in the logs.
    print("user-join", sid)

    auth = data.get("auth") if isinstance(data, dict) else None
    if not auth or "token" not in auth:
        return

    claims = decode_token(auth["token"])
    if claims is None or "id" not in claims:
        return

    # NOTE(review): the connect handler calls this same lookup with a
    # SessionLocal() instance as first argument — confirm which signature
    # Users.get_user_by_id currently expects.
    user = Users.get_user_by_id(claims["id"])
    if not user:
        return

    # If this session was previously registered for a different user, detach
    # it from that user first so the old entry does not linger.
    previous_user_id = SESSION_POOL.get(sid)
    if previous_user_id is not None and previous_user_id != user.id:
        stale = USER_POOL.get(previous_user_id)
        if stale is not None:
            stale[:] = [s for s in stale if s != sid]
            if not stale:
                del USER_POOL[previous_user_id]

    SESSION_POOL[sid] = user.id
    sessions = USER_POOL.setdefault(user.id, [])
    if sid not in sessions:
        sessions.append(sid)

    print(f"user {user.name}({user.id}) connected with session ID {sid}")

    # USER_POOL is keyed by user ID, so its length is the distinct-user count.
    await sio.emit("user-count", {"count": len(USER_POOL)})


Timothy J. Baek's avatar
Timothy J. Baek committed
69
70
71
72
@sio.on("user-count")
async def user_count(sid):
    """Emit the number of distinct users currently held in ``USER_POOL``.

    Args:
        sid: Session ID of the socket that sent the event (unused; the
            count is emitted without naming a recipient).
    """
    payload = {"count": len(set(USER_POOL))}
    await sio.emit("user-count", payload)

Timothy J. Baek's avatar
Timothy J. Baek committed
73

74
75
76
def get_models_in_use():
    """Return a list of the model IDs that have an entry in ``USAGE_POOL``."""
    # Iterating a dict yields its keys, i.e. the model IDs.
    return list(USAGE_POOL)


@sio.on("usage")
async def usage(sid, data):
    """Record that session ``sid`` is using a model and announce the change.

    Each (model, session) pair gets its own expiry task. A repeated event
    from the same session restarts only that session's timer. Timers that
    belong to other sessions using the same model are left running, so every
    session is eventually removed even if it never reports again.

    Args:
        sid: Session ID of the socket that sent the event.
        data: Event payload; ``data["model"]`` is the ID of the model in use.
    """
    model_id = data["model"]

    entry = USAGE_POOL.setdefault(model_id, {"sids": []})
    # Per-session expiry tasks, keyed by session ID.
    timers = entry.setdefault("callbacks", {})

    # Restart the timer for this session only.
    previous = timers.pop(sid, None)
    if previous is not None:
        previous.cancel()

    if sid not in entry["sids"]:
        entry["sids"].append(sid)

    # Schedule removal of this session's usage after TIMEOUT_DURATION.
    task = asyncio.create_task(remove_after_timeout(sid, model_id))
    timers[sid] = task
    # Keep the most recent task under the legacy single-task key as well, in
    # case other code still reads it.
    entry["callback"] = task

    def _forget(done_task):
        # Drop the bookkeeping entry once the task is over, unless a newer
        # task has already replaced it.
        if timers.get(sid) is done_task:
            del timers[sid]

    task.add_done_callback(_forget)

    # Broadcast the usage data to all clients
    await sio.emit("usage", {"models": get_models_in_use()})
108
109
110
111
112


async def remove_after_timeout(sid, model_id):
    """Drop ``sid`` from a model's usage entry after ``TIMEOUT_DURATION``.

    After sleeping, the session is removed from the model's ``"sids"`` list;
    when that list becomes empty the model's entry is deleted from
    ``USAGE_POOL``. The updated list of models in use is then emitted.
    Cancellation of the task is swallowed silently.

    Args:
        sid: Session ID whose usage should expire.
        model_id: ID of the model the session reported using.
    """
    try:
        await asyncio.sleep(TIMEOUT_DURATION)

        if model_id not in USAGE_POOL:
            return

        entry = USAGE_POOL[model_id]
        print(entry["sids"])
        entry["sids"].remove(sid)
        entry["sids"] = list(set(entry["sids"]))

        if not entry["sids"]:
            del USAGE_POOL[model_id]

        # Broadcast the usage data to all clients
        await sio.emit("usage", {"models": get_models_in_use()})
    except asyncio.CancelledError:
        # Task was cancelled due to new 'usage' event
        pass


Timothy J. Baek's avatar
Timothy J. Baek committed
128
@sio.event
async def disconnect(sid):
    """Unregister a disconnected session and emit the updated user count.

    Every occurrence of ``sid`` is removed from the owning user's session
    list, so a session that was registered more than once cannot keep the
    user counted after it has gone. A user with no remaining sessions is
    removed from ``USER_POOL``. Sessions that were never registered are only
    logged.

    Args:
        sid: Session ID of the socket that disconnected.
    """
    if sid not in SESSION_POOL:
        print(f"Unknown session ID {sid} disconnected")
        return

    user_id = SESSION_POOL.pop(sid)

    # Tolerate a missing user entry instead of raising KeyError.
    sessions = USER_POOL.get(user_id)
    if sessions is not None:
        # list.remove() would drop only the first occurrence.
        sessions[:] = [s for s in sessions if s != sid]
        if not sessions:
            del USER_POOL[user_id]

    await sio.emit("user-count", {"count": len(USER_POOL)})