"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "d3ef3a7494ada6ab517bd8afdde234abc5862528"
Unverified Commit 92c98eda authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #1781 from open-webui/dev

0.1.122
parents 1092ee9c 85df019c
...@@ -25,11 +25,11 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) ...@@ -25,11 +25,11 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Document(Model): class Document(Model):
collection_name = CharField(unique=True) collection_name = CharField(unique=True)
name = CharField(unique=True) name = CharField(unique=True)
title = CharField() title = TextField()
filename = CharField() filename = TextField()
content = TextField(null=True) content = TextField(null=True)
user_id = CharField() user_id = CharField()
timestamp = DateField() timestamp = BigIntegerField()
class Meta: class Meta:
database = DB database = DB
......
...@@ -20,7 +20,7 @@ class Modelfile(Model): ...@@ -20,7 +20,7 @@ class Modelfile(Model):
tag_name = CharField(unique=True) tag_name = CharField(unique=True)
user_id = CharField() user_id = CharField()
modelfile = TextField() modelfile = TextField()
timestamp = DateField() timestamp = BigIntegerField()
class Meta: class Meta:
database = DB database = DB
......
...@@ -19,9 +19,9 @@ import json ...@@ -19,9 +19,9 @@ import json
class Prompt(Model): class Prompt(Model):
command = CharField(unique=True) command = CharField(unique=True)
user_id = CharField() user_id = CharField()
title = CharField() title = TextField()
content = TextField() content = TextField()
timestamp = DateField() timestamp = BigIntegerField()
class Meta: class Meta:
database = DB database = DB
......
...@@ -35,7 +35,7 @@ class ChatIdTag(Model): ...@@ -35,7 +35,7 @@ class ChatIdTag(Model):
tag_name = CharField() tag_name = CharField()
chat_id = CharField() chat_id = CharField()
user_id = CharField() user_id = CharField()
timestamp = DateField() timestamp = BigIntegerField()
class Meta: class Meta:
database = DB database = DB
......
...@@ -18,8 +18,12 @@ class User(Model): ...@@ -18,8 +18,12 @@ class User(Model):
name = CharField() name = CharField()
email = CharField() email = CharField()
role = CharField() role = CharField()
profile_image_url = CharField() profile_image_url = TextField()
timestamp = DateField()
last_active_at = BigIntegerField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
api_key = CharField(null=True, unique=True) api_key = CharField(null=True, unique=True)
class Meta: class Meta:
...@@ -32,7 +36,11 @@ class UserModel(BaseModel): ...@@ -32,7 +36,11 @@ class UserModel(BaseModel):
email: str email: str
role: str = "pending" role: str = "pending"
profile_image_url: str profile_image_url: str
timestamp: int # timestamp in epoch
last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
api_key: Optional[str] = None api_key: Optional[str] = None
...@@ -73,7 +81,9 @@ class UsersTable: ...@@ -73,7 +81,9 @@ class UsersTable:
"email": email, "email": email,
"role": role, "role": role,
"profile_image_url": profile_image_url, "profile_image_url": profile_image_url,
"timestamp": int(time.time()), "last_active_at": int(time.time()),
"created_at": int(time.time()),
"updated_at": int(time.time()),
} }
) )
result = User.create(**user.model_dump()) result = User.create(**user.model_dump())
...@@ -137,6 +147,16 @@ class UsersTable: ...@@ -137,6 +147,16 @@ class UsersTable:
except: except:
return None return None
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
try:
query = User.update(last_active_at=int(time.time())).where(User.id == id)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try: try:
query = User.update(**updated).where(User.id == id) query = User.update(**updated).where(User.id == id)
......
import logging
from fastapi import Request from fastapi import Request
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
......
...@@ -36,15 +36,49 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) ...@@ -36,15 +36,49 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter() router = APIRouter()
############################ ############################
# GetChats # GetChatList
############################ ############################
@router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats( @router.get("/list", response_model=List[ChatTitleIdResponse])
async def get_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50 user=Depends(get_current_user), skip: int = 0, limit: int = 50
): ):
return Chats.get_chat_lists_by_user_id(user.id, skip, limit) return Chats.get_chat_list_by_user_id(user.id, skip, limit)
############################
# DeleteAllChats
############################
@router.delete("/", response_model=bool)
async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
if (
user.role == "user"
and not request.app.state.USER_PERMISSIONS["chat"]["deletion"]
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Chats.delete_chats_by_user_id(user.id)
return result
############################
# GetUserChatList
############################
@router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id(
user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50
):
return Chats.get_chat_list_by_user_id(user_id, skip, limit)
############################ ############################
...@@ -53,22 +87,22 @@ async def get_user_chats( ...@@ -53,22 +87,22 @@ async def get_user_chats(
@router.get("/archived", response_model=List[ChatTitleIdResponse]) @router.get("/archived", response_model=List[ChatTitleIdResponse])
async def get_archived_user_chats( async def get_archived_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50 user=Depends(get_current_user), skip: int = 0, limit: int = 50
): ):
return Chats.get_archived_chat_lists_by_user_id(user.id, skip, limit) return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
############################ ############################
# GetAllChats # GetChats
############################ ############################
@router.get("/all", response_model=List[ChatResponse]) @router.get("/all", response_model=List[ChatResponse])
async def get_all_user_chats(user=Depends(get_current_user)): async def get_user_chats(user=Depends(get_current_user)):
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats_by_user_id(user.id) for chat in Chats.get_chats_by_user_id(user.id)
] ]
...@@ -86,7 +120,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)): ...@@ -86,7 +120,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
) )
return [ return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats() for chat in Chats.get_chats()
] ]
...@@ -107,45 +141,6 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): ...@@ -107,45 +141,6 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
) )
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChatsByTags
############################
@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
async def get_user_chats_by_tag_name(
tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
]
chats = Chats.get_chat_lists_by_chat_ids(chat_ids, skip, limit)
if len(chats) == 0:
Tags.delete_tag_by_tag_name_and_user_id(tag_name, user.id)
return chats
############################ ############################
# GetChatById # GetChatById
############################ ############################
...@@ -193,17 +188,18 @@ async def update_chat_by_id( ...@@ -193,17 +188,18 @@ async def update_chat_by_id(
@router.delete("/{id}", response_model=bool) @router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)): async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
if ( if user.role == "admin":
user.role == "user" result = Chats.delete_chat_by_id(id)
and not request.app.state.USER_PERMISSIONS["chat"]["deletion"] return result
): else:
raise HTTPException( if not request.app.state.USER_PERMISSIONS["chat"]["deletion"]:
status_code=status.HTTP_401_UNAUTHORIZED, raise HTTPException(
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, status_code=status.HTTP_401_UNAUTHORIZED,
) detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Chats.delete_chat_by_id_and_user_id(id, user.id) result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result return result
############################ ############################
...@@ -303,6 +299,45 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)): ...@@ -303,6 +299,45 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
) )
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChatsByTags
############################
@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
]
chats = Chats.get_chat_list_by_chat_ids(chat_ids, skip, limit)
if len(chats) == 0:
Tags.delete_tag_by_tag_name_and_user_id(tag_name, user.id)
return chats
############################ ############################
# GetChatTagsById # GetChatTagsById
############################ ############################
...@@ -383,24 +418,3 @@ async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)): ...@@ -383,24 +418,3 @@ async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
) )
############################
# DeleteAllChats
############################
@router.delete("/", response_model=bool)
async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
if (
user.role == "user"
and not request.app.state.USER_PERMISSIONS["chat"]["deletion"]
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Chats.delete_chats_by_user_id(user.id)
return result
from fastapi import APIRouter, UploadFile, File, Response from fastapi import APIRouter, UploadFile, File, Response
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status
from peewee import SqliteDatabase
from starlette.responses import StreamingResponse, FileResponse from starlette.responses import StreamingResponse, FileResponse
from pydantic import BaseModel from pydantic import BaseModel
...@@ -7,7 +8,7 @@ from pydantic import BaseModel ...@@ -7,7 +8,7 @@ from pydantic import BaseModel
from fpdf import FPDF from fpdf import FPDF
import markdown import markdown
from apps.web.internal.db import DB
from utils.utils import get_admin_user from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url from utils.misc import calculate_sha256, get_gravatar_url
...@@ -96,8 +97,13 @@ async def download_db(user=Depends(get_admin_user)): ...@@ -96,8 +97,13 @@ async def download_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
if not isinstance(DB, SqliteDatabase):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DB_NOT_SQLITE,
)
return FileResponse( return FileResponse(
f"{DATA_DIR}/webui.db", DB.database,
media_type="application/octet-stream", media_type="application/octet-stream",
filename="webui.db", filename="webui.db",
) )
...@@ -71,6 +71,9 @@ except ImportError: ...@@ -71,6 +71,9 @@ except ImportError:
log.warning("dotenv not installed, skipping...") log.warning("dotenv not installed, skipping...")
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI") WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
#################################### ####################################
...@@ -195,9 +198,6 @@ if CUSTOM_NAME: ...@@ -195,9 +198,6 @@ if CUSTOM_NAME:
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
pass pass
else:
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
#################################### ####################################
...@@ -220,7 +220,7 @@ Path(CACHE_DIR).mkdir(parents=True, exist_ok=True) ...@@ -220,7 +220,7 @@ Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
# Docs DIR # Docs DIR
#################################### ####################################
DOCS_DIR = f"{DATA_DIR}/docs" DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True) Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
...@@ -375,8 +375,7 @@ USER_PERMISSIONS_CHAT_DELETION = ( ...@@ -375,8 +375,7 @@ USER_PERMISSIONS_CHAT_DELETION = (
USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}} USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}
ENABLE_MODEL_FILTER = os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true"
MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", "False").lower() == "true"
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "") MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")] MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
...@@ -418,19 +417,57 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "": ...@@ -418,19 +417,57 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
#################################### ####################################
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db" CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
# Comma-separated list of header=value pairs
CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
if CHROMA_HTTP_HEADERS:
CHROMA_HTTP_HEADERS = dict(
[pair.split("=") for pair in CHROMA_HTTP_HEADERS.split(",")]
)
else:
CHROMA_HTTP_HEADERS = None
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2) # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5"))
RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0"))
ENABLE_RAG_HYBRID_SEARCH = (
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
)
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true"
RAG_EMBEDDING_MODEL = os.environ.get( RAG_EMBEDDING_MODEL = os.environ.get(
"RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2" "RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
) )
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"), log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
)
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = ( RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true" os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
) )
RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
if not RAG_RERANKING_MODEL == "":
log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
RAG_RERANKING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
)
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance # device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false") USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
...@@ -439,16 +476,28 @@ if USE_CUDA.lower() == "true": ...@@ -439,16 +476,28 @@ if USE_CUDA.lower() == "true":
else: else:
DEVICE_TYPE = "cpu" DEVICE_TYPE = "cpu"
if CHROMA_HTTP_HOST != "":
CHROMA_CLIENT = chromadb.HttpClient(
host=CHROMA_HTTP_HOST,
port=CHROMA_HTTP_PORT,
headers=CHROMA_HTTP_HEADERS,
ssl=CHROMA_HTTP_SSL,
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
)
else:
CHROMA_CLIENT = chromadb.PersistentClient(
path=CHROMA_DATA_PATH,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
)
CHROMA_CLIENT = chromadb.PersistentClient( CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500"))
path=CHROMA_DATA_PATH, CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100"))
settings=Settings(allow_reset=True, anonymized_telemetry=False),
)
CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100
RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags. DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context> <context>
[context] [context]
</context> </context>
...@@ -458,10 +507,12 @@ When answer to user: ...@@ -458,10 +507,12 @@ When answer to user:
- If you don't know when you are not sure, ask for clarification. - If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context. Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question. And answer according to the language of the user's question.
Given the context information, answer the query. Given the context information, answer the query.
Query: [query]""" Query: [query]"""
RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY) RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
...@@ -480,18 +531,25 @@ WHISPER_MODEL_AUTO_UPDATE = ( ...@@ -480,18 +531,25 @@ WHISPER_MODEL_AUTO_UPDATE = (
# Images # Images
#################################### ####################################
IMAGE_GENERATION_ENGINE = os.getenv("IMAGE_GENERATION_ENGINE", "")
ENABLE_IMAGE_GENERATION = ( ENABLE_IMAGE_GENERATION = (
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true" os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true"
) )
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "") AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
IMAGES_OPENAI_API_BASE_URL = os.getenv( IMAGES_OPENAI_API_BASE_URL = os.getenv(
"IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL "IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL
) )
IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY) IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY)
IMAGE_SIZE = os.getenv("IMAGE_SIZE", "512x512")
IMAGE_STEPS = int(os.getenv("IMAGE_STEPS", 50))
IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "")
#################################### ####################################
# Audio # Audio
...@@ -504,7 +562,17 @@ AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY) ...@@ -504,7 +562,17 @@ AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY)
# LiteLLM # LiteLLM
#################################### ####################################
ENABLE_LITELLM = os.environ.get("ENABLE_LITELLM", "True").lower() == "true"
LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365")) LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365"))
if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535: if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535:
raise ValueError("Invalid port number for LITELLM_PROXY_PORT") raise ValueError("Invalid port number for LITELLM_PROXY_PORT")
LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1") LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1")
####################################
# Database
####################################
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
...@@ -69,3 +69,5 @@ class ERROR_MESSAGES(str, Enum): ...@@ -69,3 +69,5 @@ class ERROR_MESSAGES(str, Enum):
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance." CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding." EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
...@@ -47,7 +47,8 @@ from config import ( ...@@ -47,7 +47,8 @@ from config import (
FRONTEND_BUILD_DIR, FRONTEND_BUILD_DIR,
CACHE_DIR, CACHE_DIR,
STATIC_DIR, STATIC_DIR,
MODEL_FILTER_ENABLED, ENABLE_LITELLM,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
GLOBAL_LOG_LEVEL, GLOBAL_LOG_LEVEL,
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
...@@ -89,7 +90,7 @@ https://github.com/open-webui/open-webui ...@@ -89,7 +90,7 @@ https://github.com/open-webui/open-webui
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.WEBHOOK_URL = WEBHOOK_URL app.state.WEBHOOK_URL = WEBHOOK_URL
...@@ -116,15 +117,14 @@ class RAGMiddleware(BaseHTTPMiddleware): ...@@ -116,15 +117,14 @@ class RAGMiddleware(BaseHTTPMiddleware):
if "docs" in data: if "docs" in data:
data = {**data} data = {**data}
data["messages"] = rag_messages( data["messages"] = rag_messages(
data["docs"], docs=data["docs"],
data["messages"], messages=data["messages"],
rag_app.state.RAG_TEMPLATE, template=rag_app.state.RAG_TEMPLATE,
rag_app.state.TOP_K, embedding_function=rag_app.state.EMBEDDING_FUNCTION,
rag_app.state.RAG_EMBEDDING_ENGINE, k=rag_app.state.TOP_K,
rag_app.state.RAG_EMBEDDING_MODEL, reranking_function=rag_app.state.sentence_transformer_rf,
rag_app.state.sentence_transformer_ef, r=rag_app.state.RELEVANCE_THRESHOLD,
rag_app.state.OPENAI_API_KEY, hybrid_search=rag_app.state.ENABLE_RAG_HYBRID_SEARCH,
rag_app.state.OPENAI_API_BASE_URL,
) )
del data["docs"] del data["docs"]
...@@ -176,7 +176,8 @@ async def check_url(request: Request, call_next): ...@@ -176,7 +176,8 @@ async def check_url(request: Request, call_next):
@app.on_event("startup") @app.on_event("startup")
async def on_startup(): async def on_startup():
asyncio.create_task(start_litellm_background()) if ENABLE_LITELLM:
asyncio.create_task(start_litellm_background())
app.mount("/api/v1", webui_app) app.mount("/api/v1", webui_app)
...@@ -215,7 +216,7 @@ async def get_app_config(): ...@@ -215,7 +216,7 @@ async def get_app_config():
@app.get("/api/config/model/filter") @app.get("/api/config/model/filter")
async def get_model_filter_config(user=Depends(get_admin_user)): async def get_model_filter_config(user=Depends(get_admin_user)):
return { return {
"enabled": app.state.MODEL_FILTER_ENABLED, "enabled": app.state.ENABLE_MODEL_FILTER,
"models": app.state.MODEL_FILTER_LIST, "models": app.state.MODEL_FILTER_LIST,
} }
...@@ -229,20 +230,20 @@ class ModelFilterConfigForm(BaseModel): ...@@ -229,20 +230,20 @@ class ModelFilterConfigForm(BaseModel):
async def update_model_filter_config( async def update_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user) form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
): ):
app.state.MODEL_FILTER_ENABLED = form_data.enabled app.state.ENABLE_MODEL_FILTER = form_data.enabled
app.state.MODEL_FILTER_LIST = form_data.models app.state.MODEL_FILTER_LIST = form_data.models
ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED ollama_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED openai_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
litellm_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED litellm_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
litellm_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST litellm_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
return { return {
"enabled": app.state.MODEL_FILTER_ENABLED, "enabled": app.state.ENABLE_MODEL_FILTER,
"models": app.state.MODEL_FILTER_LIST, "models": app.state.MODEL_FILTER_LIST,
} }
...@@ -326,4 +327,5 @@ app.mount( ...@@ -326,4 +327,5 @@ app.mount(
@app.on_event("shutdown") @app.on_event("shutdown")
async def shutdown_event(): async def shutdown_event():
await shutdown_litellm_background() if ENABLE_LITELLM:
await shutdown_litellm_background()
...@@ -15,6 +15,8 @@ requests ...@@ -15,6 +15,8 @@ requests
aiohttp aiohttp
peewee peewee
peewee-migrate peewee-migrate
psycopg2-binary
pymysql
bcrypt bcrypt
litellm==1.35.17 litellm==1.35.17
......
...@@ -6,6 +6,7 @@ cd "$SCRIPT_DIR" || exit ...@@ -6,6 +6,7 @@ cd "$SCRIPT_DIR" || exit
KEY_FILE=.webui_secret_key KEY_FILE=.webui_secret_key
PORT="${PORT:-8080}" PORT="${PORT:-8080}"
HOST="${HOST:-0.0.0.0}"
if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then
echo "No WEBUI_SECRET_KEY provided" echo "No WEBUI_SECRET_KEY provided"
...@@ -29,4 +30,4 @@ if [ "$USE_CUDA_DOCKER" = "true" ]; then ...@@ -29,4 +30,4 @@ if [ "$USE_CUDA_DOCKER" = "true" ]; then
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib" export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib"
fi fi
WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" exec uvicorn main:app --host 0.0.0.0 --port "$PORT" --forwarded-allow-ips '*' WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" exec uvicorn main:app --host "$HOST" --port "$PORT" --forwarded-allow-ips '*'
...@@ -7,7 +7,7 @@ SET "SCRIPT_DIR=%~dp0" ...@@ -7,7 +7,7 @@ SET "SCRIPT_DIR=%~dp0"
cd /d "%SCRIPT_DIR%" || exit /b cd /d "%SCRIPT_DIR%" || exit /b
SET "KEY_FILE=.webui_secret_key" SET "KEY_FILE=.webui_secret_key"
SET "PORT=%PORT:8080%" IF "%PORT%"=="" SET PORT=8080
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%" SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%" SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%"
......
...@@ -89,6 +89,8 @@ def get_current_user( ...@@ -89,6 +89,8 @@ def get_current_user(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.INVALID_TOKEN,
) )
else:
Users.update_user_last_active_by_id(user.id)
return user return user
else: else:
raise HTTPException( raise HTTPException(
...@@ -99,11 +101,15 @@ def get_current_user( ...@@ -99,11 +101,15 @@ def get_current_user(
def get_current_user_by_api_key(api_key: str): def get_current_user_by_api_key(api_key: str):
user = Users.get_user_by_api_key(api_key) user = Users.get_user_by_api_key(api_key)
if user is None: if user is None:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.INVALID_TOKEN,
) )
else:
Users.update_user_last_active_by_id(user.id)
return user return user
......
...@@ -2,7 +2,12 @@ ...@@ -2,7 +2,12 @@
echo "Warning: This will remove all containers and volumes, including persistent data. Do you want to continue? [Y/N]" echo "Warning: This will remove all containers and volumes, including persistent data. Do you want to continue? [Y/N]"
read ans read ans
if [ "$ans" == "Y" ] || [ "$ans" == "y" ]; then if [ "$ans" == "Y" ] || [ "$ans" == "y" ]; then
docker-compose down -v command docker-compose 2>/dev/null
if [ "$?" == "0" ]; then
docker-compose down -v
else
docker compose down -v
fi
else else
echo "Operation cancelled." echo "Operation cancelled."
fi fi
import { defineConfig } from 'cypress';
export default defineConfig({
e2e: {
baseUrl: 'http://localhost:8080'
},
video: true
});
// eslint-disable-next-line @typescript-eslint/triple-slash-reference
/// <reference path="../support/index.d.ts" />
// These tests run through the chat flow.
describe('Settings', () => {
// Wait for 2 seconds after all tests to fix an issue with Cypress's video recording missing the last few frames
after(() => {
// eslint-disable-next-line cypress/no-unnecessary-waiting
cy.wait(2000);
});
beforeEach(() => {
// Login as the admin user
cy.loginAdmin();
// Visit the home page
cy.visit('/');
});
context('Ollama', () => {
it('user can select a model', () => {
// Click on the model selector
cy.get('button[aria-label="Select a model"]').click();
// Select the first model
cy.get('div[role="option"][data-value]').first().click();
});
it('user can perform text chat', () => {
// Click on the model selector
cy.get('button[aria-label="Select a model"]').click();
// Select the first model
cy.get('div[role="option"][data-value]').first().click();
// Type a message
cy.get('#chat-textarea').type('Hi, what can you do? A single sentence only please.', {
force: true
});
// Send the message
cy.get('button[type="submit"]').click();
// User's message should be visible
cy.get('.chat-user').should('exist');
// Wait for the response
cy.get('.chat-assistant', { timeout: 120_000 }) // .chat-assistant is created after the first token is received
.find('div[aria-label="Generation Info"]', { timeout: 120_000 }) // Generation Info is created after the stop token is received
.should('exist');
});
});
});
// eslint-disable-next-line @typescript-eslint/triple-slash-reference
/// <reference path="../support/index.d.ts" />
import { adminUser } from '../support/e2e';
// These tests assume the following defaults:
// 1. No users exist in the database or that the test admin user is an admin
// 2. Language is set to English
// 3. The default role for new users is 'pending'
describe('Registration and Login', () => {
// Wait for 2 seconds after all tests to fix an issue with Cypress's video recording missing the last few frames
after(() => {
// eslint-disable-next-line cypress/no-unnecessary-waiting
cy.wait(2000);
});
beforeEach(() => {
cy.visit('/');
});
it('should register a new user as pending', () => {
const userName = `Test User - ${Date.now()}`;
const userEmail = `cypress-${Date.now()}@example.com`;
// Toggle from sign in to sign up
cy.contains('Sign up').click();
// Fill out the form
cy.get('input[autocomplete="name"]').type(userName);
cy.get('input[autocomplete="email"]').type(userEmail);
cy.get('input[type="password"]').type('password');
// Submit the form
cy.get('button[type="submit"]').click();
// Wait until the user is redirected to the home page
cy.contains(userName);
// Expect the user to be pending
cy.contains('Check Again');
});
it('can login with the admin user', () => {
// Fill out the form
cy.get('input[autocomplete="email"]').type(adminUser.email);
cy.get('input[type="password"]').type(adminUser.password);
// Submit the form
cy.get('button[type="submit"]').click();
// Wait until the user is redirected to the home page
cy.contains(adminUser.name);
// Dismiss the changelog dialog if it is visible
cy.getAllLocalStorage().then((ls) => {
if (!ls['version']) {
cy.get('button').contains("Okay, Let's Go!").click();
}
});
});
});
// eslint-disable-next-line @typescript-eslint/triple-slash-reference
/// <reference path="../support/index.d.ts" />
import { adminUser } from '../support/e2e';
// These tests run through the various settings pages, ensuring that the user can interact with them as expected
describe('Settings', () => {
// Wait for 2 seconds after all tests to fix an issue with Cypress's video recording missing the last few frames
after(() => {
// eslint-disable-next-line cypress/no-unnecessary-waiting
cy.wait(2000);
});
beforeEach(() => {
// Login as the admin user
cy.loginAdmin();
// Visit the home page
cy.visit('/');
// Open the sidebar if it is not already open
cy.get('[aria-label="Open sidebar"]').then(() => {
cy.get('button[id="sidebar-toggle-button"]').click();
});
// Click on the profile link
cy.get('button').contains(adminUser.name).click();
// Click on the settings link
cy.get('button').contains('Settings').click();
});
context('General', () => {
it('user can open the General modal and hit save', () => {
cy.get('button').contains('General').click();
cy.get('button').contains('Save').click();
});
});
context('Connections', () => {
it('user can open the Connections modal and hit save', () => {
cy.get('button').contains('Connections').click();
cy.get('button').contains('Save').click();
});
});
context('Models', () => {
it('user can open the Models modal', () => {
cy.get('button').contains('Models').click();
});
});
context('Interface', () => {
it('user can open the Interface modal and hit save', () => {
cy.get('button').contains('Interface').click();
cy.get('button').contains('Save').click();
});
});
context('Audio', () => {
it('user can open the Audio modal and hit save', () => {
cy.get('button').contains('Audio').click();
cy.get('button').contains('Save').click();
});
});
context('Images', () => {
it('user can open the Images modal and hit save', () => {
cy.get('button').contains('Images').click();
// Currently fails because the backend requires a valid URL
// cy.get('button').contains('Save').click();
});
});
context('Chats', () => {
it('user can open the Chats modal', () => {
cy.get('button').contains('Chats').click();
});
});
context('Account', () => {
it('user can open the Account modal and hit save', () => {
cy.get('button').contains('Account').click();
cy.get('button').contains('Save').click();
});
});
context('About', () => {
it('user can open the About modal', () => {
cy.get('button').contains('About').click();
});
});
});
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment