Commit df09d083 authored by Jonathan Rohde's avatar Jonathan Rohde
Browse files

feat(sqlalchemy): Replace peewee with sqlalchemy

parent 8dac2a21
from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from apps.webui.internal.db import get_db
from utils.utils import get_current_user, get_admin_user
from fastapi import APIRouter
from pydantic import BaseModel
......@@ -43,9 +45,9 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse])
@router.get("/list", response_model=List[ChatTitleIdResponse])
async def get_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50
user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db)
):
return Chats.get_chat_list_by_user_id(user.id, skip, limit)
return Chats.get_chat_list_by_user_id(db, user.id, skip, limit)
############################
......@@ -54,7 +56,9 @@ async def get_session_user_chat_list(
@router.delete("/", response_model=bool)
async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
async def delete_all_user_chats(
request: Request, user=Depends(get_current_user), db=Depends(get_db)
):
if (
user.role == "user"
......@@ -65,7 +69,7 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user)
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Chats.delete_chats_by_user_id(user.id)
result = Chats.delete_chats_by_user_id(db, user.id)
return result
......@@ -76,10 +80,14 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user)
@router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id(
user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50
user_id: str,
user=Depends(get_admin_user),
skip: int = 0,
limit: int = 50,
db=Depends(get_db),
):
return Chats.get_chat_list_by_user_id(
user_id, include_archived=True, skip=skip, limit=limit
db, user_id, include_archived=True, skip=skip, limit=limit
)
......@@ -89,9 +97,11 @@ async def get_user_chat_list_by_user_id(
@router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
async def create_new_chat(
form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db)
):
try:
chat = Chats.insert_new_chat(user.id, form_data)
chat = Chats.insert_new_chat(db, user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
except Exception as e:
log.exception(e)
......@@ -106,10 +116,10 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
@router.get("/all", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user)):
async def get_user_chats(user=Depends(get_current_user), db=Depends(get_db)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_chats_by_user_id(user.id)
for chat in Chats.get_chats_by_user_id(db, user.id)
]
......@@ -119,10 +129,10 @@ async def get_user_chats(user=Depends(get_current_user)):
@router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user)):
async def get_user_archived_chats(user=Depends(get_current_user), db=Depends(get_db)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_archived_chats_by_user_id(user.id)
for chat in Chats.get_archived_chats_by_user_id(db, user.id)
]
......@@ -132,7 +142,7 @@ async def get_user_chats(user=Depends(get_current_user)):
@router.get("/all/db", response_model=List[ChatResponse])
async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_db)):
if not ENABLE_ADMIN_EXPORT:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
......@@ -140,7 +150,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
)
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_chats()
for chat in Chats.get_chats(db)
]
......@@ -151,9 +161,9 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
@router.get("/archived", response_model=List[ChatTitleIdResponse])
async def get_archived_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50
user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db)
):
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
return Chats.get_archived_chat_list_by_user_id(db, user.id, skip, limit)
############################
......@@ -162,8 +172,8 @@ async def get_archived_session_user_chat_list(
@router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_current_user)):
return Chats.archive_all_chats_by_user_id(user.id)
async def archive_all_chats(user=Depends(get_current_user), db=Depends(get_db)):
return Chats.archive_all_chats_by_user_id(db, user.id)
############################
......@@ -172,16 +182,18 @@ async def archive_all_chats(user=Depends(get_current_user)):
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
async def get_shared_chat_by_id(
share_id: str, user=Depends(get_current_user), db=Depends(get_db)
):
if user.role == "pending":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role == "user":
chat = Chats.get_chat_by_share_id(share_id)
chat = Chats.get_chat_by_share_id(db, share_id)
elif user.role == "admin":
chat = Chats.get_chat_by_id(share_id)
chat = Chats.get_chat_by_id(db, share_id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
......@@ -204,21 +216,23 @@ class TagNameForm(BaseModel):
@router.post("/tags", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
form_data: TagNameForm, user=Depends(get_current_user)
form_data: TagNameForm, user=Depends(get_current_user), db=Depends(get_db)
):
print(form_data)
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
form_data.name, user.id
db, form_data.name, user.id
)
]
chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
chats = Chats.get_chat_list_by_chat_ids(
db, chat_ids, form_data.skip, form_data.limit
)
if len(chats) == 0:
Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
Tags.delete_tag_by_tag_name_and_user_id(db, form_data.name, user.id)
return chats
......@@ -229,9 +243,9 @@ async def get_user_chat_list_by_tag_name(
@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)):
async def get_all_tags(user=Depends(get_current_user), db=Depends(get_db)):
try:
tags = Tags.get_tags_by_user_id(user.id)
tags = Tags.get_tags_by_user_id(db, user.id)
return tags
except Exception as e:
log.exception(e)
......@@ -246,8 +260,8 @@ async def get_all_tags(user=Depends(get_current_user)):
@router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
async def get_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
......@@ -264,13 +278,13 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
@router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(
id: str, form_data: ChatForm, user=Depends(get_current_user)
id: str, form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db)
):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat}
chat = Chats.update_chat_by_id(id, updated_chat)
chat = Chats.update_chat_by_id(db, id, updated_chat)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
......@@ -285,10 +299,12 @@ async def update_chat_by_id(
@router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
async def delete_chat_by_id(
request: Request, id: str, user=Depends(get_current_user), db=Depends(get_db)
):
if user.role == "admin":
result = Chats.delete_chat_by_id(id)
result = Chats.delete_chat_by_id(db, id)
return result
else:
if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]:
......@@ -297,7 +313,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
result = Chats.delete_chat_by_id_and_user_id(db, id, user.id)
return result
......@@ -307,8 +323,8 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
@router.get("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
if chat:
chat_body = json.loads(chat.chat)
......@@ -319,7 +335,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
"title": f"Clone of {chat.title}",
}
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
chat = Chats.insert_new_chat(db, user.id, ChatForm(**{"chat": updated_chat}))
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
......@@ -333,10 +349,12 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
@router.get("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
async def archive_chat_by_id(
id: str, user=Depends(get_current_user), db=Depends(get_db)
):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
if chat:
chat = Chats.toggle_chat_archive_by_id(id)
chat = Chats.toggle_chat_archive_by_id(db, id)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
......@@ -350,16 +368,16 @@ async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
@router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
async def share_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
if chat:
if chat.share_id:
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
shared_chat = Chats.update_shared_chat_by_chat_id(db, chat.id)
return ChatResponse(
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
)
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
shared_chat = Chats.insert_shared_chat_by_chat_id(db, chat.id)
if not shared_chat:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
......@@ -382,14 +400,16 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user)):
@router.delete("/{id}/share", response_model=Optional[bool])
async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
async def delete_shared_chat_by_id(
id: str, user=Depends(get_current_user), db=Depends(get_db)
):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id)
if chat:
if not chat.share_id:
return False
result = Chats.delete_shared_chat_by_chat_id(id)
update_result = Chats.update_chat_share_id_by_id(id, None)
result = Chats.delete_shared_chat_by_chat_id(db, id)
update_result = Chats.update_chat_share_id_by_id(db, id, None)
return result and update_result != None
else:
......@@ -405,8 +425,10 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
@router.get("/{id}/tags", response_model=List[TagModel])
async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
async def get_chat_tags_by_id(
id: str, user=Depends(get_current_user), db=Depends(get_db)
):
tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id)
if tags != None:
return tags
......@@ -423,12 +445,15 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
async def add_chat_tag_by_id(
id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
id: str,
form_data: ChatIdTagForm,
user=Depends(get_current_user),
db=Depends(get_db),
):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id)
if form_data.tag_name not in tags:
tag = Tags.add_tag_to_chat(user.id, form_data)
tag = Tags.add_tag_to_chat(db, user.id, form_data)
if tag:
return tag
......@@ -450,10 +475,13 @@ async def add_chat_tag_by_id(
@router.delete("/{id}/tags", response_model=Optional[bool])
async def delete_chat_tag_by_id(
id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
id: str,
form_data: ChatIdTagForm,
user=Depends(get_current_user),
db=Depends(get_db),
):
result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
form_data.tag_name, id, user.id
db, form_data.tag_name, id, user.id
)
if result:
......@@ -470,8 +498,10 @@ async def delete_chat_tag_by_id(
@router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
async def delete_all_chat_tags_by_id(
id: str, user=Depends(get_current_user), db=Depends(get_db)
):
result = Tags.delete_tags_by_chat_id_and_user_id(db, id, user.id)
if result:
return result
......
......@@ -6,6 +6,7 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.internal.db import get_db
from apps.webui.models.documents import (
Documents,
DocumentForm,
......@@ -25,7 +26,7 @@ router = APIRouter()
@router.get("/", response_model=List[DocumentResponse])
async def get_documents(user=Depends(get_current_user)):
async def get_documents(user=Depends(get_current_user), db=Depends(get_db)):
docs = [
DocumentResponse(
**{
......@@ -33,7 +34,7 @@ async def get_documents(user=Depends(get_current_user)):
"content": json.loads(doc.content if doc.content else "{}"),
}
)
for doc in Documents.get_docs()
for doc in Documents.get_docs(db)
]
return docs
......@@ -44,10 +45,12 @@ async def get_documents(user=Depends(get_current_user)):
@router.post("/create", response_model=Optional[DocumentResponse])
async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
doc = Documents.get_doc_by_name(form_data.name)
async def create_new_doc(
form_data: DocumentForm, user=Depends(get_admin_user), db=Depends(get_db)
):
doc = Documents.get_doc_by_name(db, form_data.name)
if doc == None:
doc = Documents.insert_new_doc(user.id, form_data)
doc = Documents.insert_new_doc(db, user.id, form_data)
if doc:
return DocumentResponse(
......@@ -74,8 +77,10 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
@router.get("/doc", response_model=Optional[DocumentResponse])
async def get_doc_by_name(name: str, user=Depends(get_current_user)):
doc = Documents.get_doc_by_name(name)
async def get_doc_by_name(
name: str, user=Depends(get_current_user), db=Depends(get_db)
):
doc = Documents.get_doc_by_name(db, name)
if doc:
return DocumentResponse(
......@@ -106,8 +111,12 @@ class TagDocumentForm(BaseModel):
@router.post("/doc/tags", response_model=Optional[DocumentResponse])
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
async def tag_doc_by_name(
form_data: TagDocumentForm, user=Depends(get_current_user), db=Depends(get_db)
):
doc = Documents.update_doc_content_by_name(
db, form_data.name, {"tags": form_data.tags}
)
if doc:
return DocumentResponse(
......@@ -130,9 +139,12 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u
@router.post("/doc/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
name: str,
form_data: DocumentUpdateForm,
user=Depends(get_admin_user),
db=Depends(get_db),
):
doc = Documents.update_doc_by_name(name, form_data)
doc = Documents.update_doc_by_name(db, name, form_data)
if doc:
return DocumentResponse(
**{
......@@ -153,6 +165,8 @@ async def update_doc_by_name(
@router.delete("/doc/delete", response_model=bool)
async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
result = Documents.delete_doc_by_name(name)
async def delete_doc_by_name(
name: str, user=Depends(get_admin_user), db=Depends(get_db)
):
result = Documents.delete_doc_by_name(db, name)
return result
......@@ -20,6 +20,7 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from pydantic import BaseModel
import json
from apps.webui.internal.db import get_db
from apps.webui.models.files import (
Files,
FileForm,
......@@ -53,6 +54,7 @@ router = APIRouter()
def upload_file(
file: UploadFile = File(...),
user=Depends(get_verified_user),
db=Depends(get_db)
):
log.info(f"file.content_type: {file.content_type}")
try:
......@@ -70,6 +72,7 @@ def upload_file(
f.close()
file = Files.insert_new_file(
db,
user.id,
FileForm(
**{
......@@ -106,8 +109,8 @@ def upload_file(
@router.get("/", response_model=List[FileModel])
async def list_files(user=Depends(get_verified_user)):
files = Files.get_files()
async def list_files(user=Depends(get_verified_user), db=Depends(get_db)):
files = Files.get_files(db)
return files
......@@ -117,8 +120,8 @@ async def list_files(user=Depends(get_verified_user)):
@router.delete("/all")
async def delete_all_files(user=Depends(get_admin_user)):
result = Files.delete_all_files()
async def delete_all_files(user=Depends(get_admin_user), db=Depends(get_db)):
result = Files.delete_all_files(db)
if result:
folder = f"{UPLOAD_DIR}"
......@@ -154,8 +157,8 @@ async def delete_all_files(user=Depends(get_admin_user)):
@router.get("/{id}", response_model=Optional[FileModel])
async def get_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
async def get_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
file = Files.get_file_by_id(db, id)
if file:
return file
......@@ -172,8 +175,8 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user)):
@router.get("/{id}/content", response_model=Optional[FileModel])
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
async def get_file_content_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
file = Files.get_file_by_id(db, id)
if file:
file_path = Path(file.meta["path"])
......@@ -223,11 +226,11 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
@router.delete("/{id}")
async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
async def delete_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
file = Files.get_file_by_id(db, id)
if file:
result = Files.delete_file_by_id(id)
result = Files.delete_file_by_id(db, id)
if result:
return {"message": "File deleted successfully"}
else:
......
......@@ -6,6 +6,7 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.internal.db import get_db
from apps.webui.models.functions import (
Functions,
FunctionForm,
......@@ -31,8 +32,8 @@ router = APIRouter()
@router.get("/", response_model=List[FunctionResponse])
async def get_functions(user=Depends(get_verified_user)):
return Functions.get_functions()
async def get_functions(user=Depends(get_verified_user), db=Depends(get_db)):
return Functions.get_functions(db)
############################
......@@ -41,8 +42,8 @@ async def get_functions(user=Depends(get_verified_user)):
@router.get("/export", response_model=List[FunctionModel])
async def get_functions(user=Depends(get_admin_user)):
return Functions.get_functions()
async def get_functions(user=Depends(get_admin_user), db=Depends(get_db)):
return Functions.get_functions(db)
############################
......@@ -52,7 +53,7 @@ async def get_functions(user=Depends(get_admin_user)):
@router.post("/create", response_model=Optional[FunctionResponse])
async def create_new_function(
request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
request: Request, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db)
):
if not form_data.id.isidentifier():
raise HTTPException(
......@@ -62,7 +63,7 @@ async def create_new_function(
form_data.id = form_data.id.lower()
function = Functions.get_function_by_id(form_data.id)
function = Functions.get_function_by_id(db, form_data.id)
if function == None:
function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
try:
......@@ -77,7 +78,7 @@ async def create_new_function(
FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[form_data.id] = function_module
function = Functions.insert_new_function(user.id, function_type, form_data)
function = Functions.insert_new_function(db, user.id, function_type, form_data)
function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
function_cache_dir.mkdir(parents=True, exist_ok=True)
......@@ -108,8 +109,8 @@ async def create_new_function(
@router.get("/id/{id}", response_model=Optional[FunctionModel])
async def get_function_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
async def get_function_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
function = Functions.get_function_by_id(db, id)
if function:
return function
......@@ -154,7 +155,7 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
async def update_function_by_id(
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db)
):
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
......@@ -171,7 +172,7 @@ async def update_function_by_id(
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
print(updated)
function = Functions.update_function_by_id(id, updated)
function = Functions.update_function_by_id(db, id, updated)
if function:
return function
......@@ -195,9 +196,9 @@ async def update_function_by_id(
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_function_by_id(
request: Request, id: str, user=Depends(get_admin_user)
request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db)
):
result = Functions.delete_function_by_id(id)
result = Functions.delete_function_by_id(db, id)
if result:
FUNCTIONS = request.app.state.FUNCTIONS
......
......@@ -7,6 +7,7 @@ from fastapi import APIRouter
from pydantic import BaseModel
import logging
from apps.webui.internal.db import get_db
from apps.webui.models.memories import Memories, MemoryModel
from utils.utils import get_verified_user
......@@ -31,8 +32,8 @@ async def get_embeddings(request: Request):
@router.get("/", response_model=List[MemoryModel])
async def get_memories(user=Depends(get_verified_user)):
return Memories.get_memories_by_user_id(user.id)
async def get_memories(user=Depends(get_verified_user), db=Depends(get_db)):
return Memories.get_memories_by_user_id(db, user.id)
############################
......@@ -50,9 +51,12 @@ class MemoryUpdateModel(BaseModel):
@router.post("/add", response_model=Optional[MemoryModel])
async def add_memory(
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
request: Request,
form_data: AddMemoryForm,
user=Depends(get_verified_user),
db=Depends(get_db),
):
memory = Memories.insert_new_memory(user.id, form_data.content)
memory = Memories.insert_new_memory(db, user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
......@@ -72,8 +76,9 @@ async def update_memory_by_id(
request: Request,
form_data: MemoryUpdateModel,
user=Depends(get_verified_user),
db=Depends(get_db),
):
memory = Memories.update_memory_by_id(memory_id, form_data.content)
memory = Memories.update_memory_by_id(db, memory_id, form_data.content)
if memory is None:
raise HTTPException(status_code=404, detail="Memory not found")
......@@ -124,12 +129,12 @@ async def query_memory(
############################
@router.get("/reset", response_model=bool)
async def reset_memory_from_vector_db(
request: Request, user=Depends(get_verified_user)
request: Request, user=Depends(get_verified_user), db=Depends(get_db)
):
CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
memories = Memories.get_memories_by_user_id(user.id)
memories = Memories.get_memories_by_user_id(db, user.id)
for memory in memories:
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection.upsert(
......@@ -146,8 +151,8 @@ async def reset_memory_from_vector_db(
@router.delete("/user", response_model=bool)
async def delete_memory_by_user_id(user=Depends(get_verified_user)):
result = Memories.delete_memories_by_user_id(user.id)
async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(get_db)):
result = Memories.delete_memories_by_user_id(db, user.id)
if result:
try:
......@@ -165,8 +170,10 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user)):
@router.delete("/{memory_id}", response_model=bool)
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
async def delete_memory_by_id(
memory_id: str, user=Depends(get_verified_user), db=Depends(get_db)
):
result = Memories.delete_memory_by_id_and_user_id(db, memory_id, user.id)
if result:
collection = CHROMA_CLIENT.get_or_create_collection(
......
......@@ -5,6 +5,8 @@ from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.internal.db import get_db
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
from utils.utils import get_verified_user, get_admin_user
......@@ -18,8 +20,8 @@ router = APIRouter()
@router.get("/", response_model=List[ModelResponse])
async def get_models(user=Depends(get_verified_user)):
return Models.get_all_models()
async def get_models(user=Depends(get_verified_user), db=Depends(get_db)):
return Models.get_all_models(db)
############################
......@@ -29,7 +31,10 @@ async def get_models(user=Depends(get_verified_user)):
@router.post("/add", response_model=Optional[ModelModel])
async def add_new_model(
request: Request, form_data: ModelForm, user=Depends(get_admin_user)
request: Request,
form_data: ModelForm,
user=Depends(get_admin_user),
db=Depends(get_db),
):
if form_data.id in request.app.state.MODELS:
raise HTTPException(
......@@ -37,7 +42,7 @@ async def add_new_model(
detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
)
else:
model = Models.insert_new_model(form_data, user.id)
model = Models.insert_new_model(db, form_data, user.id)
if model:
return model
......@@ -53,9 +58,9 @@ async def add_new_model(
############################
@router.get("/", response_model=Optional[ModelModel])
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
@router.get("/{id}", response_model=Optional[ModelModel])
async def get_model_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
model = Models.get_model_by_id(db, id)
if model:
return model
......@@ -73,15 +78,19 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
@router.post("/update", response_model=Optional[ModelModel])
async def update_model_by_id(
request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user)
request: Request,
id: str,
form_data: ModelForm,
user=Depends(get_admin_user),
db=Depends(get_db),
):
model = Models.get_model_by_id(id)
model = Models.get_model_by_id(db, id)
if model:
model = Models.update_model_by_id(id, form_data)
model = Models.update_model_by_id(db, id, form_data)
return model
else:
if form_data.id in request.app.state.MODELS:
model = Models.insert_new_model(form_data, user.id)
model = Models.insert_new_model(db, form_data, user.id)
if model:
return model
else:
......@@ -102,6 +111,6 @@ async def update_model_by_id(
@router.delete("/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
result = Models.delete_model_by_id(id)
async def delete_model_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
result = Models.delete_model_by_id(db, id)
return result
......@@ -6,6 +6,7 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.internal.db import get_db
from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
from utils.utils import get_current_user, get_admin_user
......@@ -19,8 +20,8 @@ router = APIRouter()
@router.get("/", response_model=List[PromptModel])
async def get_prompts(user=Depends(get_current_user)):
return Prompts.get_prompts()
async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)):
return Prompts.get_prompts(db)
############################
......@@ -29,10 +30,12 @@ async def get_prompts(user=Depends(get_current_user)):
@router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
prompt = Prompts.get_prompt_by_command(form_data.command)
async def create_new_prompt(
form_data: PromptForm, user=Depends(get_admin_user), db=Depends(get_db)
):
prompt = Prompts.get_prompt_by_command(db, form_data.command)
if prompt == None:
prompt = Prompts.insert_new_prompt(user.id, form_data)
prompt = Prompts.insert_new_prompt(db, user.id, form_data)
if prompt:
return prompt
......@@ -52,8 +55,10 @@ async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user))
@router.get("/command/{command}", response_model=Optional[PromptModel])
async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
prompt = Prompts.get_prompt_by_command(f"/{command}")
async def get_prompt_by_command(
command: str, user=Depends(get_current_user), db=Depends(get_db)
):
prompt = Prompts.get_prompt_by_command(db, f"/{command}")
if prompt:
return prompt
......@@ -71,9 +76,12 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
@router.post("/command/{command}/update", response_model=Optional[PromptModel])
async def update_prompt_by_command(
command: str, form_data: PromptForm, user=Depends(get_admin_user)
command: str,
form_data: PromptForm,
user=Depends(get_admin_user),
db=Depends(get_db),
):
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
prompt = Prompts.update_prompt_by_command(db, f"/{command}", form_data)
if prompt:
return prompt
else:
......@@ -89,6 +97,8 @@ async def update_prompt_by_command(
@router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
result = Prompts.delete_prompt_by_command(f"/{command}")
async def delete_prompt_by_command(
command: str, user=Depends(get_admin_user), db=Depends(get_db)
):
result = Prompts.delete_prompt_by_command(db, f"/{command}")
return result
......@@ -6,7 +6,7 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.internal.db import get_db
from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
......@@ -34,7 +34,7 @@ router = APIRouter()
@router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_verified_user)):
async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
......@@ -45,8 +45,8 @@ async def get_toolkits(user=Depends(get_verified_user)):
@router.get("/export", response_model=List[ToolModel])
async def get_toolkits(user=Depends(get_admin_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
async def get_toolkits(user=Depends(get_admin_user), db=Depends(get_db)):
toolkits = [toolkit for toolkit in Tools.get_tools(db)]
return toolkits
......@@ -57,7 +57,10 @@ async def get_toolkits(user=Depends(get_admin_user)):
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit(
request: Request, form_data: ToolForm, user=Depends(get_admin_user)
request: Request,
form_data: ToolForm,
user=Depends(get_admin_user),
db=Depends(get_db),
):
if not form_data.id.isidentifier():
raise HTTPException(
......@@ -67,7 +70,7 @@ async def create_new_toolkit(
form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(form_data.id)
toolkit = Tools.get_tool_by_id(db, form_data.id)
if toolkit == None:
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
try:
......@@ -81,7 +84,7 @@ async def create_new_toolkit(
TOOLS[form_data.id] = toolkit_module
specs = get_tools_specs(TOOLS[form_data.id])
toolkit = Tools.insert_new_tool(user.id, form_data, specs)
toolkit = Tools.insert_new_tool(db, user.id, form_data, specs)
tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
tool_cache_dir.mkdir(parents=True, exist_ok=True)
......@@ -112,8 +115,8 @@ async def create_new_toolkit(
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id)
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
toolkit = Tools.get_tool_by_id(db, id)
if toolkit:
return toolkit
......@@ -131,7 +134,11 @@ async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id(
request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
request: Request,
id: str,
form_data: ToolForm,
user=Depends(get_admin_user),
db=Depends(get_db),
):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
......@@ -153,7 +160,7 @@ async def update_toolkit_by_id(
}
print(updated)
toolkit = Tools.update_tool_by_id(id, updated)
toolkit = Tools.update_tool_by_id(db, id, updated)
if toolkit:
return toolkit
......@@ -176,8 +183,10 @@ async def update_toolkit_by_id(
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
result = Tools.delete_tool_by_id(id)
async def delete_toolkit_by_id(
request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db)
):
result = Tools.delete_tool_by_id(db, id)
if result:
TOOLS = request.app.state.TOOLS
......
......@@ -9,6 +9,7 @@ import time
import uuid
import logging
from apps.webui.internal.db import get_db
from apps.webui.models.users import (
UserModel,
UserUpdateForm,
......@@ -40,8 +41,10 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel])
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
return Users.get_users(skip, limit)
async def get_users(
skip: int = 0, limit: int = 50, user=Depends(get_admin_user), db=Depends(get_db)
):
return Users.get_users(db, skip, limit)
############################
......@@ -68,10 +71,12 @@ async def update_user_permissions(
@router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
async def update_user_role(
form_data: UserRoleUpdateForm, user=Depends(get_admin_user), db=Depends(get_db)
):
if user.id != form_data.id and form_data.id != Users.get_first_user().id:
return Users.update_user_role_by_id(form_data.id, form_data.role)
if user.id != form_data.id and form_data.id != Users.get_first_user(db).id:
return Users.update_user_role_by_id(db, form_data.id, form_data.role)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
......@@ -85,8 +90,10 @@ async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin
@router.get("/user/settings", response_model=Optional[UserSettings])
async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id)
async def get_user_settings_by_session_user(
user=Depends(get_verified_user), db=Depends(get_db)
):
user = Users.get_user_by_id(db, user.id)
if user:
return user.settings
else:
......@@ -103,9 +110,9 @@ async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
@router.post("/user/settings/update", response_model=UserSettings)
async def update_user_settings_by_session_user(
form_data: UserSettings, user=Depends(get_verified_user)
form_data: UserSettings, user=Depends(get_verified_user), db=Depends(get_db)
):
user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
user = Users.update_user_by_id(db, user.id, {"settings": form_data.model_dump()})
if user:
return user.settings
else:
......@@ -121,8 +128,10 @@ async def update_user_settings_by_session_user(
@router.get("/user/info", response_model=Optional[dict])
async def get_user_info_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id)
async def get_user_info_by_session_user(
user=Depends(get_verified_user), db=Depends(get_db)
):
user = Users.get_user_by_id(db, user.id)
if user:
return user.info
else:
......@@ -138,15 +147,17 @@ async def get_user_info_by_session_user(user=Depends(get_verified_user)):
@router.post("/user/info/update", response_model=Optional[dict])
async def update_user_settings_by_session_user(
form_data: dict, user=Depends(get_verified_user)
async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user), db=Depends(get_db)
):
user = Users.get_user_by_id(user.id)
user = Users.get_user_by_id(db, user.id)
if user:
if user.info is None:
user.info = {}
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
user = Users.update_user_by_id(
db, user.id, {"info": {**user.info, **form_data}}
)
if user:
return user.info
else:
......@@ -172,13 +183,15 @@ class UserResponse(BaseModel):
@router.get("/{user_id}", response_model=UserResponse)
async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
async def get_user_by_id(
user_id: str, user=Depends(get_verified_user), db=Depends(get_db)
):
# Check if user_id is a shared chat
# If it is, get the user_id from the chat
if user_id.startswith("shared-"):
chat_id = user_id.replace("shared-", "")
chat = Chats.get_chat_by_id(chat_id)
chat = Chats.get_chat_by_id(db, chat_id)
if chat:
user_id = chat.user_id
else:
......@@ -187,7 +200,7 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
user = Users.get_user_by_id(user_id)
user = Users.get_user_by_id(db, user_id)
if user:
return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
......@@ -205,13 +218,16 @@ async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
@router.post("/{user_id}/update", response_model=Optional[UserModel])
async def update_user_by_id(
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user)
user_id: str,
form_data: UserUpdateForm,
session_user=Depends(get_admin_user),
db=Depends(get_db),
):
user = Users.get_user_by_id(user_id)
user = Users.get_user_by_id(db, user_id)
if user:
if form_data.email.lower() != user.email:
email_user = Users.get_user_by_email(form_data.email.lower())
email_user = Users.get_user_by_email(db, form_data.email.lower())
if email_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
......@@ -221,10 +237,11 @@ async def update_user_by_id(
if form_data.password:
hashed = get_password_hash(form_data.password)
log.debug(f"hashed: {hashed}")
Auths.update_user_password_by_id(user_id, hashed)
Auths.update_user_password_by_id(db, user_id, hashed)
Auths.update_email_by_id(user_id, form_data.email.lower())
Auths.update_email_by_id(db, user_id, form_data.email.lower())
updated_user = Users.update_user_by_id(
db,
user_id,
{
"name": form_data.name,
......@@ -253,9 +270,11 @@ async def update_user_by_id(
@router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
async def delete_user_by_id(
user_id: str, user=Depends(get_admin_user), db=Depends(get_db)
):
if user.id != user_id:
result = Auths.delete_auth_by_id(user_id)
result = Auths.delete_auth_by_id(db, user_id)
if result:
return True
......
from fastapi import APIRouter, UploadFile, File, Response
from fastapi import Depends, HTTPException, status
from peewee import SqliteDatabase
from starlette.responses import StreamingResponse, FileResponse
from pydantic import BaseModel
......@@ -10,7 +9,6 @@ import markdown
import black
from apps.webui.internal.db import DB
from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url
......@@ -114,13 +112,15 @@ async def download_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if not isinstance(DB, SqliteDatabase):
from apps.webui.internal.db import engine
if engine.name != "sqlite":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DB_NOT_SQLITE,
)
return FileResponse(
DB.database,
engine.url.database,
media_type="application/octet-stream",
filename="webui.db",
)
......
import base64
import uuid
import subprocess
from contextlib import asynccontextmanager
from authlib.integrations.starlette_client import OAuth
......@@ -27,6 +28,8 @@ from fastapi.responses import JSONResponse
from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from sqlalchemy import text
from sqlalchemy.orm import Session
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.sessions import SessionMiddleware
......@@ -54,6 +57,7 @@ from apps.webui.main import (
get_pipe_models,
generate_function_chat_completion,
)
from apps.webui.internal.db import get_db, SessionLocal
from pydantic import BaseModel
......@@ -124,6 +128,8 @@ from config import (
WEBUI_SESSION_COOKIE_SAME_SITE,
WEBUI_SESSION_COOKIE_SECURE,
AppConfig,
BACKEND_DIR,
DATABASE_URL,
)
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from utils.webhook import post_webhook
......@@ -166,8 +172,19 @@ https://github.com/open-webui/open-webui
)
def run_migrations():
from alembic.config import Config
from alembic import command
alembic_cfg = Config(f"{BACKEND_DIR}/alembic.ini")
alembic_cfg.set_main_option("sqlalchemy.url", DATABASE_URL)
alembic_cfg.set_main_option("script_location", f"{BACKEND_DIR}/migrations")
command.upgrade(alembic_cfg, "head")
@asynccontextmanager
async def lifespan(app: FastAPI):
run_migrations()
yield
......@@ -393,6 +410,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
SessionLocal(),
)
# Flag to skip RAG completions if file_handler is present in tools/functions
skip_files = False
......@@ -736,6 +754,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
SessionLocal(),
)
try:
......@@ -781,7 +800,9 @@ app.add_middleware(
@app.middleware("http")
async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
await get_all_models()
db = SessionLocal()
await get_all_models(db)
db.commit()
else:
pass
......@@ -815,12 +836,12 @@ app.mount("/api/v1", webui_app)
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
async def get_all_models():
async def get_all_models(db: Session):
pipe_models = []
openai_models = []
ollama_models = []
pipe_models = await get_pipe_models()
pipe_models = await get_pipe_models(db)
if app.state.config.ENABLE_OPENAI_API:
openai_models = await get_openai_models()
......@@ -842,7 +863,7 @@ async def get_all_models():
models = pipe_models + openai_models + ollama_models
custom_models = Models.get_all_models()
custom_models = Models.get_all_models(db)
for custom_model in custom_models:
if custom_model.base_model_id == None:
for model in models:
......@@ -882,8 +903,8 @@ async def get_all_models():
@app.get("/api/models")
async def get_models(user=Depends(get_verified_user)):
models = await get_all_models()
async def get_models(user=Depends(get_verified_user), db=Depends(get_db)):
models = await get_all_models(db)
# Filter out filter pipelines
models = [
......@@ -1584,9 +1605,12 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
@app.get("/api/pipelines/{pipeline_id}/valves")
async def get_pipeline_valves(
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
db=Depends(get_db),
):
models = await get_all_models()
models = await get_all_models(db)
r = None
try:
......@@ -1622,9 +1646,12 @@ async def get_pipeline_valves(
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
async def get_pipeline_valves_spec(
urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
db=Depends(get_db),
):
models = await get_all_models()
models = await get_all_models(db)
r = None
try:
......@@ -1663,8 +1690,9 @@ async def update_pipeline_valves(
pipeline_id: str,
form_data: dict,
user=Depends(get_admin_user),
db=Depends(get_db),
):
models = await get_all_models()
models = await get_all_models(db)
r = None
try:
......@@ -2011,6 +2039,12 @@ async def healthcheck():
return {"status": True}
@app.get("/health/db")
async def healthcheck_with_db(db: Session = Depends(get_db)):
result = db.execute(text("SELECT 1;")).all()
return {"status": True}
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
......
Generic single-database configuration.
Create new migrations with
DATABASE_URL=<replace with actual url> alembic revision --autogenerate -m "a description"
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config
from sqlalchemy import pool
from alembic import context
from apps.webui.models.auths import Auth
from apps.webui.models.chats import Chat
from apps.webui.models.documents import Document
from apps.webui.models.memories import Memory
from apps.webui.models.models import Model
from apps.webui.models.prompts import Prompt
from apps.webui.models.tags import Tag, ChatIdTag
from apps.webui.models.tools import Tool
from apps.webui.models.users import User
from apps.webui.models.files import File
from apps.webui.models.functions import Function
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = Auth.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
database_url = os.getenv("DATABASE_URL", None)
if database_url:
config.set_main_option("sqlalchemy.url", database_url)
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import apps.webui.internal.db
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}
"""init
Revision ID: 22b5ab2667b8
Revises:
Create Date: 2024-06-20 13:22:40.397002
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector
import apps.webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "22b5ab2667b8"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
con = op.get_bind()
inspector = Inspector.from_engine(con)
tables = set(inspector.get_table_names())
# ### commands auto generated by Alembic - please adjust! ###
if not "auth" in tables:
op.create_table(
"auth",
sa.Column("id", sa.String(), nullable=False),
sa.Column("email", sa.String(), nullable=True),
sa.Column("password", sa.String(), nullable=True),
sa.Column("active", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "chat" in tables:
op.create_table(
"chat",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.String(), nullable=True),
sa.Column("chat", sa.String(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("share_id", sa.String(), nullable=True),
sa.Column("archived", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("share_id"),
)
if not "chatidtag" in tables:
op.create_table(
"chatidtag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("tag_name", sa.String(), nullable=True),
sa.Column("chat_id", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "document" in tables:
op.create_table(
"document",
sa.Column("collection_name", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("title", sa.String(), nullable=True),
sa.Column("filename", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("collection_name"),
sa.UniqueConstraint("name"),
)
if not "memory" in tables:
op.create_table(
"memory",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "model" in tables:
op.create_table(
"model",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("base_model_id", sa.String(), nullable=True),
sa.Column("name", sa.String(), nullable=True),
sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "prompt" in tables:
op.create_table(
"prompt",
sa.Column("command", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("command"),
)
if not "tag" in tables:
op.create_table(
"tag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("data", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "tool" in tables:
op.create_table(
"tool",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "user" in tables:
op.create_table(
"user",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("email", sa.String(), nullable=True),
sa.Column("role", sa.String(), nullable=True),
sa.Column("profile_image_url", sa.String(), nullable=True),
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("api_key"),
)
if not "file" in tables:
op.create_table('file',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('filename', sa.String(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
if not "function" in tables:
op.create_table('function',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.Text(), nullable=True),
sa.Column('type', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# do nothing as we assume we had previous migrations from peewee-migrate
pass
# ### end Alembic commands ###
......@@ -12,8 +12,10 @@ passlib[bcrypt]==1.7.4
requests==2.32.2
aiohttp==3.9.5
peewee==3.17.5
peewee-migrate==1.12.2
sqlalchemy==2.0.30
alembic==1.13.1
# peewee==3.17.5
# peewee-migrate==1.12.2
psycopg2-binary==2.9.9
PyMySQL==1.1.1
bcrypt==4.1.3
......@@ -67,4 +69,9 @@ pytube==15.0.0
extract_msg
pydub
duckduckgo-search~=6.1.5
\ No newline at end of file
duckduckgo-search~=6.1.5
## Tests
docker~=7.1.0
pytest~=8.2.1
pytest-docker~=3.1.1
import pytest
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestAuths(AbstractPostgresTest):
BASE_PATH = "/api/v1/auths"
def setup_class(cls):
super().setup_class()
from apps.webui.models.users import Users
from apps.webui.models.auths import Auths
cls.users = Users
cls.auths = Auths
def test_get_session_user(self):
with mock_webui_user():
response = self.fast_api_client.get(self.create_url(""))
assert response.status_code == 200
assert response.json() == {
"id": "1",
"name": "John Doe",
"email": "john.doe@openwebui.com",
"role": "user",
"profile_image_url": "/user.png",
}
def test_update_profile(self):
from utils.utils import get_password_hash
user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com",
password=get_password_hash("old_password"),
name="John Doe",
profile_image_url="/user.png",
role="user",
)
with mock_webui_user(id=user.id):
response = self.fast_api_client.post(
self.create_url("/update/profile"),
json={"name": "John Doe 2", "profile_image_url": "/user2.png"},
)
assert response.status_code == 200
db_user = self.users.get_user_by_id(self.db_session, user.id)
assert db_user.name == "John Doe 2"
assert db_user.profile_image_url == "/user2.png"
def test_update_password(self):
from utils.utils import get_password_hash
user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com",
password=get_password_hash("old_password"),
name="John Doe",
profile_image_url="/user.png",
role="user",
)
with mock_webui_user(id=user.id):
response = self.fast_api_client.post(
self.create_url("/update/password"),
json={"password": "old_password", "new_password": "new_password"},
)
assert response.status_code == 200
old_auth = self.auths.authenticate_user(
self.db_session, "john.doe@openwebui.com", "old_password"
)
assert old_auth is None
new_auth = self.auths.authenticate_user(
self.db_session, "john.doe@openwebui.com", "new_password"
)
assert new_auth is not None
def test_signin(self):
from utils.utils import get_password_hash
user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com",
password=get_password_hash("password"),
name="John Doe",
profile_image_url="/user.png",
role="user",
)
response = self.fast_api_client.post(
self.create_url("/signin"),
json={"email": "john.doe@openwebui.com", "password": "password"},
)
assert response.status_code == 200
data = response.json()
assert data["id"] == user.id
assert data["name"] == "John Doe"
assert data["email"] == "john.doe@openwebui.com"
assert data["role"] == "user"
assert data["profile_image_url"] == "/user.png"
assert data["token"] is not None and len(data["token"]) > 0
assert data["token_type"] == "Bearer"
def test_signup(self):
response = self.fast_api_client.post(
self.create_url("/signup"),
json={
"name": "John Doe",
"email": "john.doe@openwebui.com",
"password": "password",
},
)
assert response.status_code == 200
data = response.json()
assert data["id"] is not None and len(data["id"]) > 0
assert data["name"] == "John Doe"
assert data["email"] == "john.doe@openwebui.com"
assert data["role"] in ["admin", "user", "pending"]
assert data["profile_image_url"] == "/user.png"
assert data["token"] is not None and len(data["token"]) > 0
assert data["token_type"] == "Bearer"
def test_add_user(self):
with mock_webui_user():
response = self.fast_api_client.post(
self.create_url("/add"),
json={
"name": "John Doe 2",
"email": "john.doe2@openwebui.com",
"password": "password2",
"role": "admin",
},
)
assert response.status_code == 200
data = response.json()
assert data["id"] is not None and len(data["id"]) > 0
assert data["name"] == "John Doe 2"
assert data["email"] == "john.doe2@openwebui.com"
assert data["role"] == "admin"
assert data["profile_image_url"] == "/user.png"
assert data["token"] is not None and len(data["token"]) > 0
assert data["token_type"] == "Bearer"
def test_get_admin_details(self):
self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
with mock_webui_user():
response = self.fast_api_client.get(self.create_url("/admin/details"))
assert response.status_code == 200
assert response.json() == {
"name": "John Doe",
"email": "john.doe@openwebui.com",
}
def test_create_api_key_(self):
user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
with mock_webui_user(id=user.id):
response = self.fast_api_client.post(self.create_url("/api_key"))
assert response.status_code == 200
data = response.json()
assert data["api_key"] is not None
assert len(data["api_key"]) > 0
def test_delete_api_key(self):
user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
self.users.update_user_api_key_by_id(self.db_session, user.id, "abc")
with mock_webui_user(id=user.id):
response = self.fast_api_client.delete(self.create_url("/api_key"))
assert response.status_code == 200
assert response.json() == True
db_user = self.users.get_user_by_id(self.db_session, user.id)
assert db_user.api_key is None
def test_get_api_key(self):
user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
self.users.update_user_api_key_by_id(self.db_session, user.id, "abc")
with mock_webui_user(id=user.id):
response = self.fast_api_client.get(self.create_url("/api_key"))
assert response.status_code == 200
assert response.json() == {"api_key": "abc"}
import uuid
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestChats(AbstractPostgresTest):
BASE_PATH = "/api/v1/chats"
def setup_class(cls):
super().setup_class()
def setup_method(self):
super().setup_method()
from apps.webui.models.chats import ChatForm
from apps.webui.models.chats import Chats
self.chats = Chats
self.chats.insert_new_chat(
self.db_session,
"2",
ChatForm(
**{
"chat": {
"name": "chat1",
"description": "chat1 description",
"tags": ["tag1", "tag2"],
"history": {"currentId": "1", "messages": []},
}
}
),
)
def test_get_session_user_chat_list(self):
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
first_chat = response.json()[0]
assert first_chat["id"] is not None
assert first_chat["title"] == "New Chat"
assert first_chat["created_at"] is not None
assert first_chat["updated_at"] is not None
def test_delete_all_user_chats(self):
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url("/"))
assert response.status_code == 200
assert len(self.chats.get_chats(self.db_session)) == 0
def test_get_user_chat_list_by_user_id(self):
with mock_webui_user(id="3"):
response = self.fast_api_client.get(self.create_url("/list/user/2"))
assert response.status_code == 200
first_chat = response.json()[0]
assert first_chat["id"] is not None
assert first_chat["title"] == "New Chat"
assert first_chat["created_at"] is not None
assert first_chat["updated_at"] is not None
def test_create_new_chat(self):
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/new"),
json={
"chat": {
"name": "chat2",
"description": "chat2 description",
"tags": ["tag1", "tag2"],
}
},
)
assert response.status_code == 200
data = response.json()
assert data["archived"] is False
assert data["chat"] == {
"name": "chat2",
"description": "chat2 description",
"tags": ["tag1", "tag2"],
}
assert data["user_id"] == "2"
assert data["id"] is not None
assert data["share_id"] is None
assert data["title"] == "New Chat"
assert data["updated_at"] is not None
assert data["created_at"] is not None
assert len(self.chats.get_chats(self.db_session)) == 2
def test_get_user_chats(self):
self.test_get_session_user_chat_list()
def test_get_user_archived_chats(self):
self.chats.archive_all_chats_by_user_id(self.db_session, "2")
self.db_session.commit()
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/all/archived"))
assert response.status_code == 200
first_chat = response.json()[0]
assert first_chat["id"] is not None
assert first_chat["title"] == "New Chat"
assert first_chat["created_at"] is not None
assert first_chat["updated_at"] is not None
def test_get_all_user_chats_in_db(self):
with mock_webui_user(id="4"):
response = self.fast_api_client.get(self.create_url("/all/db"))
assert response.status_code == 200
assert len(response.json()) == 1
def test_get_archived_session_user_chat_list(self):
self.test_get_user_archived_chats()
def test_archive_all_chats(self):
with mock_webui_user(id="2"):
response = self.fast_api_client.post(self.create_url("/archive/all"))
assert response.status_code == 200
assert len(self.chats.get_archived_chats_by_user_id(self.db_session, "2")) == 1
def test_get_shared_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
self.chats.update_chat_share_id_by_id(self.db_session, chat_id, chat_id)
self.db_session.commit()
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}"))
assert response.status_code == 200
data = response.json()
assert data["id"] == chat_id
assert data["chat"] == {
"name": "chat1",
"description": "chat1 description",
"tags": ["tag1", "tag2"],
"history": {"currentId": "1", "messages": []},
}
assert data["id"] == chat_id
assert data["share_id"] == chat_id
assert data["title"] == "New Chat"
def test_get_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}"))
assert response.status_code == 200
data = response.json()
assert data["id"] == chat_id
assert data["chat"] == {
"name": "chat1",
"description": "chat1 description",
"tags": ["tag1", "tag2"],
"history": {"currentId": "1", "messages": []},
}
assert data["share_id"] is None
assert data["title"] == "New Chat"
assert data["user_id"] == "2"
def test_update_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url(f"/{chat_id}"),
json={
"chat": {
"name": "chat2",
"description": "chat2 description",
"tags": ["tag2", "tag4"],
"title": "Just another title",
}
},
)
assert response.status_code == 200
data = response.json()
assert data["id"] == chat_id
assert data["chat"] == {
"name": "chat2",
"title": "Just another title",
"description": "chat2 description",
"tags": ["tag2", "tag4"],
"history": {"currentId": "1", "messages": []},
}
assert data["share_id"] is None
assert data["title"] == "Just another title"
assert data["user_id"] == "2"
def test_delete_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url(f"/{chat_id}"))
assert response.status_code == 200
assert response.json() is True
def test_clone_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone"))
assert response.status_code == 200
data = response.json()
assert data["id"] != chat_id
assert data["chat"] == {
"branchPointMessageId": "1",
"description": "chat1 description",
"history": {"currentId": "1", "messages": []},
"name": "chat1",
"originalChatId": chat_id,
"tags": ["tag1", "tag2"],
"title": "Clone of New Chat",
}
assert data["share_id"] is None
assert data["title"] == "Clone of New Chat"
assert data["user_id"] == "2"
def test_archive_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive"))
assert response.status_code == 200
chat = self.chats.get_chat_by_id(self.db_session, chat_id)
assert chat.archived is True
def test_share_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share"))
assert response.status_code == 200
chat = self.chats.get_chat_by_id(self.db_session, chat_id)
assert chat.share_id is not None
def test_delete_shared_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
share_id = str(uuid.uuid4())
self.chats.update_chat_share_id_by_id(self.db_session, chat_id, share_id)
self.db_session.commit()
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share"))
assert response.status_code
chat = self.chats.get_chat_by_id(self.db_session, chat_id)
assert chat.share_id is None
from test.util.abstract_integration_test import AbstractPostgresTest
from test.util.mock_user import mock_webui_user
class TestDocuments(AbstractPostgresTest):
BASE_PATH = "/api/v1/documents"
def setup_class(cls):
super().setup_class()
from apps.webui.models.documents import Documents
cls.documents = Documents
def test_documents(self):
# Empty database
assert len(self.documents.get_docs(self.db_session)) == 0
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 0
# Create a new document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/create"),
json={
"name": "doc_name",
"title": "doc title",
"collection_name": "custom collection",
"filename": "doc_name.pdf",
"content": "",
},
)
assert response.status_code == 200
assert response.json()["name"] == "doc_name"
assert len(self.documents.get_docs(self.db_session)) == 1
# Get the document
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/doc?name=doc_name"))
assert response.status_code == 200
data = response.json()
assert data["collection_name"] == "custom collection"
assert data["name"] == "doc_name"
assert data["title"] == "doc title"
assert data["filename"] == "doc_name.pdf"
assert data["content"] == {}
# Create another document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/create"),
json={
"name": "doc_name 2",
"title": "doc title 2",
"collection_name": "custom collection 2",
"filename": "doc_name2.pdf",
"content": "",
},
)
assert response.status_code == 200
assert response.json()["name"] == "doc_name 2"
assert len(self.documents.get_docs(self.db_session)) == 2
# Get all documents
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
assert len(response.json()) == 2
# Update the first document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/doc/update?name=doc_name"),
json={"name": "doc_name rework", "title": "updated title"},
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "doc_name rework"
assert data["title"] == "updated title"
# Tag the first document
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url("/doc/tags"),
json={
"name": "doc_name rework",
"tags": [{"name": "testing-tag"}, {"name": "another-tag"}],
},
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "doc_name rework"
assert data["content"] == {
"tags": [{"name": "testing-tag"}, {"name": "another-tag"}]
}
assert len(self.documents.get_docs(self.db_session)) == 2
# Delete the first document
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(
self.create_url("/doc/delete?name=doc_name rework")
)
assert response.status_code == 200
assert len(self.documents.get_docs(self.db_session)) == 1
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment