# chats.py
import json
import logging
import time
import uuid
from typing import List, Optional, Union

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String, Text

from apps.webui.internal.db import Base, get_db

####################
# Chat DB Schema
####################


class Chat(Base):
    """SQLAlchemy ORM mapping for the ``chat`` table."""

    __tablename__ = "chat"

    # Identity and ownership.
    id = Column(String, primary_key=True)
    user_id = Column(String)

    # Content; the chat payload is stored serialized as JSON text.
    title = Column(Text)
    chat = Column(Text)

    # Timestamps in epoch seconds.
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)

    # Id of this chat's shared snapshot row, if it has been shared.
    share_id = Column(Text, nullable=True, unique=True)
    archived = Column(Boolean, default=False)


class ChatModel(BaseModel):
    """Pydantic view of a ``Chat`` row; ``chat`` holds the raw JSON string."""

    # Lets ``model_validate`` read attributes straight off ORM instances.
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    title: str
    chat: str  # serialized JSON payload

    created_at: int  # epoch seconds
    updated_at: int  # epoch seconds

    share_id: Optional[str] = None
    archived: bool = False


####################
# Forms
####################


class ChatForm(BaseModel):
    """Request body carrying a chat payload as a dict."""

    chat: dict


class ChatTitleForm(BaseModel):
    """Request body carrying a single chat title."""

    title: str


class ChatResponse(BaseModel):
    """API representation of a chat; ``chat`` is a dict, not a JSON string."""

    id: str
    user_id: str
    title: str
    chat: dict

    # Both timestamps are epoch seconds.
    updated_at: int
    created_at: int

    share_id: Optional[str] = None  # id of the chat to be shared
    archived: bool


class ChatTitleIdResponse(BaseModel):
    """Lightweight chat listing entry: id, title and epoch timestamps only."""

    id: str
    title: str
    updated_at: int
    created_at: int


log = logging.getLogger(__name__)


class ChatTable:
    """Data-access helpers for the ``chat`` table.

    Every method opens its own session through ``get_db()``; methods that
    write commit before returning. Shared chats are stored as snapshot rows
    whose ``user_id`` is ``"shared-<original chat id>"``, and the original
    row's ``share_id`` holds the snapshot's ``id``.

    Methods returning ``Optional[...]`` yield ``None`` when the row does not
    exist; those wrapped in ``try`` also yield ``None`` (or ``False`` for
    ``bool`` methods) on a database error, which is logged.
    """

    def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
        """Create and persist a new chat owned by ``user_id``.

        The title comes from ``form_data.chat["title"]`` when present,
        otherwise ``"New Chat"``; the whole ``form_data.chat`` dict is stored
        as JSON text. Database errors propagate to the caller.
        """
        # A single timestamp so created_at and updated_at are identical.
        now = int(time.time())
        chat = ChatModel(
            id=str(uuid.uuid4()),
            user_id=user_id,
            title=form_data.chat.get("title", "New Chat"),
            chat=json.dumps(form_data.chat),
            created_at=now,
            updated_at=now,
        )
        with get_db() as db:
            row = Chat(**chat.model_dump())
            db.add(row)
            db.commit()
            db.refresh(row)
            return ChatModel.model_validate(row)

    def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
        """Replace the JSON payload (and derived title) of chat ``id``.

        Returns the updated chat, or ``None`` if it does not exist or the
        update fails.
        """
        try:
            with get_db() as db:
                chat_obj = db.get(Chat, id)
                if chat_obj is None:
                    return None
                chat_obj.chat = json.dumps(chat)
                chat_obj.title = chat.get("title", "New Chat")
                chat_obj.updated_at = int(time.time())
                db.commit()
                db.refresh(chat_obj)
                return ChatModel.model_validate(chat_obj)
        except Exception:
            log.exception("Failed to update chat %s", id)
            return None

    def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
        """Create a shared snapshot of chat ``chat_id``.

        If the chat already has a snapshot, that snapshot is returned instead
        of creating another. Returns ``None`` if the chat does not exist.
        """
        with get_db() as db:
            chat = db.get(Chat, chat_id)
            if chat is None:
                return None
            if chat.share_id:
                # Snapshot rows are owned by "shared-<chat_id>", not "shared",
                # so fetch the existing one by its own id.
                return self.get_chat_by_id(chat.share_id)

            shared_chat = ChatModel(
                id=str(uuid.uuid4()),
                user_id=f"shared-{chat_id}",
                title=chat.title,
                chat=chat.chat,
                created_at=chat.created_at,
                updated_at=int(time.time()),
            )
            db.add(Chat(**shared_chat.model_dump()))
            # Link the original to its snapshot in the same commit as the
            # insert, so the two writes succeed or fail together.
            chat.share_id = shared_chat.id
            db.commit()
            return shared_chat

    def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
        """Sync the shared snapshot of ``chat_id`` with the chat's current state.

        Copies the original's title and JSON onto the snapshot referenced by
        its ``share_id`` and bumps the snapshot's ``updated_at``. Returns the
        refreshed snapshot, or ``None`` if the chat is missing, not shared,
        or the update fails.
        """
        try:
            with get_db() as db:
                chat = db.get(Chat, chat_id)
                if chat is None or not chat.share_id:
                    return None
                shared_chat = db.get(Chat, chat.share_id)
                if shared_chat is None:
                    return None
                shared_chat.title = chat.title
                shared_chat.chat = chat.chat
                shared_chat.updated_at = int(time.time())
                db.commit()
                db.refresh(shared_chat)
                return ChatModel.model_validate(shared_chat)
        except Exception:
            log.exception("Failed to update shared chat for %s", chat_id)
            return None

    def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
        """Delete every snapshot row of ``chat_id``.

        The original chat's ``share_id`` is left untouched. Returns ``False``
        on error.
        """
        try:
            with get_db() as db:
                db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
                db.commit()
                return True
        except Exception:
            log.exception("Failed to delete shared chat for %s", chat_id)
            return False

    def update_chat_share_id_by_id(
        self, id: str, share_id: Optional[str]
    ) -> Optional[ChatModel]:
        """Set (or clear, with ``None``) the ``share_id`` of chat ``id``."""
        try:
            with get_db() as db:
                chat = db.get(Chat, id)
                if chat is None:
                    return None
                chat.share_id = share_id
                db.commit()
                db.refresh(chat)
                return ChatModel.model_validate(chat)
        except Exception:
            log.exception("Failed to update share_id of chat %s", id)
            return None

    def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
        """Flip the ``archived`` flag of chat ``id`` and return the result."""
        try:
            with get_db() as db:
                chat = db.get(Chat, id)
                if chat is None:
                    return None
                chat.archived = not chat.archived
                db.commit()
                db.refresh(chat)
                return ChatModel.model_validate(chat)
        except Exception:
            log.exception("Failed to toggle archive of chat %s", id)
            return None

    def archive_all_chats_by_user_id(self, user_id: str) -> bool:
        """Mark every chat owned by ``user_id`` as archived."""
        try:
            with get_db() as db:
                db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
                db.commit()
                return True
        except Exception:
            log.exception("Failed to archive chats of user %s", user_id)
            return False

    def get_archived_chat_list_by_user_id(
        self, user_id: str, skip: int = 0, limit: int = 50
    ) -> List[ChatModel]:
        """Return ``user_id``'s archived chats, most recently updated first.

        ``skip``/``limit`` are accepted for API compatibility, but pagination
        is currently disabled: all matching rows are returned.
        """
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter_by(user_id=user_id, archived=True)
                .order_by(Chat.updated_at.desc())
                .all()
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_list_by_user_id(
        self,
        user_id: str,
        include_archived: bool = False,
        skip: int = 0,
        limit: int = 50,
    ) -> List[ChatModel]:
        """Return ``user_id``'s chats, most recently updated first.

        Archived chats are excluded unless ``include_archived`` is true.
        ``skip``/``limit`` are currently ignored (pagination disabled).
        """
        with get_db() as db:
            query = db.query(Chat).filter_by(user_id=user_id)
            if not include_archived:
                query = query.filter_by(archived=False)
            all_chats = query.order_by(Chat.updated_at.desc()).all()
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_title_id_list_by_user_id(
        self,
        user_id: str,
        include_archived: bool = False,
        skip: int = 0,
        limit: int = 50,
    ) -> List[ChatTitleIdResponse]:
        """Return id/title/timestamp summaries of ``user_id``'s chats.

        Same filtering and ordering as ``get_chat_list_by_user_id``, but
        only the four needed columns are selected. ``skip``/``limit`` are
        currently ignored (pagination disabled).
        """
        with get_db() as db:
            query = db.query(Chat).filter_by(user_id=user_id)
            if not include_archived:
                query = query.filter_by(archived=False)
            # Rows are plain column tuples, not Chat instances, so they are
            # unpacked into the response model explicitly.
            rows = (
                query.order_by(Chat.updated_at.desc())
                .with_entities(Chat.id, Chat.title, Chat.updated_at, Chat.created_at)
                .all()
            )
            return [
                ChatTitleIdResponse.model_validate(
                    {
                        "id": chat_id,
                        "title": title,
                        "updated_at": updated_at,
                        "created_at": created_at,
                    }
                )
                for chat_id, title, updated_at, created_at in rows
            ]

    def get_chat_list_by_chat_ids(
        self, chat_ids: List[str], skip: int = 0, limit: int = 50
    ) -> List[ChatModel]:
        """Return the non-archived chats among ``chat_ids``, newest first.

        ``skip``/``limit`` are currently ignored.
        """
        if not chat_ids:
            return []
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter(Chat.id.in_(chat_ids))
                .filter_by(archived=False)
                .order_by(Chat.updated_at.desc())
                .all()
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
        """Return chat ``id``, or ``None`` if it does not exist."""
        try:
            with get_db() as db:
                chat = db.get(Chat, id)
                return ChatModel.model_validate(chat) if chat else None
        except Exception:
            log.exception("Failed to load chat %s", id)
            return None

    def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
        """Return the shared snapshot whose id is ``id``.

        The snapshot is only returned while some chat still references it
        through ``share_id``; otherwise ``None``.
        """
        try:
            with get_db() as db:
                if db.query(Chat).filter_by(share_id=id).first() is None:
                    return None
                snapshot = db.get(Chat, id)
                return ChatModel.model_validate(snapshot) if snapshot else None
        except Exception:
            log.exception("Failed to load shared chat %s", id)
            return None

    def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
        """Return chat ``id`` only if it is owned by ``user_id``."""
        try:
            with get_db() as db:
                chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
                return ChatModel.model_validate(chat) if chat else None
        except Exception:
            log.exception("Failed to load chat %s for user %s", id, user_id)
            return None

    def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
        """Return every chat row (snapshots included), newest first.

        ``skip``/``limit`` are currently ignored (pagination disabled).
        """
        with get_db() as db:
            all_chats = db.query(Chat).order_by(Chat.updated_at.desc()).all()
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
        """Return all chats of ``user_id`` (archived included), newest first."""
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter_by(user_id=user_id)
                .order_by(Chat.updated_at.desc())
                .all()
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
        """Return only the archived chats of ``user_id``, newest first."""
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter_by(user_id=user_id, archived=True)
                .order_by(Chat.updated_at.desc())
                .all()
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def delete_chat_by_id(self, id: str) -> bool:
        """Delete chat ``id`` and any shared snapshot of it."""
        try:
            with get_db() as db:
                db.query(Chat).filter_by(id=id).delete()
                db.commit()
            return self.delete_shared_chat_by_chat_id(id)
        except Exception:
            log.exception("Failed to delete chat %s", id)
            return False

    def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
        """Delete chat ``id`` if owned by ``user_id``, plus its snapshot.

        Returns ``True`` even when no owned row matched (nothing to delete).
        """
        try:
            with get_db() as db:
                deleted = db.query(Chat).filter_by(id=id, user_id=user_id).delete()
                db.commit()
            # Only drop the snapshot of a chat this user actually owned;
            # otherwise a caller could remove another user's shared link.
            if not deleted:
                return True
            return self.delete_shared_chat_by_chat_id(id)
        except Exception:
            log.exception("Failed to delete chat %s for user %s", id, user_id)
            return False

    def delete_chats_by_user_id(self, user_id: str) -> bool:
        """Delete every chat of ``user_id`` and their shared snapshots."""
        try:
            # Snapshots first: their owner ids are derived from the chat rows
            # deleted below.
            self.delete_shared_chats_by_user_id(user_id)
            with get_db() as db:
                db.query(Chat).filter_by(user_id=user_id).delete()
                db.commit()
                return True
        except Exception:
            log.exception("Failed to delete chats of user %s", user_id)
            return False

    def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
        """Delete the shared snapshots of every chat owned by ``user_id``."""
        try:
            with get_db() as db:
                # Only the ids are needed to derive the snapshot owner ids.
                snapshot_owner_ids = [
                    f"shared-{chat_id}"
                    for (chat_id,) in db.query(Chat.id).filter(
                        Chat.user_id == user_id
                    )
                ]
                if snapshot_owner_ids:
                    db.query(Chat).filter(
                        Chat.user_id.in_(snapshot_owner_ids)
                    ).delete()
                    db.commit()
                return True
        except Exception:
            log.exception("Failed to delete shared chats of user %s", user_id)
            return False

# Module-level ChatTable instance.
Chats = ChatTable()