chats.py 13.3 KB
Newer Older
Anuraag Jain's avatar
Anuraag Jain committed
1
from fastapi import Depends, Request, HTTPException, status
Timothy J. Baek's avatar
Timothy J. Baek committed
2
3
from datetime import datetime, timedelta
from typing import List, Union, Optional
4

5
from utils.utils import get_current_user, get_admin_user
Timothy J. Baek's avatar
Timothy J. Baek committed
6
7
from fastapi import APIRouter
from pydantic import BaseModel
8
import json
9
import logging
Timothy J. Baek's avatar
Timothy J. Baek committed
10

11
12
from apps.webui.models.users import Users
from apps.webui.models.chats import (
Timothy J. Baek's avatar
Timothy J. Baek committed
13
    ChatModel,
14
    ChatResponse,
Timothy J. Baek's avatar
Timothy J. Baek committed
15
    ChatTitleForm,
Timothy J. Baek's avatar
Timothy J. Baek committed
16
17
18
19
20
    ChatForm,
    ChatTitleIdResponse,
    Chats,
)

21

22
from apps.webui.models.tags import (
23
    TagModel,
24
    ChatIdTagModel,
25
26
27
28
29
    ChatIdTagForm,
    ChatTagsResponse,
    Tags,
)

Timothy J. Baek's avatar
Timothy J. Baek committed
30
31
from constants import ERROR_MESSAGES

32
from config import SRC_LOG_LEVELS, ENABLE_ADMIN_EXPORT
Timothy J. Baek's avatar
Timothy J. Baek committed
33

34
35
36
# Module-level logger; its verbosity follows the "MODELS" entry of the
# configured per-source log levels.
log = logging.getLogger(name=__name__)
log.setLevel(level=SRC_LOG_LEVELS["MODELS"])

# The chat endpoints below are registered on this router.
router = APIRouter()

############################
Timothy J. Baek's avatar
Timothy J. Baek committed
40
# GetChatList
Timothy J. Baek's avatar
Timothy J. Baek committed
41
42
43
44
############################


@router.get("/", response_model=List[ChatTitleIdResponse])
@router.get("/list", response_model=List[ChatTitleIdResponse])
async def get_session_user_chat_list(
    user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
    """Return the current user's chat list (titles and ids).

    Served at both ``/`` and ``/list``.

    Args:
        user: The authenticated caller, resolved by ``get_current_user``.
        skip: Forwarded to ``Chats.get_chat_list_by_user_id`` as ``skip``.
        limit: Forwarded to ``Chats.get_chat_list_by_user_id`` as ``limit``.
    """
    # Pass skip/limit by keyword, as the admin listing endpoint does. The model
    # method also accepts ``include_archived``, and a positional call binds
    # purely by position: if ``include_archived`` precedes ``skip`` in its
    # signature, a non-zero ``skip`` would silently include archived chats and
    # ``limit`` would be consumed as the offset.
    return Chats.get_chat_list_by_user_id(user.id, skip=skip, limit=limit)
Timothy J. Baek's avatar
Timothy J. Baek committed
50
51


Timothy J. Baek's avatar
refac  
Timothy J. Baek committed
52
53
54
55
56
57
############################
# DeleteAllChats
############################


@router.delete("/", response_model=bool)
async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
    """Delete every chat owned by the current user.

    Callers whose role is ``"user"`` need ``USER_PERMISSIONS["chat"]["deletion"]``
    enabled in the app config; other roles skip that check entirely.

    Returns:
        Whatever ``Chats.delete_chats_by_user_id`` reports.

    Raises:
        HTTPException: 401 when a ``"user"`` lacks the deletion permission.
    """
    if user.role == "user":
        # Only plain users consult the permission table.
        permissions = request.app.state.config.USER_PERMISSIONS
        if not permissions["chat"]["deletion"]:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
            )

    return Chats.delete_chats_by_user_id(user.id)


Timothy J. Baek's avatar
Timothy J. Baek committed
73
74
75
76
77
78
79
############################
# GetUserChatList
############################


@router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id(
    user_id: str,
    user=Depends(get_admin_user),
    skip: int = 0,
    limit: int = 50,
):
    """List another user's chats (titles and ids), archived chats included.

    Access is gated by the ``get_admin_user`` dependency; the resolved
    ``user`` is not otherwise used.
    """
    return Chats.get_chat_list_by_user_id(
        user_id,
        skip=skip,
        limit=limit,
        include_archived=True,
    )
Timothy J. Baek's avatar
Timothy J. Baek committed
88
89


90
############################
91
# CreateNewChat
Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
92
93
94
############################


95
@router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
    """Create a chat for the current user from ``form_data``.

    Returns:
        The stored chat, with its JSON ``chat`` column decoded.

    Raises:
        HTTPException: 400 when inserting or decoding fails; the underlying
            exception is logged first.
    """
    try:
        created = Chats.insert_new_chat(user.id, form_data)
        fields = created.model_dump()
        # The stored payload is a JSON string; the response carries it decoded.
        fields["chat"] = json.loads(created.chat)
        return ChatResponse(**fields)
    except Exception as err:
        log.exception(err)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
107
############################
Timothy J. Baek's avatar
Timothy J. Baek committed
108
# GetChats
Timothy J. Baek's avatar
Timothy J. Baek committed
109
110
111
112
############################


@router.get("/all", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user)):
    """Return all of the current user's chats with decoded ``chat`` payloads."""
    responses = []
    for record in Chats.get_chats_by_user_id(user.id):
        fields = record.model_dump()
        fields["chat"] = json.loads(record.chat)
        responses.append(ChatResponse(**fields))
    return responses


############################
# GetArchivedChats
############################


@router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_archived_chats(user=Depends(get_current_user)):
    """Return the current user's archived chats with decoded ``chat`` payloads."""

    def to_response(record):
        # Swap the stored JSON string for its parsed form.
        return ChatResponse(**{**record.model_dump(), "chat": json.loads(record.chat)})

    archived = Chats.get_archived_chats_by_user_id(user.id)
    return [to_response(record) for record in archived]
Timothy J. Baek's avatar
Timothy J. Baek committed
131
132


133
134
135
136
137
138
############################
# GetAllChatsInDB
############################


@router.get("/all/db", response_model=List[ChatResponse])
async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
    """Export every chat in the database, across all users.

    Gated by the ``get_admin_user`` dependency and the ``ENABLE_ADMIN_EXPORT``
    config flag.

    Raises:
        HTTPException: 401 when admin export is disabled.
    """
    if not ENABLE_ADMIN_EXPORT:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    exported = []
    for record in Chats.get_chats():
        fields = record.model_dump()
        fields["chat"] = json.loads(record.chat)
        exported.append(ChatResponse(**fields))
    return exported
149
150


Timothy J. Baek's avatar
Timothy J. Baek committed
151
############################
152
# GetArchivedChatList
Timothy J. Baek's avatar
Timothy J. Baek committed
153
154
155
############################


156
157
@router.get("/archived", response_model=List[ChatTitleIdResponse])
async def get_archived_session_user_chat_list(
    user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
    """Return the current user's archived chat list (titles and ids).

    ``skip`` and ``limit`` are forwarded positionally to
    ``Chats.get_archived_chat_list_by_user_id``.
    """
    owner_id = user.id
    archived_list = Chats.get_archived_chat_list_by_user_id(owner_id, skip, limit)
    return archived_list
161
162
163
164
165
166
167


############################
# ArchiveAllChats
############################


Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
168
@router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_current_user)):
    """Archive every chat of the current user; returns the model's result."""
    outcome = Chats.archive_all_chats_by_user_id(user.id)
    return outcome
171
172
173
174
175
176
177
178


############################
# GetSharedChatById
############################


@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
    """Return a shared chat.

    Callers with role ``"user"`` resolve ``share_id`` through
    ``Chats.get_chat_by_share_id``; ``"admin"`` callers resolve it as a plain
    chat id through ``Chats.get_chat_by_id``. ``"pending"`` callers, and any
    other role, are refused.

    Raises:
        HTTPException: 401 with ``NOT_FOUND`` when the role is not allowed or
            no matching chat exists.
    """
    if user.role == "user":
        chat = Chats.get_chat_by_share_id(share_id)
    elif user.role == "admin":
        chat = Chats.get_chat_by_id(share_id)
    else:
        # "pending" and any unrecognised role. Without this branch an unknown
        # role left ``chat`` unbound and the check below raised
        # UnboundLocalError (a 500) instead of a clean 401.
        chat = None

    if chat:
        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
    )
Timothy J. Baek's avatar
Timothy J. Baek committed
196
197


Timothy J. Baek's avatar
fix  
Timothy J. Baek committed
198
199
200
201
202
203
204
205
206
207
208
209
210
############################
# GetChatsByTags
############################


class TagNameForm(BaseModel):
    # Request body for POST /tags: select the caller's chats that carry tag `name`.
    name: str  # tag name, looked up together with the caller's user id
    skip: Optional[int] = 0  # forwarded as `skip` to Chats.get_chat_list_by_chat_ids
    limit: Optional[int] = 50  # forwarded as `limit` to Chats.get_chat_list_by_chat_ids


@router.post("/tags", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
    form_data: TagNameForm, user=Depends(get_current_user)
):
    """List the current user's chats (titles and ids) that carry a tag.

    Args:
        form_data: Tag ``name`` plus ``skip``/``limit``, which are forwarded to
            ``Chats.get_chat_list_by_chat_ids``.
        user: The authenticated caller.

    Side effect:
        When a first-page request comes back empty, the tag is deleted for this
        user via ``Tags.delete_tag_by_tag_name_and_user_id``.
    """
    # Debug-level logging instead of print(), so request bodies only reach the
    # output when this module's log level asks for them.
    log.debug("get_user_chat_list_by_tag_name: %s", form_data)

    chat_ids = [
        chat_id_tag.chat_id
        for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
            form_data.name, user.id
        )
    ]

    chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)

    # Prune the tag only when an empty result actually indicates that nothing
    # carries it: a first-page request (no skip) with a non-zero page size.
    # An empty page past the end of the list (skip > 0) or an explicit
    # limit of 0 says nothing about whether tagged chats exist, and previously
    # deleted the tag anyway.
    is_first_page = not form_data.skip and form_data.limit != 0
    if not chats and is_first_page:
        Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)

    return chats


############################
# GetAllTags
############################


@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)):
    """Return every tag belonging to the current user.

    Raises:
        HTTPException: 400 when the lookup fails; the exception is logged.
    """
    try:
        return Tags.get_tags_by_user_id(user.id)
    except Exception as err:
        log.exception(err)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
        )


Timothy J. Baek's avatar
Timothy J. Baek committed
247
248
249
250
251
############################
# GetChatById
############################


252
@router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, user=Depends(get_current_user)):
    """Return one of the current user's chats, with its ``chat`` JSON decoded.

    Raises:
        HTTPException: 401 with ``NOT_FOUND`` when no chat with this id belongs
            to the caller.
    """
    found = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not found:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )

    fields = found.model_dump()
    fields["chat"] = json.loads(found.chat)
    return ChatResponse(**fields)
Timothy J. Baek's avatar
Timothy J. Baek committed
262
263
264
265
266
267
268


############################
# UpdateChatById
############################


269
@router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(
    id: str, form_data: ChatForm, user=Depends(get_current_user)
):
    """Shallow-merge ``form_data.chat`` over one of the caller's stored chats.

    Top-level keys from the request override the stored JSON; the merged
    payload is written back with ``Chats.update_chat_by_id``.

    Raises:
        HTTPException: 401 when the chat does not exist or is not the caller's.
    """
    existing = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not existing:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    stored_body = json.loads(existing.chat)
    merged_body = {**stored_body, **form_data.chat}

    saved = Chats.update_chat_by_id(id, merged_body)
    return ChatResponse(**{**saved.model_dump(), "chat": json.loads(saved.chat)})
284
285
286
287
288
289
290
291


############################
# DeleteChatById
############################


@router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
    """Delete a chat.

    Admins delete by id alone. Everyone else needs
    ``USER_PERMISSIONS["chat"]["deletion"]`` and can only delete a chat that
    matches both the id and their own user id.

    Raises:
        HTTPException: 401 when a non-admin lacks the deletion permission.
    """
    if user.role != "admin":
        if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
            )
        return Chats.delete_chat_by_id_and_user_id(id, user.id)

    return Chats.delete_chat_by_id(id)
Timothy J. Baek's avatar
Timothy J. Baek committed
306

307

Timothy J. Baek's avatar
Timothy J. Baek committed
308
309
310
311
312
313
############################
# CloneChat
############################


@router.get("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
    """Copy one of the current user's chats into a new chat owned by them.

    The copy's payload is the source chat JSON plus:
      * ``originalChatId`` - the source chat's id,
      * ``branchPointMessageId`` - the source's ``history.currentId``, or
        ``None`` when the payload has no such entry,
      * ``title`` - ``"Clone of <source title>"``.

    Raises:
        HTTPException: 401 when the chat does not exist or is not the caller's.
    """
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )

    chat_body = json.loads(chat.chat)

    # Indexing chat_body["history"]["currentId"] raised KeyError (an unhandled
    # 500) for payloads lacking either key; clone those without a branch point.
    history = chat_body.get("history")
    branch_point = history.get("currentId") if isinstance(history, dict) else None

    updated_chat = {
        **chat_body,
        "originalChatId": chat.id,
        "branchPointMessageId": branch_point,
        "title": f"Clone of {chat.title}",
    }

    new_chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
    return ChatResponse(
        **{**new_chat.model_dump(), "chat": json.loads(new_chat.chat)}
    )


Timothy J. Baek's avatar
Timothy J. Baek committed
334
335
336
337
338
339
############################
# ArchiveChat
############################


@router.get("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
    """Toggle the archive state of one of the caller's chats.

    The toggle is done by ``Chats.toggle_chat_archive_by_id``; the updated chat
    is returned with its ``chat`` JSON decoded.

    Raises:
        HTTPException: 401 when the chat does not exist or is not the caller's.
    """
    owned = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not owned:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )

    toggled = Chats.toggle_chat_archive_by_id(id)
    fields = toggled.model_dump()
    fields["chat"] = json.loads(toggled.chat)
    return ChatResponse(**fields)


351
352
353
354
355
356
############################
# ShareChatById
############################


@router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(id: str, user=Depends(get_current_user)):
    """Share one of the current user's chats, or refresh an existing share.

    If the chat already has a ``share_id``, the existing share is updated via
    ``Chats.update_shared_chat_by_chat_id``; otherwise a new one is created via
    ``Chats.insert_shared_chat_by_chat_id``.

    Raises:
        HTTPException: 401 when the chat does not exist or is not the caller's;
            500 when either share operation returns nothing.
    """
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    if chat.share_id:
        shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
    else:
        shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)

    # Both paths are checked the same way. Previously only the insert path
    # was, so an empty update result crashed with AttributeError on
    # ``shared_chat.model_dump()`` instead of returning a clean 500.
    if not shared_chat:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=ERROR_MESSAGES.DEFAULT(),
        )

    return ChatResponse(
        **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
    )


############################
# DeleteSharedChatById
############################


388
@router.delete("/{id}/share", response_model=Optional[bool])
async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
    """Stop sharing one of the current user's chats.

    Returns:
        ``False`` if the chat has no ``share_id``. Otherwise the result is
        truthy only when the shared copy was deleted and clearing the chat's
        ``share_id`` returned a value.

    Raises:
        HTTPException: 401 when the chat does not exist or is not the caller's.
    """
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    if not chat.share_id:
        return False

    result = Chats.delete_shared_chat_by_chat_id(id)
    update_result = Chats.update_chat_share_id_by_id(id, None)

    # Identity test against None (PEP 8) rather than `!= None`.
    return result and update_result is not None


406
407
408
409
410
411
############################
# GetChatTagsById
############################


@router.get("/{id}/tags", response_model=List[TagModel])
async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
    """Return the tags the current user has attached to chat ``id``.

    Raises:
        HTTPException: 401 with ``NOT_FOUND`` when the lookup returns ``None``.
    """
    tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)

    # Identity test against None (PEP 8): an empty tag list is still a valid
    # answer and is returned as-is.
    if tags is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )
    return tags


############################
# AddChatTagById
############################


428
@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
async def add_chat_tag_by_id(
    id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
):
    """Attach tag ``form_data.tag_name`` to chat ``id`` for the current user.

    Raises:
        HTTPException: 401 with ``DEFAULT`` when the chat already carries the
            tag; 401 with ``NOT_FOUND`` when ``Tags.add_tag_to_chat`` returns
            nothing.
    """
    tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)

    # The lookup yields tag models (it backs the List[TagModel] response of
    # GET /{id}/tags), so compare names. Testing the raw string against the
    # model list never matched, which let duplicate tags through; ``or []``
    # also covers a None lookup result.
    # NOTE(review): relies on TagModel exposing a ``name`` field - confirm.
    existing_names = {tag.name for tag in tags or []}

    if form_data.tag_name in existing_names:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )

    tag = Tags.add_tag_to_chat(user.id, form_data)
    if not tag:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )
    return tag


############################
# DeleteChatTagById
############################


@router.delete("/{id}/tags", response_model=Optional[bool])
async def delete_chat_tag_by_id(
    id: str,
    form_data: ChatIdTagForm,
    user=Depends(get_current_user),
):
    """Detach tag ``form_data.tag_name`` from chat ``id`` for the current user.

    Raises:
        HTTPException: 401 with ``NOT_FOUND`` when the deletion reports nothing
            removed.
    """
    deleted = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
        form_data.tag_name, id, user.id
    )
    if not deleted:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )
    return deleted


############################
# DeleteAllChatTagsById
############################


@router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
    """Remove every tag the current user attached to chat ``id``.

    Raises:
        HTTPException: 401 with ``NOT_FOUND`` when nothing was removed.
    """
    removed = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
    if not removed:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )
    return removed