"vscode:/vscode.git/clone" did not exist on "453bafb96f48dfa1423bcec0a8cdc6e75f04fe1c"
Commit bee835cb authored by Jonathan Rohde's avatar Jonathan Rohde
Browse files

feat(sqlalchemy): remove session reference from router

parent df09d083
...@@ -7,7 +7,6 @@ from fastapi import APIRouter ...@@ -7,7 +7,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import logging import logging
from apps.webui.internal.db import get_db
from apps.webui.models.memories import Memories, MemoryModel from apps.webui.models.memories import Memories, MemoryModel
from utils.utils import get_verified_user from utils.utils import get_verified_user
...@@ -32,8 +31,8 @@ async def get_embeddings(request: Request): ...@@ -32,8 +31,8 @@ async def get_embeddings(request: Request):
@router.get("/", response_model=List[MemoryModel]) @router.get("/", response_model=List[MemoryModel])
async def get_memories(user=Depends(get_verified_user), db=Depends(get_db)): async def get_memories(user=Depends(get_verified_user)):
return Memories.get_memories_by_user_id(db, user.id) return Memories.get_memories_by_user_id(user.id)
############################ ############################
...@@ -54,9 +53,8 @@ async def add_memory( ...@@ -54,9 +53,8 @@ async def add_memory(
request: Request, request: Request,
form_data: AddMemoryForm, form_data: AddMemoryForm,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db=Depends(get_db),
): ):
memory = Memories.insert_new_memory(db, user.id, form_data.content) memory = Memories.insert_new_memory(user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
...@@ -76,9 +74,8 @@ async def update_memory_by_id( ...@@ -76,9 +74,8 @@ async def update_memory_by_id(
request: Request, request: Request,
form_data: MemoryUpdateModel, form_data: MemoryUpdateModel,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db=Depends(get_db),
): ):
memory = Memories.update_memory_by_id(db, memory_id, form_data.content) memory = Memories.update_memory_by_id(memory_id, form_data.content)
if memory is None: if memory is None:
raise HTTPException(status_code=404, detail="Memory not found") raise HTTPException(status_code=404, detail="Memory not found")
...@@ -129,12 +126,12 @@ async def query_memory( ...@@ -129,12 +126,12 @@ async def query_memory(
############################ ############################
@router.get("/reset", response_model=bool) @router.get("/reset", response_model=bool)
async def reset_memory_from_vector_db( async def reset_memory_from_vector_db(
request: Request, user=Depends(get_verified_user), db=Depends(get_db) request: Request, user=Depends(get_verified_user)
): ):
CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}") CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}") collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
memories = Memories.get_memories_by_user_id(db, user.id) memories = Memories.get_memories_by_user_id(user.id)
for memory in memories: for memory in memories:
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content) memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection.upsert( collection.upsert(
...@@ -151,8 +148,8 @@ async def reset_memory_from_vector_db( ...@@ -151,8 +148,8 @@ async def reset_memory_from_vector_db(
@router.delete("/user", response_model=bool) @router.delete("/user", response_model=bool)
async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(get_db)): async def delete_memory_by_user_id(user=Depends(get_verified_user)):
result = Memories.delete_memories_by_user_id(db, user.id) result = Memories.delete_memories_by_user_id(user.id)
if result: if result:
try: try:
...@@ -171,9 +168,9 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(g ...@@ -171,9 +168,9 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(g
@router.delete("/{memory_id}", response_model=bool) @router.delete("/{memory_id}", response_model=bool)
async def delete_memory_by_id( async def delete_memory_by_id(
memory_id: str, user=Depends(get_verified_user), db=Depends(get_db) memory_id: str, user=Depends(get_verified_user)
): ):
result = Memories.delete_memory_by_id_and_user_id(db, memory_id, user.id) result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
if result: if result:
collection = CHROMA_CLIENT.get_or_create_collection( collection = CHROMA_CLIENT.get_or_create_collection(
......
...@@ -6,7 +6,6 @@ from fastapi import APIRouter ...@@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.internal.db import get_db
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
from utils.utils import get_verified_user, get_admin_user from utils.utils import get_verified_user, get_admin_user
...@@ -20,8 +19,8 @@ router = APIRouter() ...@@ -20,8 +19,8 @@ router = APIRouter()
@router.get("/", response_model=List[ModelResponse]) @router.get("/", response_model=List[ModelResponse])
async def get_models(user=Depends(get_verified_user), db=Depends(get_db)): async def get_models(user=Depends(get_verified_user)):
return Models.get_all_models(db) return Models.get_all_models()
############################ ############################
...@@ -34,7 +33,6 @@ async def add_new_model( ...@@ -34,7 +33,6 @@ async def add_new_model(
request: Request, request: Request,
form_data: ModelForm, form_data: ModelForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
if form_data.id in request.app.state.MODELS: if form_data.id in request.app.state.MODELS:
raise HTTPException( raise HTTPException(
...@@ -42,7 +40,7 @@ async def add_new_model( ...@@ -42,7 +40,7 @@ async def add_new_model(
detail=ERROR_MESSAGES.MODEL_ID_TAKEN, detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
) )
else: else:
model = Models.insert_new_model(db, form_data, user.id) model = Models.insert_new_model(form_data, user.id)
if model: if model:
return model return model
...@@ -59,8 +57,8 @@ async def add_new_model( ...@@ -59,8 +57,8 @@ async def add_new_model(
@router.get("/{id}", response_model=Optional[ModelModel]) @router.get("/{id}", response_model=Optional[ModelModel])
async def get_model_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): async def get_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(db, id) model = Models.get_model_by_id(id)
if model: if model:
return model return model
...@@ -82,15 +80,14 @@ async def update_model_by_id( ...@@ -82,15 +80,14 @@ async def update_model_by_id(
id: str, id: str,
form_data: ModelForm, form_data: ModelForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
model = Models.get_model_by_id(db, id) model = Models.get_model_by_id(id)
if model: if model:
model = Models.update_model_by_id(db, id, form_data) model = Models.update_model_by_id(id, form_data)
return model return model
else: else:
if form_data.id in request.app.state.MODELS: if form_data.id in request.app.state.MODELS:
model = Models.insert_new_model(db, form_data, user.id) model = Models.insert_new_model(form_data, user.id)
if model: if model:
return model return model
else: else:
...@@ -111,6 +108,6 @@ async def update_model_by_id( ...@@ -111,6 +108,6 @@ async def update_model_by_id(
@router.delete("/delete", response_model=bool) @router.delete("/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)): async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
result = Models.delete_model_by_id(db, id) result = Models.delete_model_by_id(id)
return result return result
...@@ -6,7 +6,6 @@ from fastapi import APIRouter ...@@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.internal.db import get_db
from apps.webui.models.prompts import Prompts, PromptForm, PromptModel from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
from utils.utils import get_current_user, get_admin_user from utils.utils import get_current_user, get_admin_user
...@@ -20,8 +19,8 @@ router = APIRouter() ...@@ -20,8 +19,8 @@ router = APIRouter()
@router.get("/", response_model=List[PromptModel]) @router.get("/", response_model=List[PromptModel])
async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)): async def get_prompts(user=Depends(get_current_user)):
return Prompts.get_prompts(db) return Prompts.get_prompts()
############################ ############################
...@@ -31,11 +30,11 @@ async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)): ...@@ -31,11 +30,11 @@ async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)):
@router.post("/create", response_model=Optional[PromptModel]) @router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt( async def create_new_prompt(
form_data: PromptForm, user=Depends(get_admin_user), db=Depends(get_db) form_data: PromptForm, user=Depends(get_admin_user)
): ):
prompt = Prompts.get_prompt_by_command(db, form_data.command) prompt = Prompts.get_prompt_by_command(form_data.command)
if prompt == None: if prompt == None:
prompt = Prompts.insert_new_prompt(db, user.id, form_data) prompt = Prompts.insert_new_prompt(user.id, form_data)
if prompt: if prompt:
return prompt return prompt
...@@ -56,9 +55,9 @@ async def create_new_prompt( ...@@ -56,9 +55,9 @@ async def create_new_prompt(
@router.get("/command/{command}", response_model=Optional[PromptModel]) @router.get("/command/{command}", response_model=Optional[PromptModel])
async def get_prompt_by_command( async def get_prompt_by_command(
command: str, user=Depends(get_current_user), db=Depends(get_db) command: str, user=Depends(get_current_user)
): ):
prompt = Prompts.get_prompt_by_command(db, f"/{command}") prompt = Prompts.get_prompt_by_command(f"/{command}")
if prompt: if prompt:
return prompt return prompt
...@@ -79,9 +78,8 @@ async def update_prompt_by_command( ...@@ -79,9 +78,8 @@ async def update_prompt_by_command(
command: str, command: str,
form_data: PromptForm, form_data: PromptForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
prompt = Prompts.update_prompt_by_command(db, f"/{command}", form_data) prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt: if prompt:
return prompt return prompt
else: else:
...@@ -98,7 +96,7 @@ async def update_prompt_by_command( ...@@ -98,7 +96,7 @@ async def update_prompt_by_command(
@router.delete("/command/{command}/delete", response_model=bool) @router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command( async def delete_prompt_by_command(
command: str, user=Depends(get_admin_user), db=Depends(get_db) command: str, user=Depends(get_admin_user)
): ):
result = Prompts.delete_prompt_by_command(db, f"/{command}") result = Prompts.delete_prompt_by_command(f"/{command}")
return result return result
...@@ -6,7 +6,6 @@ from fastapi import APIRouter ...@@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.internal.db import get_db
from apps.webui.models.users import Users from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id from apps.webui.utils import load_toolkit_module_by_id
...@@ -34,7 +33,7 @@ router = APIRouter() ...@@ -34,7 +33,7 @@ router = APIRouter()
@router.get("/", response_model=List[ToolResponse]) @router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)): async def get_toolkits(user=Depends(get_verified_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()] toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits return toolkits
...@@ -45,8 +44,8 @@ async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)): ...@@ -45,8 +44,8 @@ async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)):
@router.get("/export", response_model=List[ToolModel]) @router.get("/export", response_model=List[ToolModel])
async def get_toolkits(user=Depends(get_admin_user), db=Depends(get_db)): async def get_toolkits(user=Depends(get_admin_user)):
toolkits = [toolkit for toolkit in Tools.get_tools(db)] toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits return toolkits
...@@ -60,7 +59,6 @@ async def create_new_toolkit( ...@@ -60,7 +59,6 @@ async def create_new_toolkit(
request: Request, request: Request,
form_data: ToolForm, form_data: ToolForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
if not form_data.id.isidentifier(): if not form_data.id.isidentifier():
raise HTTPException( raise HTTPException(
...@@ -70,7 +68,7 @@ async def create_new_toolkit( ...@@ -70,7 +68,7 @@ async def create_new_toolkit(
form_data.id = form_data.id.lower() form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(db, form_data.id) toolkit = Tools.get_tool_by_id(form_data.id)
if toolkit == None: if toolkit == None:
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py") toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
try: try:
...@@ -84,7 +82,7 @@ async def create_new_toolkit( ...@@ -84,7 +82,7 @@ async def create_new_toolkit(
TOOLS[form_data.id] = toolkit_module TOOLS[form_data.id] = toolkit_module
specs = get_tools_specs(TOOLS[form_data.id]) specs = get_tools_specs(TOOLS[form_data.id])
toolkit = Tools.insert_new_tool(db, user.id, form_data, specs) toolkit = Tools.insert_new_tool(user.id, form_data, specs)
tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
tool_cache_dir.mkdir(parents=True, exist_ok=True) tool_cache_dir.mkdir(parents=True, exist_ok=True)
...@@ -115,8 +113,8 @@ async def create_new_toolkit( ...@@ -115,8 +113,8 @@ async def create_new_toolkit(
@router.get("/id/{id}", response_model=Optional[ToolModel]) @router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)): async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(db, id) toolkit = Tools.get_tool_by_id(id)
if toolkit: if toolkit:
return toolkit return toolkit
...@@ -138,7 +136,6 @@ async def update_toolkit_by_id( ...@@ -138,7 +136,6 @@ async def update_toolkit_by_id(
id: str, id: str,
form_data: ToolForm, form_data: ToolForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py") toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
...@@ -160,7 +157,7 @@ async def update_toolkit_by_id( ...@@ -160,7 +157,7 @@ async def update_toolkit_by_id(
} }
print(updated) print(updated)
toolkit = Tools.update_tool_by_id(db, id, updated) toolkit = Tools.update_tool_by_id(id, updated)
if toolkit: if toolkit:
return toolkit return toolkit
...@@ -184,9 +181,9 @@ async def update_toolkit_by_id( ...@@ -184,9 +181,9 @@ async def update_toolkit_by_id(
@router.delete("/id/{id}/delete", response_model=bool) @router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id( async def delete_toolkit_by_id(
request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db) request: Request, id: str, user=Depends(get_admin_user)
): ):
result = Tools.delete_tool_by_id(db, id) result = Tools.delete_tool_by_id(id)
if result: if result:
TOOLS = request.app.state.TOOLS TOOLS = request.app.state.TOOLS
......
...@@ -9,7 +9,6 @@ import time ...@@ -9,7 +9,6 @@ import time
import uuid import uuid
import logging import logging
from apps.webui.internal.db import get_db
from apps.webui.models.users import ( from apps.webui.models.users import (
UserModel, UserModel,
UserUpdateForm, UserUpdateForm,
...@@ -42,9 +41,9 @@ router = APIRouter() ...@@ -42,9 +41,9 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel]) @router.get("/", response_model=List[UserModel])
async def get_users( async def get_users(
skip: int = 0, limit: int = 50, user=Depends(get_admin_user), db=Depends(get_db) skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
): ):
return Users.get_users(db, skip, limit) return Users.get_users(skip, limit)
############################ ############################
...@@ -72,11 +71,11 @@ async def update_user_permissions( ...@@ -72,11 +71,11 @@ async def update_user_permissions(
@router.post("/update/role", response_model=Optional[UserModel]) @router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role( async def update_user_role(
form_data: UserRoleUpdateForm, user=Depends(get_admin_user), db=Depends(get_db) form_data: UserRoleUpdateForm, user=Depends(get_admin_user)
): ):
if user.id != form_data.id and form_data.id != Users.get_first_user(db).id: if user.id != form_data.id and form_data.id != Users.get_first_user().id:
return Users.update_user_role_by_id(db, form_data.id, form_data.role) return Users.update_user_role_by_id(form_data.id, form_data.role)
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
...@@ -91,9 +90,9 @@ async def update_user_role( ...@@ -91,9 +90,9 @@ async def update_user_role(
@router.get("/user/settings", response_model=Optional[UserSettings]) @router.get("/user/settings", response_model=Optional[UserSettings])
async def get_user_settings_by_session_user( async def get_user_settings_by_session_user(
user=Depends(get_verified_user), db=Depends(get_db) user=Depends(get_verified_user)
): ):
user = Users.get_user_by_id(db, user.id) user = Users.get_user_by_id(user.id)
if user: if user:
return user.settings return user.settings
else: else:
...@@ -110,9 +109,9 @@ async def get_user_settings_by_session_user( ...@@ -110,9 +109,9 @@ async def get_user_settings_by_session_user(
@router.post("/user/settings/update", response_model=UserSettings) @router.post("/user/settings/update", response_model=UserSettings)
async def update_user_settings_by_session_user( async def update_user_settings_by_session_user(
form_data: UserSettings, user=Depends(get_verified_user), db=Depends(get_db) form_data: UserSettings, user=Depends(get_verified_user)
): ):
user = Users.update_user_by_id(db, user.id, {"settings": form_data.model_dump()}) user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
if user: if user:
return user.settings return user.settings
else: else:
...@@ -129,9 +128,9 @@ async def update_user_settings_by_session_user( ...@@ -129,9 +128,9 @@ async def update_user_settings_by_session_user(
@router.get("/user/info", response_model=Optional[dict]) @router.get("/user/info", response_model=Optional[dict])
async def get_user_info_by_session_user( async def get_user_info_by_session_user(
user=Depends(get_verified_user), db=Depends(get_db) user=Depends(get_verified_user)
): ):
user = Users.get_user_by_id(db, user.id) user = Users.get_user_by_id(user.id)
if user: if user:
return user.info return user.info
else: else:
...@@ -148,15 +147,15 @@ async def get_user_info_by_session_user( ...@@ -148,15 +147,15 @@ async def get_user_info_by_session_user(
@router.post("/user/info/update", response_model=Optional[dict]) @router.post("/user/info/update", response_model=Optional[dict])
async def update_user_info_by_session_user( async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user), db=Depends(get_db) form_data: dict, user=Depends(get_verified_user)
): ):
user = Users.get_user_by_id(db, user.id) user = Users.get_user_by_id(user.id)
if user: if user:
if user.info is None: if user.info is None:
user.info = {} user.info = {}
user = Users.update_user_by_id( user = Users.update_user_by_id(
db, user.id, {"info": {**user.info, **form_data}} user.id, {"info": {**user.info, **form_data}}
) )
if user: if user:
return user.info return user.info
...@@ -184,14 +183,14 @@ class UserResponse(BaseModel): ...@@ -184,14 +183,14 @@ class UserResponse(BaseModel):
@router.get("/{user_id}", response_model=UserResponse) @router.get("/{user_id}", response_model=UserResponse)
async def get_user_by_id( async def get_user_by_id(
user_id: str, user=Depends(get_verified_user), db=Depends(get_db) user_id: str, user=Depends(get_verified_user)
): ):
# Check if user_id is a shared chat # Check if user_id is a shared chat
# If it is, get the user_id from the chat # If it is, get the user_id from the chat
if user_id.startswith("shared-"): if user_id.startswith("shared-"):
chat_id = user_id.replace("shared-", "") chat_id = user_id.replace("shared-", "")
chat = Chats.get_chat_by_id(db, chat_id) chat = Chats.get_chat_by_id(chat_id)
if chat: if chat:
user_id = chat.user_id user_id = chat.user_id
else: else:
...@@ -200,7 +199,7 @@ async def get_user_by_id( ...@@ -200,7 +199,7 @@ async def get_user_by_id(
detail=ERROR_MESSAGES.USER_NOT_FOUND, detail=ERROR_MESSAGES.USER_NOT_FOUND,
) )
user = Users.get_user_by_id(db, user_id) user = Users.get_user_by_id(user_id)
if user: if user:
return UserResponse(name=user.name, profile_image_url=user.profile_image_url) return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
...@@ -221,13 +220,12 @@ async def update_user_by_id( ...@@ -221,13 +220,12 @@ async def update_user_by_id(
user_id: str, user_id: str,
form_data: UserUpdateForm, form_data: UserUpdateForm,
session_user=Depends(get_admin_user), session_user=Depends(get_admin_user),
db=Depends(get_db),
): ):
user = Users.get_user_by_id(db, user_id) user = Users.get_user_by_id(user_id)
if user: if user:
if form_data.email.lower() != user.email: if form_data.email.lower() != user.email:
email_user = Users.get_user_by_email(db, form_data.email.lower()) email_user = Users.get_user_by_email(form_data.email.lower())
if email_user: if email_user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
...@@ -237,11 +235,10 @@ async def update_user_by_id( ...@@ -237,11 +235,10 @@ async def update_user_by_id(
if form_data.password: if form_data.password:
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
log.debug(f"hashed: {hashed}") log.debug(f"hashed: {hashed}")
Auths.update_user_password_by_id(db, user_id, hashed) Auths.update_user_password_by_id(user_id, hashed)
Auths.update_email_by_id(db, user_id, form_data.email.lower()) Auths.update_email_by_id(user_id, form_data.email.lower())
updated_user = Users.update_user_by_id( updated_user = Users.update_user_by_id(
db,
user_id, user_id,
{ {
"name": form_data.name, "name": form_data.name,
...@@ -271,10 +268,10 @@ async def update_user_by_id( ...@@ -271,10 +268,10 @@ async def update_user_by_id(
@router.delete("/{user_id}", response_model=bool) @router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id( async def delete_user_by_id(
user_id: str, user=Depends(get_admin_user), db=Depends(get_db) user_id: str, user=Depends(get_admin_user)
): ):
if user.id != user_id: if user.id != user_id:
result = Auths.delete_auth_by_id(db, user_id) result = Auths.delete_auth_by_id(user_id)
if result: if result:
return True return True
......
...@@ -57,7 +57,7 @@ from apps.webui.main import ( ...@@ -57,7 +57,7 @@ from apps.webui.main import (
get_pipe_models, get_pipe_models,
generate_function_chat_completion, generate_function_chat_completion,
) )
from apps.webui.internal.db import get_db, SessionLocal from apps.webui.internal.db import get_session, SessionLocal
from pydantic import BaseModel from pydantic import BaseModel
...@@ -410,7 +410,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -410,7 +410,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
user = get_current_user( user = get_current_user(
request, request,
get_http_authorization_cred(request.headers.get("Authorization")), get_http_authorization_cred(request.headers.get("Authorization")),
SessionLocal(),
) )
# Flag to skip RAG completions if file_handler is present in tools/functions # Flag to skip RAG completions if file_handler is present in tools/functions
skip_files = False skip_files = False
...@@ -800,9 +799,7 @@ app.add_middleware( ...@@ -800,9 +799,7 @@ app.add_middleware(
@app.middleware("http") @app.middleware("http")
async def check_url(request: Request, call_next): async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0: if len(app.state.MODELS) == 0:
db = SessionLocal() await get_all_models()
await get_all_models(db)
db.commit()
else: else:
pass pass
...@@ -836,12 +833,12 @@ app.mount("/api/v1", webui_app) ...@@ -836,12 +833,12 @@ app.mount("/api/v1", webui_app)
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
async def get_all_models(db: Session): async def get_all_models():
pipe_models = [] pipe_models = []
openai_models = [] openai_models = []
ollama_models = [] ollama_models = []
pipe_models = await get_pipe_models(db) pipe_models = await get_pipe_models()
if app.state.config.ENABLE_OPENAI_API: if app.state.config.ENABLE_OPENAI_API:
openai_models = await get_openai_models() openai_models = await get_openai_models()
...@@ -863,7 +860,7 @@ async def get_all_models(db: Session): ...@@ -863,7 +860,7 @@ async def get_all_models(db: Session):
models = pipe_models + openai_models + ollama_models models = pipe_models + openai_models + ollama_models
custom_models = Models.get_all_models(db) custom_models = Models.get_all_models()
for custom_model in custom_models: for custom_model in custom_models:
if custom_model.base_model_id == None: if custom_model.base_model_id == None:
for model in models: for model in models:
...@@ -903,8 +900,8 @@ async def get_all_models(db: Session): ...@@ -903,8 +900,8 @@ async def get_all_models(db: Session):
@app.get("/api/models") @app.get("/api/models")
async def get_models(user=Depends(get_verified_user), db=Depends(get_db)): async def get_models(user=Depends(get_verified_user)):
models = await get_all_models(db) models = await get_all_models()
# Filter out filter pipelines # Filter out filter pipelines
models = [ models = [
...@@ -1608,9 +1605,8 @@ async def get_pipeline_valves( ...@@ -1608,9 +1605,8 @@ async def get_pipeline_valves(
urlIdx: Optional[int], urlIdx: Optional[int],
pipeline_id: str, pipeline_id: str,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
models = await get_all_models(db) models = await get_all_models()
r = None r = None
try: try:
...@@ -1649,9 +1645,8 @@ async def get_pipeline_valves_spec( ...@@ -1649,9 +1645,8 @@ async def get_pipeline_valves_spec(
urlIdx: Optional[int], urlIdx: Optional[int],
pipeline_id: str, pipeline_id: str,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
models = await get_all_models(db) models = await get_all_models()
r = None r = None
try: try:
...@@ -1690,9 +1685,8 @@ async def update_pipeline_valves( ...@@ -1690,9 +1685,8 @@ async def update_pipeline_valves(
pipeline_id: str, pipeline_id: str,
form_data: dict, form_data: dict,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
models = await get_all_models(db) models = await get_all_models()
r = None r = None
try: try:
...@@ -2040,8 +2034,9 @@ async def healthcheck(): ...@@ -2040,8 +2034,9 @@ async def healthcheck():
@app.get("/health/db") @app.get("/health/db")
async def healthcheck_with_db(db: Session = Depends(get_db)): async def healthcheck_with_db():
result = db.execute(text("SELECT 1;")).all() with get_session() as db:
result = db.execute(text("SELECT 1;")).all()
return {"status": True} return {"status": True}
......
"""init
Revision ID: 22b5ab2667b8
Revises:
Create Date: 2024-06-20 13:22:40.397002
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector
import apps.webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "22b5ab2667b8"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
con = op.get_bind()
inspector = Inspector.from_engine(con)
tables = set(inspector.get_table_names())
# ### commands auto generated by Alembic - please adjust! ###
if not "auth" in tables:
op.create_table(
"auth",
sa.Column("id", sa.String(), nullable=False),
sa.Column("email", sa.String(), nullable=True),
sa.Column("password", sa.String(), nullable=True),
sa.Column("active", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "chat" in tables:
op.create_table(
"chat",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.String(), nullable=True),
sa.Column("chat", sa.String(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("share_id", sa.String(), nullable=True),
sa.Column("archived", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("share_id"),
)
if not "chatidtag" in tables:
op.create_table(
"chatidtag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("tag_name", sa.String(), nullable=True),
sa.Column("chat_id", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "document" in tables:
op.create_table(
"document",
sa.Column("collection_name", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("title", sa.String(), nullable=True),
sa.Column("filename", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("collection_name"),
sa.UniqueConstraint("name"),
)
if not "memory" in tables:
op.create_table(
"memory",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "model" in tables:
op.create_table(
"model",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("base_model_id", sa.String(), nullable=True),
sa.Column("name", sa.String(), nullable=True),
sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "prompt" in tables:
op.create_table(
"prompt",
sa.Column("command", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("command"),
)
if not "tag" in tables:
op.create_table(
"tag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("data", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "tool" in tables:
op.create_table(
"tool",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "user" in tables:
op.create_table(
"user",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("email", sa.String(), nullable=True),
sa.Column("role", sa.String(), nullable=True),
sa.Column("profile_image_url", sa.String(), nullable=True),
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("api_key"),
)
if not "file" in tables:
op.create_table('file',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('filename', sa.String(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
if not "function" in tables:
op.create_table('function',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.Text(), nullable=True),
sa.Column('type', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# do nothing as we assume we had previous migrations from peewee-migrate
pass
# ### end Alembic commands ###
"""init
Revision ID: ba76b0bae648
Revises:
Create Date: 2024-06-24 09:09:11.636336
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import apps.webui.internal.db
# revision identifiers, used by Alembic.
revision: str = 'ba76b0bae648'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('auth',
sa.Column('id', sa.String(), nullable=False),
sa.Column('email', sa.String(), nullable=True),
sa.Column('password', sa.String(), nullable=True),
sa.Column('active', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('chat',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('title', sa.String(), nullable=True),
sa.Column('chat', sa.String(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('share_id', sa.String(), nullable=True),
sa.Column('archived', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('share_id')
)
op.create_table('chatidtag',
sa.Column('id', sa.String(), nullable=False),
sa.Column('tag_name', sa.String(), nullable=True),
sa.Column('chat_id', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('document',
sa.Column('collection_name', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('title', sa.String(), nullable=True),
sa.Column('filename', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('collection_name'),
sa.UniqueConstraint('name')
)
op.create_table('file',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('filename', sa.String(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('function',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.Text(), nullable=True),
sa.Column('type', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('memory',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('model',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('base_model_id', sa.String(), nullable=True),
sa.Column('name', sa.String(), nullable=True),
sa.Column('params', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('prompt',
sa.Column('command', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('title', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('command')
)
op.create_table('tag',
sa.Column('id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('data', sa.String(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('tool',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('specs', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('user',
sa.Column('id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('email', sa.String(), nullable=True),
sa.Column('role', sa.String(), nullable=True),
sa.Column('profile_image_url', sa.String(), nullable=True),
sa.Column('last_active_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('api_key', sa.String(), nullable=True),
sa.Column('settings', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('info', apps.webui.internal.db.JSONField(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('api_key')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('user')
op.drop_table('tool')
op.drop_table('tag')
op.drop_table('prompt')
op.drop_table('model')
op.drop_table('memory')
op.drop_table('function')
op.drop_table('file')
op.drop_table('document')
op.drop_table('chatidtag')
op.drop_table('chat')
op.drop_table('auth')
# ### end Alembic commands ###
...@@ -31,7 +31,6 @@ class TestAuths(AbstractPostgresTest): ...@@ -31,7 +31,6 @@ class TestAuths(AbstractPostgresTest):
from utils.utils import get_password_hash from utils.utils import get_password_hash
user = self.auths.insert_new_auth( user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com", email="john.doe@openwebui.com",
password=get_password_hash("old_password"), password=get_password_hash("old_password"),
name="John Doe", name="John Doe",
...@@ -45,7 +44,7 @@ class TestAuths(AbstractPostgresTest): ...@@ -45,7 +44,7 @@ class TestAuths(AbstractPostgresTest):
json={"name": "John Doe 2", "profile_image_url": "/user2.png"}, json={"name": "John Doe 2", "profile_image_url": "/user2.png"},
) )
assert response.status_code == 200 assert response.status_code == 200
db_user = self.users.get_user_by_id(self.db_session, user.id) db_user = self.users.get_user_by_id(user.id)
assert db_user.name == "John Doe 2" assert db_user.name == "John Doe 2"
assert db_user.profile_image_url == "/user2.png" assert db_user.profile_image_url == "/user2.png"
...@@ -53,7 +52,6 @@ class TestAuths(AbstractPostgresTest): ...@@ -53,7 +52,6 @@ class TestAuths(AbstractPostgresTest):
from utils.utils import get_password_hash from utils.utils import get_password_hash
user = self.auths.insert_new_auth( user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com", email="john.doe@openwebui.com",
password=get_password_hash("old_password"), password=get_password_hash("old_password"),
name="John Doe", name="John Doe",
...@@ -69,11 +67,11 @@ class TestAuths(AbstractPostgresTest): ...@@ -69,11 +67,11 @@ class TestAuths(AbstractPostgresTest):
assert response.status_code == 200 assert response.status_code == 200
old_auth = self.auths.authenticate_user( old_auth = self.auths.authenticate_user(
self.db_session, "john.doe@openwebui.com", "old_password" "john.doe@openwebui.com", "old_password"
) )
assert old_auth is None assert old_auth is None
new_auth = self.auths.authenticate_user( new_auth = self.auths.authenticate_user(
self.db_session, "john.doe@openwebui.com", "new_password" "john.doe@openwebui.com", "new_password"
) )
assert new_auth is not None assert new_auth is not None
...@@ -81,7 +79,6 @@ class TestAuths(AbstractPostgresTest): ...@@ -81,7 +79,6 @@ class TestAuths(AbstractPostgresTest):
from utils.utils import get_password_hash from utils.utils import get_password_hash
user = self.auths.insert_new_auth( user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com", email="john.doe@openwebui.com",
password=get_password_hash("password"), password=get_password_hash("password"),
name="John Doe", name="John Doe",
...@@ -144,7 +141,6 @@ class TestAuths(AbstractPostgresTest): ...@@ -144,7 +141,6 @@ class TestAuths(AbstractPostgresTest):
def test_get_admin_details(self): def test_get_admin_details(self):
self.auths.insert_new_auth( self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com", email="john.doe@openwebui.com",
password="password", password="password",
name="John Doe", name="John Doe",
...@@ -162,7 +158,6 @@ class TestAuths(AbstractPostgresTest): ...@@ -162,7 +158,6 @@ class TestAuths(AbstractPostgresTest):
def test_create_api_key_(self): def test_create_api_key_(self):
user = self.auths.insert_new_auth( user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com", email="john.doe@openwebui.com",
password="password", password="password",
name="John Doe", name="John Doe",
...@@ -178,31 +173,29 @@ class TestAuths(AbstractPostgresTest): ...@@ -178,31 +173,29 @@ class TestAuths(AbstractPostgresTest):
def test_delete_api_key(self): def test_delete_api_key(self):
user = self.auths.insert_new_auth( user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com", email="john.doe@openwebui.com",
password="password", password="password",
name="John Doe", name="John Doe",
profile_image_url="/user.png", profile_image_url="/user.png",
role="admin", role="admin",
) )
self.users.update_user_api_key_by_id(self.db_session, user.id, "abc") self.users.update_user_api_key_by_id(user.id, "abc")
with mock_webui_user(id=user.id): with mock_webui_user(id=user.id):
response = self.fast_api_client.delete(self.create_url("/api_key")) response = self.fast_api_client.delete(self.create_url("/api_key"))
assert response.status_code == 200 assert response.status_code == 200
assert response.json() == True assert response.json() == True
db_user = self.users.get_user_by_id(self.db_session, user.id) db_user = self.users.get_user_by_id(user.id)
assert db_user.api_key is None assert db_user.api_key is None
def test_get_api_key(self): def test_get_api_key(self):
user = self.auths.insert_new_auth( user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com", email="john.doe@openwebui.com",
password="password", password="password",
name="John Doe", name="John Doe",
profile_image_url="/user.png", profile_image_url="/user.png",
role="admin", role="admin",
) )
self.users.update_user_api_key_by_id(self.db_session, user.id, "abc") self.users.update_user_api_key_by_id(user.id, "abc")
with mock_webui_user(id=user.id): with mock_webui_user(id=user.id):
response = self.fast_api_client.get(self.create_url("/api_key")) response = self.fast_api_client.get(self.create_url("/api_key"))
assert response.status_code == 200 assert response.status_code == 200
......
...@@ -18,7 +18,6 @@ class TestChats(AbstractPostgresTest): ...@@ -18,7 +18,6 @@ class TestChats(AbstractPostgresTest):
self.chats = Chats self.chats = Chats
self.chats.insert_new_chat( self.chats.insert_new_chat(
self.db_session,
"2", "2",
ChatForm( ChatForm(
**{ **{
...@@ -46,7 +45,7 @@ class TestChats(AbstractPostgresTest): ...@@ -46,7 +45,7 @@ class TestChats(AbstractPostgresTest):
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url("/")) response = self.fast_api_client.delete(self.create_url("/"))
assert response.status_code == 200 assert response.status_code == 200
assert len(self.chats.get_chats(self.db_session)) == 0 assert len(self.chats.get_chats()) == 0
def test_get_user_chat_list_by_user_id(self): def test_get_user_chat_list_by_user_id(self):
with mock_webui_user(id="3"): with mock_webui_user(id="3"):
...@@ -84,14 +83,13 @@ class TestChats(AbstractPostgresTest): ...@@ -84,14 +83,13 @@ class TestChats(AbstractPostgresTest):
assert data["title"] == "New Chat" assert data["title"] == "New Chat"
assert data["updated_at"] is not None assert data["updated_at"] is not None
assert data["created_at"] is not None assert data["created_at"] is not None
assert len(self.chats.get_chats(self.db_session)) == 2 assert len(self.chats.get_chats()) == 2
def test_get_user_chats(self): def test_get_user_chats(self):
self.test_get_session_user_chat_list() self.test_get_session_user_chat_list()
def test_get_user_archived_chats(self): def test_get_user_archived_chats(self):
self.chats.archive_all_chats_by_user_id(self.db_session, "2") self.chats.archive_all_chats_by_user_id("2")
self.db_session.commit()
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/all/archived")) response = self.fast_api_client.get(self.create_url("/all/archived"))
assert response.status_code == 200 assert response.status_code == 200
...@@ -114,12 +112,11 @@ class TestChats(AbstractPostgresTest): ...@@ -114,12 +112,11 @@ class TestChats(AbstractPostgresTest):
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.post(self.create_url("/archive/all")) response = self.fast_api_client.post(self.create_url("/archive/all"))
assert response.status_code == 200 assert response.status_code == 200
assert len(self.chats.get_archived_chats_by_user_id(self.db_session, "2")) == 1 assert len(self.chats.get_archived_chats_by_user_id("2")) == 1
def test_get_shared_chat_by_id(self): def test_get_shared_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
self.chats.update_chat_share_id_by_id(self.db_session, chat_id, chat_id) self.chats.update_chat_share_id_by_id(chat_id, chat_id)
self.db_session.commit()
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}")) response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}"))
assert response.status_code == 200 assert response.status_code == 200
...@@ -136,7 +133,7 @@ class TestChats(AbstractPostgresTest): ...@@ -136,7 +133,7 @@ class TestChats(AbstractPostgresTest):
assert data["title"] == "New Chat" assert data["title"] == "New Chat"
def test_get_chat_by_id(self): def test_get_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}")) response = self.fast_api_client.get(self.create_url(f"/{chat_id}"))
assert response.status_code == 200 assert response.status_code == 200
...@@ -153,7 +150,7 @@ class TestChats(AbstractPostgresTest): ...@@ -153,7 +150,7 @@ class TestChats(AbstractPostgresTest):
assert data["user_id"] == "2" assert data["user_id"] == "2"
def test_update_chat_by_id(self): def test_update_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.post( response = self.fast_api_client.post(
self.create_url(f"/{chat_id}"), self.create_url(f"/{chat_id}"),
...@@ -181,14 +178,14 @@ class TestChats(AbstractPostgresTest): ...@@ -181,14 +178,14 @@ class TestChats(AbstractPostgresTest):
assert data["user_id"] == "2" assert data["user_id"] == "2"
def test_delete_chat_by_id(self): def test_delete_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url(f"/{chat_id}")) response = self.fast_api_client.delete(self.create_url(f"/{chat_id}"))
assert response.status_code == 200 assert response.status_code == 200
assert response.json() is True assert response.json() is True
def test_clone_chat_by_id(self): def test_clone_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone")) response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone"))
...@@ -209,31 +206,30 @@ class TestChats(AbstractPostgresTest): ...@@ -209,31 +206,30 @@ class TestChats(AbstractPostgresTest):
assert data["user_id"] == "2" assert data["user_id"] == "2"
def test_archive_chat_by_id(self): def test_archive_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive")) response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive"))
assert response.status_code == 200 assert response.status_code == 200
chat = self.chats.get_chat_by_id(self.db_session, chat_id) chat = self.chats.get_chat_by_id(chat_id)
assert chat.archived is True assert chat.archived is True
def test_share_chat_by_id(self): def test_share_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share")) response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share"))
assert response.status_code == 200 assert response.status_code == 200
chat = self.chats.get_chat_by_id(self.db_session, chat_id) chat = self.chats.get_chat_by_id(chat_id)
assert chat.share_id is not None assert chat.share_id is not None
def test_delete_shared_chat_by_id(self): def test_delete_shared_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id chat_id = self.chats.get_chats()[0].id
share_id = str(uuid.uuid4()) share_id = str(uuid.uuid4())
self.chats.update_chat_share_id_by_id(self.db_session, chat_id, share_id) self.chats.update_chat_share_id_by_id(chat_id, share_id)
self.db_session.commit()
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share")) response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share"))
assert response.status_code assert response.status_code
chat = self.chats.get_chat_by_id(self.db_session, chat_id) chat = self.chats.get_chat_by_id(chat_id)
assert chat.share_id is None assert chat.share_id is None
...@@ -14,7 +14,7 @@ class TestDocuments(AbstractPostgresTest): ...@@ -14,7 +14,7 @@ class TestDocuments(AbstractPostgresTest):
def test_documents(self): def test_documents(self):
# Empty database # Empty database
assert len(self.documents.get_docs(self.db_session)) == 0 assert len(self.documents.get_docs()) == 0
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/")) response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200 assert response.status_code == 200
...@@ -34,7 +34,7 @@ class TestDocuments(AbstractPostgresTest): ...@@ -34,7 +34,7 @@ class TestDocuments(AbstractPostgresTest):
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "doc_name" assert response.json()["name"] == "doc_name"
assert len(self.documents.get_docs(self.db_session)) == 1 assert len(self.documents.get_docs()) == 1
# Get the document # Get the document
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
...@@ -61,7 +61,7 @@ class TestDocuments(AbstractPostgresTest): ...@@ -61,7 +61,7 @@ class TestDocuments(AbstractPostgresTest):
) )
assert response.status_code == 200 assert response.status_code == 200
assert response.json()["name"] == "doc_name 2" assert response.json()["name"] == "doc_name 2"
assert len(self.documents.get_docs(self.db_session)) == 2 assert len(self.documents.get_docs()) == 2
# Get all documents # Get all documents
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
...@@ -95,7 +95,7 @@ class TestDocuments(AbstractPostgresTest): ...@@ -95,7 +95,7 @@ class TestDocuments(AbstractPostgresTest):
assert data["content"] == { assert data["content"] == {
"tags": [{"name": "testing-tag"}, {"name": "another-tag"}] "tags": [{"name": "testing-tag"}, {"name": "another-tag"}]
} }
assert len(self.documents.get_docs(self.db_session)) == 2 assert len(self.documents.get_docs()) == 2
# Delete the first document # Delete the first document
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
...@@ -103,4 +103,4 @@ class TestDocuments(AbstractPostgresTest): ...@@ -103,4 +103,4 @@ class TestDocuments(AbstractPostgresTest):
self.create_url("/doc/delete?name=doc_name rework") self.create_url("/doc/delete?name=doc_name rework")
) )
assert response.status_code == 200 assert response.status_code == 200
assert len(self.documents.get_docs(self.db_session)) == 1 assert len(self.documents.get_docs()) == 1
...@@ -68,6 +68,16 @@ class TestPrompts(AbstractPostgresTest): ...@@ -68,6 +68,16 @@ class TestPrompts(AbstractPostgresTest):
assert data["content"] == "description Updated" assert data["content"] == "description Updated"
assert data["user_id"] == "3" assert data["user_id"] == "3"
# Get prompt by command
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/command/my-command2"))
assert response.status_code == 200
data = response.json()
assert data["command"] == "/my-command2"
assert data["title"] == "Hello World Updated"
assert data["content"] == "description Updated"
assert data["user_id"] == "3"
# Delete prompt # Delete prompt
with mock_webui_user(id="2"): with mock_webui_user(id="2"):
response = self.fast_api_client.delete( response = self.fast_api_client.delete(
......
...@@ -33,7 +33,6 @@ class TestUsers(AbstractPostgresTest): ...@@ -33,7 +33,6 @@ class TestUsers(AbstractPostgresTest):
def setup_method(self): def setup_method(self):
super().setup_method() super().setup_method()
self.users.insert_new_user( self.users.insert_new_user(
self.db_session,
id="1", id="1",
name="user 1", name="user 1",
email="user1@openwebui.com", email="user1@openwebui.com",
...@@ -41,7 +40,6 @@ class TestUsers(AbstractPostgresTest): ...@@ -41,7 +40,6 @@ class TestUsers(AbstractPostgresTest):
role="user", role="user",
) )
self.users.insert_new_user( self.users.insert_new_user(
self.db_session,
id="2", id="2",
name="user 2", name="user 2",
email="user2@openwebui.com", email="user2@openwebui.com",
......
...@@ -2,7 +2,6 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials ...@@ -2,7 +2,6 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends, Request from fastapi import HTTPException, status, Depends, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import get_db
from apps.webui.models.users import Users from apps.webui.models.users import Users
from pydantic import BaseModel from pydantic import BaseModel
...@@ -79,7 +78,6 @@ def get_http_authorization_cred(auth_header: str): ...@@ -79,7 +78,6 @@ def get_http_authorization_cred(auth_header: str):
def get_current_user( def get_current_user(
request: Request, request: Request,
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
db=Depends(get_db),
): ):
token = None token = None
...@@ -94,19 +92,19 @@ def get_current_user( ...@@ -94,19 +92,19 @@ def get_current_user(
# auth by api key # auth by api key
if token.startswith("sk-"): if token.startswith("sk-"):
return get_current_user_by_api_key(db, token) return get_current_user_by_api_key(token)
# auth by jwt token # auth by jwt token
data = decode_token(token) data = decode_token(token)
if data != None and "id" in data: if data != None and "id" in data:
user = Users.get_user_by_id(db, data["id"]) user = Users.get_user_by_id(data["id"])
if user is None: if user is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.INVALID_TOKEN,
) )
else: else:
Users.update_user_last_active_by_id(db, user.id) Users.update_user_last_active_by_id(user.id)
return user return user
else: else:
raise HTTPException( raise HTTPException(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment