"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "4608afa5b9fa25491f5ee3041599654e79ab4a7b"
Commit 45311bfa authored by Anuraag Jain

Merge branch 'main' into feat/cancel-model-download

# Conflicts:
#	src/lib/components/chat/Settings/Models.svelte
parents ae97a963 2fa94956
import re
from typing import List

from config import CHROMA_CLIENT


def query_doc(collection_name: str, query: str, k: int, embedding_function):
    try:
        # if you use docker use the model from the environment variable
        collection = CHROMA_CLIENT.get_collection(
            name=collection_name,
            embedding_function=embedding_function,
        )
        result = collection.query(
            query_texts=[query],
            n_results=k,
        )
        return result
    except Exception as e:
        raise e
def merge_and_sort_query_results(query_results, k):
    # Initialize lists to store combined data
    combined_ids = []
    combined_distances = []
    combined_metadatas = []
    combined_documents = []

    # Combine data from each dictionary
    for data in query_results:
        combined_ids.extend(data["ids"][0])
        combined_distances.extend(data["distances"][0])
        combined_metadatas.extend(data["metadatas"][0])
        combined_documents.extend(data["documents"][0])

    # Create a list of tuples (distance, id, metadata, document)
    combined = list(
        zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
    )

    # Sort the list based on distances
    combined.sort(key=lambda x: x[0])

    # Unzip the sorted list
    sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)

    # Slicing the lists to include only k elements
    sorted_distances = list(sorted_distances)[:k]
    sorted_ids = list(sorted_ids)[:k]
    sorted_metadatas = list(sorted_metadatas)[:k]
    sorted_documents = list(sorted_documents)[:k]

    # Create the output dictionary
    merged_query_results = {
        "ids": [sorted_ids],
        "distances": [sorted_distances],
        "metadatas": [sorted_metadatas],
        "documents": [sorted_documents],
        "embeddings": None,
        "uris": None,
        "data": None,
    }

    return merged_query_results
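
# Illustrative sketch (not part of the original change): merging two hypothetical
# single-query Chroma results with k=2 keeps the two closest documents overall.
#
#   a = {"ids": [["a1", "a2"]], "distances": [[0.1, 0.6]],
#        "metadatas": [[{}, {}]], "documents": [["doc a1", "doc a2"]]}
#   b = {"ids": [["b1"]], "distances": [[0.3]],
#        "metadatas": [[{}]], "documents": [["doc b1"]]}
#   merge_and_sort_query_results([a, b], k=2)
#   # -> {"ids": [["a1", "b1"]], "distances": [[0.1, 0.3]], ...}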
def query_collection(
    collection_names: List[str], query: str, k: int, embedding_function
):
    results = []

    for collection_name in collection_names:
        try:
            # if you use docker use the model from the environment variable
            collection = CHROMA_CLIENT.get_collection(
                name=collection_name,
                embedding_function=embedding_function,
            )
            result = collection.query(
                query_texts=[query],
                n_results=k,
            )
            results.append(result)
        except Exception:
            # Skip collections that fail to query; merge whatever succeeded.
            pass

    return merge_and_sort_query_results(results, k)
def rag_template(template: str, context: str, query: str):
    template = template.replace("[context]", context)
    template = template.replace("[query]", query)
    return template
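
# Illustrative sketch (not part of the original change): the placeholders are
# replaced literally, e.g.
#   rag_template("Context: [context]\nQuery: [query]", "Paris is in France.", "Where is Paris?")
#   # -> "Context: Paris is in France.\nQuery: Where is Paris?"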
def rag_messages(docs, messages, template, k, embedding_function):
    print(docs)

    # Find the index of the last user message
    last_user_message_idx = None
    for i in range(len(messages) - 1, -1, -1):
        if messages[i]["role"] == "user":
            last_user_message_idx = i
            break

    user_message = messages[last_user_message_idx]

    if isinstance(user_message["content"], list):
        # Handle list content input
        content_type = "list"
        query = ""
        for content_item in user_message["content"]:
            if content_item["type"] == "text":
                query = content_item["text"]
                break
    elif isinstance(user_message["content"], str):
        # Handle text content input
        content_type = "text"
        query = user_message["content"]
    else:
        # Fallback in case the input does not match expected types
        content_type = None
        query = ""

    relevant_contexts = []

    for doc in docs:
        context = None
        try:
            if doc["type"] == "collection":
                context = query_collection(
                    collection_names=doc["collection_names"],
                    query=query,
                    k=k,
                    embedding_function=embedding_function,
                )
            else:
                context = query_doc(
                    collection_name=doc["collection_name"],
                    query=query,
                    k=k,
                    embedding_function=embedding_function,
                )
        except Exception as e:
            print(e)
            context = None

        relevant_contexts.append(context)

    context_string = ""
    for context in relevant_contexts:
        if context:
            context_string += " ".join(context["documents"][0]) + "\n"

    ra_content = rag_template(
        template=template,
        context=context_string,
        query=query,
    )

    if content_type == "list":
        new_content = []
        for content_item in user_message["content"]:
            if content_item["type"] == "text":
                # Update the text item's content with ra_content
                new_content.append({"type": "text", "text": ra_content})
            else:
                # Keep other types of content as they are
                new_content.append(content_item)
        new_user_message = {**user_message, "content": new_content}
    else:
        new_user_message = {
            **user_message,
            "content": ra_content,
        }

    messages[last_user_message_idx] = new_user_message

    return messages
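
# Illustrative sketch (hypothetical values, not part of the original change): the
# last user message is rewritten with retrieved context before the chat request is
# forwarded to the model backend. "sentence_transformer_ef" stands in for whatever
# embedding function the RAG app holds.
#
#   docs = [{"type": "doc", "collection_name": "my_docs"}]
#   messages = [{"role": "user", "content": "What does the handbook say about leave?"}]
#   messages = rag_messages(docs, messages, RAG_TEMPLATE, k=4,
#                           embedding_function=sentence_transformer_ef)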
from peewee import *
from config import DATA_DIR
import os

# Check if the old ollama.db file exists
if os.path.exists(f"{DATA_DIR}/ollama.db"):
    # Rename the file
    os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
    print("File renamed successfully.")
else:
    pass

DB = SqliteDatabase(f"{DATA_DIR}/webui.db")
DB.connect()
...@@ -19,6 +19,7 @@ from config import (
    DEFAULT_USER_ROLE,
    ENABLE_SIGNUP,
    USER_PERMISSIONS,
    WEBHOOK_URL,
)

app = FastAPI()
...@@ -26,10 +27,13 @@ app = FastAPI()
origins = ["*"]

app.state.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.JWT_EXPIRES_IN = "-1"

app.state.DEFAULT_MODELS = DEFAULT_MODELS
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.USER_PERMISSIONS = USER_PERMISSIONS
app.state.WEBHOOK_URL = WEBHOOK_URL

app.add_middleware(
...@@ -55,7 +59,6 @@ app.include_router(utils.router, prefix="/utils", tags=["utils"])
async def get_status():
    return {
        "status": True,
-        "version": WEBUI_VERSION,
        "auth": WEBUI_AUTH,
        "default_models": app.state.DEFAULT_MODELS,
        "default_prompt_suggestions": app.state.DEFAULT_PROMPT_SUGGESTIONS,
...
...@@ -167,6 +167,27 @@ class TagTable:
            .count()
        )

    def delete_tag_by_tag_name_and_user_id(self, tag_name: str, user_id: str) -> bool:
        try:
            query = ChatIdTag.delete().where(
                (ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id)
            )
            res = query.execute()  # Remove the rows, return number of rows removed.
            print(res)

            tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
            if tag_count == 0:
                # Remove tag item from Tag col as well
                query = Tag.delete().where(
                    (Tag.name == tag_name) & (Tag.user_id == user_id)
                )
                query.execute()  # Remove the rows, return number of rows removed.

            return True
        except Exception as e:
            print("delete_tag", e)
            return False

    def delete_tag_by_tag_name_and_chat_id_and_user_id(
        self, tag_name: str, chat_id: str, user_id: str
    ) -> bool:
...
...@@ -7,6 +7,7 @@ from fastapi import APIRouter, status
from pydantic import BaseModel
import time
import uuid
import re

from apps.web.models.auths import (
    SigninForm,
...@@ -25,8 +26,9 @@ from utils.utils import (
    get_admin_user,
    create_token,
)
-from utils.misc import get_gravatar_url, validate_email_format
from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES

router = APIRouter()
...@@ -95,10 +97,13 @@ async def update_password(
@router.post("/signin", response_model=SigninResponse)
async def signin(request: Request, form_data: SigninForm):
    user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
    if user:
        token = create_token(
            data={"id": user.id},
            expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
        )

        return {
            "token": token,
...@@ -145,9 +150,23 @@ async def signup(request: Request, form_data: SignupForm):
        )

        if user:
            token = create_token(
                data={"id": user.id},
                expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
            )
            # response.set_cookie(key='token', value=token, httponly=True)

            if request.app.state.WEBHOOK_URL:
                post_webhook(
                    request.app.state.WEBHOOK_URL,
                    WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                    {
                        "action": "signup",
                        "message": WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                        "user": user.model_dump_json(exclude_none=True),
                    },
                )

            return {
                "token": token,
                "token_type": "Bearer",
...@@ -200,3 +219,33 @@ async def update_default_user_role(
    if form_data.role in ["pending", "user", "admin"]:
        request.app.state.DEFAULT_USER_ROLE = form_data.role
    return request.app.state.DEFAULT_USER_ROLE


############################
# JWT Expiration
############################


@router.get("/token/expires")
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
    return request.app.state.JWT_EXPIRES_IN


class UpdateJWTExpiresDurationForm(BaseModel):
    duration: str


@router.post("/token/expires/update")
async def update_token_expires_duration(
    request: Request,
    form_data: UpdateJWTExpiresDurationForm,
    user=Depends(get_admin_user),
):
    pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"

    # Check if the input string matches the pattern
    if re.match(pattern, form_data.duration):
        request.app.state.JWT_EXPIRES_IN = form_data.duration
        return request.app.state.JWT_EXPIRES_IN
    else:
        return request.app.state.JWT_EXPIRES_IN
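
# For reference (sketch, not part of the original change): values the pattern above
# accepts vs. rejects.
#   valid:   "-1", "0", "30m", "12h", "7d", "2w", "1.5h"
#   invalid: "1h30m", "10x", "soon"  (the current setting is returned unchanged)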
...@@ -115,9 +115,12 @@ async def get_user_chats_by_tag_name(
        for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
    ]

-    print(chat_ids)
    chats = Chats.get_chat_lists_by_chat_ids(chat_ids, skip, limit)

    if len(chats) == 0:
        Tags.delete_tag_by_tag_name_and_user_id(tag_name, user.id)

    return chats


############################
...@@ -268,6 +271,16 @@ async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
@router.delete("/", response_model=bool)
async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):

    if (
        user.role == "user"
        and not request.app.state.USER_PERMISSIONS["chat"]["deletion"]
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    result = Chats.delete_chats_by_user_id(user.id)
    return result
...@@ -96,6 +96,10 @@ async def get_doc_by_name(name: str, user=Depends(get_current_user)):
############################


class TagItem(BaseModel):
    name: str


class TagDocumentForm(BaseModel):
    name: str
    tags: List[dict]
...
from fastapi import APIRouter, UploadFile, File, BackgroundTasks
from fastapi import Depends, HTTPException, status
from starlette.responses import StreamingResponse, FileResponse

from pydantic import BaseModel
...@@ -9,9 +10,11 @@ import os
import aiohttp
import json

from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url

from config import OLLAMA_BASE_URLS, DATA_DIR, UPLOAD_DIR
from constants import ERROR_MESSAGES
...@@ -72,7 +75,7 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024
    hashed = calculate_sha256(file)
    file.seek(0)

-    url = f"{OLLAMA_API_BASE_URL}/blobs/sha256:{hashed}"
    url = f"{OLLAMA_BASE_URLS[0]}/api/blobs/sha256:{hashed}"
    response = requests.post(url, data=file)

    if response.ok:
...@@ -144,7 +147,7 @@ def upload(file: UploadFile = File(...)):
        hashed = calculate_sha256(f)
        f.seek(0)

-        url = f"{OLLAMA_API_BASE_URL}/blobs/sha256:{hashed}"
        url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
        response = requests.post(url, data=f)

        if response.ok:
...@@ -172,3 +175,13 @@ async def get_gravatar(
    email: str,
):
    return get_gravatar_url(email)


@router.get("/db/download")
async def download_db(user=Depends(get_admin_user)):

    return FileResponse(
        f"{DATA_DIR}/webui.db",
        media_type="application/octet-stream",
        filename="webui.db",
    )
import os
import chromadb
from chromadb import Settings
from base64 import b64encode
from bs4 import BeautifulSoup
from pathlib import Path
import json
import yaml
import markdown
import requests
import shutil

from secrets import token_bytes
from constants import ERROR_MESSAGES

try:
    from dotenv import load_dotenv, find_dotenv
...@@ -13,6 +23,8 @@ try:
except ImportError:
    print("dotenv not installed, skipping...")

WEBUI_NAME = "Open WebUI"
shutil.copyfile("../build/favicon.png", "./static/favicon.png")

####################################
# ENV (dev,test,prod)
...@@ -20,6 +32,102 @@ except ImportError:

ENV = os.environ.get("ENV", "dev")
try:
    with open(f"../package.json", "r") as f:
        PACKAGE_DATA = json.load(f)
except:
    PACKAGE_DATA = {"version": "0.0.0"}

VERSION = PACKAGE_DATA["version"]


# Function to parse each section
def parse_section(section):
    items = []
    for li in section.find_all("li"):
        # Extract raw HTML string
        raw_html = str(li)

        # Extract text without HTML tags
        text = li.get_text(separator=" ", strip=True)

        # Split into title and content
        parts = text.split(": ", 1)
        title = parts[0].strip() if len(parts) > 1 else ""
        content = parts[1].strip() if len(parts) > 1 else text

        items.append({"title": title, "content": content, "raw": raw_html})
    return items


try:
    with open("../CHANGELOG.md", "r") as file:
        changelog_content = file.read()
except:
    changelog_content = ""

# Convert markdown content to HTML
html_content = markdown.markdown(changelog_content)

# Parse the HTML content
soup = BeautifulSoup(html_content, "html.parser")

# Initialize JSON structure
changelog_json = {}

# Iterate over each version
for version in soup.find_all("h2"):
    version_number = version.get_text().strip().split(" - ")[0][1:-1]  # Remove brackets
    date = version.get_text().strip().split(" - ")[1]

    version_data = {"date": date}

    # Find the next sibling that is a h3 tag (section title)
    current = version.find_next_sibling()

    while current and current.name != "h2":
        if current.name == "h3":
            section_title = current.get_text().lower()  # e.g., "added", "fixed"
            section_items = parse_section(current.find_next_sibling("ul"))
            version_data[section_title] = section_items

        # Move to the next element
        current = current.find_next_sibling()

    changelog_json[version_number] = version_data

CHANGELOG = changelog_json
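
# Sketch of the resulting structure (hypothetical entry, not part of the original change):
#   CHANGELOG == {
#       "0.1.0": {
#           "date": "2024-01-01",
#           "added": [{"title": "Feature", "content": "...", "raw": "<li>...</li>"}],
#           "fixed": [...],
#       },
#       ...
#   }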
####################################
# CUSTOM_NAME
####################################

CUSTOM_NAME = os.environ.get("CUSTOM_NAME", "")
if CUSTOM_NAME:
    try:
        r = requests.get(f"https://api.openwebui.com/api/v1/custom/{CUSTOM_NAME}")
        data = r.json()
        if r.ok:
            if "logo" in data:
                url = (
                    f"https://api.openwebui.com{data['logo']}"
                    if data["logo"][0] == "/"
                    else data["logo"]
                )

                r = requests.get(url, stream=True)
                if r.status_code == 200:
                    with open("./static/favicon.png", "wb") as f:
                        r.raw.decode_content = True
                        shutil.copyfileobj(r.raw, f)

            WEBUI_NAME = data["name"]
    except Exception as e:
        print(e)
        pass
####################################
# DATA/FRONTEND BUILD DIR
...@@ -28,6 +136,12 @@ ENV = os.environ.get("ENV", "dev")

DATA_DIR = str(Path(os.getenv("DATA_DIR", "./data")).resolve())
FRONTEND_BUILD_DIR = str(Path(os.getenv("FRONTEND_BUILD_DIR", "../build")))

try:
    with open(f"{DATA_DIR}/config.json", "r") as f:
        CONFIG_DATA = json.load(f)
except:
    CONFIG_DATA = {}

####################################
# File Upload DIR
####################################
...@@ -43,17 +157,76 @@ Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)

CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

####################################
# Docs DIR
####################################

DOCS_DIR = f"{DATA_DIR}/docs"
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)

####################################
# LITELLM_CONFIG
####################################


def create_config_file(file_path):
    directory = os.path.dirname(file_path)

    # Check if directory exists, if not, create it
    if not os.path.exists(directory):
        os.makedirs(directory)

    # Data to write into the YAML file
    config_data = {
        "general_settings": {},
        "litellm_settings": {},
        "model_list": [],
        "router_settings": {},
    }

    # Write data to YAML file
    with open(file_path, "w") as file:
        yaml.dump(config_data, file)


LITELLM_CONFIG_PATH = f"{DATA_DIR}/litellm/config.yaml"

if not os.path.exists(LITELLM_CONFIG_PATH):
    print("Config file doesn't exist. Creating...")
    create_config_file(LITELLM_CONFIG_PATH)
    print("Config file created successfully.")


####################################
# OLLAMA_BASE_URL
####################################

OLLAMA_API_BASE_URL = os.environ.get(
    "OLLAMA_API_BASE_URL", "http://localhost:11434/api"
)

OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")

if OLLAMA_BASE_URL == "" and OLLAMA_API_BASE_URL != "":
    OLLAMA_BASE_URL = (
        OLLAMA_API_BASE_URL[:-4]
        if OLLAMA_API_BASE_URL.endswith("/api")
        else OLLAMA_API_BASE_URL
    )

if ENV == "prod":
    if OLLAMA_BASE_URL == "/ollama":
        OLLAMA_BASE_URL = "http://host.docker.internal:11434"

OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL

OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
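
# Example (hypothetical hosts, not part of the original change): multiple Ollama
# backends are configured as a single semicolon-separated environment variable, e.g.
#   OLLAMA_BASE_URLS="http://ollama-1:11434;http://ollama-2:11434"
# which yields ["http://ollama-1:11434", "http://ollama-2:11434"] here.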
####################################
# OPENAI_API
...@@ -62,19 +235,40 @@ if ENV == "prod":

OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")

if OPENAI_API_BASE_URL == "":
    OPENAI_API_BASE_URL = "https://api.openai.com/v1"

OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY

OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")]

OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
OPENAI_API_BASE_URLS = (
    OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
)

OPENAI_API_BASE_URLS = [
    url.strip() if url != "" else "https://api.openai.com/v1"
    for url in OPENAI_API_BASE_URLS.split(";")
]

####################################
# WEBUI
####################################

ENABLE_SIGNUP = os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None)

DEFAULT_PROMPT_SUGGESTIONS = (
    CONFIG_DATA["ui"]["prompt_suggestions"]
    if "ui" in CONFIG_DATA
    and "prompt_suggestions" in CONFIG_DATA["ui"]
    and type(CONFIG_DATA["ui"]["prompt_suggestions"]) is list
    else [
        {
            "title": ["Help me study", "vocabulary for a college entrance exam"],
            "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
...@@ -91,12 +285,25 @@ DEFAULT_PROMPT_SUGGESTIONS = os.environ.get(
        {
            "title": ["Show me a code snippet", "of a website's sticky header"],
            "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
        },
    ]
)

DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending")

USER_PERMISSIONS_CHAT_DELETION = (
    os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
)

USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}


MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", "False").lower() == "true"
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]

WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "")

####################################
# WEBUI_VERSION
####################################
...@@ -128,7 +335,12 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
####################################

CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
-EMBED_MODEL = "all-MiniLM-L6-v2"
# This uses the model defined in the Dockerfile ENV variable. If you don't use Docker or Docker-based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2).
RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
# Device type for embedding models: "cpu" (default), "cuda" (NVIDIA GPU required), or "mps" (Apple Silicon). Choosing the right one can lead to better performance.
RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
    "RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
)
CHROMA_CLIENT = chromadb.PersistentClient(
    path=CHROMA_DATA_PATH,
    settings=Settings(allow_reset=True, anonymized_telemetry=False),
...@@ -136,9 +348,31 @@ CHROMA_CLIENT = chromadb.PersistentClient(
CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100

RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
    [context]
</context>

When answering the user:
- If you don't know, just say that you don't know.
- If you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
Answer according to the language of the user's question.

Given the context information, answer the query.
Query: [query]"""
####################################
# Transcribe
####################################

WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")


####################################
# Images
####################################

AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
...@@ -5,6 +5,13 @@ class MESSAGES(str, Enum):
    DEFAULT = lambda msg="": f"{msg if msg else ''}"


class WEBHOOK_MESSAGES(str, Enum):
    DEFAULT = lambda msg="": f"{msg if msg else ''}"
    USER_SIGNUP = lambda username="": (
        f"New user signed up: {username}" if username else "New user signed up"
    )


class ERROR_MESSAGES(str, Enum):
    def __str__(self) -> str:
        return super().__str__()
...@@ -41,6 +48,15 @@ class ERROR_MESSAGES(str, Enum):
    NOT_FOUND = "We could not find what you're looking for :/"
    USER_NOT_FOUND = "We could not find what you're looking for :/"
    API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."

    MALICIOUS = "Unusual activities detected, please try again in a few minutes."

    PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance."
    INCORRECT_FORMAT = (
        lambda err="": f"Invalid format. Please use the correct format{err}"
    )
    RATE_LIMIT_EXCEEDED = "API rate limit exceeded"

    MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found"
    OPENAI_NOT_FOUND = lambda name="": f"OpenAI API was not found"
    OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
{
    "version": 0,
    "ui": {
        "prompt_suggestions": [
            {
                "title": [
                    "Help me study",
                    "vocabulary for a college entrance exam"
                ],
                "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."
            },
            {
                "title": [
                    "Give me ideas",
                    "for what to do with my kids' art"
                ],
                "content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."
            },
            {
                "title": [
                    "Tell me a fun fact",
                    "about the Roman Empire"
                ],
                "content": "Tell me a random fun fact about the Roman Empire"
            },
            {
                "title": [
                    "Show me a code snippet",
                    "of a website's sticky header"
                ],
                "content": "Show me a code snippet of a website's sticky header in CSS and JavaScript."
            }
        ]
    }
}
\ No newline at end of file
general_settings: {}
litellm_settings: {}
model_list: []
router_settings: {}
from bs4 import BeautifulSoup
import json
import markdown
import time
import os
import sys
import requests

from fastapi import FastAPI, Request, Depends, status
from fastapi.staticfiles import StaticFiles
from fastapi import HTTPException
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware

from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
from apps.litellm.main import app as litellm_app, startup as litellm_app_startup
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.web.main import app as webui_app

from pydantic import BaseModel
from typing import List

from utils.utils import get_admin_user
from apps.rag.utils import rag_messages

from config import (
    WEBUI_NAME,
    ENV,
    VERSION,
    CHANGELOG,
    FRONTEND_BUILD_DIR,
    MODEL_FILTER_ENABLED,
    MODEL_FILTER_LIST,
    WEBHOOK_URL,
)
from constants import ERROR_MESSAGES
class SPAStaticFiles(StaticFiles):
...@@ -32,8 +56,70 @@ class SPAStaticFiles(StaticFiles):

app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)

app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST

app.state.WEBHOOK_URL = WEBHOOK_URL

origins = ["*"]
class RAGMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        if request.method == "POST" and (
            "/api/chat" in request.url.path or "/chat/completions" in request.url.path
        ):
            print(request.url.path)

            # Read the original request body
            body = await request.body()
            # Decode body to string
            body_str = body.decode("utf-8")
            # Parse string to JSON
            data = json.loads(body_str) if body_str else {}

            # Example: Add a new key-value pair or modify existing ones
            # data["modified"] = True  # Example modification
            if "docs" in data:
                data = {**data}
                data["messages"] = rag_messages(
                    data["docs"],
                    data["messages"],
                    rag_app.state.RAG_TEMPLATE,
                    rag_app.state.TOP_K,
                    rag_app.state.sentence_transformer_ef,
                )
                del data["docs"]

                print(data["messages"])

            modified_body_bytes = json.dumps(data).encode("utf-8")

            # Replace the request body with the modified one
            request._body = modified_body_bytes

            # Set custom header to ensure content-length matches new body length
            request.headers.__dict__["_list"] = [
                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
                *[
                    (k, v)
                    for k, v in request.headers.raw
                    if k.lower() != b"content-length"
                ],
            ]

        response = await call_next(request)
        return response

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}


app.add_middleware(RAGMiddleware)
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
...@@ -53,15 +139,124 @@ async def check_url(request: Request, call_next):
    return response


@app.on_event("startup")
async def on_startup():
    await litellm_app_startup()


app.mount("/api/v1", webui_app)
app.mount("/litellm/api", litellm_app)

app.mount("/ollama", ollama_app)
app.mount("/openai/api", openai_app)

app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app)
@app.get("/api/config")
async def get_app_config():
return {
"status": True,
"name": WEBUI_NAME,
"version": VERSION,
"images": images_app.state.ENABLED,
"default_models": webui_app.state.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS,
}
@app.get("/api/config/model/filter")
async def get_model_filter_config(user=Depends(get_admin_user)):
return {
"enabled": app.state.MODEL_FILTER_ENABLED,
"models": app.state.MODEL_FILTER_LIST,
}
class ModelFilterConfigForm(BaseModel):
enabled: bool
models: List[str]
@app.post("/api/config/model/filter")
async def update_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
):
app.state.MODEL_FILTER_ENABLED = form_data.enabled
app.state.MODEL_FILTER_LIST = form_data.models
ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
return {
"enabled": app.state.MODEL_FILTER_ENABLED,
"models": app.state.MODEL_FILTER_LIST,
}
@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
return {
"url": app.state.WEBHOOK_URL,
}
class UrlForm(BaseModel):
url: str
@app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
app.state.WEBHOOK_URL = form_data.url
webui_app.state.WEBHOOK_URL = app.state.WEBHOOK_URL
return {
"url": app.state.WEBHOOK_URL,
}
@app.get("/api/version")
async def get_app_config():
return {
"version": VERSION,
}
@app.get("/api/changelog")
async def get_app_changelog():
return CHANGELOG
@app.get("/api/version/updates")
async def get_app_latest_release_version():
try:
response = requests.get(
f"https://api.github.com/repos/open-webui/open-webui/releases/latest"
)
response.raise_for_status()
latest_version = response.json()["tag_name"]
return {"current": VERSION, "latest": latest_version[1:]}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
)
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/cache", StaticFiles(directory="data/cache"), name="cache")
app.mount( app.mount(
"/", "/",
SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True), SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
......
...@@ -16,8 +16,14 @@ aiohttp
peewee
bcrypt

litellm==1.30.7

argon2-cffi
apscheduler
google-generativeai

langchain
langchain-community
fake_useragent
chromadb
sentence_transformers
pypdf
...@@ -30,6 +36,9 @@ openpyxl
pyxlsb
xlrd

opencv-python-headless
rapidocr-onnxruntime

faster-whisper
PyJWT
...
from pathlib import Path
import hashlib
import re
from datetime import timedelta
from typing import Optional


def get_gravatar_url(email):
...@@ -38,3 +41,71 @@ def validate_email_format(email: str) -> bool:
    if not re.match(r"[^@]+@[^@]+\.[^@]+", email):
        return False
    return True
def sanitize_filename(file_name):
    # Convert to lowercase
    lower_case_file_name = file_name.lower()

    # Remove special characters using regular expression
    sanitized_file_name = re.sub(r"[^\w\s]", "", lower_case_file_name)

    # Replace spaces with dashes
    final_file_name = re.sub(r"\s+", "-", sanitized_file_name)

    return final_file_name
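
# Illustrative example (not part of the original change):
#   sanitize_filename("Quarterly Report (Q1)!.PDF")
#   # -> "quarterly-report-q1pdf"  (punctuation, including the dot, is stripped)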
def extract_folders_after_data_docs(path):
    # Convert the path to a Path object if it's not already
    path = Path(path)

    # Extract parts of the path
    parts = path.parts

    # Find the index of '/data/docs' in the path
    try:
        index_data_docs = parts.index("data") + 1
        index_docs = parts.index("docs", index_data_docs) + 1
    except ValueError:
        return []

    # Exclude the filename and accumulate folder names
    tags = []

    folders = parts[index_docs:-1]
    for idx, part in enumerate(folders):
        tags.append("/".join(folders[: idx + 1]))

    return tags
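
# Illustrative example (not part of the original change):
#   extract_folders_after_data_docs("/data/docs/hr/policies/leave.md")
#   # -> ["hr", "hr/policies"]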
def parse_duration(duration: str) -> Optional[timedelta]:
    if duration == "-1" or duration == "0":
        return None

    # Regular expression to find number and unit pairs
    pattern = r"(-?\d+(\.\d+)?)(ms|s|m|h|d|w)"
    matches = re.findall(pattern, duration)

    if not matches:
        raise ValueError("Invalid duration string")

    total_duration = timedelta()

    for number, _, unit in matches:
        number = float(number)
        if unit == "ms":
            total_duration += timedelta(milliseconds=number)
        elif unit == "s":
            total_duration += timedelta(seconds=number)
        elif unit == "m":
            total_duration += timedelta(minutes=number)
        elif unit == "h":
            total_duration += timedelta(hours=number)
        elif unit == "d":
            total_duration += timedelta(days=number)
        elif unit == "w":
            total_duration += timedelta(weeks=number)

    return total_duration
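
# Illustrative examples (not part of the original change):
#   parse_duration("-1")     # -> None (no expiry)
#   parse_duration("1h30m")  # -> timedelta(hours=1, minutes=30)
#   parse_duration("7d")     # -> timedelta(days=7)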
...@@ -58,6 +58,14 @@ def extract_token_from_auth_header(auth_header: str):
    return auth_header[len("Bearer ") :]


def get_http_authorization_cred(auth_header: str):
    try:
        scheme, credentials = auth_header.split(" ")
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
    except Exception:
        raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)


def get_current_user(
    auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
...
import requests


def post_webhook(url: str, message: str, event_data: dict) -> bool:
    try:
        payload = {}

        if "https://hooks.slack.com" in url:
            payload["text"] = message
        elif "https://discord.com/api/webhooks" in url:
            payload["content"] = message
        else:
            payload = {**event_data}

        r = requests.post(url, json=payload)
        r.raise_for_status()
        return True
    except Exception as e:
        print(e)
        return False
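
# Illustrative usage (hypothetical URL, not part of the original change): a Slack
# incoming webhook only needs the plain-text message; other endpoints receive the
# full event payload.
#
#   post_webhook(
#       "https://hooks.slack.com/services/T000/B000/XXXX",
#       "New user signed up: Jane",
#       {"action": "signup", "message": "New user signed up: Jane", "user": "{}"},
#   )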
#!/bin/bash

echo "Warning: This will remove all containers and volumes, including persistent data. Do you want to continue? [Y/N]"
read ans
if [ "$ans" == "Y" ] || [ "$ans" == "y" ]; then
    docker-compose down -v
else
    echo "Operation cancelled."
fi