import json
import logging
import time
import uuid
from typing import List, Optional, Union

from pydantic import BaseModel, ConfigDict
from sqlalchemy import BigInteger, Boolean, Column, String
from sqlalchemy.orm import Session

from apps.webui.internal.db import Base

####################
# Chat DB Schema
####################


class Chat(Base):
    """SQLAlchemy model for the ``chat`` table.

    Shared snapshots of a chat are stored as ordinary rows in this same table
    (see ``share_id`` and ``ChatTable.insert_shared_chat_by_chat_id``).
    """

    __tablename__ = "chat"

    # UUID4 string generated by ChatTable on insert.
    id = Column(String, primary_key=True)
    # Owning user; shared snapshots use "shared-<original chat id>" instead.
    user_id = Column(String)
    title = Column(String)
    # Whole chat payload serialized with json.dumps (stored in a String column).
    chat = Column(String)

    # Epoch seconds (ChatTable writes int(time.time())).
    created_at = Column(BigInteger)
    updated_at = Column(BigInteger)

    # Id of the shared snapshot row created from this chat, if any.
    share_id = Column(String, unique=True, nullable=True)
    archived = Column(Boolean, default=False)


# Pydantic mirror of a ``Chat`` row. ``from_attributes`` lets
# ``ChatModel.model_validate(row)`` read fields straight off ORM objects.
class ChatModel(BaseModel):
    model_config = ConfigDict(from_attributes=True)

    id: str
    user_id: str
    title: str
    chat: str  # JSON-encoded chat payload, as stored in the table

    created_at: int  # timestamp in epoch
    updated_at: int  # timestamp in epoch

    share_id: Optional[str] = None  # id of the shared snapshot row, if any
    archived: bool = False


####################
# Forms
####################


# Input carrying a raw chat payload. ChatTable stores ``chat`` as JSON and uses
# its optional "title" key as the chat title.
class ChatForm(BaseModel):
    chat: dict


# Input carrying only a chat title.
class ChatTitleForm(BaseModel):
    title: str


# Chat representation with ``chat`` as a decoded dict rather than the JSON string
# held by ChatModel (presumably the payload returned to API clients — confirm at
# call sites).
class ChatResponse(BaseModel):
    id: str
    user_id: str
    title: str
    chat: dict
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch
    share_id: Optional[str] = None  # id of the shared snapshot row, if any
    archived: bool


# Lightweight chat entry: id, title and timestamps only (no payload).
class ChatTitleIdResponse(BaseModel):
    id: str
    title: str
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch


log = logging.getLogger(__name__)


class ChatTable:
    """Data-access helpers for the ``chat`` table.

    Every method runs against the SQLAlchemy ``Session`` passed in as ``db``.
    Methods that write commit their own changes. On failure they log the
    error, roll the session back so it stays usable, and return ``None`` /
    ``False``. The ``insert_*`` methods re-raise after rolling back.

    Sharing model: sharing chat ``X`` copies it into a new row with a fresh
    UUID and ``user_id == "shared-X"``, and stores that row's id in
    ``X.share_id``.
    """

    @staticmethod
    def _abort(db: Session, msg: str, *args) -> None:
        """Log the exception being handled and roll back ``db``."""
        log.exception(msg, *args)
        db.rollback()

    # ------------------------------------------------------------------
    # Create / update
    # ------------------------------------------------------------------

    def insert_new_chat(
        self, db: Session, user_id: str, form_data: ChatForm
    ) -> Optional[ChatModel]:
        """Create and persist a new chat owned by ``user_id``.

        ``form_data.chat`` is stored as JSON. Its ``"title"`` key, when present,
        becomes the row title; otherwise ``"New Chat"`` is used.

        Returns the stored chat. Database errors are re-raised after the
        session is rolled back.
        """
        now = int(time.time())  # one timestamp so created_at == updated_at
        chat = ChatModel(
            id=str(uuid.uuid4()),
            user_id=user_id,
            title=form_data.chat.get("title", "New Chat"),
            chat=json.dumps(form_data.chat),
            created_at=now,
            updated_at=now,
        )

        row = Chat(**chat.model_dump())
        try:
            db.add(row)
            db.commit()
        except Exception:
            db.rollback()
            raise
        db.refresh(row)
        return ChatModel.model_validate(row)

    def update_chat_by_id(
        self, db: Session, id: str, chat: dict
    ) -> Optional[ChatModel]:
        """Replace the stored payload (and title) of chat ``id``.

        Returns the updated chat, or ``None`` if it does not exist or the
        update fails.
        """
        try:
            db.query(Chat).filter_by(id=id).update(
                {
                    "chat": json.dumps(chat),
                    "title": chat.get("title", "New Chat"),
                    "updated_at": int(time.time()),
                }
            )
            db.commit()
            return self.get_chat_by_id(db, id)
        except Exception:
            self._abort(db, "Failed to update chat %s", id)
            return None

    def insert_shared_chat_by_chat_id(
        self, db: Session, chat_id: str
    ) -> Optional[ChatModel]:
        """Create a shareable snapshot of chat ``chat_id``.

        The snapshot is a new row (fresh id, ``user_id="shared-<chat_id>"``)
        that copies the original's title, payload and ``created_at``. The
        original's ``share_id`` is then pointed at it. If the chat is already
        shared, the existing snapshot is returned instead.

        Returns ``None`` when ``chat_id`` does not exist. Database errors are
        re-raised after the session is rolled back.
        """
        chat = db.get(Chat, chat_id)
        if chat is None:
            return None

        shared_user_id = f"shared-{chat_id}"

        # Already shared: return the existing snapshot. Snapshots are owned by
        # "shared-<chat_id>", so that is the user id to look them up with.
        if chat.share_id:
            return self.get_chat_by_id_and_user_id(
                db, chat.share_id, shared_user_id
            )

        shared_chat = ChatModel(
            id=str(uuid.uuid4()),
            user_id=shared_user_id,
            title=chat.title,
            chat=chat.chat,
            created_at=chat.created_at,
            updated_at=int(time.time()),
        )
        try:
            db.add(Chat(**shared_chat.model_dump()))
            # Link the original in the same transaction as the insert, so a
            # failure cannot leave an unreferenced snapshot behind.
            chat.share_id = shared_chat.id
            db.commit()
        except Exception:
            db.rollback()
            raise
        return shared_chat

    def update_shared_chat_by_chat_id(
        self, db: Session, chat_id: str
    ) -> Optional[ChatModel]:
        """Copy the current title and payload of chat ``chat_id`` into its
        shared snapshot.

        Returns the updated snapshot, or ``None`` if the chat does not exist,
        is not shared, or the update fails.
        """
        try:
            chat = db.get(Chat, chat_id)
            if chat is None or not chat.share_id:
                return None
            share_id = chat.share_id

            db.query(Chat).filter_by(id=share_id).update(
                {"title": chat.title, "chat": chat.chat}
            )
            db.commit()
            return self.get_chat_by_id(db, share_id)
        except Exception:
            self._abort(db, "Failed to update shared copy of chat %s", chat_id)
            return None

    def delete_shared_chat_by_chat_id(self, db: Session, chat_id: str) -> bool:
        """Delete the shared snapshot(s) of chat ``chat_id``.

        Returns ``True`` on success, including when there was nothing to
        delete, and ``False`` if the delete fails.
        """
        try:
            db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
            db.commit()
            return True
        except Exception:
            self._abort(db, "Failed to delete shared copy of chat %s", chat_id)
            return False

    def update_chat_share_id_by_id(
        self, db: Session, id: str, share_id: Optional[str]
    ) -> Optional[ChatModel]:
        """Set (or clear, with ``None``) the ``share_id`` of chat ``id``.

        Returns the updated chat, or ``None`` if it does not exist or the
        update fails.
        """
        try:
            db.query(Chat).filter_by(id=id).update({"share_id": share_id})
            db.commit()
            return self.get_chat_by_id(db, id)
        except Exception:
            self._abort(db, "Failed to update share_id of chat %s", id)
            return None

    def toggle_chat_archive_by_id(self, db: Session, id: str) -> Optional[ChatModel]:
        """Flip the ``archived`` flag of chat ``id``.

        Returns the updated chat, or ``None`` if it does not exist or the
        update fails.
        """
        try:
            chat = self.get_chat_by_id(db, id)
            if chat is None:
                return None
            db.query(Chat).filter_by(id=id).update({"archived": not chat.archived})
            db.commit()
            return self.get_chat_by_id(db, id)
        except Exception:
            self._abort(db, "Failed to toggle archive flag of chat %s", id)
            return None

    def archive_all_chats_by_user_id(self, db: Session, user_id: str) -> bool:
        """Mark every chat owned by ``user_id`` as archived.

        Returns ``False`` if the update fails.
        """
        try:
            db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
            db.commit()
            return True
        except Exception:
            self._abort(db, "Failed to archive chats of user %s", user_id)
            return False

    # ------------------------------------------------------------------
    # Read
    # ------------------------------------------------------------------

    def get_archived_chat_list_by_user_id(
        self, db: Session, user_id: str, skip: int = 0, limit: int = 50
    ) -> List[ChatModel]:
        """Return the archived chats of ``user_id``, newest update first.

        NOTE: ``skip``/``limit`` are accepted for API compatibility but are
        currently ignored (pagination is disabled).
        """
        all_chats = (
            db.query(Chat)
            .filter_by(user_id=user_id, archived=True)
            .order_by(Chat.updated_at.desc())
            .all()
        )
        return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_list_by_user_id(
        self,
        db: Session,
        user_id: str,
        include_archived: bool = False,
        skip: int = 0,
        limit: int = 50,
    ) -> List[ChatModel]:
        """Return the chats of ``user_id``, newest update first.

        Archived chats are included only when ``include_archived`` is true.
        NOTE: ``skip``/``limit`` are currently ignored (pagination is disabled).
        """
        query = db.query(Chat).filter_by(user_id=user_id)
        if not include_archived:
            query = query.filter_by(archived=False)
        all_chats = query.order_by(Chat.updated_at.desc()).all()
        return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_list_by_chat_ids(
        self, db: Session, chat_ids: List[str], skip: int = 0, limit: int = 50
    ) -> List[ChatModel]:
        """Return the non-archived chats whose ids are in ``chat_ids``, newest
        update first.

        NOTE: ``skip``/``limit`` are currently ignored.
        """
        if not chat_ids:
            return []  # nothing to match; skip the round trip
        all_chats = (
            db.query(Chat)
            .filter(Chat.id.in_(chat_ids))
            .filter_by(archived=False)
            .order_by(Chat.updated_at.desc())
            .all()
        )
        return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chat_by_id(self, db: Session, id: str) -> Optional[ChatModel]:
        """Return chat ``id``, or ``None`` if it does not exist or cannot be
        loaded."""
        try:
            chat = db.get(Chat, id)
            return ChatModel.model_validate(chat) if chat is not None else None
        except Exception:
            log.exception("Failed to load chat %s", id)
            return None

    def get_chat_by_share_id(self, db: Session, id: str) -> Optional[ChatModel]:
        """Return the shared snapshot whose id is ``id`` (a share id).

        The snapshot is returned only while some chat still references it
        through ``share_id``. Otherwise, or on error, returns ``None``.
        """
        try:
            owner = db.query(Chat).filter_by(share_id=id).first()
            if owner is None:
                return None
            return self.get_chat_by_id(db, id)
        except Exception:
            log.exception("Failed to load shared chat %s", id)
            return None

    def get_chat_by_id_and_user_id(
        self, db: Session, id: str, user_id: str
    ) -> Optional[ChatModel]:
        """Return chat ``id`` if it is owned by ``user_id``; otherwise (or on
        error) ``None``."""
        try:
            chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
            return ChatModel.model_validate(chat) if chat is not None else None
        except Exception:
            log.exception("Failed to load chat %s for user %s", id, user_id)
            return None

    def get_chats(self, db: Session, skip: int = 0, limit: int = 50) -> List[ChatModel]:
        """Return every chat (all users, including shared snapshots), newest
        update first.

        NOTE: ``skip``/``limit`` are currently ignored (pagination is disabled).
        """
        all_chats = db.query(Chat).order_by(Chat.updated_at.desc()).all()
        return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_chats_by_user_id(self, db: Session, user_id: str) -> List[ChatModel]:
        """Return all chats of ``user_id`` (archived or not), newest update first."""
        all_chats = (
            db.query(Chat)
            .filter_by(user_id=user_id)
            .order_by(Chat.updated_at.desc())
            .all()
        )
        return [ChatModel.model_validate(chat) for chat in all_chats]

    def get_archived_chats_by_user_id(
        self, db: Session, user_id: str
    ) -> List[ChatModel]:
        """Return the archived chats of ``user_id``, newest update first."""
        all_chats = (
            db.query(Chat)
            .filter_by(user_id=user_id, archived=True)
            .order_by(Chat.updated_at.desc())
            .all()
        )
        return [ChatModel.model_validate(chat) for chat in all_chats]

    # ------------------------------------------------------------------
    # Delete
    # ------------------------------------------------------------------

    def delete_chat_by_id(self, db: Session, id: str) -> bool:
        """Delete chat ``id`` together with its shared snapshot.

        Returns ``False`` if either delete fails.
        """
        try:
            db.query(Chat).filter_by(id=id).delete()
            db.commit()
        except Exception:
            self._abort(db, "Failed to delete chat %s", id)
            return False
        return self.delete_shared_chat_by_chat_id(db, id)

    def delete_chat_by_id_and_user_id(self, db: Session, id: str, user_id: str) -> bool:
        """Delete chat ``id`` if it is owned by ``user_id``, plus its shared
        snapshot.

        If no chat owned by ``user_id`` matched, nothing is deleted (in
        particular not another user's snapshot) and ``True`` is returned.
        Returns ``False`` if a delete fails.
        """
        try:
            deleted = db.query(Chat).filter_by(id=id, user_id=user_id).delete()
            db.commit()
        except Exception:
            self._abort(db, "Failed to delete chat %s of user %s", id, user_id)
            return False
        if not deleted:
            # The caller does not own this chat: leave its snapshot alone.
            return True
        return self.delete_shared_chat_by_chat_id(db, id)

    def delete_chats_by_user_id(self, db: Session, user_id: str) -> bool:
        """Delete all chats of ``user_id`` and their shared snapshots.

        Returns ``False`` if deleting the chats fails. A failure while
        deleting the snapshots is logged but does not block deleting the
        chats.
        """
        try:
            self.delete_shared_chats_by_user_id(db, user_id)
            db.query(Chat).filter_by(user_id=user_id).delete()
            db.commit()
            return True
        except Exception:
            self._abort(db, "Failed to delete chats of user %s", user_id)
            return False

    def delete_shared_chats_by_user_id(self, db: Session, user_id: str) -> bool:
        """Delete the shared snapshots of every chat owned by ``user_id``.

        Returns ``False`` if the delete fails.
        """
        try:
            # Only the ids are needed, so avoid loading full chat payloads.
            shared_user_ids = [
                f"shared-{chat_id}"
                for (chat_id,) in db.query(Chat.id).filter(Chat.user_id == user_id)
            ]
            if shared_user_ids:
                db.query(Chat).filter(Chat.user_id.in_(shared_user_ids)).delete()
                db.commit()
            return True
        except Exception:
            self._abort(db, "Failed to delete shared chats of user %s", user_id)
            return False

# Module-level ChatTable instance that the rest of the application imports.
Chats: ChatTable = ChatTable()