chats.py 12 KB
Newer Older
1
from pydantic import BaseModel, ConfigDict
Michael Poluektov's avatar
Michael Poluektov committed
2
from typing import Union, Optional
Timothy J. Baek's avatar
Timothy J. Baek committed
3
4
5
6
7

import json
import uuid
import time

8
from sqlalchemy import Column, String, BigInteger, Boolean, Text
9

Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
10
from apps.webui.internal.db import Base, get_db
11

Timothy J. Baek's avatar
Timothy J. Baek committed
12
13
14
15
16
17

####################
# Chat DB Schema
####################


18
19
class Chat(Base):
    """SQLAlchemy mapping for the ``chat`` table.

    Column order is kept as-is because it determines the generated DDL.
    """

    __tablename__ = "chat"

    id = Column(String, primary_key=True)
    user_id = Column(String)

    title = Column(Text)
    chat = Column(Text)  # Save Chat JSON as Text

    # Epoch timestamps in whole seconds (written with int(time.time())).
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)

    # Id of this chat's shared snapshot row, if it has been shared.
    share_id = Column(Text, unique=True, nullable=True)
    # Archived chats are excluded from the default chat listings.
    archived = Column(Boolean, default=False)
Timothy J. Baek's avatar
Timothy J. Baek committed
31
32
33


class ChatModel(BaseModel):
    # Pydantic view of a ``Chat`` row. ``chat`` holds the raw JSON text as
    # stored in the table. Documented with comments rather than a docstring so
    # the model's JSON-schema description is not changed.

    # Allows ChatModel.model_validate(<Chat ORM row>).
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    title: str
    chat: str

    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch

    share_id: Optional[str] = None
    archived: bool = False
Timothy J. Baek's avatar
Timothy J. Baek committed
46
47
48
49
50
51
52
53
54
55
56


####################
# Forms
####################


class ChatForm(BaseModel):
    # Full chat payload; ChatTable.insert_new_chat reads its optional "title"
    # key and stores the whole dict as JSON text.
    chat: dict


Timothy J. Baek's avatar
Timothy J. Baek committed
57
58
59
60
class ChatTitleForm(BaseModel):
    # Payload carrying only a chat title.
    title: str


61
class ChatResponse(BaseModel):
    # Response shape for a single chat. Unlike ChatModel, ``chat`` is a parsed
    # dict here rather than JSON text. Field order is preserved.
    id: str
    user_id: str
    title: str
    chat: dict
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch
    share_id: Optional[str] = None  # id of the chat to be shared
    archived: bool
Timothy J. Baek's avatar
Timothy J. Baek committed
70
71
72
73
74


class ChatTitleIdResponse(BaseModel):
    # Lightweight chat-list entry (no chat content); built by
    # ChatTable.get_chat_title_id_list_by_user_id.
    id: str
    title: str
    updated_at: int
    created_at: int
Timothy J. Baek's avatar
Timothy J. Baek committed
77
78
79
80


class ChatTable:
    """Data-access helpers for the ``chat`` table.

    Every method opens its own session via ``get_db()`` and returns pydantic
    objects (``ChatModel`` / ``ChatTitleIdResponse``) rather than live ORM rows.

    Sharing model: sharing a chat copies it into a snapshot row with a fresh
    id whose ``user_id`` is ``"shared-<original chat id>"``; the original
    row's ``share_id`` is then set to the snapshot's id.
    """

    @staticmethod
    def _shared_user_id(chat_id: str) -> str:
        """Return the synthetic owner id of the shared snapshot of ``chat_id``."""
        return f"shared-{chat_id}"

    def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
        """Create a chat owned by ``user_id`` from ``form_data.chat``.

        The title is ``form_data.chat["title"]`` when present, otherwise
        "New Chat"; the whole dict is stored as JSON text.

        Returns:
            The persisted chat as a ``ChatModel``.
        """
        with get_db() as db:
            # A single timestamp so a new chat has created_at == updated_at;
            # two separate time.time() calls could straddle a second boundary.
            now = int(time.time())
            chat = ChatModel(
                **{
                    "id": str(uuid.uuid4()),
                    "user_id": user_id,
                    "title": form_data.chat.get("title", "New Chat"),
                    "chat": json.dumps(form_data.chat),
                    "created_at": now,
                    "updated_at": now,
                }
            )

            result = Chat(**chat.model_dump())
            db.add(result)
            db.commit()
            db.refresh(result)
            return ChatModel.model_validate(result)

    def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
        """Replace the content of chat ``id`` with ``chat`` and bump ``updated_at``.

        The title is re-derived from ``chat["title"]`` (default "New Chat").

        Returns:
            The updated chat, or None if it does not exist or the update fails.
        """
        try:
            with get_db() as db:
                chat_obj = db.get(Chat, id)
                if chat_obj is None:
                    return None
                chat_obj.chat = json.dumps(chat)
                chat_obj.title = chat.get("title", "New Chat")
                chat_obj.updated_at = int(time.time())
                db.commit()
                db.refresh(chat_obj)
                return ChatModel.model_validate(chat_obj)
        except Exception:
            return None

    def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
        """Share chat ``chat_id`` by creating a snapshot copy of it.

        If the chat is already shared, the existing snapshot is returned
        instead of creating a new one.

        Returns:
            The snapshot chat, or None if ``chat_id`` does not exist or the
            snapshot could not be created/linked.
        """
        with get_db() as db:
            # Get the existing chat to share
            chat = db.get(Chat, chat_id)
            if chat is None:
                return None

            if chat.share_id:
                # Already shared. The snapshot is owned by "shared-<chat_id>"
                # (see below), so it must be looked up with that owner id.
                return self.get_chat_by_id_and_user_id(
                    chat.share_id, self._shared_user_id(chat_id)
                )

            # Create a new chat with the same data, but with a new ID
            shared_chat = ChatModel(
                **{
                    "id": str(uuid.uuid4()),
                    "user_id": self._shared_user_id(chat_id),
                    "title": chat.title,
                    "chat": chat.chat,
                    "created_at": chat.created_at,
                    "updated_at": int(time.time()),
                }
            )
            shared_result = Chat(**shared_chat.model_dump())
            db.add(shared_result)
            db.commit()
            db.refresh(shared_result)

            # Point the original chat at its new snapshot.
            result = (
                db.query(Chat)
                .filter_by(id=chat_id)
                .update({"share_id": shared_chat.id})
            )
            db.commit()
            return shared_chat if (shared_result and result) else None

    def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
        """Refresh the shared snapshot of ``chat_id`` with the chat's current
        title and content, and bump the snapshot's ``updated_at``.

        Returns:
            The refreshed snapshot, or None if the chat does not exist, is not
            shared, its snapshot row is missing, or a database error occurs.
        """
        try:
            with get_db() as db:
                chat = db.get(Chat, chat_id)
                if chat is None or not chat.share_id:
                    return None

                shared_chat = db.get(Chat, chat.share_id)
                if shared_chat is None:
                    return None

                # Copy the original's current state into the snapshot (the
                # previous code reassigned the original's own fields: a no-op).
                shared_chat.title = chat.title
                shared_chat.chat = chat.chat
                shared_chat.updated_at = int(time.time())
                db.commit()
                db.refresh(shared_chat)
                return ChatModel.model_validate(shared_chat)
        except Exception:
            return None

    def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
        """Delete the shared snapshot row(s) of chat ``chat_id``.

        Only the snapshot rows are removed; the original chat's ``share_id``
        is not touched here.

        Returns:
            True on success, False if a database error occurred.
        """
        try:
            with get_db() as db:
                db.query(Chat).filter_by(
                    user_id=self._shared_user_id(chat_id)
                ).delete()
                db.commit()
                return True
        except Exception:
            return False

    def update_chat_share_id_by_id(
        self, id: str, share_id: Optional[str]
    ) -> Optional[ChatModel]:
        """Set (or clear, with None) the ``share_id`` of chat ``id``.

        Returns:
            The updated chat, or None if it does not exist or the update fails.
        """
        try:
            with get_db() as db:
                chat = db.get(Chat, id)
                if chat is None:
                    return None
                chat.share_id = share_id
                db.commit()
                db.refresh(chat)
                return ChatModel.model_validate(chat)
        except Exception:
            return None

    def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
        """Flip the ``archived`` flag of chat ``id``.

        Returns:
            The updated chat, or None if it does not exist or the update fails.
        """
        try:
            with get_db() as db:
                chat = db.get(Chat, id)
                if chat is None:
                    return None
                chat.archived = not chat.archived
                db.commit()
                db.refresh(chat)
                return ChatModel.model_validate(chat)
        except Exception:
            return None

    def archive_all_chats_by_user_id(self, user_id: str) -> bool:
        """Mark every chat of ``user_id`` as archived.

        Returns:
            True on success, False if a database error occurred.
        """
        try:
            with get_db() as db:
                db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
                db.commit()
                return True
        except Exception:
            return False

    def get_archived_chat_list_by_user_id(
        self, user_id: str, skip: int = 0, limit: int = 50
    ) -> list[ChatModel]:
        """Return all archived chats of ``user_id``, most recently updated first.

        NOTE: ``skip`` and ``limit`` are accepted but currently not applied;
        every matching chat is returned.
        """
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter_by(user_id=user_id, archived=True)
                .order_by(Chat.updated_at.desc())
                .all()
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_list_by_user_id(
        self,
        user_id: str,
        include_archived: bool = False,
        skip: int = 0,
        limit: int = 50,
    ) -> list[ChatModel]:
        """Return the chats of ``user_id``, most recently updated first.

        Archived chats are excluded unless ``include_archived`` is True.
        NOTE: ``skip`` and ``limit`` are accepted but currently not applied.
        """
        with get_db() as db:
            query = db.query(Chat).filter_by(user_id=user_id)
            if not include_archived:
                query = query.filter_by(archived=False)
            all_chats = query.order_by(Chat.updated_at.desc()).all()
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_title_id_list_by_user_id(
        self,
        user_id: str,
        include_archived: bool = False,
        skip: int = 0,
        limit: int = -1,
    ) -> list[ChatTitleIdResponse]:
        """Return id/title/timestamp summaries of ``user_id``'s chats.

        Chats are ordered most recently updated first; archived chats are
        excluded unless ``include_archived`` is True. The first ``skip`` rows
        are skipped and at most ``limit`` rows are returned; a negative
        ``limit`` (the default) means no limit.
        """
        with get_db() as db:
            query = db.query(Chat).filter_by(user_id=user_id)
            if not include_archived:
                query = query.filter_by(archived=False)

            # Select only the columns the summary needs, not the chat JSON.
            query = query.order_by(Chat.updated_at.desc()).with_entities(
                Chat.id, Chat.title, Chat.updated_at, Chat.created_at
            )
            # A negative LIMIT means "unbounded" only in SQLite; PostgreSQL
            # rejects it, so omit the clause instead of sending LIMIT -1.
            if limit is not None and limit >= 0:
                query = query.limit(limit)
            query = query.offset(skip)

            # Rows are column tuples rather than Chat objects, so build the
            # response models from the named columns explicitly.
            return [
                ChatTitleIdResponse(
                    id=row.id,
                    title=row.title,
                    updated_at=row.updated_at,
                    created_at=row.created_at,
                )
                for row in query.all()
            ]

    def get_chat_list_by_chat_ids(
        self, chat_ids: list[str], skip: int = 0, limit: int = 50
    ) -> list[ChatModel]:
        """Return the non-archived chats whose ids are in ``chat_ids``, most
        recently updated first.

        NOTE: ``skip`` and ``limit`` are accepted but currently not applied.
        """
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter(Chat.id.in_(chat_ids))
                .filter_by(archived=False)
                .order_by(Chat.updated_at.desc())
                .all()
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
        """Return chat ``id``, or None if it does not exist or the lookup fails."""
        try:
            with get_db() as db:
                chat = db.get(Chat, id)
                return ChatModel.model_validate(chat) if chat is not None else None
        except Exception:
            return None

    def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
        """Return the shared snapshot whose id is ``id``.

        The snapshot is returned only if some chat still references it via
        ``share_id``; a snapshot that is no longer referenced is treated as not
        shared. Returns None in that case or on a database error.
        """
        try:
            with get_db() as db:
                chat = db.query(Chat).filter_by(share_id=id).first()
                if chat is None:
                    return None
                return self.get_chat_by_id(id)
        except Exception:
            return None

    def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
        """Return chat ``id`` only if it is owned by ``user_id``; else None."""
        try:
            with get_db() as db:
                chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
                return ChatModel.model_validate(chat) if chat is not None else None
        except Exception:
            return None

    def get_chats(self, skip: int = 0, limit: int = 50) -> list[ChatModel]:
        """Return every chat, most recently updated first.

        NOTE: ``skip`` and ``limit`` are accepted but currently not applied.
        """
        with get_db() as db:
            all_chats = db.query(Chat).order_by(Chat.updated_at.desc()).all()
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
        """Return all chats of ``user_id`` (archived included), newest update first."""
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter_by(user_id=user_id)
                .order_by(Chat.updated_at.desc())
                .all()
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_archived_chats_by_user_id(self, user_id: str) -> list[ChatModel]:
        """Return the archived chats of ``user_id``, newest update first."""
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter_by(user_id=user_id, archived=True)
                .order_by(Chat.updated_at.desc())
                .all()
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def delete_chat_by_id(self, id: str) -> bool:
        """Delete chat ``id`` and then its shared snapshot, if any.

        Returns:
            The result of deleting the snapshot (True on success), or False if
            deleting the chat itself failed.
        """
        try:
            with get_db() as db:
                db.query(Chat).filter_by(id=id).delete()
                db.commit()
                return self.delete_shared_chat_by_chat_id(id)
        except Exception:
            return False

    def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
        """Delete chat ``id`` if owned by ``user_id``, then its shared snapshot.

        NOTE: the snapshot deletion runs even when no chat matched the owner.

        Returns:
            The result of deleting the snapshot (True on success), or False if
            deleting the chat itself failed.
        """
        try:
            with get_db() as db:
                db.query(Chat).filter_by(id=id, user_id=user_id).delete()
                db.commit()
                return self.delete_shared_chat_by_chat_id(id)
        except Exception:
            return False

    def delete_chats_by_user_id(self, user_id: str) -> bool:
        """Delete every chat of ``user_id`` together with their shared snapshots.

        Returns:
            True on success, False if a database error occurred.
        """
        try:
            with get_db() as db:
                # Snapshots are found via the user's chat ids, so remove them
                # before the chats themselves disappear.
                self.delete_shared_chats_by_user_id(user_id)

                db.query(Chat).filter_by(user_id=user_id).delete()
                db.commit()
                return True
        except Exception:
            return False

    def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
        """Delete the shared snapshots of every chat owned by ``user_id``.

        Returns:
            True on success, False if a database error occurred.
        """
        try:
            with get_db() as db:
                # Only the ids are needed; don't load each chat's JSON body.
                chat_ids = [
                    chat_id
                    for (chat_id,) in db.query(Chat.id)
                    .filter(Chat.user_id == user_id)
                    .all()
                ]
                if not chat_ids:
                    return True

                shared_owner_ids = [self._shared_user_id(cid) for cid in chat_ids]
                # Bulk delete with IN: nothing from these rows is held in this
                # session, so skip in-Python synchronization of the criteria.
                db.query(Chat).filter(Chat.user_id.in_(shared_owner_ids)).delete(
                    synchronize_session=False
                )
                db.commit()
                return True
        except Exception:
            return False

Timothy J. Baek's avatar
Timothy J. Baek committed
406

407
# Module-level ChatTable instance; ChatTable keeps no per-instance state, so
# this single instance can be shared by all callers.
Chats = ChatTable()