Commit 276b7b90 authored by Jun Siang Cheah's avatar Jun Siang Cheah
Browse files

Merge remote-tracking branch 'upstream/dev' into feat/backend-web-search

parents b1265c9c 7b81271b
......@@ -5,10 +5,10 @@ import uuid
import logging
from peewee import *
from apps.web.models.users import UserModel, Users
from apps.webui.models.users import UserModel, Users
from utils.utils import verify_password
from apps.web.internal.db import DB
from apps.webui.internal.db import DB
from config import SRC_LOG_LEVELS
......
......@@ -7,7 +7,7 @@ import json
import uuid
import time
from apps.web.internal.db import DB
from apps.webui.internal.db import DB
####################
# Chat DB Schema
......@@ -191,6 +191,20 @@ class ChatTable:
except:
return None
def archive_all_chats_by_user_id(self, user_id: str) -> bool:
try:
chats = self.get_chats_by_user_id(user_id)
for chat in chats:
query = Chat.update(
archived=True,
).where(Chat.id == chat.id)
query.execute()
return True
except:
return False
def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]:
......@@ -205,8 +219,22 @@ class ChatTable:
]
def get_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
self,
user_id: str,
include_archived: bool = False,
skip: int = 0,
limit: int = 50,
) -> List[ChatModel]:
if include_archived:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
else:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
......
......@@ -8,7 +8,7 @@ import logging
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from apps.web.internal.db import DB
from apps.webui.internal.db import DB
import json
......
......@@ -3,8 +3,8 @@ from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
from apps.web.internal.db import DB
from apps.web.models.chats import Chats
from apps.webui.internal.db import DB
from apps.webui.models.chats import Chats
import time
import uuid
......
......@@ -8,7 +8,7 @@ from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from apps.web.internal.db import DB, JSONField
from apps.webui.internal.db import DB, JSONField
from typing import List, Union, Optional
from config import SRC_LOG_LEVELS
......
......@@ -7,7 +7,7 @@ import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from apps.web.internal.db import DB
from apps.webui.internal.db import DB
import json
......
......@@ -8,7 +8,7 @@ import uuid
import time
import logging
from apps.web.internal.db import DB
from apps.webui.internal.db import DB
from config import SRC_LOG_LEVELS
......
......@@ -5,8 +5,8 @@ from typing import List, Union, Optional
import time
from utils.misc import get_gravatar_url
from apps.web.internal.db import DB
from apps.web.models.chats import Chats
from apps.webui.internal.db import DB
from apps.webui.models.chats import Chats
####################
# User DB Schema
......
......@@ -10,7 +10,7 @@ import uuid
import csv
from apps.web.models.auths import (
from apps.webui.models.auths import (
SigninForm,
SignupForm,
AddUserForm,
......@@ -21,7 +21,7 @@ from apps.web.models.auths import (
Auths,
ApiKey,
)
from apps.web.models.users import Users
from apps.webui.models.users import Users
from utils.utils import (
get_password_hash,
......
......@@ -7,8 +7,8 @@ from pydantic import BaseModel
import json
import logging
from apps.web.models.users import Users
from apps.web.models.chats import (
from apps.webui.models.users import Users
from apps.webui.models.chats import (
ChatModel,
ChatResponse,
ChatTitleForm,
......@@ -18,7 +18,7 @@ from apps.web.models.chats import (
)
from apps.web.models.tags import (
from apps.webui.models.tags import (
TagModel,
ChatIdTagModel,
ChatIdTagForm,
......@@ -78,43 +78,25 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user)
async def get_user_chat_list_by_user_id(
user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50
):
return Chats.get_chat_list_by_user_id(user_id, skip, limit)
############################
# GetArchivedChats
############################
@router.get("/archived", response_model=List[ChatTitleIdResponse])
async def get_archived_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
return Chats.get_chat_list_by_user_id(
user_id, include_archived=True, skip=skip, limit=limit
)
############################
# GetSharedChatById
# CreateNewChat
############################
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
if user.role == "pending":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role == "user":
chat = Chats.get_chat_by_share_id(share_id)
elif user.role == "admin":
chat = Chats.get_chat_by_id(share_id)
if chat:
@router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
try:
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
......@@ -150,19 +132,49 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
############################
# CreateNewChat
# GetArchivedChats
############################
@router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
try:
chat = Chats.insert_new_chat(user.id, form_data)
@router.get("/archived", response_model=List[ChatTitleIdResponse])
async def get_archived_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
############################
# ArchiveAllChats
############################
@router.post("/archive/all", response_model=List[ChatTitleIdResponse])
async def archive_all_chats(user=Depends(get_current_user)):
return Chats.archive_all_chats_by_user_id(user.id)
############################
# GetSharedChatById
############################
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
if user.role == "pending":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role == "user":
chat = Chats.get_chat_by_share_id(share_id)
elif user.role == "admin":
chat = Chats.get_chat_by_id(share_id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
except Exception as e:
log.exception(e)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
......
......@@ -8,7 +8,7 @@ from pydantic import BaseModel
import time
import uuid
from apps.web.models.users import Users
from apps.webui.models.users import Users
from utils.utils import (
get_password_hash,
......
......@@ -6,7 +6,7 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.documents import (
from apps.webui.models.documents import (
Documents,
DocumentForm,
DocumentUpdateForm,
......
......@@ -7,7 +7,7 @@ from fastapi import APIRouter
from pydantic import BaseModel
import logging
from apps.web.models.memories import Memories, MemoryModel
from apps.webui.models.memories import Memories, MemoryModel
from utils.utils import get_verified_user
from constants import ERROR_MESSAGES
......
......@@ -5,7 +5,7 @@ from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.models import Models, ModelModel, ModelForm, ModelResponse
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES
......@@ -53,7 +53,7 @@ async def add_new_model(
############################
@router.get("/{id}", response_model=Optional[ModelModel])
@router.get("/", response_model=Optional[ModelModel])
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
model = Models.get_model_by_id(id)
......@@ -71,7 +71,7 @@ async def get_model_by_id(id: str, user=Depends(get_verified_user)):
############################
@router.post("/{id}/update", response_model=Optional[ModelModel])
@router.post("/update", response_model=Optional[ModelModel])
async def update_model_by_id(
request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user)
):
......@@ -102,7 +102,7 @@ async def update_model_by_id(
############################
@router.delete("/{id}/delete", response_model=bool)
@router.delete("/delete", response_model=bool)
async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
result = Models.delete_model_by_id(id)
return result
......@@ -6,7 +6,7 @@ from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.prompts import Prompts, PromptForm, PromptModel
from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES
......
......@@ -9,9 +9,9 @@ import time
import uuid
import logging
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths
from apps.web.models.chats import Chats
from apps.webui.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.webui.models.auths import Auths
from apps.webui.models.chats import Chats
from utils.utils import get_verified_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES
......
......@@ -8,7 +8,7 @@ from pydantic import BaseModel
from fpdf import FPDF
import markdown
from apps.web.internal.db import DB
from apps.webui.internal.db import DB
from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url
......
......@@ -27,6 +27,8 @@ from constants import ERROR_MESSAGES
BACKEND_DIR = Path(__file__).parent # the path containing this file
BASE_DIR = BACKEND_DIR.parent # the path containing the backend/
print(BASE_DIR)
try:
from dotenv import load_dotenv, find_dotenv
......@@ -56,7 +58,6 @@ log_sources = [
"CONFIG",
"DB",
"IMAGES",
"LITELLM",
"MAIN",
"MODELS",
"OLLAMA",
......@@ -122,7 +123,10 @@ def parse_section(section):
try:
changelog_content = (BASE_DIR / "CHANGELOG.md").read_text()
changelog_path = BASE_DIR / "CHANGELOG.md"
with open(str(changelog_path.absolute()), "r", encoding="utf8") as file:
changelog_content = file.read()
except:
changelog_content = (pkgutil.get_data("open_webui", "CHANGELOG.md") or b"").decode()
......@@ -374,10 +378,10 @@ def create_config_file(file_path):
LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml"
if not os.path.exists(LITELLM_CONFIG_PATH):
log.info("Config file doesn't exist. Creating...")
create_config_file(LITELLM_CONFIG_PATH)
log.info("Config file created successfully.")
# if not os.path.exists(LITELLM_CONFIG_PATH):
# log.info("Config file doesn't exist. Creating...")
# create_config_file(LITELLM_CONFIG_PATH)
# log.info("Config file created successfully.")
####################################
......@@ -845,18 +849,6 @@ AUDIO_OPENAI_API_VOICE = PersistentConfig(
os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"),
)
####################################
# LiteLLM
####################################
ENABLE_LITELLM = os.environ.get("ENABLE_LITELLM", "True").lower() == "true"
LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365"))
if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535:
raise ValueError("Invalid port number for LITELLM_PROXY_PORT")
LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1")
####################################
# Database
......
......@@ -22,23 +22,16 @@ from starlette.responses import StreamingResponse, Response
from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
from apps.openai.main import app as openai_app, get_all_models as get_openai_models
from apps.litellm.main import (
app as litellm_app,
start_litellm_background,
shutdown_litellm_background,
)
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.web.main import app as webui_app
from apps.webui.main import app as webui_app
import asyncio
from pydantic import BaseModel
from typing import List, Optional
from apps.web.models.models import Models, ModelModel
from apps.webui.models.models import Models, ModelModel
from utils.utils import get_admin_user, get_verified_user
from apps.rag.utils import rag_messages
......@@ -55,7 +48,6 @@ from config import (
STATIC_DIR,
ENABLE_OPENAI_API,
ENABLE_OLLAMA_API,
ENABLE_LITELLM,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
GLOBAL_LOG_LEVEL,
......@@ -101,11 +93,7 @@ https://github.com/open-webui/open-webui
@asynccontextmanager
async def lifespan(app: FastAPI):
if ENABLE_LITELLM:
asyncio.create_task(start_litellm_background())
yield
if ENABLE_LITELLM:
await shutdown_litellm_background()
app = FastAPI(
......@@ -263,9 +251,6 @@ async def update_embedding_function(request: Request, call_next):
return response
# TODO: Deprecate LiteLLM
app.mount("/litellm/api", litellm_app)
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)
......@@ -373,13 +358,14 @@ async def get_app_config():
"name": WEBUI_NAME,
"version": VERSION,
"auth": WEBUI_AUTH,
"auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"enable_signup": webui_app.state.config.ENABLE_SIGNUP,
"enable_image_generation": images_app.state.config.ENABLED,
"enable_admin_export": ENABLE_ADMIN_EXPORT,
"default_locale": default_locale,
"images": images_app.state.config.ENABLED,
"default_models": webui_app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"admin_export_enabled": ENABLE_ADMIN_EXPORT,
"websearch": RAG_WEB_SEARCH_ENABLED,
"enable_websearch": RAG_WEB_SEARCH_ENABLED,
}
......@@ -403,15 +389,6 @@ async def update_model_filter_config(
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
app.state.config.MODEL_FILTER_LIST = form_data.models
ollama_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
ollama_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
openai_app.state.config.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
openai_app.state.config.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
return {
"enabled": app.state.config.ENABLE_MODEL_FILTER,
"models": app.state.config.MODEL_FILTER_LIST,
......@@ -432,7 +409,6 @@ class UrlForm(BaseModel):
@app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
app.state.config.WEBHOOK_URL = form_data.url
webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
return {
......
......@@ -18,8 +18,6 @@ psycopg2-binary==2.9.9
PyMySQL==1.1.1
bcrypt==4.1.3
litellm[proxy]==1.37.20
boto3==1.34.110
argon2-cffi==23.1.0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment