chats.py 14 KB
Newer Older
Anuraag Jain's avatar
Anuraag Jain committed
1
from fastapi import Depends, Request, HTTPException, status
Timothy J. Baek's avatar
Timothy J. Baek committed
2
3
from datetime import datetime, timedelta
from typing import List, Union, Optional
4
5

from apps.webui.internal.db import get_db
6
from utils.utils import get_current_user, get_admin_user
Timothy J. Baek's avatar
Timothy J. Baek committed
7
8
from fastapi import APIRouter
from pydantic import BaseModel
9
import json
10
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
11

12
13
from apps.webui.models.users import Users
from apps.webui.models.chats import (
Timothy J. Baek's avatar
Timothy J. Baek committed
14
    ChatModel,
15
    ChatResponse,
Timothy J. Baek's avatar
Timothy J. Baek committed
16
    ChatTitleForm,
Timothy J. Baek's avatar
Timothy J. Baek committed
17
18
19
20
21
    ChatForm,
    ChatTitleIdResponse,
    Chats,
)

22

23
from apps.webui.models.tags import (
24
    TagModel,
25
    ChatIdTagModel,
26
27
28
29
30
    ChatIdTagForm,
    ChatTagsResponse,
    Tags,
)

Timothy J. Baek's avatar
Timothy J. Baek committed
31
32
from constants import ERROR_MESSAGES

33
from config import SRC_LOG_LEVELS, ENABLE_ADMIN_EXPORT
Timothy J. Baek's avatar
Timothy J. Baek committed
34

# Router that every chat endpoint in this module registers on.
router = APIRouter()

# Module logger; its level is taken from the "MODELS" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

############################
# GetChatList
############################


@router.get("/", response_model=List[ChatTitleIdResponse])
@router.get("/list", response_model=List[ChatTitleIdResponse])
async def get_session_user_chat_list(
    user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db)
):
    """List the current user's chats as title/id summaries.

    Served at both "/" and "/list"; ``skip``/``limit`` page through the results.
    """
    # Pagination is passed by keyword: get_chat_list_by_user_id also accepts an
    # ``include_archived`` argument (the admin listing endpoint in this module
    # passes it by keyword), so a positional call can bind ``skip``/``limit``
    # to the wrong parameters.
    return Chats.get_chat_list_by_user_id(db, user.id, skip=skip, limit=limit)


############################
# DeleteAllChats
############################


@router.delete("/", response_model=bool)
async def delete_all_user_chats(
    request: Request, user=Depends(get_current_user), db=Depends(get_db)
):
    """Delete every chat owned by the current user.

    Accounts with role "user" need ``USER_PERMISSIONS["chat"]["deletion"]`` from
    the app config; other roles skip that check.

    Raises:
        HTTPException: 401 ``ACCESS_PROHIBITED`` when a "user" lacks permission.
    """
    if user.role == "user":
        # Only consult the permission table for plain users.
        may_delete = request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]
        if not may_delete:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
            )

    return Chats.delete_chats_by_user_id(db, user.id)


############################
# GetUserChatList
############################


@router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id(
    user_id: str,
    user=Depends(get_admin_user),
    skip: int = 0,
    limit: int = 50,
    db=Depends(get_db),
):
    """Admin-only: list chat summaries belonging to ``user_id``.

    ``include_archived=True`` is passed so archived chats are requested as well.
    """
    paging = {"skip": skip, "limit": limit}
    return Chats.get_chat_list_by_user_id(
        db, user_id, include_archived=True, **paging
    )


############################
# CreateNewChat
############################


@router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(
    form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db)
):
    """Create a chat owned by the current user from ``form_data``.

    The stored record is returned with its ``chat`` field decoded from JSON.
    Any failure is logged and reported as 400 with the default error message.
    """
    try:
        created = Chats.insert_new_chat(db, user.id, form_data)
        body = created.model_dump()
        body["chat"] = json.loads(created.chat)
        return ChatResponse(**body)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )


############################
# GetChats
############################


@router.get("/all", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user), db=Depends(get_db)):
    """Return all of the current user's chats with the ``chat`` JSON decoded."""
    responses = []
    for stored in Chats.get_chats_by_user_id(db, user.id):
        fields = stored.model_dump()
        fields["chat"] = json.loads(stored.chat)
        responses.append(ChatResponse(**fields))
    return responses


############################
# GetArchivedChats
############################


@router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_archived_chats(user=Depends(get_current_user), db=Depends(get_db)):
    """Return every archived chat of the current user, ``chat`` JSON decoded."""
    archived = Chats.get_archived_chats_by_user_id(db, user.id)

    def to_response(stored):
        return ChatResponse(**{**stored.model_dump(), "chat": json.loads(stored.chat)})

    return list(map(to_response, archived))


############################
# GetAllChatsInDB
############################


@router.get("/all/db", response_model=List[ChatResponse])
async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_db)):
    """Admin-only export of every chat returned by ``Chats.get_chats``.

    Raises:
        HTTPException: 401 ``ACCESS_PROHIBITED`` unless ``ENABLE_ADMIN_EXPORT``.
    """
    if ENABLE_ADMIN_EXPORT:
        exported = []
        for stored in Chats.get_chats(db):
            data = stored.model_dump()
            data["chat"] = json.loads(stored.chat)
            exported.append(ChatResponse(**data))
        return exported

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
    )


############################
# GetArchivedChatList
############################


@router.get("/archived", response_model=List[ChatTitleIdResponse])
async def get_archived_session_user_chat_list(
    user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db)
):
    """Page through the current user's archived chats as title/id summaries."""
    archived_page = Chats.get_archived_chat_list_by_user_id(db, user.id, skip, limit)
    return archived_page


############################
# ArchiveAllChats
############################


@router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_current_user), db=Depends(get_db)):
    """Archive all of the current user's chats; returns the model layer's result."""
    outcome = Chats.archive_all_chats_by_user_id(db, user.id)
    return outcome


############################
# GetSharedChatById
############################


@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(
    share_id: str, user=Depends(get_current_user), db=Depends(get_db)
):
    """Fetch a shared chat for the current user.

    - "user" accounts resolve ``share_id`` via ``Chats.get_chat_by_share_id``.
    - "admin" accounts resolve it as a plain chat id via ``Chats.get_chat_by_id``.

    Raises:
        HTTPException: 401 ``NOT_FOUND`` for "pending" accounts, for any other
            unrecognised role, or when no chat is found.
    """
    if user.role == "pending":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )

    # Start from None so roles other than "user"/"admin" fall through to the
    # not-found response instead of reading an unbound local (UnboundLocalError).
    chat = None
    if user.role == "user":
        chat = Chats.get_chat_by_share_id(db, share_id)
    elif user.role == "admin":
        chat = Chats.get_chat_by_id(db, share_id)

    if chat:
        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
    )


############################
# GetChatsByTags
############################


class TagNameForm(BaseModel):
    """Request body for listing chats by tag: the tag name plus pagination."""

    # Tag to filter chats by.
    name: str
    # Pagination offset and page size; the type allows either to be null.
    skip: Union[int, None] = 0
    limit: Union[int, None] = 50


@router.post("/tags", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
    form_data: TagNameForm, user=Depends(get_current_user), db=Depends(get_db)
):
    """List the current user's chats tagged ``form_data.name``.

    Chat ids come from the user's tag associations and are paged with
    ``form_data.skip``/``form_data.limit``. If the first page is empty, the tag
    is deleted for this user.
    """
    # Was a bare print(); route debug output through the module logger instead.
    log.debug("get_user_chat_list_by_tag_name: %s", form_data)

    chat_ids = [
        chat_id_tag.chat_id
        for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
            db, form_data.name, user.id
        )
    ]

    chats = Chats.get_chat_list_by_chat_ids(
        db, chat_ids, form_data.skip, form_data.limit
    )

    # Only prune the tag when the *first* page is empty. An empty page at a
    # non-zero offset just means the client paged past the end, and deleting
    # the tag then would drop it even though earlier pages still hold chats.
    if not chats and not form_data.skip:
        Tags.delete_tag_by_tag_name_and_user_id(db, form_data.name, user.id)

    return chats


############################
# GetAllTags
############################


@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user), db=Depends(get_db)):
    """Return every tag belonging to the current user.

    Any failure in the lookup is logged and surfaced as 400 with the default
    error message.
    """
    try:
        return Tags.get_tags_by_user_id(db, user.id)
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )


############################
# GetChatById
############################


@router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)):
    """Return one of the current user's chats with its ``chat`` JSON decoded.

    Raises:
        HTTPException: 401 ``NOT_FOUND`` if the user has no chat with ``id``.
    """
    found = Chats.get_chat_by_id_and_user_id(db, id, user.id)
    if not found:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )

    payload = found.model_dump()
    payload["chat"] = json.loads(found.chat)
    return ChatResponse(**payload)


############################
# UpdateChatById
############################


@router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(
    id: str, form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db)
):
    """Merge ``form_data.chat`` into one of the current user's chats.

    The merge is shallow: top-level keys from the form replace stored ones.

    Raises:
        HTTPException: 401 ``ACCESS_PROHIBITED`` if the user has no chat ``id``.
    """
    existing = Chats.get_chat_by_id_and_user_id(db, id, user.id)
    if not existing:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    merged = {**json.loads(existing.chat), **form_data.chat}
    updated = Chats.update_chat_by_id(db, id, merged)

    payload = updated.model_dump()
    payload["chat"] = json.loads(updated.chat)
    return ChatResponse(**payload)


############################
# DeleteChatById
############################


@router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(
    request: Request, id: str, user=Depends(get_current_user), db=Depends(get_db)
):
    """Delete a chat by id.

    Admins delete via ``Chats.delete_chat_by_id`` with no ownership filter.
    Other roles need ``USER_PERMISSIONS["chat"]["deletion"]`` and can only
    delete a chat matching their own user id.

    Raises:
        HTTPException: 401 ``ACCESS_PROHIBITED`` when a non-admin lacks permission.
    """
    if user.role != "admin":
        if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
            )
        return Chats.delete_chat_by_id_and_user_id(db, id, user.id)

    return Chats.delete_chat_by_id(db, id)

############################
# CloneChat
############################


@router.get("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)):
    """Create a copy of one of the current user's chats.

    The copy's body records ``originalChatId``, sets ``branchPointMessageId``
    from the source's ``history.currentId``, and gets the title "Clone of <title>".

    Raises:
        HTTPException: 401 with the default message if the user has no chat ``id``.
    """
    source = Chats.get_chat_by_id_and_user_id(db, id, user.id)
    if not source:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )

    original_body = json.loads(source.chat)
    clone_body = dict(original_body)
    clone_body["originalChatId"] = source.id
    clone_body["branchPointMessageId"] = original_body["history"]["currentId"]
    clone_body["title"] = f"Clone of {source.title}"

    cloned = Chats.insert_new_chat(db, user.id, ChatForm(chat=clone_body))
    payload = cloned.model_dump()
    payload["chat"] = json.loads(cloned.chat)
    return ChatResponse(**payload)


############################
# ArchiveChat
############################


@router.get("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(
    id: str, user=Depends(get_current_user), db=Depends(get_db)
):
    """Toggle the archived state of one of the current user's chats.

    Calls ``Chats.toggle_chat_archive_by_id``, so a second call un-archives.

    Raises:
        HTTPException: 401 with the default message if the user has no chat ``id``.
    """
    if not Chats.get_chat_by_id_and_user_id(db, id, user.id):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )

    toggled = Chats.toggle_chat_archive_by_id(db, id)
    payload = toggled.model_dump()
    payload["chat"] = json.loads(toggled.chat)
    return ChatResponse(**payload)


############################
# ShareChatById
############################


@router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)):
    """Share one of the current user's chats, or refresh an existing share.

    If the chat already has a ``share_id``, the shared copy is updated;
    otherwise a new shared copy is inserted. The shared copy is returned with
    its ``chat`` JSON decoded.

    Raises:
        HTTPException: 401 ``ACCESS_PROHIBITED`` if the user has no chat ``id``;
            500 with the default message if the shared copy could not be
            created or updated.
    """
    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    if chat.share_id:
        shared_chat = Chats.update_shared_chat_by_chat_id(db, chat.id)
    else:
        shared_chat = Chats.insert_shared_chat_by_chat_id(db, chat.id)

    # Guard both paths. Only the insert path was checked before; a falsy
    # result from the update path presumably means failure too, and it used
    # to crash with AttributeError on ``.model_dump()``.
    if not shared_chat:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )

    return ChatResponse(
        **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
    )


############################
# DeleteSharedChatById
############################


@router.delete("/{id}/share", response_model=Optional[bool])
async def delete_shared_chat_by_id(
    id: str, user=Depends(get_current_user), db=Depends(get_db)
):
    """Stop sharing one of the current user's chats.

    Returns False if the chat has no ``share_id``. Otherwise it deletes the
    shared copy, clears the chat's ``share_id``, and returns the delete result
    ``and`` whether the ``share_id`` update returned a value.

    Raises:
        HTTPException: 401 ``ACCESS_PROHIBITED`` if the user has no chat ``id``.
    """
    chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    if not chat.share_id:
        return False

    result = Chats.delete_shared_chat_by_chat_id(db, id)
    update_result = Chats.update_chat_share_id_by_id(db, id, None)

    # Identity check against None (PEP 8), not equality.
    return result and update_result is not None


############################
# GetChatTagsById
############################


@router.get("/{id}/tags", response_model=List[TagModel])
async def get_chat_tags_by_id(
    id: str, user=Depends(get_current_user), db=Depends(get_db)
):
    """Return the current user's tags on chat ``id``.

    An empty result is returned as-is; only a None result is treated as not found.

    Raises:
        HTTPException: 401 ``NOT_FOUND`` when the lookup returns None.
    """
    tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id)

    # Identity check: an empty list is a valid (tag-less) answer.
    if tags is not None:
        return tags

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
    )


############################
# AddChatTagById
############################


@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
async def add_chat_tag_by_id(
    id: str,
    form_data: ChatIdTagForm,
    user=Depends(get_current_user),
    db=Depends(get_db),
):
    """Attach tag ``form_data.tag_name`` to chat ``id`` for the current user.

    Raises:
        HTTPException: 401 with the default message if the chat already carries
            the tag; 401 ``NOT_FOUND`` if ``Tags.add_tag_to_chat`` returns nothing.
    """
    tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id)

    # ``tags`` holds tag model objects (get_chat_tags_by_id serves the same
    # call as List[TagModel]). Testing the raw name string against those
    # objects never matched, so duplicates slipped through; compare names
    # instead. ``or []`` covers a None result. Assumes the models expose
    # ``name`` — confirm against TagModel.
    existing_names = {tag.name for tag in tags or []}

    if form_data.tag_name in existing_names:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )

    tag = Tags.add_tag_to_chat(db, user.id, form_data)
    if tag:
        return tag

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=ERROR_MESSAGES.NOT_FOUND,
    )


############################
# DeleteChatTagById
############################


@router.delete("/{id}/tags", response_model=Optional[bool])
async def delete_chat_tag_by_id(
    id: str,
    form_data: ChatIdTagForm,
    user=Depends(get_current_user),
    db=Depends(get_db),
):
    """Remove tag ``form_data.tag_name`` from chat ``id`` for the current user.

    Raises:
        HTTPException: 401 ``NOT_FOUND`` when the delete reports nothing removed.
    """
    deleted = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
        db, form_data.tag_name, id, user.id
    )
    if not deleted:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )
    return deleted


############################
# DeleteAllChatTagsById
############################


@router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_chat_tags_by_id(
    id: str, user=Depends(get_current_user), db=Depends(get_db)
):
    """Remove every tag the current user has on chat ``id``.

    Raises:
        HTTPException: 401 ``NOT_FOUND`` when the delete reports nothing removed.
    """
    deleted = Tags.delete_tags_by_chat_id_and_user_id(db, id, user.id)
    if not deleted:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )
    return deleted