Unverified Commit 0ed54055 authored by Ankur's avatar Ankur Committed by GitHub
Browse files

Merge branch 'dev' into main

parents 6d99ef8b c6317640
## Pull Request Checklist ## Pull Request Checklist
- [ ] **Target branch:** Pull requests should target the `dev` branch.
- [ ] **Description:** Briefly describe the changes in this pull request. - [ ] **Description:** Briefly describe the changes in this pull request.
- [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description. - [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
- [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources? - [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources?
......
...@@ -20,7 +20,16 @@ jobs: ...@@ -20,7 +20,16 @@ jobs:
- name: Build and run Compose Stack - name: Build and run Compose Stack
run: | run: |
docker compose up --detach --build docker compose --file docker-compose.yaml --file docker-compose.api.yaml up --detach --build
- name: Wait for Ollama to be up
timeout-minutes: 5
run: |
until curl --output /dev/null --silent --fail http://localhost:11434; do
printf '.'
sleep 1
done
echo "Service is up!"
- name: Preload Ollama model - name: Preload Ollama model
run: | run: |
......
...@@ -36,6 +36,10 @@ from config import ( ...@@ -36,6 +36,10 @@ from config import (
LITELLM_PROXY_HOST, LITELLM_PROXY_HOST,
) )
import warnings
warnings.simplefilter("ignore")
from litellm.utils import get_llm_provider from litellm.utils import get_llm_provider
import asyncio import asyncio
......
...@@ -31,7 +31,12 @@ from typing import Optional, List, Union ...@@ -31,7 +31,12 @@ from typing import Optional, List, Union
from apps.web.models.users import Users from apps.web.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user, get_admin_user from utils.utils import (
decode_token,
get_current_user,
get_verified_user,
get_admin_user,
)
from config import ( from config import (
...@@ -164,7 +169,7 @@ async def get_all_models(): ...@@ -164,7 +169,7 @@ async def get_all_models():
@app.get("/api/tags") @app.get("/api/tags")
@app.get("/api/tags/{url_idx}") @app.get("/api/tags/{url_idx}")
async def get_ollama_tags( async def get_ollama_tags(
url_idx: Optional[int] = None, user=Depends(get_current_user) url_idx: Optional[int] = None, user=Depends(get_verified_user)
): ):
if url_idx == None: if url_idx == None:
models = await get_all_models() models = await get_all_models()
...@@ -563,7 +568,7 @@ async def delete_model( ...@@ -563,7 +568,7 @@ async def delete_model(
@app.post("/api/show") @app.post("/api/show")
async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_user)): async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
if form_data.name not in app.state.MODELS: if form_data.name not in app.state.MODELS:
raise HTTPException( raise HTTPException(
status_code=400, status_code=400,
...@@ -612,7 +617,7 @@ class GenerateEmbeddingsForm(BaseModel): ...@@ -612,7 +617,7 @@ class GenerateEmbeddingsForm(BaseModel):
async def generate_embeddings( async def generate_embeddings(
form_data: GenerateEmbeddingsForm, form_data: GenerateEmbeddingsForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_current_user), user=Depends(get_verified_user),
): ):
if url_idx == None: if url_idx == None:
model = form_data.model model = form_data.model
...@@ -730,7 +735,7 @@ class GenerateCompletionForm(BaseModel): ...@@ -730,7 +735,7 @@ class GenerateCompletionForm(BaseModel):
async def generate_completion( async def generate_completion(
form_data: GenerateCompletionForm, form_data: GenerateCompletionForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_current_user), user=Depends(get_verified_user),
): ):
if url_idx == None: if url_idx == None:
...@@ -833,7 +838,7 @@ class GenerateChatCompletionForm(BaseModel): ...@@ -833,7 +838,7 @@ class GenerateChatCompletionForm(BaseModel):
async def generate_chat_completion( async def generate_chat_completion(
form_data: GenerateChatCompletionForm, form_data: GenerateChatCompletionForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_current_user), user=Depends(get_verified_user),
): ):
if url_idx == None: if url_idx == None:
...@@ -942,7 +947,7 @@ class OpenAIChatCompletionForm(BaseModel): ...@@ -942,7 +947,7 @@ class OpenAIChatCompletionForm(BaseModel):
async def generate_openai_chat_completion( async def generate_openai_chat_completion(
form_data: OpenAIChatCompletionForm, form_data: OpenAIChatCompletionForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_current_user), user=Depends(get_verified_user),
): ):
if url_idx == None: if url_idx == None:
...@@ -1241,7 +1246,9 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): ...@@ -1241,7 +1246,9 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): async def deprecated_proxy(
path: str, request: Request, user=Depends(get_verified_user)
):
url = app.state.OLLAMA_BASE_URLS[0] url = app.state.OLLAMA_BASE_URLS[0]
target_url = f"{url}/{path}" target_url = f"{url}/{path}"
......
...@@ -79,6 +79,7 @@ from config import ( ...@@ -79,6 +79,7 @@ from config import (
RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE, RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
ENABLE_RAG_HYBRID_SEARCH, ENABLE_RAG_HYBRID_SEARCH,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
RAG_RERANKING_MODEL, RAG_RERANKING_MODEL,
PDF_EXTRACT_IMAGES, PDF_EXTRACT_IMAGES,
RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_AUTO_UPDATE,
...@@ -90,7 +91,7 @@ from config import ( ...@@ -90,7 +91,7 @@ from config import (
CHUNK_SIZE, CHUNK_SIZE,
CHUNK_OVERLAP, CHUNK_OVERLAP,
RAG_TEMPLATE, RAG_TEMPLATE,
ENABLE_LOCAL_WEB_FETCH, ENABLE_RAG_LOCAL_WEB_FETCH,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
...@@ -104,6 +105,9 @@ app.state.TOP_K = RAG_TOP_K ...@@ -104,6 +105,9 @@ app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.CHUNK_SIZE = CHUNK_SIZE app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
...@@ -113,6 +117,7 @@ app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL ...@@ -113,6 +117,7 @@ app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
...@@ -308,6 +313,7 @@ async def get_rag_config(user=Depends(get_admin_user)): ...@@ -308,6 +313,7 @@ async def get_rag_config(user=Depends(get_admin_user)):
"chunk_size": app.state.CHUNK_SIZE, "chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP, "chunk_overlap": app.state.CHUNK_OVERLAP,
}, },
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
} }
...@@ -317,15 +323,34 @@ class ChunkParamUpdateForm(BaseModel): ...@@ -317,15 +323,34 @@ class ChunkParamUpdateForm(BaseModel):
class ConfigUpdateForm(BaseModel): class ConfigUpdateForm(BaseModel):
pdf_extract_images: bool pdf_extract_images: Optional[bool] = None
chunk: ChunkParamUpdateForm chunk: Optional[ChunkParamUpdateForm] = None
web_loader_ssl_verification: Optional[bool] = None
@app.post("/config/update") @app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images app.state.PDF_EXTRACT_IMAGES = (
app.state.CHUNK_SIZE = form_data.chunk.chunk_size form_data.pdf_extract_images
app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap if form_data.pdf_extract_images != None
else app.state.PDF_EXTRACT_IMAGES
)
app.state.CHUNK_SIZE = (
form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE
)
app.state.CHUNK_OVERLAP = (
form_data.chunk.chunk_overlap
if form_data.chunk != None
else app.state.CHUNK_OVERLAP
)
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None
else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
return { return {
"status": True, "status": True,
...@@ -334,6 +359,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ ...@@ -334,6 +359,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
"chunk_size": app.state.CHUNK_SIZE, "chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP, "chunk_overlap": app.state.CHUNK_OVERLAP,
}, },
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
} }
...@@ -485,7 +511,9 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)): ...@@ -485,7 +511,9 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
def store_web(form_data: UrlForm, user=Depends(get_current_user)): def store_web(form_data: UrlForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try: try:
loader = get_web_loader(form_data.url) loader = get_web_loader(
form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
data = loader.load() data = loader.load()
collection_name = form_data.collection_name collection_name = form_data.collection_name
...@@ -506,11 +534,11 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)): ...@@ -506,11 +534,11 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
) )
def get_web_loader(url: str): def get_web_loader(url: str, verify_ssl: bool = True):
# Check if the URL is valid # Check if the URL is valid
if isinstance(validators.url(url), validators.ValidationError): if isinstance(validators.url(url), validators.ValidationError):
raise ValueError(ERROR_MESSAGES.INVALID_URL) raise ValueError(ERROR_MESSAGES.INVALID_URL)
if not ENABLE_LOCAL_WEB_FETCH: if not ENABLE_RAG_LOCAL_WEB_FETCH:
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
parsed_url = urllib.parse.urlparse(url) parsed_url = urllib.parse.urlparse(url)
# Get IPv4 and IPv6 addresses # Get IPv4 and IPv6 addresses
...@@ -523,7 +551,7 @@ def get_web_loader(url: str): ...@@ -523,7 +551,7 @@ def get_web_loader(url: str):
for ip in ipv6_addresses: for ip in ipv6_addresses:
if validators.ipv6(ip, private=True): if validators.ipv6(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL) raise ValueError(ERROR_MESSAGES.INVALID_URL)
return WebBaseLoader(url) return WebBaseLoader(url, verify_ssl=verify_ssl)
def resolve_hostname(hostname): def resolve_hostname(hostname):
...@@ -594,7 +622,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b ...@@ -594,7 +622,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
for batch in create_batches( for batch in create_batches(
api=CHROMA_CLIENT, api=CHROMA_CLIENT,
ids=[str(uuid.uuid1()) for _ in texts], ids=[str(uuid.uuid4()) for _ in texts],
metadatas=metadatas, metadatas=metadatas,
embeddings=embeddings, embeddings=embeddings,
documents=texts, documents=texts,
......
...@@ -271,14 +271,14 @@ def rag_messages( ...@@ -271,14 +271,14 @@ def rag_messages(
for doc in docs: for doc in docs:
context = None context = None
collection = doc.get("collection_name") collection_names = (
if collection: doc["collection_names"]
collection = [collection] if doc["type"] == "collection"
else: else [doc["collection_name"]]
collection = doc.get("collection_names", []) )
collection = set(collection).difference(extracted_collections) collection_names = set(collection_names).difference(extracted_collections)
if not collection: if not collection_names:
log.debug(f"skipping {doc} as it has already been extracted") log.debug(f"skipping {doc} as it has already been extracted")
continue continue
...@@ -288,11 +288,7 @@ def rag_messages( ...@@ -288,11 +288,7 @@ def rag_messages(
else: else:
if hybrid_search: if hybrid_search:
context = query_collection_with_hybrid_search( context = query_collection_with_hybrid_search(
collection_names=( collection_names=collection_names,
doc["collection_names"]
if doc["type"] == "collection"
else [doc["collection_name"]]
),
query=query, query=query,
embedding_function=embedding_function, embedding_function=embedding_function,
k=k, k=k,
...@@ -301,11 +297,7 @@ def rag_messages( ...@@ -301,11 +297,7 @@ def rag_messages(
) )
else: else:
context = query_collection( context = query_collection(
collection_names=( collection_names=collection_names,
doc["collection_names"]
if doc["type"] == "collection"
else [doc["collection_name"]]
),
query=query, query=query,
embedding_function=embedding_function, embedding_function=embedding_function,
k=k, k=k,
...@@ -315,18 +307,31 @@ def rag_messages( ...@@ -315,18 +307,31 @@ def rag_messages(
context = None context = None
if context: if context:
relevant_contexts.append(context) relevant_contexts.append({**context, "source": doc})
extracted_collections.extend(collection) extracted_collections.extend(collection_names)
context_string = "" context_string = ""
citations = []
for context in relevant_contexts: for context in relevant_contexts:
try: try:
if "documents" in context: if "documents" in context:
items = [item for item in context["documents"][0] if item is not None] context_string += "\n\n".join(
context_string += "\n\n".join(items) [text for text in context["documents"][0] if text is not None]
)
if "metadatas" in context:
citations.append(
{
"source": context["source"],
"document": context["documents"][0],
"metadata": context["metadatas"][0],
}
)
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
context_string = context_string.strip() context_string = context_string.strip()
ra_content = rag_template( ra_content = rag_template(
...@@ -355,7 +360,7 @@ def rag_messages( ...@@ -355,7 +360,7 @@ def rag_messages(
messages[last_user_message_idx] = new_user_message messages[last_user_message_idx] = new_user_message
return messages return messages, citations
def get_model_path(model: str, update_model: bool = False): def get_model_path(model: str, update_model: bool = False):
......
...@@ -93,6 +93,31 @@ async def get_archived_session_user_chat_list( ...@@ -93,6 +93,31 @@ async def get_archived_session_user_chat_list(
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit) return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
############################
# GetSharedChatById
############################
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
if user.role == "pending":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role == "user":
chat = Chats.get_chat_by_share_id(share_id)
elif user.role == "admin":
chat = Chats.get_chat_by_id(share_id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
############################ ############################
# GetChats # GetChats
############################ ############################
...@@ -141,6 +166,55 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): ...@@ -141,6 +166,55 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
) )
############################
# GetChatsByTags
############################
class TagNameForm(BaseModel):
name: str
skip: Optional[int] = 0
limit: Optional[int] = 50
@router.post("/tags", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
form_data: TagNameForm, user=Depends(get_current_user)
):
print(form_data)
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
form_data.name, user.id
)
]
chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
if len(chats) == 0:
Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
return chats
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################ ############################
# GetChatById # GetChatById
############################ ############################
...@@ -274,70 +348,6 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)): ...@@ -274,70 +348,6 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
) )
############################
# GetSharedChatById
############################
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
if user.role == "pending":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
if user.role == "user":
chat = Chats.get_chat_by_share_id(share_id)
elif user.role == "admin":
chat = Chats.get_chat_by_id(share_id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChatsByTags
############################
@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
]
chats = Chats.get_chat_list_by_chat_ids(chat_ids, skip, limit)
if len(chats) == 0:
Tags.delete_tag_by_tag_name_and_user_id(tag_name, user.id)
return chats
############################ ############################
# GetChatTagsById # GetChatTagsById
############################ ############################
......
...@@ -18,6 +18,18 @@ from secrets import token_bytes ...@@ -18,6 +18,18 @@ from secrets import token_bytes
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
####################################
# Load .env file
####################################
try:
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv("../.env"))
except ImportError:
print("dotenv not installed, skipping...")
#################################### ####################################
# LOGGING # LOGGING
#################################### ####################################
...@@ -59,23 +71,16 @@ for source in log_sources: ...@@ -59,23 +71,16 @@ for source in log_sources:
log.setLevel(SRC_LOG_LEVELS["CONFIG"]) log.setLevel(SRC_LOG_LEVELS["CONFIG"])
####################################
# Load .env file
####################################
try:
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv("../.env"))
except ImportError:
log.warning("dotenv not installed, skipping...")
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI") WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI": if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)" WEBUI_NAME += " (Open WebUI)"
WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png" WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
#################################### ####################################
# ENV (dev,test,prod) # ENV (dev,test,prod)
#################################### ####################################
...@@ -454,6 +459,11 @@ ENABLE_RAG_HYBRID_SEARCH = ( ...@@ -454,6 +459,11 @@ ENABLE_RAG_HYBRID_SEARCH = (
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true" os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
) )
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true"
)
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "") RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true" PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true"
...@@ -531,7 +541,9 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE) ...@@ -531,7 +541,9 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL) RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY) RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true" ENABLE_RAG_LOCAL_WEB_FETCH = (
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
)
#################################### ####################################
# Transcribe # Transcribe
......
...@@ -15,7 +15,7 @@ from fastapi.middleware.wsgi import WSGIMiddleware ...@@ -15,7 +15,7 @@ from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import StreamingResponse, Response
from apps.ollama.main import app as ollama_app from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app from apps.openai.main import app as openai_app
...@@ -25,6 +25,8 @@ from apps.litellm.main import ( ...@@ -25,6 +25,8 @@ from apps.litellm.main import (
start_litellm_background, start_litellm_background,
shutdown_litellm_background, shutdown_litellm_background,
) )
from apps.audio.main import app as audio_app from apps.audio.main import app as audio_app
from apps.images.main import app as images_app from apps.images.main import app as images_app
from apps.rag.main import app as rag_app from apps.rag.main import app as rag_app
...@@ -41,6 +43,7 @@ from apps.rag.utils import rag_messages ...@@ -41,6 +43,7 @@ from apps.rag.utils import rag_messages
from config import ( from config import (
CONFIG_DATA, CONFIG_DATA,
WEBUI_NAME, WEBUI_NAME,
WEBUI_URL,
ENV, ENV,
VERSION, VERSION,
CHANGELOG, CHANGELOG,
...@@ -74,7 +77,7 @@ class SPAStaticFiles(StaticFiles): ...@@ -74,7 +77,7 @@ class SPAStaticFiles(StaticFiles):
print( print(
f""" rf"""
___ __ __ _ _ _ ___ ___ __ __ _ _ _ ___
/ _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _| / _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || | | | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
...@@ -100,6 +103,8 @@ origins = ["*"] ...@@ -100,6 +103,8 @@ origins = ["*"]
class RAGMiddleware(BaseHTTPMiddleware): class RAGMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
return_citations = False
if request.method == "POST" and ( if request.method == "POST" and (
"/api/chat" in request.url.path or "/chat/completions" in request.url.path "/api/chat" in request.url.path or "/chat/completions" in request.url.path
): ):
...@@ -112,11 +117,15 @@ class RAGMiddleware(BaseHTTPMiddleware): ...@@ -112,11 +117,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Parse string to JSON # Parse string to JSON
data = json.loads(body_str) if body_str else {} data = json.loads(body_str) if body_str else {}
return_citations = data.get("citations", False)
if "citations" in data:
del data["citations"]
# Example: Add a new key-value pair or modify existing ones # Example: Add a new key-value pair or modify existing ones
# data["modified"] = True # Example modification # data["modified"] = True # Example modification
if "docs" in data: if "docs" in data:
data = {**data} data = {**data}
data["messages"] = rag_messages( data["messages"], citations = rag_messages(
docs=data["docs"], docs=data["docs"],
messages=data["messages"], messages=data["messages"],
template=rag_app.state.RAG_TEMPLATE, template=rag_app.state.RAG_TEMPLATE,
...@@ -128,7 +137,9 @@ class RAGMiddleware(BaseHTTPMiddleware): ...@@ -128,7 +137,9 @@ class RAGMiddleware(BaseHTTPMiddleware):
) )
del data["docs"] del data["docs"]
log.debug(f"data['messages']: {data['messages']}") log.debug(
f"data['messages']: {data['messages']}, citations: {citations}"
)
modified_body_bytes = json.dumps(data).encode("utf-8") modified_body_bytes = json.dumps(data).encode("utf-8")
...@@ -146,11 +157,36 @@ class RAGMiddleware(BaseHTTPMiddleware): ...@@ -146,11 +157,36 @@ class RAGMiddleware(BaseHTTPMiddleware):
] ]
response = await call_next(request) response = await call_next(request)
if return_citations:
# Inject the citations into the response
if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers.get("Content-Type")
if "text/event-stream" in content_type:
return StreamingResponse(
self.openai_stream_wrapper(response.body_iterator, citations),
)
if "application/x-ndjson" in content_type:
return StreamingResponse(
self.ollama_stream_wrapper(response.body_iterator, citations),
)
return response return response
async def _receive(self, body: bytes): async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False} return {"type": "http.request", "body": body, "more_body": False}
async def openai_stream_wrapper(self, original_generator, citations):
yield f"data: {json.dumps({'citations': citations})}\n\n"
async for data in original_generator:
yield data
async def ollama_stream_wrapper(self, original_generator, citations):
yield f"{json.dumps({'citations': citations})}\n"
async for data in original_generator:
yield data
app.add_middleware(RAGMiddleware) app.add_middleware(RAGMiddleware)
...@@ -315,6 +351,21 @@ async def get_manifest_json(): ...@@ -315,6 +351,21 @@ async def get_manifest_json():
} }
@app.get("/opensearch.xml")
async def get_opensearch_xml():
xml_content = rf"""
<OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
<ShortName>{WEBUI_NAME}</ShortName>
<Description>Search {WEBUI_NAME}</Description>
<InputEncoding>UTF-8</InputEncoding>
<Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/favicon.png</Image>
<Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/>
<moz:SearchForm>{WEBUI_URL}</moz:SearchForm>
</OpenSearchDescription>
"""
return Response(content=xml_content, media_type="application/xml")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache") app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
......
...@@ -9,7 +9,6 @@ Flask-Cors==4.0.0 ...@@ -9,7 +9,6 @@ Flask-Cors==4.0.0
python-socketio==5.11.2 python-socketio==5.11.2
python-jose==3.3.0 python-jose==3.3.0
passlib[bcrypt]==1.7.4 passlib[bcrypt]==1.7.4
uuid==1.30
requests==2.31.0 requests==2.31.0
aiohttp==3.9.5 aiohttp==3.9.5
...@@ -19,7 +18,6 @@ psycopg2-binary==2.9.9 ...@@ -19,7 +18,6 @@ psycopg2-binary==2.9.9
PyMySQL==1.1.0 PyMySQL==1.1.0
bcrypt==4.1.2 bcrypt==4.1.2
litellm==1.35.28
litellm[proxy]==1.35.28 litellm[proxy]==1.35.28
boto3==1.34.95 boto3==1.34.95
...@@ -54,9 +52,8 @@ rank-bm25==0.2.2 ...@@ -54,9 +52,8 @@ rank-bm25==0.2.2
faster-whisper==1.0.1 faster-whisper==1.0.1
PyJWT==2.8.0
PyJWT[crypto]==2.8.0 PyJWT[crypto]==2.8.0
black==24.4.2 black==24.4.2
langfuse==2.27.3 langfuse==2.27.3
youtube-transcript-api youtube-transcript-api==0.6.2
apiVersion: v2
name: open-webui
version: 1.0.0
appVersion: "latest"
home: https://www.openwebui.com/
icon: https://raw.githubusercontent.com/open-webui/open-webui/main/static/favicon.png
description: "Open WebUI: A User-Friendly Web Interface for Chat Interactions 👋"
keywords:
- llm
- chat
- web-ui
sources:
- https://github.com/open-webui/open-webui/tree/main/kubernetes/helm
- https://hub.docker.com/r/ollama/ollama
- https://github.com/open-webui/open-webui/pkgs/container/open-webui
annotations:
licenses: MIT
# Helm Charts
Open WebUI Helm Charts are now hosted in a separate repo, which can be found here: https://github.com/open-webui/helm-charts
The charts are released at https://helm.openwebui.com.
\ No newline at end of file
{{- define "open-webui.name" -}}
{{- default .Chart.Name .Values.nameOverride | trunc 63 | trimSuffix "-" }}
{{- end -}}
{{- define "ollama.name" -}}
ollama
{{- end -}}
{{- define "ollama.url" -}}
{{- if .Values.ollama.externalHost }}
{{- printf .Values.ollama.externalHost }}
{{- else }}
{{- printf "http://%s.%s.svc.cluster.local:%d" (include "ollama.name" .) (.Release.Namespace) (.Values.ollama.service.port | int) }}
{{- end }}
{{- end }}
{{- define "chart.name" -}}
{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }}
{{- end }}
{{- define "base.labels" -}}
helm.sh/chart: {{ include "chart.name" . }}
{{- if .Chart.AppVersion }}
app.kubernetes.io/version: {{ .Chart.AppVersion | quote }}
{{- end }}
app.kubernetes.io/managed-by: {{ .Release.Service }}
{{- end }}
{{- define "base.selectorLabels" -}}
app.kubernetes.io/instance: {{ .Release.Name }}
{{- end -}}
{{- define "open-webui.selectorLabels" -}}
{{ include "base.selectorLabels" . }}
app.kubernetes.io/component: {{ .Chart.Name }}
{{- end }}
{{- define "open-webui.labels" -}}
{{ include "base.labels" . }}
{{ include "open-webui.selectorLabels" . }}
{{- end }}
{{- define "ollama.selectorLabels" -}}
{{ include "base.selectorLabels" . }}
app.kubernetes.io/component: {{ include "ollama.name" . }}
{{- end }}
{{- define "ollama.labels" -}}
{{ include "base.labels" . }}
{{ include "ollama.selectorLabels" . }}
{{- end }}
{{- if not .Values.ollama.externalHost }}
apiVersion: v1
kind: Service
metadata:
name: {{ include "ollama.name" . }}
labels:
{{- include "ollama.labels" . | nindent 4 }}
{{- with .Values.ollama.service.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
selector:
{{- include "ollama.selectorLabels" . | nindent 4 }}
{{- with .Values.ollama.service }}
type: {{ .type }}
ports:
- protocol: TCP
name: http
port: {{ .port }}
targetPort: http
{{- end }}
{{- end }}
{{- if not .Values.ollama.externalHost }}
apiVersion: apps/v1
kind: StatefulSet
metadata:
name: {{ include "ollama.name" . }}
labels:
{{- include "ollama.labels" . | nindent 4 }}
{{- with .Values.ollama.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
serviceName: {{ include "ollama.name" . }}
replicas: {{ .Values.ollama.replicaCount }}
selector:
matchLabels:
{{- include "ollama.selectorLabels" . | nindent 6 }}
template:
metadata:
labels:
{{- include "ollama.labels" . | nindent 8 }}
{{- with .Values.ollama.podAnnotations }}
annotations:
{{- toYaml . | nindent 8 }}
{{- end }}
spec:
enableServiceLinks: false
automountServiceAccountToken: false
{{- with .Values.ollama.runtimeClassName }}
runtimeClassName: {{ . }}
{{- end }}
containers:
- name: {{ include "ollama.name" . }}
{{- with .Values.ollama.image }}
image: {{ .repository }}:{{ .tag }}
imagePullPolicy: {{ .pullPolicy }}
{{- end }}
tty: true
ports:
- name: http
containerPort: {{ .Values.ollama.service.containerPort }}
env:
{{- if .Values.ollama.gpu.enabled }}
- name: PATH
value: /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
- name: LD_LIBRARY_PATH
value: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
- name: NVIDIA_DRIVER_CAPABILITIES
value: compute,utility
{{- end }}
{{- with .Values.ollama.resources }}
resources: {{- toYaml . | nindent 10 }}
{{- end }}
volumeMounts:
- name: data
mountPath: /root/.ollama
{{- with .Values.ollama.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- with .Values.ollama.tolerations }}
tolerations:
{{- toYaml . | nindent 8 }}
{{- end }}
volumes:
{{- if and .Values.ollama.persistence.enabled .Values.ollama.persistence.existingClaim }}
- name: data
persistentVolumeClaim:
claimName: {{ .Values.ollama.persistence.existingClaim }}
{{- else if not .Values.ollama.persistence.enabled }}
- name: data
emptyDir: {}
{{- else if and .Values.ollama.persistence.enabled (not .Values.ollama.persistence.existingClaim) }}
[]
volumeClaimTemplates:
- metadata:
name: data
labels:
{{- include "ollama.selectorLabels" . | nindent 8 }}
{{- with .Values.ollama.persistence.annotations }}
annotations:
{{- toYaml . | nindent 8 }}
{{- end }}
spec:
accessModes:
{{- range .Values.ollama.persistence.accessModes }}
- {{ . | quote }}
{{- end }}
resources:
requests:
storage: {{ .Values.ollama.persistence.size | quote }}
storageClassName: {{ .Values.ollama.persistence.storageClass }}
{{- with .Values.ollama.persistence.selector }}
selector:
{{- toYaml . | nindent 8 }}
{{- end }}
{{- end }}
{{- end }}
apiVersion: apps/v1
kind: Deployment
metadata:
name: {{ include "open-webui.name" . }}
labels:
{{- include "open-webui.labels" . | nindent 4 }}
{{- with .Values.webui.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
replicas: {{ .Values.webui.replicaCount }}
selector:
matchLabels:
{{- include "open-webui.selectorLabels" . | nindent 6 }}
template:
metadata:
labels:
{{- include "open-webui.labels" . | nindent 8 }}
{{- with .Values.webui.podAnnotations }}
annotations:
{{- toYaml . | nindent 8 }}
{{- end }}
spec:
enableServiceLinks: false
automountServiceAccountToken: false
containers:
- name: {{ .Chart.Name }}
{{- with .Values.webui.image }}
image: {{ .repository }}:{{ .tag | default $.Chart.AppVersion }}
imagePullPolicy: {{ .pullPolicy }}
{{- end }}
ports:
- name: http
containerPort: {{ .Values.webui.service.containerPort }}
{{- with .Values.webui.resources }}
resources: {{- toYaml . | nindent 10 }}
{{- end }}
volumeMounts:
- name: data
mountPath: /app/backend/data
env:
- name: OLLAMA_BASE_URL
value: {{ include "ollama.url" . | quote }}
tty: true
{{- with .Values.webui.nodeSelector }}
nodeSelector:
{{- toYaml . | nindent 8 }}
{{- end }}
volumes:
{{- if and .Values.webui.persistence.enabled .Values.webui.persistence.existingClaim }}
- name: data
persistentVolumeClaim:
claimName: {{ .Values.webui.persistence.existingClaim }}
{{- else if not .Values.webui.persistence.enabled }}
- name: data
emptyDir: {}
{{- else if and .Values.webui.persistence.enabled (not .Values.webui.persistence.existingClaim) }}
- name: data
persistentVolumeClaim:
claimName: {{ include "open-webui.name" . }}
{{- end }}
{{- if .Values.webui.ingress.enabled }}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: {{ include "open-webui.name" . }}
labels:
{{- include "open-webui.labels" . | nindent 4 }}
{{- with .Values.webui.ingress.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
{{- with .Values.webui.ingress.class }}
ingressClassName: {{ . }}
{{- end }}
{{- if .Values.webui.ingress.tls }}
tls:
- hosts:
- {{ .Values.webui.ingress.host | quote }}
secretName: {{ default (printf "%s-tls" .Release.Name) .Values.webui.ingress.existingSecret }}
{{- end }}
rules:
- host: {{ .Values.webui.ingress.host }}
http:
paths:
- path: /
pathType: Prefix
backend:
service:
name: {{ include "open-webui.name" . }}
port:
name: http
{{- end }}
{{- if and .Values.webui.persistence.enabled (not .Values.webui.persistence.existingClaim) }}
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: {{ include "open-webui.name" . }}
labels:
{{- include "open-webui.selectorLabels" . | nindent 4 }}
{{- with .Values.webui.persistence.annotations }}
annotations:
{{- toYaml . | nindent 8 }}
{{- end }}
spec:
accessModes:
{{- range .Values.webui.persistence.accessModes }}
- {{ . | quote }}
{{- end }}
resources:
requests:
storage: {{ .Values.webui.persistence.size }}
{{- if .Values.webui.persistence.storageClass }}
storageClassName: {{ .Values.webui.persistence.storageClass }}
{{- end }}
{{- with .Values.webui.persistence.selector }}
selector:
{{- toYaml . | nindent 4 }}
{{- end }}
{{- end }}
apiVersion: v1
kind: Service
metadata:
name: {{ include "open-webui.name" . }}
labels:
{{- include "open-webui.labels" . | nindent 4 }}
{{- with .Values.webui.service.labels }}
{{- toYaml . | nindent 4 }}
{{- end }}
{{- with .Values.webui.service.annotations }}
annotations:
{{- toYaml . | nindent 4 }}
{{- end }}
spec:
selector:
{{- include "open-webui.selectorLabels" . | nindent 4 }}
type: {{ .Values.webui.service.type | default "ClusterIP" }}
ports:
- protocol: TCP
name: http
port: {{ .Values.webui.service.port }}
targetPort: http
{{- if .Values.webui.service.nodePort }}
nodePort: {{ .Values.webui.service.nodePort | int }}
{{- end }}
{{- if .Values.webui.service.loadBalancerClass }}
loadBalancerClass: {{ .Values.webui.service.loadBalancerClass | quote }}
{{- end }}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment