Commit bee835cb authored by Jonathan Rohde's avatar Jonathan Rohde
Browse files

feat(sqlalchemy): remove session reference from router

parent df09d083
......@@ -7,7 +7,6 @@ from fastapi import APIRouter
from pydantic import BaseModel
import logging
from apps.webui.internal.db import get_db
from apps.webui.models.memories import Memories, MemoryModel
from utils.utils import get_verified_user
......@@ -32,8 +31,8 @@ async def get_embeddings(request: Request):
@router.get("/", response_model=List[MemoryModel])
async def get_memories(user=Depends(get_verified_user), db=Depends(get_db)):
return Memories.get_memories_by_user_id(db, user.id)
async def get_memories(user=Depends(get_verified_user)):
return Memories.get_memories_by_user_id(user.id)
############################
......@@ -54,9 +53,8 @@ async def add_memory(
request: Request,
form_data: AddMemoryForm,
user=Depends(get_verified_user),
db=Depends(get_db),
):
memory = Memories.insert_new_memory(db, user.id, form_data.content)
memory = Memories.insert_new_memory(user.id, form_data.content)
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
......@@ -76,9 +74,8 @@ async def update_memory_by_id(
request: Request,
form_data: MemoryUpdateModel,
user=Depends(get_verified_user),
db=Depends(get_db),
):
memory = Memories.update_memory_by_id(db, memory_id, form_data.content)
memory = Memories.update_memory_by_id(memory_id, form_data.content)
if memory is None:
raise HTTPException(status_code=404, detail="Memory not found")
......@@ -129,12 +126,12 @@ async def query_memory(
############################
@router.get("/reset", response_model=bool)
async def reset_memory_from_vector_db(
request: Request, user=Depends(get_verified_user), db=Depends(get_db)
request: Request, user=Depends(get_verified_user)
):
CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
memories = Memories.get_memories_by_user_id(db, user.id)
memories = Memories.get_memories_by_user_id(user.id)
for memory in memories:
memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
collection.upsert(
......@@ -151,8 +148,8 @@ async def reset_memory_from_vector_db(
@router.delete("/user", response_model=bool)
async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(get_db)):
result = Memories.delete_memories_by_user_id(db, user.id)
async def delete_memory_by_user_id(user=Depends(get_verified_user)):
result = Memories.delete_memories_by_user_id(user.id)
if result:
try:
......@@ -171,9 +168,9 @@ async def delete_memory_by_user_id(user=Depends(get_verified_user), db=Depends(g
@router.delete("/{memory_id}", response_model=bool)
async def delete_memory_by_id(
memory_id: str, user=Depends(get_verified_user), db=Depends(get_db)
memory_id: str, user=Depends(get_verified_user)
):
result = Memories.delete_memory_by_id_and_user_id(db, memory_id, user.id)
result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
if result:
collection = CHROMA_CLIENT.get_or_create_collection(
......
......@@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.internal.db import get_db
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
from utils.utils import get_verified_user, get_admin_user
......@@ -20,8 +19,8 @@ router = APIRouter()
@router.get("/", response_model=List[ModelResponse])
async def get_models(user=Depends(get_verified_user), db=Depends(get_db)):
return Models.get_all_models(db)
async def get_models(user=Depends(get_verified_user)):
return Models.get_all_models()
############################
......@@ -34,7 +33,6 @@ async def add_new_model(
request: Request,
form_data: ModelForm,
user=Depends(get_admin_user),
db=Depends(get_db),
):
if form_data.id in request.app.state.MODELS:
raise HTTPException(
......@@ -42,7 +40,7 @@ async def add_new_model(
detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
)
else:
model = Models.insert_new_model(db, form_data, user.id)
model = Models.insert_new_model(form_data, user.id)
if model:
return model
......@@ -59,8 +57,8 @@ async def add_new_model(
@router.get("/{id}", response_model=Optional[ModelModel])
async def get_model_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)):
model = Models.get_model_by_id(db, id)
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
if model:
return model
......@@ -82,15 +80,14 @@ async def update_model_by_id(
id: str,
form_data: ModelForm,
user=Depends(get_admin_user),
db=Depends(get_db),
):
model = Models.get_model_by_id(db, id)
model = Models.get_model_by_id(id)
if model:
model = Models.update_model_by_id(db, id, form_data)
model = Models.update_model_by_id(id, form_data)
return model
else:
if form_data.id in request.app.state.MODELS:
model = Models.insert_new_model(db, form_data, user.id)
model = Models.insert_new_model(form_data, user.id)
if model:
return model
else:
......@@ -111,6 +108,6 @@ async def update_model_by_id(
@router.delete("/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
result = Models.delete_model_by_id(db, id)
async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
result = Models.delete_model_by_id(id)
return result
......@@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.internal.db import get_db
from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
from utils.utils import get_current_user, get_admin_user
......@@ -20,8 +19,8 @@ router = APIRouter()
@router.get("/", response_model=List[PromptModel])
async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)):
return Prompts.get_prompts(db)
async def get_prompts(user=Depends(get_current_user)):
return Prompts.get_prompts()
############################
......@@ -31,11 +30,11 @@ async def get_prompts(user=Depends(get_current_user), db=Depends(get_db)):
@router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt(
form_data: PromptForm, user=Depends(get_admin_user), db=Depends(get_db)
form_data: PromptForm, user=Depends(get_admin_user)
):
prompt = Prompts.get_prompt_by_command(db, form_data.command)
prompt = Prompts.get_prompt_by_command(form_data.command)
if prompt == None:
prompt = Prompts.insert_new_prompt(db, user.id, form_data)
prompt = Prompts.insert_new_prompt(user.id, form_data)
if prompt:
return prompt
......@@ -56,9 +55,9 @@ async def create_new_prompt(
@router.get("/command/{command}", response_model=Optional[PromptModel])
async def get_prompt_by_command(
command: str, user=Depends(get_current_user), db=Depends(get_db)
command: str, user=Depends(get_current_user)
):
prompt = Prompts.get_prompt_by_command(db, f"/{command}")
prompt = Prompts.get_prompt_by_command(f"/{command}")
if prompt:
return prompt
......@@ -79,9 +78,8 @@ async def update_prompt_by_command(
command: str,
form_data: PromptForm,
user=Depends(get_admin_user),
db=Depends(get_db),
):
prompt = Prompts.update_prompt_by_command(db, f"/{command}", form_data)
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt:
return prompt
else:
......@@ -98,7 +96,7 @@ async def update_prompt_by_command(
@router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command(
command: str, user=Depends(get_admin_user), db=Depends(get_db)
command: str, user=Depends(get_admin_user)
):
result = Prompts.delete_prompt_by_command(db, f"/{command}")
result = Prompts.delete_prompt_by_command(f"/{command}")
return result
......@@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.internal.db import get_db
from apps.webui.models.users import Users
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
......@@ -34,7 +33,7 @@ router = APIRouter()
@router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)):
async def get_toolkits(user=Depends(get_verified_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
......@@ -45,8 +44,8 @@ async def get_toolkits(user=Depends(get_verified_user), db=Depends(get_db)):
@router.get("/export", response_model=List[ToolModel])
async def get_toolkits(user=Depends(get_admin_user), db=Depends(get_db)):
toolkits = [toolkit for toolkit in Tools.get_tools(db)]
async def get_toolkits(user=Depends(get_admin_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
......@@ -60,7 +59,6 @@ async def create_new_toolkit(
request: Request,
form_data: ToolForm,
user=Depends(get_admin_user),
db=Depends(get_db),
):
if not form_data.id.isidentifier():
raise HTTPException(
......@@ -70,7 +68,7 @@ async def create_new_toolkit(
form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(db, form_data.id)
toolkit = Tools.get_tool_by_id(form_data.id)
if toolkit == None:
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
try:
......@@ -84,7 +82,7 @@ async def create_new_toolkit(
TOOLS[form_data.id] = toolkit_module
specs = get_tools_specs(TOOLS[form_data.id])
toolkit = Tools.insert_new_tool(db, user.id, form_data, specs)
toolkit = Tools.insert_new_tool(user.id, form_data, specs)
tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
tool_cache_dir.mkdir(parents=True, exist_ok=True)
......@@ -115,8 +113,8 @@ async def create_new_toolkit(
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)):
toolkit = Tools.get_tool_by_id(db, id)
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
return toolkit
......@@ -138,7 +136,6 @@ async def update_toolkit_by_id(
id: str,
form_data: ToolForm,
user=Depends(get_admin_user),
db=Depends(get_db),
):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
......@@ -160,7 +157,7 @@ async def update_toolkit_by_id(
}
print(updated)
toolkit = Tools.update_tool_by_id(db, id, updated)
toolkit = Tools.update_tool_by_id(id, updated)
if toolkit:
return toolkit
......@@ -184,9 +181,9 @@ async def update_toolkit_by_id(
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id(
request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db)
request: Request, id: str, user=Depends(get_admin_user)
):
result = Tools.delete_tool_by_id(db, id)
result = Tools.delete_tool_by_id(id)
if result:
TOOLS = request.app.state.TOOLS
......
......@@ -9,7 +9,6 @@ import time
import uuid
import logging
from apps.webui.internal.db import get_db
from apps.webui.models.users import (
UserModel,
UserUpdateForm,
......@@ -42,9 +41,9 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel])
async def get_users(
skip: int = 0, limit: int = 50, user=Depends(get_admin_user), db=Depends(get_db)
skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
):
return Users.get_users(db, skip, limit)
return Users.get_users(skip, limit)
############################
......@@ -72,11 +71,11 @@ async def update_user_permissions(
@router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(
form_data: UserRoleUpdateForm, user=Depends(get_admin_user), db=Depends(get_db)
form_data: UserRoleUpdateForm, user=Depends(get_admin_user)
):
if user.id != form_data.id and form_data.id != Users.get_first_user(db).id:
return Users.update_user_role_by_id(db, form_data.id, form_data.role)
if user.id != form_data.id and form_data.id != Users.get_first_user().id:
return Users.update_user_role_by_id(form_data.id, form_data.role)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
......@@ -91,9 +90,9 @@ async def update_user_role(
@router.get("/user/settings", response_model=Optional[UserSettings])
async def get_user_settings_by_session_user(
user=Depends(get_verified_user), db=Depends(get_db)
user=Depends(get_verified_user)
):
user = Users.get_user_by_id(db, user.id)
user = Users.get_user_by_id(user.id)
if user:
return user.settings
else:
......@@ -110,9 +109,9 @@ async def get_user_settings_by_session_user(
@router.post("/user/settings/update", response_model=UserSettings)
async def update_user_settings_by_session_user(
form_data: UserSettings, user=Depends(get_verified_user), db=Depends(get_db)
form_data: UserSettings, user=Depends(get_verified_user)
):
user = Users.update_user_by_id(db, user.id, {"settings": form_data.model_dump()})
user = Users.update_user_by_id(user.id, {"settings": form_data.model_dump()})
if user:
return user.settings
else:
......@@ -129,9 +128,9 @@ async def update_user_settings_by_session_user(
@router.get("/user/info", response_model=Optional[dict])
async def get_user_info_by_session_user(
user=Depends(get_verified_user), db=Depends(get_db)
user=Depends(get_verified_user)
):
user = Users.get_user_by_id(db, user.id)
user = Users.get_user_by_id(user.id)
if user:
return user.info
else:
......@@ -148,15 +147,15 @@ async def get_user_info_by_session_user(
@router.post("/user/info/update", response_model=Optional[dict])
async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user), db=Depends(get_db)
form_data: dict, user=Depends(get_verified_user)
):
user = Users.get_user_by_id(db, user.id)
user = Users.get_user_by_id(user.id)
if user:
if user.info is None:
user.info = {}
user = Users.update_user_by_id(
db, user.id, {"info": {**user.info, **form_data}}
user.id, {"info": {**user.info, **form_data}}
)
if user:
return user.info
......@@ -184,14 +183,14 @@ class UserResponse(BaseModel):
@router.get("/{user_id}", response_model=UserResponse)
async def get_user_by_id(
user_id: str, user=Depends(get_verified_user), db=Depends(get_db)
user_id: str, user=Depends(get_verified_user)
):
# Check if user_id is a shared chat
# If it is, get the user_id from the chat
if user_id.startswith("shared-"):
chat_id = user_id.replace("shared-", "")
chat = Chats.get_chat_by_id(db, chat_id)
chat = Chats.get_chat_by_id(chat_id)
if chat:
user_id = chat.user_id
else:
......@@ -200,7 +199,7 @@ async def get_user_by_id(
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
user = Users.get_user_by_id(db, user_id)
user = Users.get_user_by_id(user_id)
if user:
return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
......@@ -221,13 +220,12 @@ async def update_user_by_id(
user_id: str,
form_data: UserUpdateForm,
session_user=Depends(get_admin_user),
db=Depends(get_db),
):
user = Users.get_user_by_id(db, user_id)
user = Users.get_user_by_id(user_id)
if user:
if form_data.email.lower() != user.email:
email_user = Users.get_user_by_email(db, form_data.email.lower())
email_user = Users.get_user_by_email(form_data.email.lower())
if email_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
......@@ -237,11 +235,10 @@ async def update_user_by_id(
if form_data.password:
hashed = get_password_hash(form_data.password)
log.debug(f"hashed: {hashed}")
Auths.update_user_password_by_id(db, user_id, hashed)
Auths.update_user_password_by_id(user_id, hashed)
Auths.update_email_by_id(db, user_id, form_data.email.lower())
Auths.update_email_by_id(user_id, form_data.email.lower())
updated_user = Users.update_user_by_id(
db,
user_id,
{
"name": form_data.name,
......@@ -271,10 +268,10 @@ async def update_user_by_id(
@router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(
user_id: str, user=Depends(get_admin_user), db=Depends(get_db)
user_id: str, user=Depends(get_admin_user)
):
if user.id != user_id:
result = Auths.delete_auth_by_id(db, user_id)
result = Auths.delete_auth_by_id(user_id)
if result:
return True
......
......@@ -57,7 +57,7 @@ from apps.webui.main import (
get_pipe_models,
generate_function_chat_completion,
)
from apps.webui.internal.db import get_db, SessionLocal
from apps.webui.internal.db import get_session, SessionLocal
from pydantic import BaseModel
......@@ -410,7 +410,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
SessionLocal(),
)
# Flag to skip RAG completions if file_handler is present in tools/functions
skip_files = False
......@@ -800,9 +799,7 @@ app.add_middleware(
@app.middleware("http")
async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
db = SessionLocal()
await get_all_models(db)
db.commit()
await get_all_models()
else:
pass
......@@ -836,12 +833,12 @@ app.mount("/api/v1", webui_app)
webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
async def get_all_models(db: Session):
async def get_all_models():
pipe_models = []
openai_models = []
ollama_models = []
pipe_models = await get_pipe_models(db)
pipe_models = await get_pipe_models()
if app.state.config.ENABLE_OPENAI_API:
openai_models = await get_openai_models()
......@@ -863,7 +860,7 @@ async def get_all_models(db: Session):
models = pipe_models + openai_models + ollama_models
custom_models = Models.get_all_models(db)
custom_models = Models.get_all_models()
for custom_model in custom_models:
if custom_model.base_model_id == None:
for model in models:
......@@ -903,8 +900,8 @@ async def get_all_models(db: Session):
@app.get("/api/models")
async def get_models(user=Depends(get_verified_user), db=Depends(get_db)):
models = await get_all_models(db)
async def get_models(user=Depends(get_verified_user)):
models = await get_all_models()
# Filter out filter pipelines
models = [
......@@ -1608,9 +1605,8 @@ async def get_pipeline_valves(
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
db=Depends(get_db),
):
models = await get_all_models(db)
models = await get_all_models()
r = None
try:
......@@ -1649,9 +1645,8 @@ async def get_pipeline_valves_spec(
urlIdx: Optional[int],
pipeline_id: str,
user=Depends(get_admin_user),
db=Depends(get_db),
):
models = await get_all_models(db)
models = await get_all_models()
r = None
try:
......@@ -1690,9 +1685,8 @@ async def update_pipeline_valves(
pipeline_id: str,
form_data: dict,
user=Depends(get_admin_user),
db=Depends(get_db),
):
models = await get_all_models(db)
models = await get_all_models()
r = None
try:
......@@ -2040,7 +2034,8 @@ async def healthcheck():
@app.get("/health/db")
async def healthcheck_with_db(db: Session = Depends(get_db)):
async def healthcheck_with_db():
with get_session() as db:
result = db.execute(text("SELECT 1;")).all()
return {"status": True}
......
"""init
Revision ID: 22b5ab2667b8
Revises:
Create Date: 2024-06-20 13:22:40.397002
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.engine.reflection import Inspector
import apps.webui.internal.db
# revision identifiers, used by Alembic.
revision: str = "22b5ab2667b8"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
con = op.get_bind()
inspector = Inspector.from_engine(con)
tables = set(inspector.get_table_names())
# ### commands auto generated by Alembic - please adjust! ###
if not "auth" in tables:
op.create_table(
"auth",
sa.Column("id", sa.String(), nullable=False),
sa.Column("email", sa.String(), nullable=True),
sa.Column("password", sa.String(), nullable=True),
sa.Column("active", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "chat" in tables:
op.create_table(
"chat",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.String(), nullable=True),
sa.Column("chat", sa.String(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("share_id", sa.String(), nullable=True),
sa.Column("archived", sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("share_id"),
)
if not "chatidtag" in tables:
op.create_table(
"chatidtag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("tag_name", sa.String(), nullable=True),
sa.Column("chat_id", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "document" in tables:
op.create_table(
"document",
sa.Column("collection_name", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("title", sa.String(), nullable=True),
sa.Column("filename", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("collection_name"),
sa.UniqueConstraint("name"),
)
if not "memory" in tables:
op.create_table(
"memory",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "model" in tables:
op.create_table(
"model",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("base_model_id", sa.String(), nullable=True),
sa.Column("name", sa.String(), nullable=True),
sa.Column("params", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "prompt" in tables:
op.create_table(
"prompt",
sa.Column("command", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("title", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("timestamp", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("command"),
)
if not "tag" in tables:
op.create_table(
"tag",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("data", sa.String(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "tool" in tables:
op.create_table(
"tool",
sa.Column("id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=True),
sa.Column("name", sa.String(), nullable=True),
sa.Column("content", sa.String(), nullable=True),
sa.Column("specs", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("meta", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint("id"),
)
if not "user" in tables:
op.create_table(
"user",
sa.Column("id", sa.String(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("email", sa.String(), nullable=True),
sa.Column("role", sa.String(), nullable=True),
sa.Column("profile_image_url", sa.String(), nullable=True),
sa.Column("last_active_at", sa.BigInteger(), nullable=True),
sa.Column("updated_at", sa.BigInteger(), nullable=True),
sa.Column("created_at", sa.BigInteger(), nullable=True),
sa.Column("api_key", sa.String(), nullable=True),
sa.Column("settings", apps.webui.internal.db.JSONField(), nullable=True),
sa.Column("info", apps.webui.internal.db.JSONField(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("api_key"),
)
if not "file" in tables:
op.create_table('file',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('filename', sa.String(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
if not "function" in tables:
op.create_table('function',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.Text(), nullable=True),
sa.Column('type', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
# do nothing as we assume we had previous migrations from peewee-migrate
pass
# ### end Alembic commands ###
"""init
Revision ID: ba76b0bae648
Revises:
Create Date: 2024-06-24 09:09:11.636336
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
import apps.webui.internal.db
# revision identifiers, used by Alembic.
revision: str = 'ba76b0bae648'
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('auth',
sa.Column('id', sa.String(), nullable=False),
sa.Column('email', sa.String(), nullable=True),
sa.Column('password', sa.String(), nullable=True),
sa.Column('active', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('chat',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('title', sa.String(), nullable=True),
sa.Column('chat', sa.String(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('share_id', sa.String(), nullable=True),
sa.Column('archived', sa.Boolean(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('share_id')
)
op.create_table('chatidtag',
sa.Column('id', sa.String(), nullable=False),
sa.Column('tag_name', sa.String(), nullable=True),
sa.Column('chat_id', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('document',
sa.Column('collection_name', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('title', sa.String(), nullable=True),
sa.Column('filename', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('collection_name'),
sa.UniqueConstraint('name')
)
op.create_table('file',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('filename', sa.String(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('function',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.Text(), nullable=True),
sa.Column('type', sa.Text(), nullable=True),
sa.Column('content', sa.Text(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('is_active', sa.Boolean(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('memory',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('model',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('base_model_id', sa.String(), nullable=True),
sa.Column('name', sa.String(), nullable=True),
sa.Column('params', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('prompt',
sa.Column('command', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('title', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('timestamp', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('command')
)
op.create_table('tag',
sa.Column('id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('data', sa.String(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('tool',
sa.Column('id', sa.String(), nullable=False),
sa.Column('user_id', sa.String(), nullable=True),
sa.Column('name', sa.String(), nullable=True),
sa.Column('content', sa.String(), nullable=True),
sa.Column('specs', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('meta', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('valves', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_table('user',
sa.Column('id', sa.String(), nullable=False),
sa.Column('name', sa.String(), nullable=True),
sa.Column('email', sa.String(), nullable=True),
sa.Column('role', sa.String(), nullable=True),
sa.Column('profile_image_url', sa.String(), nullable=True),
sa.Column('last_active_at', sa.BigInteger(), nullable=True),
sa.Column('updated_at', sa.BigInteger(), nullable=True),
sa.Column('created_at', sa.BigInteger(), nullable=True),
sa.Column('api_key', sa.String(), nullable=True),
sa.Column('settings', apps.webui.internal.db.JSONField(), nullable=True),
sa.Column('info', apps.webui.internal.db.JSONField(), nullable=True),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('api_key')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('user')
op.drop_table('tool')
op.drop_table('tag')
op.drop_table('prompt')
op.drop_table('model')
op.drop_table('memory')
op.drop_table('function')
op.drop_table('file')
op.drop_table('document')
op.drop_table('chatidtag')
op.drop_table('chat')
op.drop_table('auth')
# ### end Alembic commands ###
......@@ -31,7 +31,6 @@ class TestAuths(AbstractPostgresTest):
from utils.utils import get_password_hash
user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com",
password=get_password_hash("old_password"),
name="John Doe",
......@@ -45,7 +44,7 @@ class TestAuths(AbstractPostgresTest):
json={"name": "John Doe 2", "profile_image_url": "/user2.png"},
)
assert response.status_code == 200
db_user = self.users.get_user_by_id(self.db_session, user.id)
db_user = self.users.get_user_by_id(user.id)
assert db_user.name == "John Doe 2"
assert db_user.profile_image_url == "/user2.png"
......@@ -53,7 +52,6 @@ class TestAuths(AbstractPostgresTest):
from utils.utils import get_password_hash
user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com",
password=get_password_hash("old_password"),
name="John Doe",
......@@ -69,11 +67,11 @@ class TestAuths(AbstractPostgresTest):
assert response.status_code == 200
old_auth = self.auths.authenticate_user(
self.db_session, "john.doe@openwebui.com", "old_password"
"john.doe@openwebui.com", "old_password"
)
assert old_auth is None
new_auth = self.auths.authenticate_user(
self.db_session, "john.doe@openwebui.com", "new_password"
"john.doe@openwebui.com", "new_password"
)
assert new_auth is not None
......@@ -81,7 +79,6 @@ class TestAuths(AbstractPostgresTest):
from utils.utils import get_password_hash
user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com",
password=get_password_hash("password"),
name="John Doe",
......@@ -144,7 +141,6 @@ class TestAuths(AbstractPostgresTest):
def test_get_admin_details(self):
self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com",
password="password",
name="John Doe",
......@@ -162,7 +158,6 @@ class TestAuths(AbstractPostgresTest):
def test_create_api_key_(self):
user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com",
password="password",
name="John Doe",
......@@ -178,31 +173,29 @@ class TestAuths(AbstractPostgresTest):
def test_delete_api_key(self):
user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
self.users.update_user_api_key_by_id(self.db_session, user.id, "abc")
self.users.update_user_api_key_by_id(user.id, "abc")
with mock_webui_user(id=user.id):
response = self.fast_api_client.delete(self.create_url("/api_key"))
assert response.status_code == 200
assert response.json() == True
db_user = self.users.get_user_by_id(self.db_session, user.id)
db_user = self.users.get_user_by_id(user.id)
assert db_user.api_key is None
def test_get_api_key(self):
user = self.auths.insert_new_auth(
self.db_session,
email="john.doe@openwebui.com",
password="password",
name="John Doe",
profile_image_url="/user.png",
role="admin",
)
self.users.update_user_api_key_by_id(self.db_session, user.id, "abc")
self.users.update_user_api_key_by_id(user.id, "abc")
with mock_webui_user(id=user.id):
response = self.fast_api_client.get(self.create_url("/api_key"))
assert response.status_code == 200
......
......@@ -18,7 +18,6 @@ class TestChats(AbstractPostgresTest):
self.chats = Chats
self.chats.insert_new_chat(
self.db_session,
"2",
ChatForm(
**{
......@@ -46,7 +45,7 @@ class TestChats(AbstractPostgresTest):
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url("/"))
assert response.status_code == 200
assert len(self.chats.get_chats(self.db_session)) == 0
assert len(self.chats.get_chats()) == 0
def test_get_user_chat_list_by_user_id(self):
with mock_webui_user(id="3"):
......@@ -84,14 +83,13 @@ class TestChats(AbstractPostgresTest):
assert data["title"] == "New Chat"
assert data["updated_at"] is not None
assert data["created_at"] is not None
assert len(self.chats.get_chats(self.db_session)) == 2
assert len(self.chats.get_chats()) == 2
def test_get_user_chats(self):
self.test_get_session_user_chat_list()
def test_get_user_archived_chats(self):
self.chats.archive_all_chats_by_user_id(self.db_session, "2")
self.db_session.commit()
self.chats.archive_all_chats_by_user_id("2")
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/all/archived"))
assert response.status_code == 200
......@@ -114,12 +112,11 @@ class TestChats(AbstractPostgresTest):
with mock_webui_user(id="2"):
response = self.fast_api_client.post(self.create_url("/archive/all"))
assert response.status_code == 200
assert len(self.chats.get_archived_chats_by_user_id(self.db_session, "2")) == 1
assert len(self.chats.get_archived_chats_by_user_id("2")) == 1
def test_get_shared_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
self.chats.update_chat_share_id_by_id(self.db_session, chat_id, chat_id)
self.db_session.commit()
chat_id = self.chats.get_chats()[0].id
self.chats.update_chat_share_id_by_id(chat_id, chat_id)
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/share/{chat_id}"))
assert response.status_code == 200
......@@ -136,7 +133,7 @@ class TestChats(AbstractPostgresTest):
assert data["title"] == "New Chat"
def test_get_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}"))
assert response.status_code == 200
......@@ -153,7 +150,7 @@ class TestChats(AbstractPostgresTest):
assert data["user_id"] == "2"
def test_update_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.post(
self.create_url(f"/{chat_id}"),
......@@ -181,14 +178,14 @@ class TestChats(AbstractPostgresTest):
assert data["user_id"] == "2"
def test_delete_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url(f"/{chat_id}"))
assert response.status_code == 200
assert response.json() is True
def test_clone_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}/clone"))
......@@ -209,31 +206,30 @@ class TestChats(AbstractPostgresTest):
assert data["user_id"] == "2"
def test_archive_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url(f"/{chat_id}/archive"))
assert response.status_code == 200
chat = self.chats.get_chat_by_id(self.db_session, chat_id)
chat = self.chats.get_chat_by_id(chat_id)
assert chat.archived is True
def test_share_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
chat_id = self.chats.get_chats()[0].id
with mock_webui_user(id="2"):
response = self.fast_api_client.post(self.create_url(f"/{chat_id}/share"))
assert response.status_code == 200
chat = self.chats.get_chat_by_id(self.db_session, chat_id)
chat = self.chats.get_chat_by_id(chat_id)
assert chat.share_id is not None
def test_delete_shared_chat_by_id(self):
chat_id = self.chats.get_chats(self.db_session)[0].id
chat_id = self.chats.get_chats()[0].id
share_id = str(uuid.uuid4())
self.chats.update_chat_share_id_by_id(self.db_session, chat_id, share_id)
self.db_session.commit()
self.chats.update_chat_share_id_by_id(chat_id, share_id)
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(self.create_url(f"/{chat_id}/share"))
assert response.status_code
chat = self.chats.get_chat_by_id(self.db_session, chat_id)
chat = self.chats.get_chat_by_id(chat_id)
assert chat.share_id is None
......@@ -14,7 +14,7 @@ class TestDocuments(AbstractPostgresTest):
def test_documents(self):
# Empty database
assert len(self.documents.get_docs(self.db_session)) == 0
assert len(self.documents.get_docs()) == 0
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/"))
assert response.status_code == 200
......@@ -34,7 +34,7 @@ class TestDocuments(AbstractPostgresTest):
)
assert response.status_code == 200
assert response.json()["name"] == "doc_name"
assert len(self.documents.get_docs(self.db_session)) == 1
assert len(self.documents.get_docs()) == 1
# Get the document
with mock_webui_user(id="2"):
......@@ -61,7 +61,7 @@ class TestDocuments(AbstractPostgresTest):
)
assert response.status_code == 200
assert response.json()["name"] == "doc_name 2"
assert len(self.documents.get_docs(self.db_session)) == 2
assert len(self.documents.get_docs()) == 2
# Get all documents
with mock_webui_user(id="2"):
......@@ -95,7 +95,7 @@ class TestDocuments(AbstractPostgresTest):
assert data["content"] == {
"tags": [{"name": "testing-tag"}, {"name": "another-tag"}]
}
assert len(self.documents.get_docs(self.db_session)) == 2
assert len(self.documents.get_docs()) == 2
# Delete the first document
with mock_webui_user(id="2"):
......@@ -103,4 +103,4 @@ class TestDocuments(AbstractPostgresTest):
self.create_url("/doc/delete?name=doc_name rework")
)
assert response.status_code == 200
assert len(self.documents.get_docs(self.db_session)) == 1
assert len(self.documents.get_docs()) == 1
......@@ -68,6 +68,16 @@ class TestPrompts(AbstractPostgresTest):
assert data["content"] == "description Updated"
assert data["user_id"] == "3"
# Get prompt by command
with mock_webui_user(id="2"):
response = self.fast_api_client.get(self.create_url("/command/my-command2"))
assert response.status_code == 200
data = response.json()
assert data["command"] == "/my-command2"
assert data["title"] == "Hello World Updated"
assert data["content"] == "description Updated"
assert data["user_id"] == "3"
# Delete prompt
with mock_webui_user(id="2"):
response = self.fast_api_client.delete(
......
......@@ -33,7 +33,6 @@ class TestUsers(AbstractPostgresTest):
def setup_method(self):
super().setup_method()
self.users.insert_new_user(
self.db_session,
id="1",
name="user 1",
email="user1@openwebui.com",
......@@ -41,7 +40,6 @@ class TestUsers(AbstractPostgresTest):
role="user",
)
self.users.insert_new_user(
self.db_session,
id="2",
name="user 2",
email="user2@openwebui.com",
......
......@@ -2,7 +2,6 @@ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends, Request
from sqlalchemy.orm import Session
from apps.webui.internal.db import get_db
from apps.webui.models.users import Users
from pydantic import BaseModel
......@@ -79,7 +78,6 @@ def get_http_authorization_cred(auth_header: str):
def get_current_user(
request: Request,
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
db=Depends(get_db),
):
token = None
......@@ -94,19 +92,19 @@ def get_current_user(
# auth by api key
if token.startswith("sk-"):
return get_current_user_by_api_key(db, token)
return get_current_user_by_api_key(token)
# auth by jwt token
data = decode_token(token)
if data != None and "id" in data:
user = Users.get_user_by_id(db, data["id"])
user = Users.get_user_by_id(data["id"])
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
else:
Users.update_user_last_active_by_id(db, user.id)
Users.update_user_last_active_by_id(user.id)
return user
else:
raise HTTPException(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment