Commit bee835cb authored by Jonathan Rohde's avatar Jonathan Rohde
Browse files

feat(sqlalchemy): remove session reference from router

parent df09d083
...@@ -31,7 +31,6 @@ from typing import Optional, List, Union ...@@ -31,7 +31,6 @@ from typing import Optional, List, Union
from starlette.background import BackgroundTask from starlette.background import BackgroundTask
from apps.webui.internal.db import get_db
from apps.webui.models.models import Models from apps.webui.models.models import Models
from apps.webui.models.users import Users from apps.webui.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
...@@ -712,7 +711,6 @@ async def generate_chat_completion( ...@@ -712,7 +711,6 @@ async def generate_chat_completion(
form_data: GenerateChatCompletionForm, form_data: GenerateChatCompletionForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db=Depends(get_db),
): ):
log.debug( log.debug(
...@@ -726,7 +724,7 @@ async def generate_chat_completion( ...@@ -726,7 +724,7 @@ async def generate_chat_completion(
} }
model_id = form_data.model model_id = form_data.model
model_info = Models.get_model_by_id(db, model_id) model_info = Models.get_model_by_id(model_id)
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:
...@@ -885,7 +883,6 @@ async def generate_openai_chat_completion( ...@@ -885,7 +883,6 @@ async def generate_openai_chat_completion(
form_data: dict, form_data: dict,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db=Depends(get_db),
): ):
form_data = OpenAIChatCompletionForm(**form_data) form_data = OpenAIChatCompletionForm(**form_data)
...@@ -894,7 +891,7 @@ async def generate_openai_chat_completion( ...@@ -894,7 +891,7 @@ async def generate_openai_chat_completion(
} }
model_id = form_data.model model_id = form_data.model
model_info = Models.get_model_by_id(db, model_id) model_info = Models.get_model_by_id(model_id)
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:
......
...@@ -11,7 +11,6 @@ import logging ...@@ -11,7 +11,6 @@ import logging
from pydantic import BaseModel from pydantic import BaseModel
from starlette.background import BackgroundTask from starlette.background import BackgroundTask
from apps.webui.internal.db import get_db
from apps.webui.models.models import Models from apps.webui.models.models import Models
from apps.webui.models.users import Users from apps.webui.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
...@@ -354,13 +353,12 @@ async def generate_chat_completion( ...@@ -354,13 +353,12 @@ async def generate_chat_completion(
form_data: dict, form_data: dict,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
db=Depends(get_db),
): ):
idx = 0 idx = 0
payload = {**form_data} payload = {**form_data}
model_id = form_data.get("model") model_id = form_data.get("model")
model_info = Models.get_model_by_id(db, model_id) model_info = Models.get_model_by_id(model_id)
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:
......
import os import os
import logging import logging
import json import json
from contextlib import contextmanager
from typing import Optional, Any from typing import Optional, Any
from typing_extensions import Self from typing_extensions import Self
...@@ -52,11 +53,12 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL: ...@@ -52,11 +53,12 @@ if "sqlite" in SQLALCHEMY_DATABASE_URL:
) )
else: else:
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True) engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine, expire_on_commit=False)
Base = declarative_base() Base = declarative_base()
def get_db(): @contextmanager
def get_session():
db = SessionLocal() db = SessionLocal()
try: try:
yield db yield db
...@@ -64,5 +66,4 @@ def get_db(): ...@@ -64,5 +66,4 @@ def get_db():
except Exception as e: except Exception as e:
db.rollback() db.rollback()
raise e raise e
finally:
db.close()
...@@ -114,8 +114,8 @@ async def get_status(): ...@@ -114,8 +114,8 @@ async def get_status():
} }
async def get_pipe_models(db: Session): async def get_pipe_models():
pipes = Functions.get_functions_by_type(db, "pipe", active_only=True) pipes = Functions.get_functions_by_type("pipe", active_only=True)
pipe_models = [] pipe_models = []
for pipe in pipes: for pipe in pipes:
......
...@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session ...@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
from apps.webui.models.users import UserModel, Users from apps.webui.models.users import UserModel, Users
from utils.utils import verify_password from utils.utils import verify_password
from apps.webui.internal.db import Base from apps.webui.internal.db import Base, get_session
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -96,7 +96,6 @@ class AuthsTable: ...@@ -96,7 +96,6 @@ class AuthsTable:
def insert_new_auth( def insert_new_auth(
self, self,
db: Session,
email: str, email: str,
password: str, password: str,
name: str, name: str,
...@@ -104,100 +103,107 @@ class AuthsTable: ...@@ -104,100 +103,107 @@ class AuthsTable:
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
log.info("insert_new_auth") with get_session() as db:
log.info("insert_new_auth")
id = str(uuid.uuid4()) id = str(uuid.uuid4())
auth = AuthModel( auth = AuthModel(
**{"id": id, "email": email, "password": password, "active": True} **{"id": id, "email": email, "password": password, "active": True}
) )
result = Auth(**auth.model_dump()) result = Auth(**auth.model_dump())
db.add(result) db.add(result)
user = Users.insert_new_user( user = Users.insert_new_user(
db, id, name, email, profile_image_url, role, oauth_sub id, name, email, profile_image_url, role, oauth_sub
) )
db.commit() db.commit()
db.refresh(result) db.refresh(result)
if result and user: if result and user:
return user return user
else: else:
return None return None
def authenticate_user( def authenticate_user(
self, db: Session, email: str, password: str self, email: str, password: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}") log.info(f"authenticate_user: {email}")
try: with get_session() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first() try:
if auth: auth = db.query(Auth).filter_by(email=email, active=True).first()
if verify_password(password, auth.password): if auth:
user = Users.get_user_by_id(db, auth.id) if verify_password(password, auth.password):
return user user = Users.get_user_by_id(auth.id)
return user
else:
return None
else: else:
return None return None
else: except:
return None return None
except:
return None
def authenticate_user_by_api_key( def authenticate_user_by_api_key(
self, db: Session, api_key: str self, api_key: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key: {api_key}") log.info(f"authenticate_user_by_api_key: {api_key}")
# if no api_key, return None with get_session() as db:
if not api_key: # if no api_key, return None
return None if not api_key:
return None
try: try:
user = Users.get_user_by_api_key(db, api_key) user = Users.get_user_by_api_key(api_key)
return user if user else None return user if user else None
except: except:
return False return False
def authenticate_user_by_trusted_header( def authenticate_user_by_trusted_header(
self, db: Session, email: str self, email: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}") log.info(f"authenticate_user_by_trusted_header: {email}")
try: with get_session() as db:
auth = db.query(Auth).filter(email=email, active=True).first() try:
if auth: auth = db.query(Auth).filter(email=email, active=True).first()
user = Users.get_user_by_id(auth.id) if auth:
return user user = Users.get_user_by_id(auth.id)
except: return user
return None except:
return None
def update_user_password_by_id( def update_user_password_by_id(
self, db: Session, id: str, new_password: str self, id: str, new_password: str
) -> bool: ) -> bool:
try: with get_session() as db:
result = db.query(Auth).filter_by(id=id).update({"password": new_password}) try:
return True if result == 1 else False result = db.query(Auth).filter_by(id=id).update({"password": new_password})
except: return True if result == 1 else False
return False except:
return False
def update_email_by_id(self, db: Session, id: str, email: str) -> bool:
try: def update_email_by_id(self, id: str, email: str) -> bool:
result = db.query(Auth).filter_by(id=id).update({"email": email}) with get_session() as db:
return True if result == 1 else False try:
except: result = db.query(Auth).filter_by(id=id).update({"email": email})
return False return True if result == 1 else False
except:
def delete_auth_by_id(self, db: Session, id: str) -> bool: return False
try:
# Delete User def delete_auth_by_id(self, id: str) -> bool:
result = Users.delete_user_by_id(db, id) with get_session() as db:
try:
if result: # Delete User
db.query(Auth).filter_by(id=id).delete() result = Users.delete_user_by_id(id)
return True if result:
else: db.query(Auth).filter_by(id=id).delete()
return True
else:
return False
except:
return False return False
except:
return False
Auths = AuthsTable() Auths = AuthsTable()
...@@ -8,7 +8,7 @@ import time ...@@ -8,7 +8,7 @@ import time
from sqlalchemy import Column, String, BigInteger, Boolean from sqlalchemy import Column, String, BigInteger, Boolean
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import Base from apps.webui.internal.db import Base, get_session
#################### ####################
...@@ -80,249 +80,269 @@ class ChatTitleIdResponse(BaseModel): ...@@ -80,249 +80,269 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable: class ChatTable:
def insert_new_chat( def insert_new_chat(
self, db: Session, user_id: str, form_data: ChatForm self, user_id: str, form_data: ChatForm
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
id = str(uuid.uuid4()) with get_session() as db:
chat = ChatModel( id = str(uuid.uuid4())
**{ chat = ChatModel(
"id": id, **{
"user_id": user_id, "id": id,
"title": ( "user_id": user_id,
form_data.chat["title"] if "title" in form_data.chat else "New Chat" "title": (
), form_data.chat["title"] if "title" in form_data.chat else "New Chat"
"chat": json.dumps(form_data.chat), ),
"created_at": int(time.time()), "chat": json.dumps(form_data.chat),
"updated_at": int(time.time()), "created_at": int(time.time()),
}
)
result = Chat(**chat.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return ChatModel.model_validate(result) if result else None
def update_chat_by_id(
self, db: Session, id: str, chat: dict
) -> Optional[ChatModel]:
try:
db.query(Chat).filter_by(id=id).update(
{
"chat": json.dumps(chat),
"title": chat["title"] if "title" in chat else "New Chat",
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
return self.get_chat_by_id(db, id) result = Chat(**chat.model_dump())
except: db.add(result)
return None db.commit()
db.refresh(result)
return ChatModel.model_validate(result) if result else None
def insert_shared_chat_by_chat_id( def update_chat_by_id(
self, db: Session, chat_id: str self, id: str, chat: dict
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
# Get the existing chat to share with get_session() as db:
chat = db.get(Chat, chat_id) try:
# Check if the chat is already shared chat_obj = db.get(Chat, id)
if chat.share_id: chat_obj.chat = json.dumps(chat)
return self.get_chat_by_id_and_user_id(db, chat.share_id, "shared") chat_obj.title = chat["title"] if "title" in chat else "New Chat"
# Create a new chat with the same data, but with a new ID chat_obj.updated_at = int(time.time())
shared_chat = ChatModel( db.commit()
**{ db.refresh(chat_obj)
"id": str(uuid.uuid4()),
"user_id": f"shared-{chat_id}", return ChatModel.model_validate(chat_obj)
"title": chat.title, except Exception as e:
"chat": chat.chat, return None
"created_at": chat.created_at,
"updated_at": int(time.time()),
}
)
shared_result = Chat(**shared_chat.model_dump())
db.add(shared_result)
db.commit()
db.refresh(shared_result)
# Update the original chat with the share_id
result = (
db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id})
)
return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id( def insert_shared_chat_by_chat_id(
self, db: Session, chat_id: str self, chat_id: str
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
try: with get_session() as db:
print("update_shared_chat_by_id") # Get the existing chat to share
chat = db.get(Chat, chat_id) chat = db.get(Chat, chat_id)
print(chat) # Check if the chat is already shared
if chat.share_id:
db.query(Chat).filter_by(id=chat.share_id).update( return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
{"title": chat.title, "chat": chat.chat} # Create a new chat with the same data, but with a new ID
shared_chat = ChatModel(
**{
"id": str(uuid.uuid4()),
"user_id": f"shared-{chat_id}",
"title": chat.title,
"chat": chat.chat,
"created_at": chat.created_at,
"updated_at": int(time.time()),
}
)
shared_result = Chat(**shared_chat.model_dump())
db.add(shared_result)
db.commit()
db.refresh(shared_result)
# Update the original chat with the share_id
result = (
db.query(Chat).filter_by(id=chat_id).update({"share_id": shared_chat.id})
) )
return self.get_chat_by_id(db, chat.share_id) return shared_chat if (shared_result and result) else None
except:
return None def update_shared_chat_by_chat_id(
self, chat_id: str
) -> Optional[ChatModel]:
with get_session() as db:
try:
print("update_shared_chat_by_id")
chat = db.get(Chat, chat_id)
print(chat)
chat.title = chat.title
chat.chat = chat.chat
db.commit()
db.refresh(chat)
return self.get_chat_by_id(chat.share_id)
except:
return None
def delete_shared_chat_by_chat_id(self, db: Session, chat_id: str) -> bool: def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
try: try:
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete() with get_session() as db:
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
return True return True
except: except:
return False return False
def update_chat_share_id_by_id( def update_chat_share_id_by_id(
self, db: Session, id: str, share_id: Optional[str] self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
try: try:
db.query(Chat).filter_by(id=id).update({"share_id": share_id}) with get_session() as db:
chat = db.get(Chat, id)
return self.get_chat_by_id(db, id) chat.share_id = share_id
db.commit()
db.refresh(chat)
return chat
except: except:
return None return None
def toggle_chat_archive_by_id(self, db: Session, id: str) -> Optional[ChatModel]: def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
chat = self.get_chat_by_id(db, id) with get_session() as db:
db.query(Chat).filter_by(id=id).update({"archived": not chat.archived}) chat = self.get_chat_by_id(id)
db.query(Chat).filter_by(id=id).update({"archived": not chat.archived})
return self.get_chat_by_id(db, id) return self.get_chat_by_id(id)
except: except:
return None return None
def archive_all_chats_by_user_id(self, db: Session, user_id: str) -> bool: def archive_all_chats_by_user_id(self, user_id: str) -> bool:
try: try:
db.query(Chat).filter_by(user_id=user_id).update({"archived": True}) with get_session() as db:
db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
return True return True
except: except:
return False return False
def get_archived_chat_list_by_user_id( def get_archived_chat_list_by_user_id(
self, db: Session, user_id: str, skip: int = 0, limit: int = 50 self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
all_chats = ( with get_session() as db:
db.query(Chat) all_chats = (
.filter_by(user_id=user_id, archived=True) db.query(Chat)
.order_by(Chat.updated_at.desc()) .filter_by(user_id=user_id, archived=True)
# .limit(limit).offset(skip) .order_by(Chat.updated_at.desc())
.all() # .limit(limit).offset(skip)
) .all()
return [ChatModel.model_validate(chat) for chat in all_chats] )
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_user_id( def get_chat_list_by_user_id(
self, self,
db: Session,
user_id: str, user_id: str,
include_archived: bool = False, include_archived: bool = False,
skip: int = 0, skip: int = 0,
limit: int = 50, limit: int = 50,
) -> List[ChatModel]: ) -> List[ChatModel]:
query = db.query(Chat).filter_by(user_id=user_id) with get_session() as db:
if not include_archived: query = db.query(Chat).filter_by(user_id=user_id)
query = query.filter_by(archived=False) if not include_archived:
all_chats = ( query = query.filter_by(archived=False)
query.order_by(Chat.updated_at.desc()) all_chats = (
# .limit(limit).offset(skip) query.order_by(Chat.updated_at.desc())
.all() # .limit(limit).offset(skip)
) .all()
return [ChatModel.model_validate(chat) for chat in all_chats] )
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_chat_ids( def get_chat_list_by_chat_ids(
self, db: Session, chat_ids: List[str], skip: int = 0, limit: int = 50 self, chat_ids: List[str], skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
all_chats = ( with get_session() as db:
db.query(Chat) all_chats = (
.filter(Chat.id.in_(chat_ids)) db.query(Chat)
.filter_by(archived=False) .filter(Chat.id.in_(chat_ids))
.order_by(Chat.updated_at.desc()) .filter_by(archived=False)
.all() .order_by(Chat.updated_at.desc())
) .all()
return [ChatModel.model_validate(chat) for chat in all_chats] )
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_by_id(self, db: Session, id: str) -> Optional[ChatModel]:
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try: try:
chat = db.get(Chat, id) with get_session() as db:
return ChatModel.model_validate(chat) chat = db.get(Chat, id)
return ChatModel.model_validate(chat)
except: except:
return None return None
def get_chat_by_share_id(self, db: Session, id: str) -> Optional[ChatModel]: def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
try: try:
chat = db.query(Chat).filter_by(share_id=id).first() with get_session() as db:
chat = db.query(Chat).filter_by(share_id=id).first()
if chat: if chat:
return self.get_chat_by_id(db, id) return self.get_chat_by_id(id)
else: else:
return None return None
except Exception as e: except Exception as e:
return None return None
def get_chat_by_id_and_user_id( def get_chat_by_id_and_user_id(
self, db: Session, id: str, user_id: str self, id: str, user_id: str
) -> Optional[ChatModel]: ) -> Optional[ChatModel]:
try: try:
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first() with get_session() as db:
return ChatModel.model_validate(chat) chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
return ChatModel.model_validate(chat)
except: except:
return None return None
def get_chats(self, db: Session, skip: int = 0, limit: int = 50) -> List[ChatModel]: def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
all_chats = ( with get_session() as db:
db.query(Chat) all_chats = (
# .limit(limit).offset(skip) db.query(Chat)
.order_by(Chat.updated_at.desc()) # .limit(limit).offset(skip)
) .order_by(Chat.updated_at.desc())
return [ChatModel.model_validate(chat) for chat in all_chats] )
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id(self, db: Session, user_id: str) -> List[ChatModel]: def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
all_chats = ( with get_session() as db:
db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc()) all_chats = (
) db.query(Chat).filter_by(user_id=user_id).order_by(Chat.updated_at.desc())
return [ChatModel.model_validate(chat) for chat in all_chats] )
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id( def get_archived_chats_by_user_id(
self, db: Session, user_id: str self, user_id: str
) -> List[ChatModel]: ) -> List[ChatModel]:
all_chats = ( with get_session() as db:
db.query(Chat) all_chats = (
.filter_by(user_id=user_id, archived=True) db.query(Chat)
.order_by(Chat.updated_at.desc()) .filter_by(user_id=user_id, archived=True)
) .order_by(Chat.updated_at.desc())
return [ChatModel.model_validate(chat) for chat in all_chats] )
return [ChatModel.model_validate(chat) for chat in all_chats]
def delete_chat_by_id(self, db: Session, id: str) -> bool:
def delete_chat_by_id(self, id: str) -> bool:
try: try:
db.query(Chat).filter_by(id=id).delete() with get_session() as db:
db.query(Chat).filter_by(id=id).delete()
return True and self.delete_shared_chat_by_chat_id(db, id) return True and self.delete_shared_chat_by_chat_id(id)
except: except:
return False return False
def delete_chat_by_id_and_user_id(self, db: Session, id: str, user_id: str) -> bool: def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try: try:
db.query(Chat).filter_by(id=id, user_id=user_id).delete() with get_session() as db:
db.query(Chat).filter_by(id=id, user_id=user_id).delete()
return True and self.delete_shared_chat_by_chat_id(db, id) return True and self.delete_shared_chat_by_chat_id(id)
except: except:
return False return False
def delete_chats_by_user_id(self, db: Session, user_id: str) -> bool: def delete_chats_by_user_id(self, user_id: str) -> bool:
try: try:
with get_session() as db:
self.delete_shared_chats_by_user_id(user_id)
self.delete_shared_chats_by_user_id(db, user_id) db.query(Chat).filter_by(user_id=user_id).delete()
db.query(Chat).filter_by(user_id=user_id).delete()
return True return True
except: except:
return False return False
def delete_shared_chats_by_user_id(self, db: Session, user_id: str) -> bool: def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try: try:
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all() with get_session() as db:
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user] chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete() db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
return True return True
except: except:
......
...@@ -6,7 +6,7 @@ import logging ...@@ -6,7 +6,7 @@ import logging
from sqlalchemy import String, Column, BigInteger from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import Base from apps.webui.internal.db import Base, get_session
import json import json
...@@ -73,7 +73,7 @@ class DocumentForm(DocumentUpdateForm): ...@@ -73,7 +73,7 @@ class DocumentForm(DocumentUpdateForm):
class DocumentsTable: class DocumentsTable:
def insert_new_doc( def insert_new_doc(
self, db: Session, user_id: str, form_data: DocumentForm self, user_id: str, form_data: DocumentForm
) -> Optional[DocumentModel]: ) -> Optional[DocumentModel]:
document = DocumentModel( document = DocumentModel(
**{ **{
...@@ -84,66 +84,73 @@ class DocumentsTable: ...@@ -84,66 +84,73 @@ class DocumentsTable:
) )
try: try:
result = Document(**document.model_dump()) with get_session() as db:
db.add(result) result = Document(**document.model_dump())
db.commit() db.add(result)
db.refresh(result) db.commit()
if result: db.refresh(result)
return DocumentModel.model_validate(result) if result:
else: return DocumentModel.model_validate(result)
return None else:
return None
except: except:
return None return None
def get_doc_by_name(self, db: Session, name: str) -> Optional[DocumentModel]: def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
try: try:
document = db.query(Document).filter_by(name=name).first() with get_session() as db:
return DocumentModel.model_validate(document) if document else None document = db.query(Document).filter_by(name=name).first()
return DocumentModel.model_validate(document) if document else None
except: except:
return None return None
def get_docs(self, db: Session) -> List[DocumentModel]: def get_docs(self) -> List[DocumentModel]:
return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()] with get_session() as db:
return [DocumentModel.model_validate(doc) for doc in db.query(Document).all()]
def update_doc_by_name( def update_doc_by_name(
self, db: Session, name: str, form_data: DocumentUpdateForm self, name: str, form_data: DocumentUpdateForm
) -> Optional[DocumentModel]: ) -> Optional[DocumentModel]:
try: try:
db.query(Document).filter_by(name=name).update( with get_session() as db:
{ db.query(Document).filter_by(name=name).update(
"title": form_data.title, {
"name": form_data.name, "title": form_data.title,
"timestamp": int(time.time()), "name": form_data.name,
} "timestamp": int(time.time()),
) }
return self.get_doc_by_name(db, form_data.name) )
db.commit()
return self.get_doc_by_name(form_data.name)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
def update_doc_content_by_name( def update_doc_content_by_name(
self, db: Session, name: str, updated: dict self, name: str, updated: dict
) -> Optional[DocumentModel]: ) -> Optional[DocumentModel]:
try: try:
doc = self.get_doc_by_name(db, name) with get_session() as db:
doc_content = json.loads(doc.content if doc.content else "{}") doc = self.get_doc_by_name(name)
doc_content = {**doc_content, **updated} doc_content = json.loads(doc.content if doc.content else "{}")
doc_content = {**doc_content, **updated}
db.query(Document).filter_by(name=name).update(
{ db.query(Document).filter_by(name=name).update(
"content": json.dumps(doc_content), {
"timestamp": int(time.time()), "content": json.dumps(doc_content),
} "timestamp": int(time.time()),
) }
)
return self.get_doc_by_name(db, name) db.commit()
return self.get_doc_by_name(name)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
return None return None
def delete_doc_by_name(self, db: Session, name: str) -> bool: def delete_doc_by_name(self, name: str) -> bool:
try: try:
db.query(Document).filter_by(name=name).delete() with get_session() as db:
db.query(Document).filter_by(name=name).delete()
return True return True
except: except:
return False return False
......
...@@ -6,7 +6,7 @@ import logging ...@@ -6,7 +6,7 @@ import logging
from sqlalchemy import Column, String, BigInteger from sqlalchemy import Column, String, BigInteger
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import JSONField, Base from apps.webui.internal.db import JSONField, Base, get_session
import json import json
...@@ -60,7 +60,7 @@ class FileForm(BaseModel): ...@@ -60,7 +60,7 @@ class FileForm(BaseModel):
class FilesTable: class FilesTable:
def insert_new_file(self, db: Session, user_id: str, form_data: FileForm) -> Optional[FileModel]: def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
file = FileModel( file = FileModel(
**{ **{
**form_data.model_dump(), **form_data.model_dump(),
...@@ -70,38 +70,45 @@ class FilesTable: ...@@ -70,38 +70,45 @@ class FilesTable:
) )
try: try:
result = File(**file.model_dump()) with get_session() as db:
db.add(result) result = File(**file.model_dump())
db.commit() db.add(result)
db.refresh(result) db.commit()
if result: db.refresh(result)
return FileModel.model_validate(result) if result:
else: return FileModel.model_validate(result)
return None else:
return None
except Exception as e: except Exception as e:
print(f"Error creating tool: {e}") print(f"Error creating tool: {e}")
return None return None
def get_file_by_id(self, db: Session, id: str) -> Optional[FileModel]: def get_file_by_id(self, id: str) -> Optional[FileModel]:
try: try:
file = db.get(File, id) with get_session() as db:
return FileModel.model_validate(file) file = db.get(File, id)
return FileModel.model_validate(file)
except: except:
return None return None
def get_files(self, db: Session) -> List[FileModel]: def get_files(self) -> List[FileModel]:
return [FileModel.model_validate(file) for file in db.query(File).all()] with get_session() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()]
def delete_file_by_id(self, db: Session, id: str) -> bool: def delete_file_by_id(self, id: str) -> bool:
try: try:
db.query(File).filter_by(id=id).delete() with get_session() as db:
db.query(File).filter_by(id=id).delete()
db.commit()
return True return True
except: except:
return False return False
def delete_all_files(self, db: Session) -> bool: def delete_all_files(self) -> bool:
try: try:
db.query(File).delete() with get_session() as db:
db.query(File).delete()
db.commit()
return True return True
except: except:
return False return False
......
...@@ -6,7 +6,7 @@ import logging ...@@ -6,7 +6,7 @@ import logging
from sqlalchemy import Column, String, Text, BigInteger, Boolean from sqlalchemy import Column, String, Text, BigInteger, Boolean
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import JSONField, Base from apps.webui.internal.db import JSONField, Base, get_session
from apps.webui.models.users import Users from apps.webui.models.users import Users
import json import json
...@@ -87,7 +87,7 @@ class FunctionValves(BaseModel): ...@@ -87,7 +87,7 @@ class FunctionValves(BaseModel):
class FunctionsTable: class FunctionsTable:
def insert_new_function( def insert_new_function(
self, db: Session, user_id: str, type: str, form_data: FunctionForm self, user_id: str, type: str, form_data: FunctionForm
) -> Optional[FunctionModel]: ) -> Optional[FunctionModel]:
function = FunctionModel( function = FunctionModel(
**{ **{
...@@ -100,57 +100,64 @@ class FunctionsTable: ...@@ -100,57 +100,64 @@ class FunctionsTable:
) )
try: try:
result = Function(**function.model_dump()) with get_session() as db:
db.add(result) result = Function(**function.model_dump())
db.commit() db.add(result)
db.refresh(result) db.commit()
if result: db.refresh(result)
return FunctionModel.model_validate(result) if result:
else: return FunctionModel.model_validate(result)
return None else:
return None
except Exception as e: except Exception as e:
print(f"Error creating tool: {e}") print(f"Error creating tool: {e}")
return None return None
def get_function_by_id(self, db: Session, id: str) -> Optional[FunctionModel]: def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
try: try:
function = db.get(Function, id) with get_session() as db:
return FunctionModel.model_validate(function) function = db.get(Function, id)
return FunctionModel.model_validate(function)
except: except:
return None return None
def get_functions(self, active_only=False) -> List[FunctionModel]: def get_functions(self, active_only=False) -> List[FunctionModel]:
if active_only: if active_only:
return [ with get_session() as db:
FunctionModel(**model_to_dict(function)) return [
for function in Function.select().where(Function.is_active == True) FunctionModel.model_validate(function)
] for function in db.query(Function).filter_by(is_active=True).all()
]
else: else:
return [ with get_session() as db:
FunctionModel(**model_to_dict(function)) return [
for function in Function.select() FunctionModel.model_validate(function)
] for function in db.query(Function).all()
]
def get_functions_by_type( def get_functions_by_type(
self, type: str, active_only=False self, type: str, active_only=False
) -> List[FunctionModel]: ) -> List[FunctionModel]:
if active_only: if active_only:
return [ with get_session() as db:
FunctionModel(**model_to_dict(function)) return [
for function in Function.select().where( FunctionModel.model_validate(function)
Function.type == type, Function.is_active == True for function in db.query(Function).filter_by(
) type=type, is_active=True
] ).all()
]
else: else:
return [ with get_session() as db:
FunctionModel(**model_to_dict(function)) return [
for function in Function.select().where(Function.type == type) FunctionModel.model_validate(function)
] for function in db.query(Function).filter_by(type=type).all()
]
def get_function_valves_by_id(self, id: str) -> Optional[dict]: def get_function_valves_by_id(self, id: str) -> Optional[dict]:
try: try:
function = Function.get(Function.id == id) with get_session() as db:
return function.valves if function.valves else {} function = db.get(Function, id)
return function.valves if function.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
return None return None
...@@ -159,14 +166,12 @@ class FunctionsTable: ...@@ -159,14 +166,12 @@ class FunctionsTable:
self, id: str, valves: dict self, id: str, valves: dict
) -> Optional[FunctionValves]: ) -> Optional[FunctionValves]:
try: try:
query = Function.update( with get_session() as db:
**{"valves": valves}, db.query(Function).filter_by(id=id).update(
updated_at=int(time.time()), {"valves": valves, "updated_at": int(time.time())}
).where(Function.id == id) )
query.execute() db.commit()
return self.get_function_by_id(id)
function = Function.get(Function.id == id)
return FunctionValves(**model_to_dict(function))
except: except:
return None return None
...@@ -214,30 +219,32 @@ class FunctionsTable: ...@@ -214,30 +219,32 @@ class FunctionsTable:
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]: def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
try: try:
db.query(Function).filter_by(id=id).update({ with get_session() as db:
**updated, db.query(Function).filter_by(id=id).update({
"updated_at": int(time.time()), **updated,
}) "updated_at": int(time.time()),
return self.get_function_by_id(db, id) })
db.commit()
return self.get_function_by_id(id)
except: except:
return None return None
def deactivate_all_functions(self) -> Optional[bool]: def deactivate_all_functions(self) -> Optional[bool]:
try: try:
query = Function.update( with get_session() as db:
**{"is_active": False}, db.query(Function).update({
updated_at=int(time.time()), "is_active": False,
) "updated_at": int(time.time()),
})
query.execute() db.commit()
return True return True
except: except:
return None return None
def delete_function_by_id(self, db: Session, id: str) -> bool: def delete_function_by_id(self, id: str) -> bool:
try: try:
db.query(Function).filter_by(id=id).delete() with get_session() as db:
db.query(Function).filter_by(id=id).delete()
return True return True
except: except:
return False return False
......
...@@ -4,7 +4,7 @@ from typing import List, Union, Optional ...@@ -4,7 +4,7 @@ from typing import List, Union, Optional
from sqlalchemy import Column, String, BigInteger from sqlalchemy import Column, String, BigInteger
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import Base from apps.webui.internal.db import Base, get_session
from apps.webui.models.chats import Chats from apps.webui.models.chats import Chats
import time import time
...@@ -44,7 +44,6 @@ class MemoriesTable: ...@@ -44,7 +44,6 @@ class MemoriesTable:
def insert_new_memory( def insert_new_memory(
self, self,
db: Session,
user_id: str, user_id: str,
content: str, content: str,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
...@@ -59,53 +58,59 @@ class MemoriesTable: ...@@ -59,53 +58,59 @@ class MemoriesTable:
"updated_at": int(time.time()), "updated_at": int(time.time()),
} }
) )
result = Memory(**memory.dict()) with get_session() as db:
db.add(result) result = Memory(**memory.model_dump())
db.commit() db.add(result)
db.refresh(result) db.commit()
if result: db.refresh(result)
return MemoryModel.model_validate(result) if result:
else: return MemoryModel.model_validate(result)
return None else:
return None
def update_memory_by_id( def update_memory_by_id(
self, self,
db: Session,
id: str, id: str,
content: str, content: str,
) -> Optional[MemoryModel]: ) -> Optional[MemoryModel]:
try: try:
db.query(Memory).filter_by(id=id).update( with get_session() as db:
{"content": content, "updated_at": int(time.time())} db.query(Memory).filter_by(id=id).update(
) {"content": content, "updated_at": int(time.time())}
return self.get_memory_by_id(db, id) )
db.commit()
return self.get_memory_by_id(id)
except: except:
return None return None
def get_memories(self, db: Session) -> List[MemoryModel]: def get_memories(self) -> List[MemoryModel]:
try: try:
memories = db.query(Memory).all() with get_session() as db:
return [MemoryModel.model_validate(memory) for memory in memories] memories = db.query(Memory).all()
return [MemoryModel.model_validate(memory) for memory in memories]
except: except:
return None return None
def get_memories_by_user_id(self, db: Session, user_id: str) -> List[MemoryModel]: def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
try: try:
memories = db.query(Memory).filter_by(user_id=user_id).all() with get_session() as db:
return [MemoryModel.model_validate(memory) for memory in memories] memories = db.query(Memory).filter_by(user_id=user_id).all()
return [MemoryModel.model_validate(memory) for memory in memories]
except: except:
return None return None
def get_memory_by_id(self, db: Session, id: str) -> Optional[MemoryModel]: def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
try: try:
memory = db.get(Memory, id) with get_session() as db:
return MemoryModel.model_validate(memory) memory = db.get(Memory, id)
return MemoryModel.model_validate(memory)
except: except:
return None return None
def delete_memory_by_id(self, db: Session, id: str) -> bool: def delete_memory_by_id(self, id: str) -> bool:
try: try:
db.query(Memory).filter_by(id=id).delete() with get_session() as db:
db.query(Memory).filter_by(id=id).delete()
return True return True
except: except:
...@@ -113,7 +118,8 @@ class MemoriesTable: ...@@ -113,7 +118,8 @@ class MemoriesTable:
def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool: def delete_memories_by_user_id(self, db: Session, user_id: str) -> bool:
try: try:
db.query(Memory).filter_by(user_id=user_id).delete() with get_session() as db:
db.query(Memory).filter_by(user_id=user_id).delete()
return True return True
except: except:
return False return False
...@@ -122,7 +128,8 @@ class MemoriesTable: ...@@ -122,7 +128,8 @@ class MemoriesTable:
self, db: Session, id: str, user_id: str self, db: Session, id: str, user_id: str
) -> bool: ) -> bool:
try: try:
db.query(Memory).filter_by(id=id, user_id=user_id).delete() with get_session() as db:
db.query(Memory).filter_by(id=id, user_id=user_id).delete()
return True return True
except: except:
return False return False
......
...@@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict ...@@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict
from sqlalchemy import String, Column, BigInteger from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import Base, JSONField from apps.webui.internal.db import Base, JSONField, get_session
from typing import List, Union, Optional from typing import List, Union, Optional
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -78,8 +78,6 @@ class Model(Base): ...@@ -78,8 +78,6 @@ class Model(Base):
class ModelModel(BaseModel): class ModelModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str id: str
user_id: str user_id: str
base_model_id: Optional[str] = None base_model_id: Optional[str] = None
...@@ -91,6 +89,8 @@ class ModelModel(BaseModel): ...@@ -91,6 +89,8 @@ class ModelModel(BaseModel):
updated_at: int # timestamp in epoch updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -116,7 +116,7 @@ class ModelForm(BaseModel): ...@@ -116,7 +116,7 @@ class ModelForm(BaseModel):
class ModelsTable: class ModelsTable:
def insert_new_model( def insert_new_model(
self, db: Session, form_data: ModelForm, user_id: str self, form_data: ModelForm, user_id: str
) -> Optional[ModelModel]: ) -> Optional[ModelModel]:
model = ModelModel( model = ModelModel(
**{ **{
...@@ -127,47 +127,52 @@ class ModelsTable: ...@@ -127,47 +127,52 @@ class ModelsTable:
} }
) )
try: try:
result = Model(**model.dict()) with get_session() as db:
db.add(result) result = Model(**model.model_dump())
db.commit() db.add(result)
db.refresh(result) db.commit()
db.refresh(result)
if result:
return ModelModel.model_validate(result) if result:
else: return ModelModel.model_validate(result)
return None else:
return None
except Exception as e: except Exception as e:
print(e) print(e)
return None return None
def get_all_models(self, db: Session) -> List[ModelModel]: def get_all_models(self) -> List[ModelModel]:
return [ModelModel.model_validate(model) for model in db.query(Model).all()] with get_session() as db:
return [ModelModel.model_validate(model) for model in db.query(Model).all()]
def get_model_by_id(self, db: Session, id: str) -> Optional[ModelModel]: def get_model_by_id(self, id: str) -> Optional[ModelModel]:
try: try:
model = db.get(Model, id) with get_session() as db:
return ModelModel.model_validate(model) model = db.get(Model, id)
return ModelModel.model_validate(model)
except: except:
return None return None
def update_model_by_id( def update_model_by_id(
self, db: Session, id: str, model: ModelForm self, id: str, model: ModelForm
) -> Optional[ModelModel]: ) -> Optional[ModelModel]:
try: try:
# update only the fields that are present in the model # update only the fields that are present in the model
model = db.query(Model).get(id) with get_session() as db:
model.update(**model.model_dump()) model = db.query(Model).get(id)
db.commit() model.update(**model.model_dump())
db.refresh(model) db.commit()
return ModelModel.model_validate(model) db.refresh(model)
return ModelModel.model_validate(model)
except Exception as e: except Exception as e:
print(e) print(e)
return None return None
def delete_model_by_id(self, db: Session, id: str) -> bool: def delete_model_by_id(self, id: str) -> bool:
try: try:
db.query(Model).filter_by(id=id).delete() with get_session() as db:
db.query(Model).filter_by(id=id).delete()
return True return True
except: except:
return False return False
......
...@@ -5,7 +5,7 @@ import time ...@@ -5,7 +5,7 @@ import time
from sqlalchemy import String, Column, BigInteger from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import Base from apps.webui.internal.db import Base, get_session
import json import json
...@@ -48,61 +48,65 @@ class PromptForm(BaseModel): ...@@ -48,61 +48,65 @@ class PromptForm(BaseModel):
class PromptsTable: class PromptsTable:
def insert_new_prompt( def insert_new_prompt(
self, db: Session, user_id: str, form_data: PromptForm self, user_id: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
prompt = PromptModel( with get_session() as db:
**{ prompt = PromptModel(
"user_id": user_id, **{
"command": form_data.command, "user_id": user_id,
"title": form_data.title, "command": form_data.command,
"content": form_data.content, "title": form_data.title,
"timestamp": int(time.time()), "content": form_data.content,
} "timestamp": int(time.time()),
) }
)
try:
result = Prompt(**prompt.dict()) try:
db.add(result) result = Prompt(**prompt.dict())
db.commit() db.add(result)
db.refresh(result) db.commit()
if result: db.refresh(result)
return PromptModel.model_validate(result) if result:
else: return PromptModel.model_validate(result)
else:
return None
except Exception as e:
return None return None
except Exception as e:
return None
def get_prompt_by_command(self, db: Session, command: str) -> Optional[PromptModel]: def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
try: with get_session() as db:
prompt = db.query(Prompt).filter_by(command=command).first() try:
return PromptModel.model_validate(prompt) prompt = db.query(Prompt).filter_by(command=command).first()
except: return PromptModel.model_validate(prompt)
return None except:
return None
def get_prompts(self, db: Session) -> List[PromptModel]: def get_prompts(self) -> List[PromptModel]:
return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()] with get_session() as db:
return [PromptModel.model_validate(prompt) for prompt in db.query(Prompt).all()]
def update_prompt_by_command( def update_prompt_by_command(
self, db: Session, command: str, form_data: PromptForm self, command: str, form_data: PromptForm
) -> Optional[PromptModel]: ) -> Optional[PromptModel]:
try: with get_session() as db:
db.query(Prompt).filter_by(command=command).update( try:
{ prompt = db.query(Prompt).filter_by(command=command).first()
"title": form_data.title, prompt.title = form_data.title
"content": form_data.content, prompt.content = form_data.content
"timestamp": int(time.time()), prompt.timestamp = int(time.time())
} db.commit()
) return prompt
return self.get_prompt_by_command(db, command) # return self.get_prompt_by_command(command)
except: except:
return None return None
def delete_prompt_by_command(self, db: Session, command: str) -> bool: def delete_prompt_by_command(self, command: str) -> bool:
try: with get_session() as db:
db.query(Prompt).filter_by(command=command).delete() try:
return True db.query(Prompt).filter_by(command=command).delete()
except: return True
return False except:
return False
Prompts = PromptsTable() Prompts = PromptsTable()
...@@ -9,7 +9,7 @@ import logging ...@@ -9,7 +9,7 @@ import logging
from sqlalchemy import String, Column, BigInteger from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import Base from apps.webui.internal.db import Base, get_session
from config import SRC_LOG_LEVELS from config import SRC_LOG_LEVELS
...@@ -80,37 +80,39 @@ class ChatTagsResponse(BaseModel): ...@@ -80,37 +80,39 @@ class ChatTagsResponse(BaseModel):
class TagTable: class TagTable:
def insert_new_tag( def insert_new_tag(
self, db: Session, name: str, user_id: str self, name: str, user_id: str
) -> Optional[TagModel]: ) -> Optional[TagModel]:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
tag = TagModel(**{"id": id, "user_id": user_id, "name": name}) tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try: try:
result = Tag(**tag.dict()) with get_session() as db:
db.add(result) result = Tag(**tag.model_dump())
db.commit() db.add(result)
db.refresh(result) db.commit()
if result: db.refresh(result)
return TagModel.model_validate(result) if result:
else: return TagModel.model_validate(result)
return None else:
return None
except Exception as e: except Exception as e:
return None return None
def get_tag_by_name_and_user_id( def get_tag_by_name_and_user_id(
self, db: Session, name: str, user_id: str self, name: str, user_id: str
) -> Optional[TagModel]: ) -> Optional[TagModel]:
try: try:
tag = db.query(Tag).filter(name=name, user_id=user_id).first() with get_session() as db:
return TagModel.model_validate(tag) tag = db.query(Tag).filter(name=name, user_id=user_id).first()
return TagModel.model_validate(tag)
except Exception as e: except Exception as e:
return None return None
def add_tag_to_chat( def add_tag_to_chat(
self, db: Session, user_id: str, form_data: ChatIdTagForm self, user_id: str, form_data: ChatIdTagForm
) -> Optional[ChatIdTagModel]: ) -> Optional[ChatIdTagModel]:
tag = self.get_tag_by_name_and_user_id(db, form_data.tag_name, user_id) tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
if tag == None: if tag == None:
tag = self.insert_new_tag(db, form_data.tag_name, user_id) tag = self.insert_new_tag(form_data.tag_name, user_id)
id = str(uuid.uuid4()) id = str(uuid.uuid4())
chatIdTag = ChatIdTagModel( chatIdTag = ChatIdTagModel(
...@@ -123,118 +125,127 @@ class TagTable: ...@@ -123,118 +125,127 @@ class TagTable:
} }
) )
try: try:
result = ChatIdTag(**chatIdTag.dict()) with get_session() as db:
db.add(result) result = ChatIdTag(**chatIdTag.model_dump())
db.commit() db.add(result)
db.refresh(result) db.commit()
if result: db.refresh(result)
return ChatIdTagModel.model_validate(result) if result:
else: return ChatIdTagModel.model_validate(result)
return None else:
return None
except: except:
return None return None
def get_tags_by_user_id(self, db: Session, user_id: str) -> List[TagModel]: def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
tag_names = [ with get_session() as db:
chat_id_tag.tag_name tag_names = [
for chat_id_tag in ( chat_id_tag.tag_name
db.query(ChatIdTag) for chat_id_tag in (
.filter_by(user_id=user_id) db.query(ChatIdTag)
.order_by(ChatIdTag.timestamp.desc()) .filter_by(user_id=user_id)
.all() .order_by(ChatIdTag.timestamp.desc())
) .all()
] )
]
return [
TagModel.model_validate(tag) return [
for tag in ( TagModel.model_validate(tag)
db.query(Tag) for tag in (
.filter_by(user_id=user_id) db.query(Tag)
.filter(Tag.name.in_(tag_names)) .filter_by(user_id=user_id)
.all() .filter(Tag.name.in_(tag_names))
) .all()
] )
]
def get_tags_by_chat_id_and_user_id( def get_tags_by_chat_id_and_user_id(
self, db: Session, chat_id: str, user_id: str self, chat_id: str, user_id: str
) -> List[TagModel]: ) -> List[TagModel]:
tag_names = [ with get_session() as db:
chat_id_tag.tag_name tag_names = [
for chat_id_tag in ( chat_id_tag.tag_name
db.query(ChatIdTag) for chat_id_tag in (
.filter_by(user_id=user_id, chat_id=chat_id) db.query(ChatIdTag)
.order_by(ChatIdTag.timestamp.desc()) .filter_by(user_id=user_id, chat_id=chat_id)
.all() .order_by(ChatIdTag.timestamp.desc())
) .all()
] )
]
return [
TagModel.model_validate(tag) return [
for tag in ( TagModel.model_validate(tag)
db.query(Tag) for tag in (
.filter_by(user_id=user_id) db.query(Tag)
.filter(Tag.name.in_(tag_names)) .filter_by(user_id=user_id)
.all() .filter(Tag.name.in_(tag_names))
) .all()
] )
]
def get_chat_ids_by_tag_name_and_user_id( def get_chat_ids_by_tag_name_and_user_id(
self, db: Session, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> List[ChatIdTagModel]: ) -> List[ChatIdTagModel]:
return [ with get_session() as db:
ChatIdTagModel.model_validate(chat_id_tag) return [
for chat_id_tag in ( ChatIdTagModel.model_validate(chat_id_tag)
db.query(ChatIdTag) for chat_id_tag in (
.filter_by(user_id=user_id, tag_name=tag_name) db.query(ChatIdTag)
.order_by(ChatIdTag.timestamp.desc()) .filter_by(user_id=user_id, tag_name=tag_name)
.all() .order_by(ChatIdTag.timestamp.desc())
) .all()
] )
]
def count_chat_ids_by_tag_name_and_user_id( def count_chat_ids_by_tag_name_and_user_id(
self, db: Session, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> int: ) -> int:
return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count() with get_session() as db:
return db.query(ChatIdTag).filter_by(tag_name=tag_name, user_id=user_id).count()
def delete_tag_by_tag_name_and_user_id( def delete_tag_by_tag_name_and_user_id(
self, db: Session, tag_name: str, user_id: str self, tag_name: str, user_id: str
) -> bool: ) -> bool:
try: try:
res = ( with get_session() as db:
db.query(ChatIdTag) res = (
.filter_by(tag_name=tag_name, user_id=user_id) db.query(ChatIdTag)
.delete() .filter_by(tag_name=tag_name, user_id=user_id)
) .delete()
log.debug(f"res: {res}") )
log.debug(f"res: {res}")
tag_count = self.count_chat_ids_by_tag_name_and_user_id( db.commit()
db, tag_name, user_id
) tag_count = self.count_chat_ids_by_tag_name_and_user_id(
if tag_count == 0: tag_name, user_id
# Remove tag item from Tag col as well )
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
return True return True
except Exception as e: except Exception as e:
log.error(f"delete_tag: {e}") log.error(f"delete_tag: {e}")
return False return False
def delete_tag_by_tag_name_and_chat_id_and_user_id( def delete_tag_by_tag_name_and_chat_id_and_user_id(
self, db: Session, tag_name: str, chat_id: str, user_id: str self, tag_name: str, chat_id: str, user_id: str
) -> bool: ) -> bool:
try: try:
res = ( with get_session() as db:
db.query(ChatIdTag) res = (
.filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id) db.query(ChatIdTag)
.delete() .filter_by(tag_name=tag_name, chat_id=chat_id, user_id=user_id)
) .delete()
log.debug(f"res: {res}") )
log.debug(f"res: {res}")
tag_count = self.count_chat_ids_by_tag_name_and_user_id( db.commit()
db, tag_name, user_id
) tag_count = self.count_chat_ids_by_tag_name_and_user_id(
if tag_count == 0: tag_name, user_id
# Remove tag item from Tag col as well )
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete() if tag_count == 0:
# Remove tag item from Tag col as well
db.query(Tag).filter_by(name=tag_name, user_id=user_id).delete()
return True return True
except Exception as e: except Exception as e:
...@@ -242,13 +253,13 @@ class TagTable: ...@@ -242,13 +253,13 @@ class TagTable:
return False return False
def delete_tags_by_chat_id_and_user_id( def delete_tags_by_chat_id_and_user_id(
self, db: Session, chat_id: str, user_id: str self, chat_id: str, user_id: str
) -> bool: ) -> bool:
tags = self.get_tags_by_chat_id_and_user_id(db, chat_id, user_id) tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
for tag in tags: for tag in tags:
self.delete_tag_by_tag_name_and_chat_id_and_user_id( self.delete_tag_by_tag_name_and_chat_id_and_user_id(
db, tag.tag_name, chat_id, user_id tag.tag_name, chat_id, user_id
) )
return True return True
......
...@@ -5,7 +5,7 @@ import logging ...@@ -5,7 +5,7 @@ import logging
from sqlalchemy import String, Column, BigInteger from sqlalchemy import String, Column, BigInteger
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from apps.webui.internal.db import Base, JSONField from apps.webui.internal.db import Base, JSONField, get_session
from apps.webui.models.users import Users from apps.webui.models.users import Users
import json import json
...@@ -82,7 +82,7 @@ class ToolValves(BaseModel): ...@@ -82,7 +82,7 @@ class ToolValves(BaseModel):
class ToolsTable: class ToolsTable:
def insert_new_tool( def insert_new_tool(
self, db: Session, user_id: str, form_data: ToolForm, specs: List[dict] self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]: ) -> Optional[ToolModel]:
tool = ToolModel( tool = ToolModel(
**{ **{
...@@ -95,46 +95,48 @@ class ToolsTable: ...@@ -95,46 +95,48 @@ class ToolsTable:
) )
try: try:
result = Tool(**tool.dict()) with get_session() as db:
db.add(result) result = Tool(**tool.model_dump())
db.commit() db.add(result)
db.refresh(result) db.commit()
if result: db.refresh(result)
return ToolModel.model_validate(result) if result:
else: return ToolModel.model_validate(result)
return None else:
return None
except Exception as e: except Exception as e:
print(f"Error creating tool: {e}") print(f"Error creating tool: {e}")
return None return None
def get_tool_by_id(self, db: Session, id: str) -> Optional[ToolModel]: def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try: try:
tool = db.get(Tool, id) with get_session() as db:
return ToolModel.model_validate(tool) tool = db.get(Tool, id)
return ToolModel.model_validate(tool)
except: except:
return None return None
def get_tools(self, db: Session) -> List[ToolModel]: def get_tools(self) -> List[ToolModel]:
return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()] with get_session() as db:
return [ToolModel.model_validate(tool) for tool in db.query(Tool).all()]
def get_tool_valves_by_id(self, id: str) -> Optional[dict]: def get_tool_valves_by_id(self, id: str) -> Optional[dict]:
try: try:
tool = Tool.get(Tool.id == id) with get_session() as db:
return tool.valves if tool.valves else {} tool = db.get(Tool, id)
return tool.valves if tool.valves else {}
except Exception as e: except Exception as e:
print(f"An error occurred: {e}") print(f"An error occurred: {e}")
return None return None
def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]: def update_tool_valves_by_id(self, id: str, valves: dict) -> Optional[ToolValves]:
try: try:
query = Tool.update( with get_session() as db:
**{"valves": valves}, db.query(Tool).filter_by(id=id).update(
updated_at=int(time.time()), {"valves": valves, "updated_at": int(time.time())}
).where(Tool.id == id) )
query.execute() db.commit()
return self.get_tool_by_id(id)
tool = Tool.get(Tool.id == id)
return ToolValves(**model_to_dict(tool))
except: except:
return None return None
...@@ -172,8 +174,7 @@ class ToolsTable: ...@@ -172,8 +174,7 @@ class ToolsTable:
user_settings["tools"]["valves"][id] = valves user_settings["tools"]["valves"][id] = valves
# Update the user settings in the database # Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user_settings}) Users.update_user_by_id(user_id, {"settings": user_settings})
query.execute()
return user_settings["tools"]["valves"][id] return user_settings["tools"]["valves"][id]
except Exception as e: except Exception as e:
...@@ -182,16 +183,19 @@ class ToolsTable: ...@@ -182,16 +183,19 @@ class ToolsTable:
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]: def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try: try:
db.query(Tool).filter_by(id=id).update( with get_session() as db:
{**updated, "updated_at": int(time.time())} db.query(Tool).filter_by(id=id).update(
) {**updated, "updated_at": int(time.time())}
return self.get_tool_by_id(db, id) )
db.commit()
return self.get_tool_by_id(id)
except: except:
return None return None
def delete_tool_by_id(self, db: Session, id: str) -> bool: def delete_tool_by_id(self, id: str) -> bool:
try: try:
db.query(Tool).filter_by(id=id).delete() with get_session() as db:
db.query(Tool).filter_by(id=id).delete()
return True return True
except: except:
return False return False
......
...@@ -7,7 +7,7 @@ from sqlalchemy.orm import Session ...@@ -7,7 +7,7 @@ from sqlalchemy.orm import Session
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from apps.webui.internal.db import Base, JSONField from apps.webui.internal.db import Base, JSONField, get_session
from apps.webui.models.chats import Chats from apps.webui.models.chats import Chats
#################### ####################
...@@ -42,8 +42,6 @@ class UserSettings(BaseModel): ...@@ -42,8 +42,6 @@ class UserSettings(BaseModel):
class UserModel(BaseModel): class UserModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str id: str
name: str name: str
email: str email: str
...@@ -60,6 +58,8 @@ class UserModel(BaseModel): ...@@ -60,6 +58,8 @@ class UserModel(BaseModel):
oauth_sub: Optional[str] = None oauth_sub: Optional[str] = None
model_config = ConfigDict(from_attributes=True)
#################### ####################
# Forms # Forms
...@@ -82,7 +82,6 @@ class UsersTable: ...@@ -82,7 +82,6 @@ class UsersTable:
def insert_new_user( def insert_new_user(
self, self,
db: Session,
id: str, id: str,
name: str, name: str,
email: str, email: str,
...@@ -90,165 +89,181 @@ class UsersTable: ...@@ -90,165 +89,181 @@ class UsersTable:
role: str = "pending", role: str = "pending",
oauth_sub: Optional[str] = None, oauth_sub: Optional[str] = None,
) -> Optional[UserModel]: ) -> Optional[UserModel]:
user = UserModel( with get_session() as db:
**{ user = UserModel(
"id": id, **{
"name": name, "id": id,
"email": email, "name": name,
"role": role, "email": email,
"profile_image_url": profile_image_url, "role": role,
"last_active_at": int(time.time()), "profile_image_url": profile_image_url,
"created_at": int(time.time()), "last_active_at": int(time.time()),
"updated_at": int(time.time()), "created_at": int(time.time()),
"oauth_sub": oauth_sub, "updated_at": int(time.time()),
} "oauth_sub": oauth_sub,
) }
result = User(**user.model_dump()) )
db.add(result) result = User(**user.model_dump())
db.commit() db.add(result)
db.refresh(result) db.commit()
if result: db.refresh(result)
return user if result:
else: return user
return None else:
return None
def get_user_by_id(self, db: Session, id: str) -> Optional[UserModel]:
try: def get_user_by_id(self, id: str) -> Optional[UserModel]:
user = db.query(User).filter_by(id=id).first() with get_session() as db:
return UserModel.model_validate(user) try:
except Exception as e: user = db.query(User).filter_by(id=id).first()
return None return UserModel.model_validate(user)
except Exception as e:
def get_user_by_api_key(self, db: Session, api_key: str) -> Optional[UserModel]: return None
try:
user = db.query(User).filter_by(api_key=api_key).first() def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
return UserModel.model_validate(user) with get_session() as db:
except: try:
return None user = db.query(User).filter_by(api_key=api_key).first()
return UserModel.model_validate(user)
def get_user_by_email(self, db: Session, email: str) -> Optional[UserModel]: except:
try: return None
user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user) def get_user_by_email(self, email: str) -> Optional[UserModel]:
except: with get_session() as db:
return None try:
user = db.query(User).filter_by(email=email).first()
return UserModel.model_validate(user)
except:
return None
def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]: def get_user_by_oauth_sub(self, sub: str) -> Optional[UserModel]:
try: with get_session() as db:
user = User.get(User.oauth_sub == sub) try:
return UserModel(**model_to_dict(user)) user = db.query(User).filter_by(oauth_sub=sub).first()
except: return UserModel.model_validate(user)
return None except:
return None
def get_users(self, db: Session, skip: int = 0, limit: int = 50) -> List[UserModel]:
users = ( def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
db.query(User) with get_session() as db:
# .offset(skip).limit(limit) users = (
.all() db.query(User)
) # .offset(skip).limit(limit)
return [UserModel.model_validate(user) for user in users] .all()
)
def get_num_users(self, db: Session) -> Optional[int]: return [UserModel.model_validate(user) for user in users]
return db.query(User).count()
def get_num_users(self) -> Optional[int]:
def get_first_user(self, db: Session) -> UserModel: with get_session() as db:
try: return db.query(User).count()
user = db.query(User).order_by(User.created_at).first()
return UserModel.model_validate(user) def get_first_user(self) -> UserModel:
except: with get_session() as db:
return None try:
user = db.query(User).order_by(User.created_at).first()
return UserModel.model_validate(user)
except:
return None
def update_user_role_by_id( def update_user_role_by_id(
self, db: Session, id: str, role: str self, id: str, role: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: with get_session() as db:
db.query(User).filter_by(id=id).update({"role": role}) try:
db.commit() db.query(User).filter_by(id=id).update({"role": role})
db.commit()
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
def update_user_profile_image_url_by_id( def update_user_profile_image_url_by_id(
self, db: Session, id: str, profile_image_url: str self, id: str, profile_image_url: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: with get_session() as db:
db.query(User).filter_by(id=id).update( try:
{"profile_image_url": profile_image_url} db.query(User).filter_by(id=id).update(
) {"profile_image_url": profile_image_url}
db.commit() )
db.commit()
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
def update_user_last_active_by_id( def update_user_last_active_by_id(
self, db: Session, id: str self, id: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: with get_session() as db:
db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())}) try:
db.query(User).filter_by(id=id).update({"last_active_at": int(time.time())})
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
def update_user_oauth_sub_by_id( def update_user_oauth_sub_by_id(
self, db: Session, id: str, oauth_sub: str self, id: str, oauth_sub: str
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: with get_session() as db:
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub}) try:
db.query(User).filter_by(id=id).update({"oauth_sub": oauth_sub})
user = db.query(User).filter_by(id=id).first() user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user) return UserModel.model_validate(user)
except: except:
return None return None
def update_user_by_id( def update_user_by_id(
self, db: Session, id: str, updated: dict self, id: str, updated: dict
) -> Optional[UserModel]: ) -> Optional[UserModel]:
try: with get_session() as db:
db.query(User).filter_by(id=id).update(updated) try:
db.commit() db.query(User).filter_by(id=id).update(updated)
db.commit()
user = db.query(User).filter_by(id=id).first()
return UserModel.model_validate(user)
# return UserModel(**user.dict())
except Exception as e:
return None
def delete_user_by_id(self, db: Session, id: str) -> bool: user = db.query(User).filter_by(id=id).first()
try: return UserModel.model_validate(user)
# Delete User Chats # return UserModel(**user.dict())
result = Chats.delete_chats_by_user_id(db, id) except Exception as e:
return None
def delete_user_by_id(self, id: str) -> bool:
with get_session() as db:
try:
# Delete User Chats
result = Chats.delete_chats_by_user_id(id)
if result:
# Delete User
db.query(User).filter_by(id=id).delete()
db.commit()
return True
else:
return False
except:
return False
if result: def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
# Delete User with get_session() as db:
db.query(User).filter_by(id=id).delete() try:
result = db.query(User).filter_by(id=id).update({"api_key": api_key})
db.commit() db.commit()
return True if result == 1 else False
return True except:
else:
return False return False
except:
return False
def update_user_api_key_by_id(self, db: Session, id: str, api_key: str) -> str: def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try: with get_session() as db:
result = db.query(User).filter_by(id=id).update({"api_key": api_key}) try:
db.commit() user = db.query(User).filter_by(id=id).first()
return True if result == 1 else False return user.api_key
except: except Exception as e:
return False return None
def get_user_api_key_by_id(self, db: Session, id: str) -> Optional[str]:
try:
user = db.query(User).filter_by(id=id).first()
return user.api_key
except Exception as e:
return None
Users = UsersTable() Users = UsersTable()
...@@ -10,7 +10,6 @@ import re ...@@ -10,7 +10,6 @@ import re
import uuid import uuid
import csv import csv
from apps.webui.internal.db import get_db
from apps.webui.models.auths import ( from apps.webui.models.auths import (
SigninForm, SigninForm,
SignupForm, SignupForm,
...@@ -80,12 +79,10 @@ async def get_session_user( ...@@ -80,12 +79,10 @@ async def get_session_user(
@router.post("/update/profile", response_model=UserResponse) @router.post("/update/profile", response_model=UserResponse)
async def update_profile( async def update_profile(
form_data: UpdateProfileForm, form_data: UpdateProfileForm,
session_user=Depends(get_current_user), session_user=Depends(get_current_user)
db=Depends(get_db),
): ):
if session_user: if session_user:
user = Users.update_user_by_id( user = Users.update_user_by_id(
db,
session_user.id, session_user.id,
{"profile_image_url": form_data.profile_image_url, "name": form_data.name}, {"profile_image_url": form_data.profile_image_url, "name": form_data.name},
) )
...@@ -105,17 +102,16 @@ async def update_profile( ...@@ -105,17 +102,16 @@ async def update_profile(
@router.post("/update/password", response_model=bool) @router.post("/update/password", response_model=bool)
async def update_password( async def update_password(
form_data: UpdatePasswordForm, form_data: UpdatePasswordForm,
session_user=Depends(get_current_user), session_user=Depends(get_current_user)
db=Depends(get_db),
): ):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED) raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
if session_user: if session_user:
user = Auths.authenticate_user(db, session_user.email, form_data.password) user = Auths.authenticate_user(session_user.email, form_data.password)
if user: if user:
hashed = get_password_hash(form_data.new_password) hashed = get_password_hash(form_data.new_password)
return Auths.update_user_password_by_id(db, user.id, hashed) return Auths.update_user_password_by_id(user.id, hashed)
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_PASSWORD)
else: else:
...@@ -128,7 +124,7 @@ async def update_password( ...@@ -128,7 +124,7 @@ async def update_password(
@router.post("/signin", response_model=SigninResponse) @router.post("/signin", response_model=SigninResponse)
async def signin(request: Request, response: Response, form_data: SigninForm, db=Depends(get_db)): async def signin(request: Request, response: Response, form_data: SigninForm):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers: if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER) raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
...@@ -139,34 +135,32 @@ async def signin(request: Request, response: Response, form_data: SigninForm, db ...@@ -139,34 +135,32 @@ async def signin(request: Request, response: Response, form_data: SigninForm, db
trusted_name = request.headers.get( trusted_name = request.headers.get(
WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
) )
if not Users.get_user_by_email(db, trusted_email.lower()): if not Users.get_user_by_email(trusted_email.lower()):
await signup( await signup(
request, request,
SignupForm( SignupForm(
email=trusted_email, password=str(uuid.uuid4()), name=trusted_name email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
), ),
db,
) )
user = Auths.authenticate_user_by_trusted_header(db, trusted_email) user = Auths.authenticate_user_by_trusted_header(trusted_email)
elif WEBUI_AUTH == False: elif WEBUI_AUTH == False:
admin_email = "admin@localhost" admin_email = "admin@localhost"
admin_password = "admin" admin_password = "admin"
if Users.get_user_by_email(db, admin_email.lower()): if Users.get_user_by_email(admin_email.lower()):
user = Auths.authenticate_user(db, admin_email.lower(), admin_password) user = Auths.authenticate_user(admin_email.lower(), admin_password)
else: else:
if Users.get_num_users(db) != 0: if Users.get_num_users() != 0:
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS) raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
await signup( await signup(
request, request,
SignupForm(email=admin_email, password=admin_password, name="User"), SignupForm(email=admin_email, password=admin_password, name="User"),
db,
) )
user = Auths.authenticate_user(db, admin_email.lower(), admin_password) user = Auths.authenticate_user(admin_email.lower(), admin_password)
else: else:
user = Auths.authenticate_user(db, form_data.email.lower(), form_data.password) user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
if user: if user:
token = create_token( token = create_token(
...@@ -200,7 +194,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm, db ...@@ -200,7 +194,7 @@ async def signin(request: Request, response: Response, form_data: SigninForm, db
@router.post("/signup", response_model=SigninResponse) @router.post("/signup", response_model=SigninResponse)
async def signup(request: Request, response: Response, form_data: SignupForm, db=Depends(get_db)): async def signup(request: Request, response: Response, form_data: SignupForm):
if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH: if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
raise HTTPException( raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
...@@ -211,18 +205,17 @@ async def signup(request: Request, response: Response, form_data: SignupForm, db ...@@ -211,18 +205,17 @@ async def signup(request: Request, response: Response, form_data: SignupForm, db
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
) )
if Users.get_user_by_email(db, form_data.email.lower()): if Users.get_user_by_email(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try: try:
role = ( role = (
"admin" "admin"
if Users.get_num_users(db) == 0 if Users.get_num_users() == 0
else request.app.state.config.DEFAULT_USER_ROLE else request.app.state.config.DEFAULT_USER_ROLE
) )
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(
db,
form_data.email.lower(), form_data.email.lower(),
hashed, hashed,
form_data.name, form_data.name,
...@@ -277,7 +270,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm, db ...@@ -277,7 +270,7 @@ async def signup(request: Request, response: Response, form_data: SignupForm, db
@router.post("/add", response_model=SigninResponse) @router.post("/add", response_model=SigninResponse)
async def add_user( async def add_user(
form_data: AddUserForm, user=Depends(get_admin_user), db=Depends(get_db) form_data: AddUserForm, user=Depends(get_admin_user)
): ):
if not validate_email_format(form_data.email.lower()): if not validate_email_format(form_data.email.lower()):
...@@ -285,7 +278,7 @@ async def add_user( ...@@ -285,7 +278,7 @@ async def add_user(
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
) )
if Users.get_user_by_email(db, form_data.email.lower()): if Users.get_user_by_email(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try: try:
...@@ -293,7 +286,6 @@ async def add_user( ...@@ -293,7 +286,6 @@ async def add_user(
print(form_data) print(form_data)
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(
db,
form_data.email.lower(), form_data.email.lower(),
hashed, hashed,
form_data.name, form_data.name,
...@@ -325,7 +317,7 @@ async def add_user( ...@@ -325,7 +317,7 @@ async def add_user(
@router.get("/admin/details") @router.get("/admin/details")
async def get_admin_details( async def get_admin_details(
request: Request, user=Depends(get_current_user), db=Depends(get_db) request: Request, user=Depends(get_current_user)
): ):
if request.app.state.config.SHOW_ADMIN_DETAILS: if request.app.state.config.SHOW_ADMIN_DETAILS:
admin_email = request.app.state.config.ADMIN_EMAIL admin_email = request.app.state.config.ADMIN_EMAIL
...@@ -334,11 +326,11 @@ async def get_admin_details( ...@@ -334,11 +326,11 @@ async def get_admin_details(
print(admin_email, admin_name) print(admin_email, admin_name)
if admin_email: if admin_email:
admin = Users.get_user_by_email(db, admin_email) admin = Users.get_user_by_email(admin_email)
if admin: if admin:
admin_name = admin.name admin_name = admin.name
else: else:
admin = Users.get_first_user(db) admin = Users.get_first_user()
if admin: if admin:
admin_email = admin.email admin_email = admin.email
admin_name = admin.name admin_name = admin.name
...@@ -411,9 +403,9 @@ async def update_admin_config( ...@@ -411,9 +403,9 @@ async def update_admin_config(
# create api key # create api key
@router.post("/api_key", response_model=ApiKey) @router.post("/api_key", response_model=ApiKey)
async def create_api_key_(user=Depends(get_current_user), db=Depends(get_db)): async def create_api_key_(user=Depends(get_current_user)):
api_key = create_api_key() api_key = create_api_key()
success = Users.update_user_api_key_by_id(db, user.id, api_key) success = Users.update_user_api_key_by_id(user.id, api_key)
if success: if success:
return { return {
"api_key": api_key, "api_key": api_key,
...@@ -424,15 +416,15 @@ async def create_api_key_(user=Depends(get_current_user), db=Depends(get_db)): ...@@ -424,15 +416,15 @@ async def create_api_key_(user=Depends(get_current_user), db=Depends(get_db)):
# delete api key # delete api key
@router.delete("/api_key", response_model=bool) @router.delete("/api_key", response_model=bool)
async def delete_api_key(user=Depends(get_current_user), db=Depends(get_db)): async def delete_api_key(user=Depends(get_current_user)):
success = Users.update_user_api_key_by_id(db, user.id, None) success = Users.update_user_api_key_by_id(user.id, None)
return success return success
# get api key # get api key
@router.get("/api_key", response_model=ApiKey) @router.get("/api_key", response_model=ApiKey)
async def get_api_key(user=Depends(get_current_user), db=Depends(get_db)): async def get_api_key(user=Depends(get_current_user)):
api_key = Users.get_user_api_key_by_id(db, user.id) api_key = Users.get_user_api_key_by_id(user.id)
if api_key: if api_key:
return { return {
"api_key": api_key, "api_key": api_key,
......
...@@ -2,7 +2,6 @@ from fastapi import Depends, Request, HTTPException, status ...@@ -2,7 +2,6 @@ from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
from apps.webui.internal.db import get_db
from utils.utils import get_current_user, get_admin_user from utils.utils import get_current_user, get_admin_user
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
...@@ -45,9 +44,9 @@ router = APIRouter() ...@@ -45,9 +44,9 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/", response_model=List[ChatTitleIdResponse])
@router.get("/list", response_model=List[ChatTitleIdResponse]) @router.get("/list", response_model=List[ChatTitleIdResponse])
async def get_session_user_chat_list( async def get_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db) user=Depends(get_current_user), skip: int = 0, limit: int = 50
): ):
return Chats.get_chat_list_by_user_id(db, user.id, skip, limit) return Chats.get_chat_list_by_user_id(user.id, skip, limit)
############################ ############################
...@@ -57,7 +56,7 @@ async def get_session_user_chat_list( ...@@ -57,7 +56,7 @@ async def get_session_user_chat_list(
@router.delete("/", response_model=bool) @router.delete("/", response_model=bool)
async def delete_all_user_chats( async def delete_all_user_chats(
request: Request, user=Depends(get_current_user), db=Depends(get_db) request: Request, user=Depends(get_current_user)
): ):
if ( if (
...@@ -69,7 +68,7 @@ async def delete_all_user_chats( ...@@ -69,7 +68,7 @@ async def delete_all_user_chats(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
result = Chats.delete_chats_by_user_id(db, user.id) result = Chats.delete_chats_by_user_id(user.id)
return result return result
...@@ -84,10 +83,9 @@ async def get_user_chat_list_by_user_id( ...@@ -84,10 +83,9 @@ async def get_user_chat_list_by_user_id(
user=Depends(get_admin_user), user=Depends(get_admin_user),
skip: int = 0, skip: int = 0,
limit: int = 50, limit: int = 50,
db=Depends(get_db),
): ):
return Chats.get_chat_list_by_user_id( return Chats.get_chat_list_by_user_id(
db, user_id, include_archived=True, skip=skip, limit=limit user_id, include_archived=True, skip=skip, limit=limit
) )
...@@ -98,10 +96,10 @@ async def get_user_chat_list_by_user_id( ...@@ -98,10 +96,10 @@ async def get_user_chat_list_by_user_id(
@router.post("/new", response_model=Optional[ChatResponse]) @router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat( async def create_new_chat(
form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db) form_data: ChatForm, user=Depends(get_current_user)
): ):
try: try:
chat = Chats.insert_new_chat(db, user.id, form_data) chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
...@@ -116,10 +114,10 @@ async def create_new_chat( ...@@ -116,10 +114,10 @@ async def create_new_chat(
@router.get("/all", response_model=List[ChatResponse]) @router.get("/all", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user), db=Depends(get_db)): async def get_user_chats(user=Depends(get_current_user)):
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_chats_by_user_id(db, user.id) for chat in Chats.get_chats_by_user_id(user.id)
] ]
...@@ -129,10 +127,10 @@ async def get_user_chats(user=Depends(get_current_user), db=Depends(get_db)): ...@@ -129,10 +127,10 @@ async def get_user_chats(user=Depends(get_current_user), db=Depends(get_db)):
@router.get("/all/archived", response_model=List[ChatResponse]) @router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_archived_chats(user=Depends(get_current_user), db=Depends(get_db)): async def get_user_archived_chats(user=Depends(get_current_user)):
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_archived_chats_by_user_id(db, user.id) for chat in Chats.get_archived_chats_by_user_id(user.id)
] ]
...@@ -142,7 +140,7 @@ async def get_user_archived_chats(user=Depends(get_current_user), db=Depends(get ...@@ -142,7 +140,7 @@ async def get_user_archived_chats(user=Depends(get_current_user), db=Depends(get
@router.get("/all/db", response_model=List[ChatResponse]) @router.get("/all/db", response_model=List[ChatResponse])
async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_db)): async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
if not ENABLE_ADMIN_EXPORT: if not ENABLE_ADMIN_EXPORT:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
...@@ -150,7 +148,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_ ...@@ -150,7 +148,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_
) )
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_chats(db) for chat in Chats.get_chats()
] ]
...@@ -161,9 +159,9 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_ ...@@ -161,9 +159,9 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user), db=Depends(get_
@router.get("/archived", response_model=List[ChatTitleIdResponse]) @router.get("/archived", response_model=List[ChatTitleIdResponse])
async def get_archived_session_user_chat_list( async def get_archived_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50, db=Depends(get_db) user=Depends(get_current_user), skip: int = 0, limit: int = 50
): ):
return Chats.get_archived_chat_list_by_user_id(db, user.id, skip, limit) return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
############################ ############################
...@@ -172,8 +170,8 @@ async def get_archived_session_user_chat_list( ...@@ -172,8 +170,8 @@ async def get_archived_session_user_chat_list(
@router.post("/archive/all", response_model=bool) @router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_current_user), db=Depends(get_db)): async def archive_all_chats(user=Depends(get_current_user)):
return Chats.archive_all_chats_by_user_id(db, user.id) return Chats.archive_all_chats_by_user_id(user.id)
############################ ############################
...@@ -183,7 +181,7 @@ async def archive_all_chats(user=Depends(get_current_user), db=Depends(get_db)): ...@@ -183,7 +181,7 @@ async def archive_all_chats(user=Depends(get_current_user), db=Depends(get_db)):
@router.get("/share/{share_id}", response_model=Optional[ChatResponse]) @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id( async def get_shared_chat_by_id(
share_id: str, user=Depends(get_current_user), db=Depends(get_db) share_id: str, user=Depends(get_current_user)
): ):
if user.role == "pending": if user.role == "pending":
raise HTTPException( raise HTTPException(
...@@ -191,9 +189,9 @@ async def get_shared_chat_by_id( ...@@ -191,9 +189,9 @@ async def get_shared_chat_by_id(
) )
if user.role == "user": if user.role == "user":
chat = Chats.get_chat_by_share_id(db, share_id) chat = Chats.get_chat_by_share_id(share_id)
elif user.role == "admin": elif user.role == "admin":
chat = Chats.get_chat_by_id(db, share_id) chat = Chats.get_chat_by_id(share_id)
if chat: if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
...@@ -216,23 +214,23 @@ class TagNameForm(BaseModel): ...@@ -216,23 +214,23 @@ class TagNameForm(BaseModel):
@router.post("/tags", response_model=List[ChatTitleIdResponse]) @router.post("/tags", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name( async def get_user_chat_list_by_tag_name(
form_data: TagNameForm, user=Depends(get_current_user), db=Depends(get_db) form_data: TagNameForm, user=Depends(get_current_user)
): ):
print(form_data) print(form_data)
chat_ids = [ chat_ids = [
chat_id_tag.chat_id chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id( for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
db, form_data.name, user.id form_data.name, user.id
) )
] ]
chats = Chats.get_chat_list_by_chat_ids( chats = Chats.get_chat_list_by_chat_ids(
db, chat_ids, form_data.skip, form_data.limit chat_ids, form_data.skip, form_data.limit
) )
if len(chats) == 0: if len(chats) == 0:
Tags.delete_tag_by_tag_name_and_user_id(db, form_data.name, user.id) Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
return chats return chats
...@@ -243,9 +241,9 @@ async def get_user_chat_list_by_tag_name( ...@@ -243,9 +241,9 @@ async def get_user_chat_list_by_tag_name(
@router.get("/tags/all", response_model=List[TagModel]) @router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user), db=Depends(get_db)): async def get_all_tags(user=Depends(get_current_user)):
try: try:
tags = Tags.get_tags_by_user_id(db, user.id) tags = Tags.get_tags_by_user_id(user.id)
return tags return tags
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
...@@ -260,8 +258,8 @@ async def get_all_tags(user=Depends(get_current_user), db=Depends(get_db)): ...@@ -260,8 +258,8 @@ async def get_all_tags(user=Depends(get_current_user), db=Depends(get_db)):
@router.get("/{id}", response_model=Optional[ChatResponse]) @router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)): async def get_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
...@@ -278,13 +276,13 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get ...@@ -278,13 +276,13 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get
@router.post("/{id}", response_model=Optional[ChatResponse]) @router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id( async def update_chat_by_id(
id: str, form_data: ChatForm, user=Depends(get_current_user), db=Depends(get_db) id: str, form_data: ChatForm, user=Depends(get_current_user)
): ):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat} updated_chat = {**json.loads(chat.chat), **form_data.chat}
chat = Chats.update_chat_by_id(db, id, updated_chat) chat = Chats.update_chat_by_id(id, updated_chat)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else: else:
raise HTTPException( raise HTTPException(
...@@ -300,11 +298,11 @@ async def update_chat_by_id( ...@@ -300,11 +298,11 @@ async def update_chat_by_id(
@router.delete("/{id}", response_model=bool) @router.delete("/{id}", response_model=bool)
async def delete_chat_by_id( async def delete_chat_by_id(
request: Request, id: str, user=Depends(get_current_user), db=Depends(get_db) request: Request, id: str, user=Depends(get_current_user)
): ):
if user.role == "admin": if user.role == "admin":
result = Chats.delete_chat_by_id(db, id) result = Chats.delete_chat_by_id(id)
return result return result
else: else:
if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]: if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]:
...@@ -313,7 +311,7 @@ async def delete_chat_by_id( ...@@ -313,7 +311,7 @@ async def delete_chat_by_id(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
result = Chats.delete_chat_by_id_and_user_id(db, id, user.id) result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result return result
...@@ -323,8 +321,8 @@ async def delete_chat_by_id( ...@@ -323,8 +321,8 @@ async def delete_chat_by_id(
@router.get("/{id}/clone", response_model=Optional[ChatResponse]) @router.get("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)): async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
chat_body = json.loads(chat.chat) chat_body = json.loads(chat.chat)
...@@ -335,7 +333,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g ...@@ -335,7 +333,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g
"title": f"Clone of {chat.title}", "title": f"Clone of {chat.title}",
} }
chat = Chats.insert_new_chat(db, user.id, ChatForm(**{"chat": updated_chat})) chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else: else:
raise HTTPException( raise HTTPException(
...@@ -350,11 +348,11 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g ...@@ -350,11 +348,11 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g
@router.get("/{id}/archive", response_model=Optional[ChatResponse]) @router.get("/{id}/archive", response_model=Optional[ChatResponse])
async def archive_chat_by_id( async def archive_chat_by_id(
id: str, user=Depends(get_current_user), db=Depends(get_db) id: str, user=Depends(get_current_user)
): ):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
chat = Chats.toggle_chat_archive_by_id(db, id) chat = Chats.toggle_chat_archive_by_id(id)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else: else:
raise HTTPException( raise HTTPException(
...@@ -368,16 +366,16 @@ async def archive_chat_by_id( ...@@ -368,16 +366,16 @@ async def archive_chat_by_id(
@router.post("/{id}/share", response_model=Optional[ChatResponse]) @router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(get_db)): async def share_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
if chat.share_id: if chat.share_id:
shared_chat = Chats.update_shared_chat_by_chat_id(db, chat.id) shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
return ChatResponse( return ChatResponse(
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)} **{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
) )
shared_chat = Chats.insert_shared_chat_by_chat_id(db, chat.id) shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
if not shared_chat: if not shared_chat:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
...@@ -401,15 +399,15 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g ...@@ -401,15 +399,15 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user), db=Depends(g
@router.delete("/{id}/share", response_model=Optional[bool]) @router.delete("/{id}/share", response_model=Optional[bool])
async def delete_shared_chat_by_id( async def delete_shared_chat_by_id(
id: str, user=Depends(get_current_user), db=Depends(get_db) id: str, user=Depends(get_current_user)
): ):
chat = Chats.get_chat_by_id_and_user_id(db, id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
if not chat.share_id: if not chat.share_id:
return False return False
result = Chats.delete_shared_chat_by_chat_id(db, id) result = Chats.delete_shared_chat_by_chat_id(id)
update_result = Chats.update_chat_share_id_by_id(db, id, None) update_result = Chats.update_chat_share_id_by_id(id, None)
return result and update_result != None return result and update_result != None
else: else:
...@@ -426,9 +424,9 @@ async def delete_shared_chat_by_id( ...@@ -426,9 +424,9 @@ async def delete_shared_chat_by_id(
@router.get("/{id}/tags", response_model=List[TagModel]) @router.get("/{id}/tags", response_model=List[TagModel])
async def get_chat_tags_by_id( async def get_chat_tags_by_id(
id: str, user=Depends(get_current_user), db=Depends(get_db) id: str, user=Depends(get_current_user)
): ):
tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id) tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
if tags != None: if tags != None:
return tags return tags
...@@ -447,13 +445,12 @@ async def get_chat_tags_by_id( ...@@ -447,13 +445,12 @@ async def get_chat_tags_by_id(
async def add_chat_tag_by_id( async def add_chat_tag_by_id(
id: str, id: str,
form_data: ChatIdTagForm, form_data: ChatIdTagForm,
user=Depends(get_current_user), user=Depends(get_current_user)
db=Depends(get_db),
): ):
tags = Tags.get_tags_by_chat_id_and_user_id(db, id, user.id) tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
if form_data.tag_name not in tags: if form_data.tag_name not in tags:
tag = Tags.add_tag_to_chat(db, user.id, form_data) tag = Tags.add_tag_to_chat(user.id, form_data)
if tag: if tag:
return tag return tag
...@@ -478,10 +475,9 @@ async def delete_chat_tag_by_id( ...@@ -478,10 +475,9 @@ async def delete_chat_tag_by_id(
id: str, id: str,
form_data: ChatIdTagForm, form_data: ChatIdTagForm,
user=Depends(get_current_user), user=Depends(get_current_user),
db=Depends(get_db),
): ):
result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id( result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
db, form_data.tag_name, id, user.id form_data.tag_name, id, user.id
) )
if result: if result:
...@@ -499,9 +495,9 @@ async def delete_chat_tag_by_id( ...@@ -499,9 +495,9 @@ async def delete_chat_tag_by_id(
@router.delete("/{id}/tags/all", response_model=Optional[bool]) @router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_chat_tags_by_id( async def delete_all_chat_tags_by_id(
id: str, user=Depends(get_current_user), db=Depends(get_db) id: str, user=Depends(get_current_user)
): ):
result = Tags.delete_tags_by_chat_id_and_user_id(db, id, user.id) result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
if result: if result:
return result return result
......
...@@ -6,7 +6,6 @@ from fastapi import APIRouter ...@@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.internal.db import get_db
from apps.webui.models.documents import ( from apps.webui.models.documents import (
Documents, Documents,
DocumentForm, DocumentForm,
...@@ -26,7 +25,7 @@ router = APIRouter() ...@@ -26,7 +25,7 @@ router = APIRouter()
@router.get("/", response_model=List[DocumentResponse]) @router.get("/", response_model=List[DocumentResponse])
async def get_documents(user=Depends(get_current_user), db=Depends(get_db)): async def get_documents(user=Depends(get_current_user)):
docs = [ docs = [
DocumentResponse( DocumentResponse(
**{ **{
...@@ -34,7 +33,7 @@ async def get_documents(user=Depends(get_current_user), db=Depends(get_db)): ...@@ -34,7 +33,7 @@ async def get_documents(user=Depends(get_current_user), db=Depends(get_db)):
"content": json.loads(doc.content if doc.content else "{}"), "content": json.loads(doc.content if doc.content else "{}"),
} }
) )
for doc in Documents.get_docs(db) for doc in Documents.get_docs()
] ]
return docs return docs
...@@ -46,11 +45,11 @@ async def get_documents(user=Depends(get_current_user), db=Depends(get_db)): ...@@ -46,11 +45,11 @@ async def get_documents(user=Depends(get_current_user), db=Depends(get_db)):
@router.post("/create", response_model=Optional[DocumentResponse]) @router.post("/create", response_model=Optional[DocumentResponse])
async def create_new_doc( async def create_new_doc(
form_data: DocumentForm, user=Depends(get_admin_user), db=Depends(get_db) form_data: DocumentForm, user=Depends(get_admin_user)
): ):
doc = Documents.get_doc_by_name(db, form_data.name) doc = Documents.get_doc_by_name(form_data.name)
if doc == None: if doc == None:
doc = Documents.insert_new_doc(db, user.id, form_data) doc = Documents.insert_new_doc(user.id, form_data)
if doc: if doc:
return DocumentResponse( return DocumentResponse(
...@@ -78,9 +77,9 @@ async def create_new_doc( ...@@ -78,9 +77,9 @@ async def create_new_doc(
@router.get("/doc", response_model=Optional[DocumentResponse]) @router.get("/doc", response_model=Optional[DocumentResponse])
async def get_doc_by_name( async def get_doc_by_name(
name: str, user=Depends(get_current_user), db=Depends(get_db) name: str, user=Depends(get_current_user)
): ):
doc = Documents.get_doc_by_name(db, name) doc = Documents.get_doc_by_name(name)
if doc: if doc:
return DocumentResponse( return DocumentResponse(
...@@ -112,10 +111,10 @@ class TagDocumentForm(BaseModel): ...@@ -112,10 +111,10 @@ class TagDocumentForm(BaseModel):
@router.post("/doc/tags", response_model=Optional[DocumentResponse]) @router.post("/doc/tags", response_model=Optional[DocumentResponse])
async def tag_doc_by_name( async def tag_doc_by_name(
form_data: TagDocumentForm, user=Depends(get_current_user), db=Depends(get_db) form_data: TagDocumentForm, user=Depends(get_current_user)
): ):
doc = Documents.update_doc_content_by_name( doc = Documents.update_doc_content_by_name(
db, form_data.name, {"tags": form_data.tags} form_data.name, {"tags": form_data.tags}
) )
if doc: if doc:
...@@ -142,9 +141,8 @@ async def update_doc_by_name( ...@@ -142,9 +141,8 @@ async def update_doc_by_name(
name: str, name: str,
form_data: DocumentUpdateForm, form_data: DocumentUpdateForm,
user=Depends(get_admin_user), user=Depends(get_admin_user),
db=Depends(get_db),
): ):
doc = Documents.update_doc_by_name(db, name, form_data) doc = Documents.update_doc_by_name(name, form_data)
if doc: if doc:
return DocumentResponse( return DocumentResponse(
**{ **{
...@@ -166,7 +164,7 @@ async def update_doc_by_name( ...@@ -166,7 +164,7 @@ async def update_doc_by_name(
@router.delete("/doc/delete", response_model=bool) @router.delete("/doc/delete", response_model=bool)
async def delete_doc_by_name( async def delete_doc_by_name(
name: str, user=Depends(get_admin_user), db=Depends(get_db) name: str, user=Depends(get_admin_user)
): ):
result = Documents.delete_doc_by_name(db, name) result = Documents.delete_doc_by_name(name)
return result return result
...@@ -20,7 +20,6 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse ...@@ -20,7 +20,6 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.internal.db import get_db
from apps.webui.models.files import ( from apps.webui.models.files import (
Files, Files,
FileForm, FileForm,
...@@ -53,8 +52,7 @@ router = APIRouter() ...@@ -53,8 +52,7 @@ router = APIRouter()
@router.post("/") @router.post("/")
def upload_file( def upload_file(
file: UploadFile = File(...), file: UploadFile = File(...),
user=Depends(get_verified_user), user=Depends(get_verified_user)
db=Depends(get_db)
): ):
log.info(f"file.content_type: {file.content_type}") log.info(f"file.content_type: {file.content_type}")
try: try:
...@@ -72,7 +70,6 @@ def upload_file( ...@@ -72,7 +70,6 @@ def upload_file(
f.close() f.close()
file = Files.insert_new_file( file = Files.insert_new_file(
db,
user.id, user.id,
FileForm( FileForm(
**{ **{
...@@ -109,8 +106,8 @@ def upload_file( ...@@ -109,8 +106,8 @@ def upload_file(
@router.get("/", response_model=List[FileModel]) @router.get("/", response_model=List[FileModel])
async def list_files(user=Depends(get_verified_user), db=Depends(get_db)): async def list_files(user=Depends(get_verified_user)):
files = Files.get_files(db) files = Files.get_files()
return files return files
...@@ -120,8 +117,8 @@ async def list_files(user=Depends(get_verified_user), db=Depends(get_db)): ...@@ -120,8 +117,8 @@ async def list_files(user=Depends(get_verified_user), db=Depends(get_db)):
@router.delete("/all") @router.delete("/all")
async def delete_all_files(user=Depends(get_admin_user), db=Depends(get_db)): async def delete_all_files(user=Depends(get_admin_user)):
result = Files.delete_all_files(db) result = Files.delete_all_files()
if result: if result:
folder = f"{UPLOAD_DIR}" folder = f"{UPLOAD_DIR}"
...@@ -157,8 +154,8 @@ async def delete_all_files(user=Depends(get_admin_user), db=Depends(get_db)): ...@@ -157,8 +154,8 @@ async def delete_all_files(user=Depends(get_admin_user), db=Depends(get_db)):
@router.get("/{id}", response_model=Optional[FileModel]) @router.get("/{id}", response_model=Optional[FileModel])
async def get_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): async def get_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(db, id) file = Files.get_file_by_id(id)
if file: if file:
return file return file
...@@ -175,8 +172,8 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(ge ...@@ -175,8 +172,8 @@ async def get_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(ge
@router.get("/{id}/content", response_model=Optional[FileModel]) @router.get("/{id}/content", response_model=Optional[FileModel])
async def get_file_content_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(db, id) file = Files.get_file_by_id(id)
if file: if file:
file_path = Path(file.meta["path"]) file_path = Path(file.meta["path"])
...@@ -226,11 +223,11 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)): ...@@ -226,11 +223,11 @@ async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
@router.delete("/{id}") @router.delete("/{id}")
async def delete_file_by_id(id: str, user=Depends(get_verified_user), db=Depends(get_db)): async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(db, id) file = Files.get_file_by_id(id)
if file: if file:
result = Files.delete_file_by_id(db, id) result = Files.delete_file_by_id(id)
if result: if result:
return {"message": "File deleted successfully"} return {"message": "File deleted successfully"}
else: else:
......
...@@ -6,7 +6,6 @@ from fastapi import APIRouter ...@@ -6,7 +6,6 @@ from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.webui.internal.db import get_db
from apps.webui.models.functions import ( from apps.webui.models.functions import (
Functions, Functions,
FunctionForm, FunctionForm,
...@@ -32,8 +31,8 @@ router = APIRouter() ...@@ -32,8 +31,8 @@ router = APIRouter()
@router.get("/", response_model=List[FunctionResponse]) @router.get("/", response_model=List[FunctionResponse])
async def get_functions(user=Depends(get_verified_user), db=Depends(get_db)): async def get_functions(user=Depends(get_verified_user)):
return Functions.get_functions(db) return Functions.get_functions()
############################ ############################
...@@ -42,8 +41,8 @@ async def get_functions(user=Depends(get_verified_user), db=Depends(get_db)): ...@@ -42,8 +41,8 @@ async def get_functions(user=Depends(get_verified_user), db=Depends(get_db)):
@router.get("/export", response_model=List[FunctionModel]) @router.get("/export", response_model=List[FunctionModel])
async def get_functions(user=Depends(get_admin_user), db=Depends(get_db)): async def get_functions(user=Depends(get_admin_user)):
return Functions.get_functions(db) return Functions.get_functions()
############################ ############################
...@@ -53,7 +52,7 @@ async def get_functions(user=Depends(get_admin_user), db=Depends(get_db)): ...@@ -53,7 +52,7 @@ async def get_functions(user=Depends(get_admin_user), db=Depends(get_db)):
@router.post("/create", response_model=Optional[FunctionResponse]) @router.post("/create", response_model=Optional[FunctionResponse])
async def create_new_function( async def create_new_function(
request: Request, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db) request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
): ):
if not form_data.id.isidentifier(): if not form_data.id.isidentifier():
raise HTTPException( raise HTTPException(
...@@ -63,7 +62,7 @@ async def create_new_function( ...@@ -63,7 +62,7 @@ async def create_new_function(
form_data.id = form_data.id.lower() form_data.id = form_data.id.lower()
function = Functions.get_function_by_id(db, form_data.id) function = Functions.get_function_by_id(form_data.id)
if function == None: if function == None:
function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py") function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
try: try:
...@@ -78,7 +77,7 @@ async def create_new_function( ...@@ -78,7 +77,7 @@ async def create_new_function(
FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[form_data.id] = function_module FUNCTIONS[form_data.id] = function_module
function = Functions.insert_new_function(db, user.id, function_type, form_data) function = Functions.insert_new_function(user.id, function_type, form_data)
function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
function_cache_dir.mkdir(parents=True, exist_ok=True) function_cache_dir.mkdir(parents=True, exist_ok=True)
...@@ -109,8 +108,8 @@ async def create_new_function( ...@@ -109,8 +108,8 @@ async def create_new_function(
@router.get("/id/{id}", response_model=Optional[FunctionModel]) @router.get("/id/{id}", response_model=Optional[FunctionModel])
async def get_function_by_id(id: str, user=Depends(get_admin_user), db=Depends(get_db)): async def get_function_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(db, id) function = Functions.get_function_by_id(id)
if function: if function:
return function return function
...@@ -155,7 +154,7 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)): ...@@ -155,7 +154,7 @@ async def toggle_function_by_id(id: str, user=Depends(get_admin_user)):
@router.post("/id/{id}/update", response_model=Optional[FunctionModel]) @router.post("/id/{id}/update", response_model=Optional[FunctionModel])
async def update_function_by_id( async def update_function_by_id(
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user), db=Depends(get_db) request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
): ):
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py") function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
...@@ -172,7 +171,7 @@ async def update_function_by_id( ...@@ -172,7 +171,7 @@ async def update_function_by_id(
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type} updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
print(updated) print(updated)
function = Functions.update_function_by_id(db, id, updated) function = Functions.update_function_by_id(id, updated)
if function: if function:
return function return function
...@@ -196,9 +195,9 @@ async def update_function_by_id( ...@@ -196,9 +195,9 @@ async def update_function_by_id(
@router.delete("/id/{id}/delete", response_model=bool) @router.delete("/id/{id}/delete", response_model=bool)
async def delete_function_by_id( async def delete_function_by_id(
request: Request, id: str, user=Depends(get_admin_user), db=Depends(get_db) request: Request, id: str, user=Depends(get_admin_user)
): ):
result = Functions.delete_function_by_id(db, id) result = Functions.delete_function_by_id(id)
if result: if result:
FUNCTIONS = request.app.state.FUNCTIONS FUNCTIONS = request.app.state.FUNCTIONS
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment