main.py 3.64 KB
Newer Older
Timothy J. Baek's avatar
Timothy J. Baek committed
1
import socketio
2
3
import asyncio

Timothy J. Baek's avatar
Timothy J. Baek committed
4
5
6
7
8
9
10
11

from apps.webui.models.users import Users
from utils.utils import decode_token

# Async Socket.IO server running in ASGI mode.
# NOTE(review): per the python-socketio docs an empty cors_allowed_origins
# list disables the library's own CORS handling -- confirm origin checks are
# enforced elsewhere.
sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")
# ASGI application wrapping the server; Socket.IO traffic uses /ws/socket.io.
app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")

# Dictionary to maintain the user pool
Timothy J. Baek's avatar
Timothy J. Baek committed
12
13


Timothy J. Baek's avatar
Timothy J. Baek committed
14
# Keyed by Socket.IO session ID (sid); each value is the id of the user
# authenticated on that session.
USER_POOL = {}
15
16
17
# Per-session model usage, keyed by sid. Each entry holds a "models" list plus
# the expiry-timer bookkeeping written by the "usage" handler.
USAGE_POOL = {}
# Timeout duration in seconds: how long remove_after_timeout sleeps before it
# drops a model from a session's usage entry.
TIMEOUT_DURATION = 3
Timothy J. Baek's avatar
Timothy J. Baek committed
18
19
20
21
22
23
24


@sio.event
async def connect(sid, environ, auth):
    """Handle a new Socket.IO connection and register its user, if any.

    When ``auth`` carries a "token" that decodes to a payload with an "id"
    and that id resolves to a user, the session is recorded in USER_POOL and
    the updated count is broadcast on "user-count". In every other case the
    connection is accepted without registering anything.

    Args:
        sid: Session ID of the connecting client.
        environ: Request environment supplied by the server (unused here).
        auth: Optional authentication payload sent by the client.
    """
    print("connect ", sid)

    if not auth or "token" not in auth:
        return

    payload = decode_token(auth["token"])
    has_id = payload is not None and "id" in payload
    user = Users.get_user_by_id(payload["id"]) if has_id else None
    if not user:
        return

    USER_POOL[sid] = user.id
    print(f"user {user.name}({user.id}) connected with session ID {sid}")

    # Iterating USER_POOL yields its keys, so this is the number of
    # registered sessions.
    session_total = len(set(USER_POOL))
    print(session_total)
    await sio.emit("user-count", {"count": session_total})


Timothy J. Baek's avatar
Timothy J. Baek committed
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
@sio.on("user-join")
async def user_join(sid, data):
    """Register the sending session in USER_POOL from a token in the payload.

    The payload is expected to look like ``{"auth": {"token": ...}}``. If the
    token decodes to a payload with an "id" that resolves to a user, the
    session is recorded in USER_POOL and the updated count is broadcast on
    "user-count". Payloads without usable credentials are ignored.

    Args:
        sid: Session ID of the client that sent the event.
        data: Event payload sent by the client.
    """
    print("user-join", sid, data)

    # A payload (or "auth" value) that is not a mapping cannot carry a token;
    # ignore it instead of failing on the membership test or lookup.
    auth = data.get("auth") if isinstance(data, dict) else None
    if not isinstance(auth, dict) or "token" not in auth:
        return

    token_data = decode_token(auth["token"])

    # Start from None so that a token which fails to decode, or decodes
    # without an "id", skips registration rather than raising
    # UnboundLocalError on the check below.
    user = None
    if token_data is not None and "id" in token_data:
        user = Users.get_user_by_id(token_data["id"])

    if not user:
        return

    USER_POOL[sid] = user.id
    print(f"user {user.name}({user.id}) connected with session ID {sid}")

    # USER_POOL is keyed by sid, so its length is the number of registered
    # sessions.
    session_total = len(USER_POOL)
    print(session_total)
    await sio.emit("user-count", {"count": session_total})


Timothy J. Baek's avatar
Timothy J. Baek committed
59
60
61
62
63
@sio.on("user-count")
async def user_count(sid):
    """Broadcast the current number of USER_POOL entries on "user-count".

    Args:
        sid: Session ID of the client that requested the count.
    """
    print("user-count", sid)
    # USER_POOL is a dict, so its length equals the number of distinct keys.
    payload = {"count": len(USER_POOL)}
    await sio.emit("user-count", payload)

Timothy J. Baek's avatar
Timothy J. Baek committed
64

65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
def get_models_in_use():
    """Return every model listed in USAGE_POOL, flattened across sessions.

    Entries are concatenated in USAGE_POOL iteration order and are not
    de-duplicated, so a model used by several sessions appears once per
    session.

    Returns:
        A new list of the model identifiers found under each entry's
        "models" key.
    """
    active_models = [
        model for entry in USAGE_POOL.values() for model in entry["models"]
    ]
    print(f"Models in use: {active_models}")
    return active_models


@sio.on("usage")
async def usage(sid, data):
    """Record that a session is using a model and broadcast the usage list.

    Each (session, model) pair gets its own expiry timer. A repeated event
    for the same model restarts only that model's timer. Previously a single
    timer was kept per session and cancelled on every event, so when a
    session alternated between models only the most recent one was ever
    expired and the others stayed listed indefinitely.

    Args:
        sid: Session ID of the client that sent the event.
        data: Event payload; must be a mapping with a "model" entry.
            NOTE(review): the model value is used as a dict key, so it must
            be hashable -- presumably a string id; confirm against the client.
    """
    print(f'Received "usage" event from {sid}: {data}')

    # Without a model there is nothing to track; ignore the event instead of
    # raising KeyError/TypeError inside the handler.
    if not isinstance(data, dict) or "model" not in data:
        print(f'Ignoring "usage" event from {sid}: no model given')
        return
    model_id = data["model"]

    entry = USAGE_POOL.setdefault(sid, {"models": []})
    models = entry.setdefault("models", [])
    timers = entry.setdefault("tasks", {})

    # Forget timers that already ran so the mapping only holds live ones.
    for finished_id in [mid for mid, task in timers.items() if task.done()]:
        del timers[finished_id]

    # Restart the countdown for this model only; other models' timers keep
    # running so each one still expires on its own schedule.
    previous = timers.get(model_id)
    if previous is not None:
        previous.cancel()

    if model_id not in models:
        models.append(model_id)

    # Schedule removal of this model after TIMEOUT_DURATION
    timers[model_id] = asyncio.create_task(remove_after_timeout(sid, model_id))

    # Broadcast the usage data to all clients
    await sio.emit("usage", {"models": get_models_in_use()})


async def remove_after_timeout(sid, model_id):
    """Drop ``model_id`` from a session's usage entry once the timeout passes.

    Sleeps for TIMEOUT_DURATION seconds. If the session still has an entry
    in USAGE_POOL, the model is removed from its "models" list, the entry is
    deleted when that list becomes empty, and the aggregated usage list is
    broadcast on "usage". Cancelling the task ends it quietly.

    Args:
        sid: Session ID whose usage entry should be trimmed.
        model_id: Model identifier to remove from that entry.
    """
    try:
        await asyncio.sleep(TIMEOUT_DURATION)

        if sid not in USAGE_POOL:
            return

        active = USAGE_POOL[sid]["models"]
        if model_id in active:
            active.remove(model_id)
        if not active:
            del USAGE_POOL[sid]
        print(f"Removed usage data for {sid} due to timeout")

        # Broadcast the usage data to all clients
        await sio.emit("usage", {"models": get_models_in_use()})
    except asyncio.CancelledError:
        # Cancelled because a newer 'usage' event superseded this timer
        return


Timothy J. Baek's avatar
Timothy J. Baek committed
125
@sio.event
async def disconnect(sid):
    """Handle a client disconnect and broadcast the new count if it was known.

    A session present in USER_POOL is removed and the remaining number of
    entries is broadcast on "user-count". A session that was never
    registered is only logged.

    Args:
        sid: Session ID of the client that disconnected.
    """
    if sid not in USER_POOL:
        print(f"Unknown session ID {sid} disconnected")
        return

    disconnected_user = USER_POOL.pop(sid)
    print(f"user {disconnected_user} disconnected with session ID {sid}")

    remaining = len(USER_POOL)
    await sio.emit("user-count", {"count": remaining})