chats.py 10.5 KB
Newer Older
1
from pydantic import BaseModel, ConfigDict
Timothy J. Baek's avatar
Timothy J. Baek committed
2
3
4
5
6
7
from typing import List, Union, Optional

import json
import uuid
import time

8
from sqlalchemy import Column, String, BigInteger, Boolean, Text
9

Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
10
from apps.webui.internal.db import Base, get_db
11

Timothy J. Baek's avatar
Timothy J. Baek committed
12
13
14
15
16
17

####################
# Chat DB Schema
####################


18
19
class Chat(Base):
    """SQLAlchemy ORM mapping for the ``chat`` table."""

    __tablename__ = "chat"

    id = Column(String, primary_key=True)
    # Owning user's id. Shared snapshots created by ChatTable use the
    # pseudo-owner "shared-<original chat id>".
    user_id = Column(String)

    title = Column(Text)
    chat = Column(Text)  # Save Chat JSON as Text

    created_at = Column(BigInteger)  # epoch seconds
    updated_at = Column(BigInteger)  # epoch seconds

    # Id of the shared snapshot row made from this chat; None if never shared.
    share_id = Column(Text, unique=True, nullable=True)
    archived = Column(Boolean, default=False)
Timothy J. Baek's avatar
Timothy J. Baek committed
31
32
33


class ChatModel(BaseModel):
    """Pydantic view of a ``chat`` row, with ``chat`` kept as raw JSON text."""

    # Lets ``ChatModel.model_validate(orm_row)`` read attributes of ORM objects.
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    title: str
    chat: str  # JSON-encoded chat payload

    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch

    share_id: Optional[str] = None  # id of the shared snapshot, if any
    archived: bool = False
Timothy J. Baek's avatar
Timothy J. Baek committed
46
47
48
49
50
51
52
53
54
55
56


####################
# Forms
####################


class ChatForm(BaseModel):
    """Form carrying a chat payload as a free-form dict."""

    chat: dict  # may include a "title" key, used as the chat title when present


Timothy J. Baek's avatar
Timothy J. Baek committed
57
58
59
60
class ChatTitleForm(BaseModel):
    """Form carrying only a chat title."""

    title: str


61
class ChatResponse(BaseModel):
    """Chat as returned to clients; unlike ``ChatModel``, ``chat`` is a dict."""

    id: str
    user_id: str
    title: str
    chat: dict
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch
    share_id: Optional[str] = None  # id of the shared snapshot of this chat, if any
    archived: bool
Timothy J. Baek's avatar
Timothy J. Baek committed
70
71
72
73
74


class ChatTitleIdResponse(BaseModel):
    """Lightweight chat summary: id, title and timestamps only."""

    id: str
    title: str
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch
Timothy J. Baek's avatar
Timothy J. Baek committed
77
78
79
80


class ChatTable:
    """Data-access helpers for the ``chat`` table.

    Each method opens its own session with ``get_db()``. A shared chat is a
    separate ``Chat`` row (a snapshot) whose ``user_id`` is
    ``"shared-<original chat id>"``; the original row's ``share_id`` stores the
    snapshot's id.

    Every write, including bulk ``Query.update()`` / ``Query.delete()`` calls,
    is followed by an explicit ``db.commit()``, matching the methods that
    already committed ORM changes, so persistence does not depend on how
    ``get_db()`` closes the session.
    """

    def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
        """Create and persist a new chat owned by ``user_id``.

        The title is ``form_data.chat["title"]`` when that key exists, otherwise
        ``"New Chat"``; the whole ``form_data.chat`` dict is stored as JSON text.
        Returns the stored chat.
        """
        with get_db() as db:
            # One timestamp for both fields so they are exactly equal on creation.
            now = int(time.time())
            chat = ChatModel(
                **{
                    "id": str(uuid.uuid4()),
                    "user_id": user_id,
                    "title": form_data.chat.get("title", "New Chat"),
                    "chat": json.dumps(form_data.chat),
                    "created_at": now,
                    "updated_at": now,
                }
            )

            result = Chat(**chat.model_dump())
            db.add(result)
            db.commit()
            db.refresh(result)
            return ChatModel.model_validate(result) if result else None

    def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
        """Replace the stored payload of chat ``id`` with ``chat``.

        The title is re-derived from ``chat["title"]`` (default ``"New Chat"``)
        and ``updated_at`` is bumped. Returns the updated chat, or ``None`` if
        the chat does not exist or the update fails.
        """
        try:
            with get_db() as db:
                chat_obj = db.get(Chat, id)
                if chat_obj is None:
                    return None

                chat_obj.chat = json.dumps(chat)
                chat_obj.title = chat.get("title", "New Chat")
                chat_obj.updated_at = int(time.time())
                db.commit()
                db.refresh(chat_obj)
                return ChatModel.model_validate(chat_obj)
        except Exception:
            return None

    def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
        """Create a shared snapshot of chat ``chat_id`` and link it via ``share_id``.

        If the chat is already shared, its existing snapshot is returned instead
        of creating another one. Returns ``None`` if ``chat_id`` does not exist.
        """
        with get_db() as db:
            chat = db.get(Chat, chat_id)
            if chat is None:
                return None

            # Snapshots are owned by "shared-<chat_id>" (see below), so the
            # existing one must be looked up under that owner; a bare "shared"
            # never matches any row.
            if chat.share_id:
                return self.get_chat_by_id_and_user_id(
                    chat.share_id, f"shared-{chat_id}"
                )

            # Copy the chat under a new id, owned by the pseudo-user.
            shared_chat = ChatModel(
                **{
                    "id": str(uuid.uuid4()),
                    "user_id": f"shared-{chat_id}",
                    "title": chat.title,
                    "chat": chat.chat,
                    "created_at": chat.created_at,
                    "updated_at": int(time.time()),
                }
            )
            db.add(Chat(**shared_chat.model_dump()))
            # Link the original to its snapshot in the same transaction so a
            # snapshot is never persisted without the referencing share_id.
            chat.share_id = shared_chat.id
            db.commit()

            return shared_chat

    def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
        """Refresh the shared snapshot of chat ``chat_id`` with its current content.

        Copies ``title`` and ``chat`` from the original onto the snapshot its
        ``share_id`` points at, and bumps the snapshot's ``updated_at``.
        Returns the updated snapshot, or ``None`` if the chat or its snapshot
        does not exist or the update fails.
        """
        try:
            with get_db() as db:
                chat = db.get(Chat, chat_id)
                if chat is None or not chat.share_id:
                    return None

                shared_chat = db.get(Chat, chat.share_id)
                if shared_chat is None:
                    return None

                shared_chat.title = chat.title
                shared_chat.chat = chat.chat
                shared_chat.updated_at = int(time.time())
                db.commit()
                db.refresh(shared_chat)
                return ChatModel.model_validate(shared_chat)
        except Exception:
            return None

    def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
        """Delete the shared snapshot of chat ``chat_id``; ``False`` on error."""
        try:
            with get_db() as db:
                db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
                db.commit()
                return True
        except Exception:
            return False

    def update_chat_share_id_by_id(
        self, id: str, share_id: Optional[str]
    ) -> Optional[ChatModel]:
        """Set (or clear, with ``None``) the ``share_id`` of chat ``id``.

        Returns the updated chat, or ``None`` if it does not exist or the
        update fails.
        """
        try:
            with get_db() as db:
                chat = db.get(Chat, id)
                if chat is None:
                    return None

                chat.share_id = share_id
                db.commit()
                db.refresh(chat)
                return ChatModel.model_validate(chat)
        except Exception:
            return None

    def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
        """Flip the ``archived`` flag of chat ``id``.

        Returns the updated chat, or ``None`` if it does not exist or the
        update fails.
        """
        try:
            with get_db() as db:
                chat = db.get(Chat, id)
                if chat is None:
                    return None

                chat.archived = not chat.archived
                db.commit()
                db.refresh(chat)
                return ChatModel.model_validate(chat)
        except Exception:
            return None

    def archive_all_chats_by_user_id(self, user_id: str) -> bool:
        """Mark every chat owned by ``user_id`` as archived; ``False`` on error."""
        try:
            with get_db() as db:
                db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
                db.commit()
                return True
        except Exception:
            return False

    def get_archived_chat_list_by_user_id(
        self, user_id: str, skip: int = 0, limit: int = 50
    ) -> List[ChatModel]:
        """Return ``user_id``'s archived chats, most recently updated first.

        NOTE: ``skip`` and ``limit`` are accepted but currently not applied;
        all matching chats are returned.
        """
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter_by(user_id=user_id, archived=True)
                .order_by(Chat.updated_at.desc())
                .all()
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_list_by_user_id(
        self,
        user_id: str,
        include_archived: bool = False,
        skip: int = 0,
        limit: int = 50,
    ) -> List[ChatModel]:
        """Return ``user_id``'s chats, most recently updated first.

        Archived chats are excluded unless ``include_archived`` is true.
        NOTE: ``skip`` and ``limit`` are accepted but currently not applied.
        """
        with get_db() as db:
            query = db.query(Chat).filter_by(user_id=user_id)
            if not include_archived:
                query = query.filter_by(archived=False)
            all_chats = query.order_by(Chat.updated_at.desc()).all()
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_list_by_chat_ids(
        self, chat_ids: List[str], skip: int = 0, limit: int = 50
    ) -> List[ChatModel]:
        """Return the non-archived chats whose ids are in ``chat_ids``.

        Results are ordered most recently updated first.
        NOTE: ``skip`` and ``limit`` are accepted but currently not applied.
        """
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter(Chat.id.in_(chat_ids))
                .filter_by(archived=False)
                .order_by(Chat.updated_at.desc())
                .all()
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
        """Return chat ``id``, or ``None`` if it does not exist or the lookup fails."""
        try:
            with get_db() as db:
                chat = db.get(Chat, id)
                if chat is None:
                    return None
                return ChatModel.model_validate(chat)
        except Exception:
            return None

    def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
        """Return the shared snapshot whose id is ``id``.

        The snapshot is returned only while some chat still references it via
        ``share_id``; otherwise (or on error) ``None`` is returned.
        """
        try:
            with get_db() as db:
                chat = db.query(Chat).filter_by(share_id=id).first()
                if chat is None:
                    return None
                # ``id`` is the snapshot's own primary key, so fetch it directly.
                return self.get_chat_by_id(id)
        except Exception:
            return None

    def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
        """Return chat ``id`` if it is owned by ``user_id``, else ``None``."""
        try:
            with get_db() as db:
                chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
                if chat is None:
                    return None
                return ChatModel.model_validate(chat)
        except Exception:
            return None

    def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
        """Return every chat, most recently updated first.

        NOTE: ``skip`` and ``limit`` are accepted but currently not applied.
        """
        with get_db() as db:
            all_chats = db.query(Chat).order_by(Chat.updated_at.desc()).all()
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
        """Return all chats owned by ``user_id``, most recently updated first."""
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter_by(user_id=user_id)
                .order_by(Chat.updated_at.desc())
                .all()
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
        """Return ``user_id``'s archived chats, most recently updated first."""
        with get_db() as db:
            all_chats = (
                db.query(Chat)
                .filter_by(user_id=user_id, archived=True)
                .order_by(Chat.updated_at.desc())
                .all()
            )
            return [ChatModel.model_validate(chat) for chat in all_chats]

    def delete_chat_by_id(self, id: str) -> bool:
        """Delete chat ``id`` and its shared snapshot.

        Returns ``True`` only if both deletions succeed.
        """
        try:
            with get_db() as db:
                db.query(Chat).filter_by(id=id).delete()
                db.commit()

                return self.delete_shared_chat_by_chat_id(id)
        except Exception:
            return False

    def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
        """Delete chat ``id`` if owned by ``user_id``, plus its shared snapshot.

        Returns ``True`` only if both deletions succeed.
        """
        try:
            with get_db() as db:
                db.query(Chat).filter_by(id=id, user_id=user_id).delete()
                db.commit()

                return self.delete_shared_chat_by_chat_id(id)
        except Exception:
            return False

    def delete_chats_by_user_id(self, user_id: str) -> bool:
        """Delete all of ``user_id``'s chats and their shared snapshots."""
        try:
            with get_db() as db:
                # Snapshots are located through this user's chat ids, so they
                # must be removed before the chats themselves.
                self.delete_shared_chats_by_user_id(user_id)

                db.query(Chat).filter_by(user_id=user_id).delete()
                db.commit()
                return True
        except Exception:
            return False

    def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
        """Delete the shared snapshots of every chat owned by ``user_id``."""
        try:
            with get_db() as db:
                snapshot_owners = [
                    f"shared-{row.id}"
                    for row in db.query(Chat.id).filter_by(user_id=user_id).all()
                ]
                if not snapshot_owners:
                    return True  # the user owns no chats, so nothing is shared

                db.query(Chat).filter(Chat.user_id.in_(snapshot_owners)).delete()
                db.commit()
                return True
        except Exception:
            return False

Timothy J. Baek's avatar
Timothy J. Baek committed
367

368
# Module-level ChatTable instance. ChatTable keeps no per-instance state, so a
# single shared instance is sufficient.
Chats: ChatTable = ChatTable()