Unverified Commit b8d7fdf1 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #1965 from open-webui/dev

0.1.124
parents 30b05311 b44ae536
version: 2
updates:
- package-ecosystem: pip
directory: "/backend"
schedule:
interval: daily
time: "13:00"
groups:
python-packages:
patterns:
- "*"
- package-ecosystem: pip
directory: '/backend'
schedule:
interval: daily
time: '13:00'
- package-ecosystem: 'github-actions'
directory: '/'
schedule:
# Check for updates to GitHub Actions every week
interval: 'weekly'
## Pull Request Checklist
- [ ] **Target branch:** Pull requests should target the `dev` branch.
- [ ] **Description:** Briefly describe the changes in this pull request.
- [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
- [ ] **Documentation:** Have you updated the relevant documentation, such as the [Open WebUI Docs](https://github.com/open-webui/docs) or other documentation sources?
......
......@@ -20,7 +20,16 @@ jobs:
- name: Build and run Compose Stack
run: |
docker compose up --detach --build
docker compose --file docker-compose.yaml --file docker-compose.api.yaml up --detach --build
- name: Wait for Ollama to be up
timeout-minutes: 5
run: |
until curl --output /dev/null --silent --fail http://localhost:11434; do
printf '.'
sleep 1
done
echo "Service is up!"
- name: Preload Ollama model
run: |
......
......@@ -5,6 +5,28 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.1.124] - 2024-05-08
### Added
- **🖼️ Improved Chat Sidebar**: Now conveniently displays time ranges and organizes chats by today, yesterday, and more.
- **📜 Citations in RAG Feature**: Easily track the context fed to the LLM with added citations in the RAG feature.
- **🔒 Auth Disable Option**: Introducing the ability to disable authentication. Set 'WEBUI_AUTH' to False to disable authentication. Note: Only applicable for fresh installations without existing users.
- **📹 Enhanced YouTube RAG Pipeline**: Now supports non-English videos for an enriched experience.
- **🔊 Specify OpenAI TTS Models**: Customize your TTS experience by specifying OpenAI TTS models.
- **🔧 Additional Environment Variables**: Discover more environment variables in our comprehensive [Open WebUI Documentation](https://docs.openwebui.com).
- **🌐 Language Support**: Arabic, Finnish, and Hindi added; Improved support for German, Vietnamese, and Chinese.
### Fixed
- **🛠️ Model Selector Styling**: Addressed styling issues for improved user experience.
- **⚠️ Warning Messages**: Resolved backend warning messages.
### Changed
- **📝 Title Generation**: Limited output to 50 tokens.
- **📦 Helm Charts**: Removed Helm charts, now available in a separate repository (https://github.com/open-webui/helm-charts).
## [0.1.123] - 2024-05-02
### Added
......
......@@ -152,7 +152,7 @@ We offer various installation alternatives, including non-Docker methods, Docker
### Troubleshooting
Encountering connection issues? Our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/troubleshooting/) has got you covered. For further assistance and to join our vibrant community, visit the [Open WebUI Discord](https://discord.gg/5rJgQTnV4s).
Encountering connection issues? Our [Open WebUI Documentation](https://docs.openwebui.com/troubleshooting/) has got you covered. For further assistance and to join our vibrant community, visit the [Open WebUI Discord](https://discord.gg/5rJgQTnV4s).
### Keeping Your Docker Installation Up-to-Date
......
......@@ -43,6 +43,8 @@ from config import (
DEVICE_TYPE,
AUDIO_OPENAI_API_BASE_URL,
AUDIO_OPENAI_API_KEY,
AUDIO_OPENAI_API_MODEL,
AUDIO_OPENAI_API_VOICE,
)
log = logging.getLogger(__name__)
......@@ -60,6 +62,8 @@ app.add_middleware(
app.state.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
app.state.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
app.state.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE
# setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
......@@ -72,6 +76,8 @@ SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
class OpenAIConfigUpdateForm(BaseModel):
    # Request body for updating the audio app's OpenAI settings. The update
    # handler copies each field onto ``app.state`` as noted below.
    url: str  # stored as app.state.OPENAI_API_BASE_URL
    key: str  # stored as app.state.OPENAI_API_KEY
    model: str  # stored as app.state.OPENAI_API_MODEL
    speaker: str  # stored as app.state.OPENAI_API_VOICE
@app.get("/config")
......@@ -79,6 +85,8 @@ async def get_openai_config(user=Depends(get_admin_user)):
return {
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.OPENAI_API_MODEL,
"OPENAI_API_VOICE": app.state.OPENAI_API_VOICE,
}
......@@ -91,11 +99,15 @@ async def update_openai_config(
app.state.OPENAI_API_BASE_URL = form_data.url
app.state.OPENAI_API_KEY = form_data.key
app.state.OPENAI_API_MODEL = form_data.model
app.state.OPENAI_API_VOICE = form_data.speaker
return {
"status": True,
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.OPENAI_API_MODEL,
"OPENAI_API_VOICE": app.state.OPENAI_API_VOICE,
}
......
......@@ -36,6 +36,10 @@ from config import (
LITELLM_PROXY_HOST,
)
import warnings
warnings.simplefilter("ignore")
from litellm.utils import get_llm_provider
import asyncio
......
......@@ -25,13 +25,19 @@ import uuid
import aiohttp
import asyncio
import logging
import time
from urllib.parse import urlparse
from typing import Optional, List, Union
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user, get_admin_user
from utils.utils import (
decode_token,
get_current_user,
get_verified_user,
get_admin_user,
)
from config import (
......@@ -164,7 +170,7 @@ async def get_all_models():
@app.get("/api/tags")
@app.get("/api/tags/{url_idx}")
async def get_ollama_tags(
url_idx: Optional[int] = None, user=Depends(get_current_user)
url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
if url_idx == None:
models = await get_all_models()
......@@ -563,7 +569,7 @@ async def delete_model(
@app.post("/api/show")
async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_user)):
async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
if form_data.name not in app.state.MODELS:
raise HTTPException(
status_code=400,
......@@ -612,7 +618,7 @@ class GenerateEmbeddingsForm(BaseModel):
async def generate_embeddings(
form_data: GenerateEmbeddingsForm,
url_idx: Optional[int] = None,
user=Depends(get_current_user),
user=Depends(get_verified_user),
):
if url_idx == None:
model = form_data.model
......@@ -730,7 +736,7 @@ class GenerateCompletionForm(BaseModel):
async def generate_completion(
form_data: GenerateCompletionForm,
url_idx: Optional[int] = None,
user=Depends(get_current_user),
user=Depends(get_verified_user),
):
if url_idx == None:
......@@ -833,7 +839,7 @@ class GenerateChatCompletionForm(BaseModel):
async def generate_chat_completion(
form_data: GenerateChatCompletionForm,
url_idx: Optional[int] = None,
user=Depends(get_current_user),
user=Depends(get_verified_user),
):
if url_idx == None:
......@@ -942,7 +948,7 @@ class OpenAIChatCompletionForm(BaseModel):
async def generate_openai_chat_completion(
form_data: OpenAIChatCompletionForm,
url_idx: Optional[int] = None,
user=Depends(get_current_user),
user=Depends(get_verified_user),
):
if url_idx == None:
......@@ -1026,6 +1032,75 @@ async def generate_openai_chat_completion(
)
@app.get("/v1/models")
@app.get("/v1/models/{url_idx}")
async def get_openai_models(
    url_idx: Optional[int] = None,
    user=Depends(get_verified_user),
):
    """List Ollama models in the shape of an OpenAI ``/v1/models`` response.

    Without ``url_idx`` the merged model list from ``get_all_models()`` is
    used, and, when model filtering is enabled, users with the ``"user"``
    role only see models named in ``app.state.MODEL_FILTER_LIST``. With
    ``url_idx`` the models are fetched from that single Ollama base URL via
    its ``/api/tags`` endpoint.

    Raises:
        HTTPException: when the request to the selected Ollama server fails;
            carries the upstream status code when a response was received,
            otherwise 500.
    """

    def to_openai_list(models):
        # Each entry is stamped with the current time as "created" and a
        # fixed "owned_by" of "openai".
        return {
            "data": [
                {
                    "id": model["model"],
                    "object": "model",
                    "created": int(time.time()),
                    "owned_by": "openai",
                }
                for model in models["models"]
            ],
            "object": "list",
        }

    if url_idx is None:
        models = await get_all_models()

        if app.state.ENABLE_MODEL_FILTER and user.role == "user":
            models["models"] = [
                model
                for model in models["models"]
                if model["name"] in app.state.MODEL_FILTER_LIST
            ]

        return to_openai_list(models)

    url = app.state.OLLAMA_BASE_URLS[url_idx]

    # Bind `r` before the try block so the error handler can inspect it even
    # when the request itself raises (e.g. connection refused).
    r = None
    try:
        r = requests.request(method="GET", url=f"{url}/api/tags")
        r.raise_for_status()

        return to_openai_list(r.json())
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"Ollama: {res['error']}"
            except Exception:
                error_detail = f"Ollama: {e}"

        raise HTTPException(
            # Compare against None explicitly: a requests.Response is falsy
            # for 4xx/5xx statuses, which would otherwise always yield 500.
            status_code=r.status_code if r is not None else 500,
            detail=error_detail,
        )
class UrlForm(BaseModel):
url: str
......@@ -1241,7 +1316,9 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)):
async def deprecated_proxy(
path: str, request: Request, user=Depends(get_verified_user)
):
url = app.state.OLLAMA_BASE_URLS[0]
target_url = f"{url}/{path}"
......
......@@ -79,6 +79,7 @@ from config import (
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
ENABLE_RAG_HYBRID_SEARCH,
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
RAG_RERANKING_MODEL,
PDF_EXTRACT_IMAGES,
RAG_RERANKING_MODEL_AUTO_UPDATE,
......@@ -90,7 +91,8 @@ from config import (
CHUNK_SIZE,
CHUNK_OVERLAP,
RAG_TEMPLATE,
ENABLE_LOCAL_WEB_FETCH,
ENABLE_RAG_LOCAL_WEB_FETCH,
YOUTUBE_LOADER_LANGUAGE,
)
from constants import ERROR_MESSAGES
......@@ -104,6 +106,9 @@ app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
......@@ -113,12 +118,17 @@ app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None
def update_embedding_model(
embedding_model: str,
update_model: bool = False,
......@@ -308,6 +318,11 @@ async def get_rag_config(user=Depends(get_admin_user)):
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": app.state.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
}
......@@ -316,16 +331,53 @@ class ChunkParamUpdateForm(BaseModel):
chunk_overlap: int
class YoutubeLoaderConfig(BaseModel):
    # YouTube loader settings accepted as the ``youtube`` field of
    # ConfigUpdateForm.
    language: List[str]  # stored as app.state.YOUTUBE_LOADER_LANGUAGE
    translation: Optional[str] = None  # stored as app.state.YOUTUBE_LOADER_TRANSLATION
class ConfigUpdateForm(BaseModel):
    # Partial-update payload for the RAG config endpoint. Every field is
    # optional; the update handler keeps the current app.state value for any
    # field left as None.
    pdf_extract_images: Optional[bool] = None
    chunk: Optional[ChunkParamUpdateForm] = None
    web_loader_ssl_verification: Optional[bool] = None
    youtube: Optional[YoutubeLoaderConfig] = None
@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
app.state.CHUNK_SIZE = form_data.chunk.chunk_size
app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
app.state.PDF_EXTRACT_IMAGES = (
form_data.pdf_extract_images
if form_data.pdf_extract_images != None
else app.state.PDF_EXTRACT_IMAGES
)
app.state.CHUNK_SIZE = (
form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE
)
app.state.CHUNK_OVERLAP = (
form_data.chunk.chunk_overlap
if form_data.chunk != None
else app.state.CHUNK_OVERLAP
)
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None
else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.YOUTUBE_LOADER_LANGUAGE = (
form_data.youtube.language
if form_data.youtube != None
else app.state.YOUTUBE_LOADER_LANGUAGE
)
app.state.YOUTUBE_LOADER_TRANSLATION = (
form_data.youtube.translation
if form_data.youtube != None
else app.state.YOUTUBE_LOADER_TRANSLATION
)
return {
"status": True,
......@@ -334,6 +386,11 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": app.state.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
}
......@@ -460,7 +517,12 @@ def query_collection_handler(
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
try:
loader = YoutubeLoader.from_youtube_url(form_data.url, add_video_info=False)
loader = YoutubeLoader.from_youtube_url(
form_data.url,
add_video_info=True,
language=app.state.YOUTUBE_LOADER_LANGUAGE,
translation=app.state.YOUTUBE_LOADER_TRANSLATION,
)
data = loader.load()
collection_name = form_data.collection_name
......@@ -485,7 +547,9 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try:
loader = get_web_loader(form_data.url)
loader = get_web_loader(
form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
data = loader.load()
collection_name = form_data.collection_name
......@@ -506,11 +570,11 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
)
def get_web_loader(url: str):
def get_web_loader(url: str, verify_ssl: bool = True):
# Check if the URL is valid
if isinstance(validators.url(url), validators.ValidationError):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
if not ENABLE_LOCAL_WEB_FETCH:
if not ENABLE_RAG_LOCAL_WEB_FETCH:
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
parsed_url = urllib.parse.urlparse(url)
# Get IPv4 and IPv6 addresses
......@@ -523,7 +587,7 @@ def get_web_loader(url: str):
for ip in ipv6_addresses:
if validators.ipv6(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return WebBaseLoader(url)
return WebBaseLoader(url, verify_ssl=verify_ssl)
def resolve_hostname(hostname):
......@@ -594,7 +658,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
for batch in create_batches(
api=CHROMA_CLIENT,
ids=[str(uuid.uuid1()) for _ in texts],
ids=[str(uuid.uuid4()) for _ in texts],
metadatas=metadatas,
embeddings=embeddings,
documents=texts,
......
......@@ -271,14 +271,14 @@ def rag_messages(
for doc in docs:
context = None
collection = doc.get("collection_name")
if collection:
collection = [collection]
else:
collection = doc.get("collection_names", [])
collection_names = (
doc["collection_names"]
if doc["type"] == "collection"
else [doc["collection_name"]]
)
collection = set(collection).difference(extracted_collections)
if not collection:
collection_names = set(collection_names).difference(extracted_collections)
if not collection_names:
log.debug(f"skipping {doc} as it has already been extracted")
continue
......@@ -288,11 +288,7 @@ def rag_messages(
else:
if hybrid_search:
context = query_collection_with_hybrid_search(
collection_names=(
doc["collection_names"]
if doc["type"] == "collection"
else [doc["collection_name"]]
),
collection_names=collection_names,
query=query,
embedding_function=embedding_function,
k=k,
......@@ -301,11 +297,7 @@ def rag_messages(
)
else:
context = query_collection(
collection_names=(
doc["collection_names"]
if doc["type"] == "collection"
else [doc["collection_name"]]
),
collection_names=collection_names,
query=query,
embedding_function=embedding_function,
k=k,
......@@ -315,18 +307,31 @@ def rag_messages(
context = None
if context:
relevant_contexts.append(context)
relevant_contexts.append({**context, "source": doc})
extracted_collections.extend(collection)
extracted_collections.extend(collection_names)
context_string = ""
citations = []
for context in relevant_contexts:
try:
if "documents" in context:
items = [item for item in context["documents"][0] if item is not None]
context_string += "\n\n".join(items)
context_string += "\n\n".join(
[text for text in context["documents"][0] if text is not None]
)
if "metadatas" in context:
citations.append(
{
"source": context["source"],
"document": context["documents"][0],
"metadata": context["metadatas"][0],
}
)
except Exception as e:
log.exception(e)
context_string = context_string.strip()
ra_content = rag_template(
......@@ -355,7 +360,7 @@ def rag_messages(
messages[last_user_message_idx] = new_user_message
return messages
return messages, citations
def get_model_path(model: str, update_model: bool = False):
......
......@@ -33,7 +33,7 @@ from utils.utils import (
from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from config import WEBUI_AUTH_TRUSTED_EMAIL_HEADER
from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER
router = APIRouter()
......@@ -118,6 +118,21 @@ async def signin(request: Request, form_data: SigninForm):
),
)
user = Auths.authenticate_user_by_trusted_header(trusted_email)
if WEBUI_AUTH == False:
if Users.get_num_users() != 0:
raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
admin_email = "admin@localhost"
admin_password = "admin"
if not Users.get_user_by_email(admin_email.lower()):
await signup(
request,
SignupForm(email=admin_email, password=admin_password, name="User"),
)
user = Auths.authenticate_user(admin_email.lower(), admin_password)
else:
user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
......
......@@ -93,6 +93,31 @@ async def get_archived_session_user_chat_list(
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
############################
# GetSharedChatById
############################
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
    """Return the chat identified by ``share_id``.

    Users with the ``"user"`` role look the chat up by its share id; admins
    look it up by chat id. Pending users, and any request that does not
    resolve to a chat, get a 401 carrying the NOT_FOUND message.
    """
    if user.role == "pending":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )

    # Start from "no chat" so a role other than "user"/"admin" falls through
    # to the 401 below instead of raising UnboundLocalError.
    chat = None
    if user.role == "user":
        chat = Chats.get_chat_by_share_id(share_id)
    elif user.role == "admin":
        chat = Chats.get_chat_by_id(share_id)

    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )

    return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
############################
# GetChats
############################
......@@ -141,6 +166,55 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
)
############################
# GetChatsByTags
############################
class TagNameForm(BaseModel):
    # Request body for listing a user's chats by tag name.
    name: str  # tag name used to look up the chat ids
    skip: Optional[int] = 0  # forwarded to Chats.get_chat_list_by_chat_ids; presumably a row offset (confirm)
    limit: Optional[int] = 50  # forwarded to Chats.get_chat_list_by_chat_ids; presumably a page size (confirm)
@router.post("/tags", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
    form_data: TagNameForm, user=Depends(get_current_user)
):
    """Return the requesting user's chats that carry the tag ``form_data.name``.

    ``form_data.skip`` and ``form_data.limit`` are forwarded to the chat
    lookup. When the lookup returns no chats, the tag is deleted for this
    user.
    """
    # Route the request dump through the module logger instead of stdout.
    log.debug("get_user_chat_list_by_tag_name: %s", form_data)

    chat_ids = [
        chat_id_tag.chat_id
        for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
            form_data.name, user.id
        )
    ]

    chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)

    if not chats:
        Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)

    return chats
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)):
    """Return every tag that belongs to the requesting user.

    Any failure during the lookup is logged and reported as a 400 with the
    default error message.
    """
    try:
        user_tags = Tags.get_tags_by_user_id(user.id)
    except Exception as err:
        log.exception(err)
        raise HTTPException(
            detail=ERROR_MESSAGES.DEFAULT(),
            status_code=status.HTTP_400_BAD_REQUEST,
        )
    return user_tags
############################
# GetChatById
############################
......@@ -274,70 +348,6 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
)
############################
# GetSharedChatById
############################
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
    """Return the chat identified by ``share_id``.

    Users with the ``"user"`` role look the chat up by its share id; admins
    look it up by chat id. Pending users, and any request that does not
    resolve to a chat, get a 401 carrying the NOT_FOUND message.
    """
    if user.role == "pending":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )

    # Start from "no chat" so a role other than "user"/"admin" falls through
    # to the 401 below instead of raising UnboundLocalError.
    chat = None
    if user.role == "user":
        chat = Chats.get_chat_by_share_id(share_id)
    elif user.role == "admin":
        chat = Chats.get_chat_by_id(share_id)

    if not chat:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
        )

    return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)):
    """Return every tag that belongs to the requesting user.

    Any failure during the lookup is logged and reported as a 400 with the
    default error message.
    """
    try:
        user_tags = Tags.get_tags_by_user_id(user.id)
    except Exception as err:
        log.exception(err)
        raise HTTPException(
            detail=ERROR_MESSAGES.DEFAULT(),
            status_code=status.HTTP_400_BAD_REQUEST,
        )
    return user_tags
############################
# GetChatsByTags
############################
@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
    tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
    """Return the requesting user's chats that carry the tag ``tag_name``.

    ``skip`` and ``limit`` are forwarded to the chat lookup. When the lookup
    returns an empty list, the tag is deleted for this user.
    """
    tagged_entries = Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
    chat_ids = [entry.chat_id for entry in tagged_entries]

    chats = Chats.get_chat_list_by_chat_ids(chat_ids, skip, limit)
    if len(chats) == 0:
        Tags.delete_tag_by_tag_name_and_user_id(tag_name, user.id)
    return chats
############################
# GetChatTagsById
############################
......
......@@ -18,6 +18,18 @@ from secrets import token_bytes
from constants import ERROR_MESSAGES
####################################
# Load .env file
####################################
try:
    # python-dotenv is optional: when it is importable, load environment
    # variables from the file that find_dotenv("../.env") locates.
    from dotenv import load_dotenv, find_dotenv

    load_dotenv(find_dotenv("../.env"))
except ImportError:
    # A missing dependency is not fatal; report it and continue.
    # NOTE(review): print is presumably used because the logger is set up
    # later in this module — confirm.
    print("dotenv not installed, skipping...")
####################################
# LOGGING
####################################
......@@ -59,23 +71,16 @@ for source in log_sources:
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
####################################
# Load .env file
####################################
try:
    # python-dotenv is optional: when it is importable, load environment
    # variables from the file that find_dotenv("../.env") locates.
    from dotenv import load_dotenv, find_dotenv

    load_dotenv(find_dotenv("../.env"))
except ImportError:
    # A missing dependency is not fatal; log a warning and continue.
    log.warning("dotenv not installed, skipping...")
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
####################################
# ENV (dev,test,prod)
####################################
......@@ -408,7 +413,7 @@ WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100")
# WEBUI_AUTH (Required for security)
####################################
WEBUI_AUTH = True
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
......@@ -454,6 +459,11 @@ ENABLE_RAG_HYBRID_SEARCH = (
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
)
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true"
)
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true"
......@@ -483,13 +493,6 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
# Device for the embedding models: "cuda" when the USE_CUDA_DOCKER environment
# variable is "true" (case-insensitive), otherwise "cpu".
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
DEVICE_TYPE = "cuda" if USE_CUDA.lower() == "true" else "cpu"
if CHROMA_HTTP_HOST != "":
CHROMA_CLIENT = chromadb.HttpClient(
......@@ -509,6 +512,16 @@ else:
database=CHROMA_DATABASE,
)
# Device for the embedding models: "cuda" when the USE_CUDA_DOCKER environment
# variable is "true" (case-insensitive), otherwise "cpu".
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
DEVICE_TYPE = "cuda" if USE_CUDA.lower() == "true" else "cpu"
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500"))
CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100"))
......@@ -531,7 +544,11 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true"
ENABLE_RAG_LOCAL_WEB_FETCH = (
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
)
YOUTUBE_LOADER_LANGUAGE = os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(",")
####################################
# Transcribe
......@@ -574,6 +591,8 @@ IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "")
AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY)
AUDIO_OPENAI_API_MODEL = os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1")
AUDIO_OPENAI_API_VOICE = os.getenv("AUDIO_OPENAI_API_VOICE", "alloy")
####################################
# LiteLLM
......
......@@ -42,6 +42,9 @@ class ERROR_MESSAGES(str, Enum):
"The password provided is incorrect. Please check for typos and try again."
)
INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."
EXISTING_USERS = "You can't turn off authentication because there are existing users. If you want to disable WEBUI_AUTH, make sure your web interface doesn't have any existing users and is a fresh installation."
UNAUTHORIZED = "401 Unauthorized"
ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
ACTION_PROHIBITED = (
......
......@@ -15,7 +15,7 @@ from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import StreamingResponse, Response
from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
......@@ -25,6 +25,8 @@ from apps.litellm.main import (
start_litellm_background,
shutdown_litellm_background,
)
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
......@@ -41,6 +43,8 @@ from apps.rag.utils import rag_messages
from config import (
CONFIG_DATA,
WEBUI_NAME,
WEBUI_URL,
WEBUI_AUTH,
ENV,
VERSION,
CHANGELOG,
......@@ -74,7 +78,7 @@ class SPAStaticFiles(StaticFiles):
print(
f"""
rf"""
___ __ __ _ _ _ ___
/ _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
......@@ -100,6 +104,8 @@ origins = ["*"]
class RAGMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
return_citations = False
if request.method == "POST" and (
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
):
......@@ -112,11 +118,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
return_citations = data.get("citations", False)
if "citations" in data:
del data["citations"]
# Example: Add a new key-value pair or modify existing ones
# data["modified"] = True # Example modification
if "docs" in data:
data = {**data}
data["messages"] = rag_messages(
data["messages"], citations = rag_messages(
docs=data["docs"],
messages=data["messages"],
template=rag_app.state.RAG_TEMPLATE,
......@@ -128,7 +138,9 @@ class RAGMiddleware(BaseHTTPMiddleware):
)
del data["docs"]
log.debug(f"data['messages']: {data['messages']}")
log.debug(
f"data['messages']: {data['messages']}, citations: {citations}"
)
modified_body_bytes = json.dumps(data).encode("utf-8")
......@@ -146,11 +158,36 @@ class RAGMiddleware(BaseHTTPMiddleware):
]
response = await call_next(request)
if return_citations:
# Inject the citations into the response
if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers.get("Content-Type")
if "text/event-stream" in content_type:
return StreamingResponse(
self.openai_stream_wrapper(response.body_iterator, citations),
)
if "application/x-ndjson" in content_type:
return StreamingResponse(
self.ollama_stream_wrapper(response.body_iterator, citations),
)
return response
async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False}
async def openai_stream_wrapper(self, original_generator, citations):
    """Yield the citations as a leading ``data:`` event, then relay the stream.

    The first item is ``data: <json>`` followed by a blank line, where the
    JSON object has a single ``citations`` key. Every item of
    ``original_generator`` is then yielded unchanged.
    """
    citations_json = json.dumps({"citations": citations})
    yield f"data: {citations_json}\n\n"

    async for chunk in original_generator:
        yield chunk
async def ollama_stream_wrapper(self, original_generator, citations):
    """Yield the citations as a leading JSON line, then relay the stream.

    The first item is a JSON object with a single ``citations`` key,
    terminated by a newline. Every item of ``original_generator`` is then
    yielded unchanged.
    """
    citations_json = json.dumps({"citations": citations})
    yield f"{citations_json}\n"

    async for chunk in original_generator:
        yield chunk
app.add_middleware(RAGMiddleware)
......@@ -204,6 +241,7 @@ async def get_app_config():
"status": True,
"name": WEBUI_NAME,
"version": VERSION,
"auth": WEBUI_AUTH,
"default_locale": default_locale,
"images": images_app.state.ENABLED,
"default_models": webui_app.state.DEFAULT_MODELS,
......@@ -315,6 +353,21 @@ async def get_manifest_json():
}
@app.get("/opensearch.xml")
async def get_opensearch_xml():
    """Serve the OpenSearch description document for this instance.

    The document advertises ``WEBUI_NAME`` as the engine name and a search
    URL template rooted at ``WEBUI_URL``. It is returned with the
    ``application/xml`` media type.
    """
    # Imported locally so this handler does not depend on a file-level import.
    from xml.sax.saxutils import escape

    # Both values come from configuration; escape them so characters such as
    # "&", "<" or a double quote cannot produce malformed XML.
    name = escape(WEBUI_NAME)
    url = escape(WEBUI_URL, {'"': "&quot;"})

    xml_content = rf"""
<OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
<ShortName>{name}</ShortName>
<Description>Search {name}</Description>
<InputEncoding>UTF-8</InputEncoding>
<Image width="16" height="16" type="image/x-icon">{url}/favicon.png</Image>
<Url type="text/html" method="get" template="{url}/?q={{searchTerms}}"/>
<moz:SearchForm>{url}</moz:SearchForm>
</OpenSearchDescription>
"""
    return Response(content=xml_content, media_type="application/xml")
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
......
......@@ -9,7 +9,6 @@ Flask-Cors==4.0.0
python-socketio==5.11.2
python-jose==3.3.0
passlib[bcrypt]==1.7.4
uuid==1.30
requests==2.31.0
aiohttp==3.9.5
......@@ -19,7 +18,6 @@ psycopg2-binary==2.9.9
PyMySQL==1.1.0
bcrypt==4.1.2
litellm==1.35.28
litellm[proxy]==1.35.28
boto3==1.34.95
......@@ -54,9 +52,9 @@ rank-bm25==0.2.2
faster-whisper==1.0.1
PyJWT==2.8.0
PyJWT[crypto]==2.8.0
black==24.4.2
langfuse==2.27.3
youtube-transcript-api
youtube-transcript-api==0.6.2
pytube
\ No newline at end of file
......@@ -8,7 +8,7 @@ KEY_FILE=.webui_secret_key
PORT="${PORT:-8080}"
HOST="${HOST:-0.0.0.0}"
if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then
echo "No WEBUI_SECRET_KEY provided"
echo "Loading WEBUI_SECRET_KEY from file, not provided as an environment variable."
if ! [ -e "$KEY_FILE" ]; then
echo "Generating WEBUI_SECRET_KEY"
......
......@@ -13,7 +13,7 @@ SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%"
:: Check if WEBUI_SECRET_KEY and WEBUI_JWT_SECRET_KEY are not set
IF "%WEBUI_SECRET_KEY%%WEBUI_JWT_SECRET_KEY%" == " " (
echo No WEBUI_SECRET_KEY provided
echo Loading WEBUI_SECRET_KEY from file, not provided as an environment variable.
IF NOT EXIST "%KEY_FILE%" (
echo Generating WEBUI_SECRET_KEY
......
......@@ -38,9 +38,10 @@ def calculate_sha256_string(string):
def validate_email_format(email: str) -> bool:
    """Return True when ``email`` looks like an e-mail address.

    Addresses ending in ``@localhost`` are accepted as-is. Any other address
    must start with ``local@domain.tld``, with no ``@`` inside any of the
    three parts.
    """
    # Check "@localhost" first: such addresses have no dot in the domain and
    # would otherwise be rejected by the pattern below.
    if email.endswith("@localhost"):
        return True
    return bool(re.match(r"[^@]+@[^@]+\.[^@]+", email))
def sanitize_filename(file_name):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment