Unverified commit b8d7fdf1 authored by Timothy Jaeryang Baek, committed by GitHub

Merge pull request #1965 from open-webui/dev

0.1.124
parents 30b05311 b44ae536
version: 2
updates:
  - package-ecosystem: pip
-    directory: "/backend"
+    directory: '/backend'
    schedule:
      interval: daily
-      time: "13:00"
+      time: '13:00'
-    groups:
-      python-packages:
-        patterns:
-          - "*"
+  - package-ecosystem: 'github-actions'
+    directory: '/'
+    schedule:
+      # Check for updates to GitHub Actions every week
+      interval: 'weekly'
## Pull Request Checklist

+- [ ] **Target branch:** Pull requests should target the `dev` branch.
- [ ] **Description:** Briefly describe the changes in this pull request.
- [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
- [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources?
...
@@ -20,7 +20,16 @@ jobs:
      - name: Build and run Compose Stack
        run: |
-          docker compose up --detach --build
+          docker compose --file docker-compose.yaml --file docker-compose.api.yaml up --detach --build
+      - name: Wait for Ollama to be up
+        timeout-minutes: 5
+        run: |
+          until curl --output /dev/null --silent --fail http://localhost:11434; do
+            printf '.'
+            sleep 1
+          done
+          echo "Service is up!"
      - name: Preload Ollama model
        run: |
...
@@ -5,6 +5,28 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## [0.1.124] - 2024-05-08
+
+### Added
+
+- **🖼️ Improved Chat Sidebar**: Now conveniently displays time ranges and organizes chats by today, yesterday, and more.
+- **📜 Citations in RAG Feature**: Easily track the context fed to the LLM with added citations in the RAG feature.
+- **🔒 Auth Disable Option**: Introducing the ability to disable authentication. Set 'WEBUI_AUTH' to False to disable authentication. Note: Only applicable for fresh installations without existing users. (See the example snippet after this entry.)
+- **📹 Enhanced YouTube RAG Pipeline**: Now supports non-English videos for an enriched experience.
+- **🔊 Specify OpenAI TTS Models**: Customize your TTS experience by specifying OpenAI TTS models.
+- **🔧 Additional Environment Variables**: Discover more environment variables in our comprehensive documentation at Open WebUI Documentation (https://docs.openwebui.com).
+- **🌐 Language Support**: Arabic, Finnish, and Hindi added; improved support for German, Vietnamese, and Chinese.
+
+### Fixed
+
+- **🛠️ Model Selector Styling**: Addressed styling issues for improved user experience.
+- **⚠️ Warning Messages**: Resolved backend warning messages.
+
+### Changed
+
+- **📝 Title Generation**: Limited output to 50 tokens.
+- **📦 Helm Charts**: Removed Helm charts, now available in a separate repository (https://github.com/open-webui/helm-charts).
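The `WEBUI_AUTH` and TTS items in the Added list above boil down to a handful of environment variables. A minimal launch sketch, assuming the backend is started from the repository's `backend` directory via `start.sh` and that the specific model/voice values suit your OpenAI account (none of these choices are prescribed by this PR):

```python
# Illustrative only: supply the new 0.1.124 settings via environment variables
# when launching the backend. Paths and the chosen model/voice are assumptions.
import os
import subprocess

env = dict(
    os.environ,
    WEBUI_AUTH="False",                 # disable authentication (fresh installs only)
    AUDIO_OPENAI_API_MODEL="tts-1-hd",  # pick a specific OpenAI TTS model
    AUDIO_OPENAI_API_VOICE="nova",      # and a voice for speech synthesis
)
subprocess.run(["bash", "start.sh"], cwd="backend", env=env, check=True)
```

Because `WEBUI_AUTH` only takes effect on fresh installations, the sketch assumes no users exist yet.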
## [0.1.123] - 2024-05-02

### Added
...
@@ -152,7 +152,7 @@ We offer various installation alternatives, including non-Docker methods, Docker
### Troubleshooting

-Encountering connection issues? Our [Open WebUI Documentation](https://docs.openwebui.com/getting-started/troubleshooting/) has got you covered. For further assistance and to join our vibrant community, visit the [Open WebUI Discord](https://discord.gg/5rJgQTnV4s).
+Encountering connection issues? Our [Open WebUI Documentation](https://docs.openwebui.com/troubleshooting/) has got you covered. For further assistance and to join our vibrant community, visit the [Open WebUI Discord](https://discord.gg/5rJgQTnV4s).

### Keeping Your Docker Installation Up-to-Date
...
@@ -43,6 +43,8 @@ from config import (
    DEVICE_TYPE,
    AUDIO_OPENAI_API_BASE_URL,
    AUDIO_OPENAI_API_KEY,
+    AUDIO_OPENAI_API_MODEL,
+    AUDIO_OPENAI_API_VOICE,
)

log = logging.getLogger(__name__)
@@ -60,6 +62,8 @@ app.add_middleware(

app.state.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
+app.state.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
+app.state.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE

# setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
@@ -72,6 +76,8 @@ SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
class OpenAIConfigUpdateForm(BaseModel):
    url: str
    key: str
+    model: str
+    speaker: str


@app.get("/config")
@@ -79,6 +85,8 @@ async def get_openai_config(user=Depends(get_admin_user)):
    return {
        "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
+        "OPENAI_API_MODEL": app.state.OPENAI_API_MODEL,
+        "OPENAI_API_VOICE": app.state.OPENAI_API_VOICE,
    }
@@ -91,11 +99,15 @@ async def update_openai_config(
    app.state.OPENAI_API_BASE_URL = form_data.url
    app.state.OPENAI_API_KEY = form_data.key
+    app.state.OPENAI_API_MODEL = form_data.model
+    app.state.OPENAI_API_VOICE = form_data.speaker

    return {
        "status": True,
        "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
        "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
+        "OPENAI_API_MODEL": app.state.OPENAI_API_MODEL,
+        "OPENAI_API_VOICE": app.state.OPENAI_API_VOICE,
    }
...
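The new `model` and `speaker` fields above can also be changed at runtime through the audio app's config endpoint. A minimal sketch, assuming the audio sub-app is mounted at `/audio/api/v1` and that an admin bearer token is at hand (both are assumptions about a typical deployment, not shown in this diff):

```python
# Sketch: update the TTS model and voice via the reworked /config/update route.
# The base URL, mount prefix, and token are assumptions, not defined here.
import requests

resp = requests.post(
    "http://localhost:8080/audio/api/v1/config/update",  # assumed mount prefix
    headers={"Authorization": "Bearer ADMIN_TOKEN"},      # admin-only endpoint
    json={
        "url": "https://api.openai.com/v1",
        "key": "sk-...",       # placeholder API key
        "model": "tts-1-hd",   # OpenAIConfigUpdateForm.model
        "speaker": "nova",     # OpenAIConfigUpdateForm.speaker
    },
    timeout=10,
)
resp.raise_for_status()
print(resp.json()["OPENAI_API_MODEL"], resp.json()["OPENAI_API_VOICE"])
```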
@@ -36,6 +36,10 @@ from config import (
    LITELLM_PROXY_HOST,
)

+import warnings
+
+warnings.simplefilter("ignore")
+
from litellm.utils import get_llm_provider

import asyncio
...
@@ -25,13 +25,19 @@ import uuid
import aiohttp
import asyncio
import logging
+import time

from urllib.parse import urlparse
from typing import Optional, List, Union

from apps.web.models.users import Users
from constants import ERROR_MESSAGES
-from utils.utils import decode_token, get_current_user, get_admin_user
+from utils.utils import (
+    decode_token,
+    get_current_user,
+    get_verified_user,
+    get_admin_user,
+)
from config import (
@@ -164,7 +170,7 @@ async def get_all_models():
@app.get("/api/tags")
@app.get("/api/tags/{url_idx}")
async def get_ollama_tags(
-    url_idx: Optional[int] = None, user=Depends(get_current_user)
+    url_idx: Optional[int] = None, user=Depends(get_verified_user)
):
    if url_idx == None:
        models = await get_all_models()
@@ -563,7 +569,7 @@ async def delete_model(
@app.post("/api/show")
-async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_user)):
+async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
    if form_data.name not in app.state.MODELS:
        raise HTTPException(
            status_code=400,
@@ -612,7 +618,7 @@ class GenerateEmbeddingsForm(BaseModel):
async def generate_embeddings(
    form_data: GenerateEmbeddingsForm,
    url_idx: Optional[int] = None,
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
):
    if url_idx == None:
        model = form_data.model
@@ -730,7 +736,7 @@ class GenerateCompletionForm(BaseModel):
async def generate_completion(
    form_data: GenerateCompletionForm,
    url_idx: Optional[int] = None,
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
):
    if url_idx == None:
@@ -833,7 +839,7 @@ class GenerateChatCompletionForm(BaseModel):
async def generate_chat_completion(
    form_data: GenerateChatCompletionForm,
    url_idx: Optional[int] = None,
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
):
    if url_idx == None:
@@ -942,7 +948,7 @@ class OpenAIChatCompletionForm(BaseModel):
async def generate_openai_chat_completion(
    form_data: OpenAIChatCompletionForm,
    url_idx: Optional[int] = None,
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
):
    if url_idx == None:
@@ -1026,6 +1032,75 @@ async def generate_openai_chat_completion(
    )
+@app.get("/v1/models")
+@app.get("/v1/models/{url_idx}")
+async def get_openai_models(
+    url_idx: Optional[int] = None,
+    user=Depends(get_verified_user),
+):
+    if url_idx == None:
+        models = await get_all_models()
+
+        if app.state.ENABLE_MODEL_FILTER:
+            if user.role == "user":
+                models["models"] = list(
+                    filter(
+                        lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
+                        models["models"],
+                    )
+                )
+
+        return {
+            "data": [
+                {
+                    "id": model["model"],
+                    "object": "model",
+                    "created": int(time.time()),
+                    "owned_by": "openai",
+                }
+                for model in models["models"]
+            ],
+            "object": "list",
+        }
+    else:
+        url = app.state.OLLAMA_BASE_URLS[url_idx]
+        try:
+            r = requests.request(method="GET", url=f"{url}/api/tags")
+            r.raise_for_status()
+
+            models = r.json()
+
+            return {
+                "data": [
+                    {
+                        "id": model["model"],
+                        "object": "model",
+                        "created": int(time.time()),
+                        "owned_by": "openai",
+                    }
+                    for model in models["models"]
+                ],
+                "object": "list",
+            }
+        except Exception as e:
+            log.exception(e)
+            error_detail = "Open WebUI: Server Connection Error"
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "error" in res:
+                        error_detail = f"Ollama: {res['error']}"
+                except:
+                    error_detail = f"Ollama: {e}"
+
+            raise HTTPException(
+                status_code=r.status_code if r else 500,
+                detail=error_detail,
+            )

class UrlForm(BaseModel):
    url: str
@@ -1241,7 +1316,9 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
-async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)):
+async def deprecated_proxy(
+    path: str, request: Request, user=Depends(get_verified_user)
+):
    url = app.state.OLLAMA_BASE_URLS[0]
    target_url = f"{url}/{path}"
...
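The `/v1/models` route added above exposes the Ollama models in the OpenAI list format. A minimal client sketch; the `/ollama` mount prefix and the bearer token are assumptions about a typical deployment, not something this diff defines:

```python
# Sketch: list models through the new OpenAI-compatible endpoint.
import requests

BASE_URL = "http://localhost:8080/ollama"  # assumed mount point of this sub-app
resp = requests.get(
    f"{BASE_URL}/v1/models",
    headers={"Authorization": "Bearer YOUR_TOKEN"},  # any verified user, per get_verified_user
    timeout=10,
)
resp.raise_for_status()
for model in resp.json()["data"]:
    print(model["id"], "-", model["owned_by"])  # e.g. "llama3:latest - openai"
```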
@@ -79,6 +79,7 @@ from config import (
    RAG_EMBEDDING_MODEL_AUTO_UPDATE,
    RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
    ENABLE_RAG_HYBRID_SEARCH,
+    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
    RAG_RERANKING_MODEL,
    PDF_EXTRACT_IMAGES,
    RAG_RERANKING_MODEL_AUTO_UPDATE,
@@ -90,7 +91,8 @@ from config import (
    CHUNK_SIZE,
    CHUNK_OVERLAP,
    RAG_TEMPLATE,
-    ENABLE_LOCAL_WEB_FETCH,
+    ENABLE_RAG_LOCAL_WEB_FETCH,
+    YOUTUBE_LOADER_LANGUAGE,
)

from constants import ERROR_MESSAGES
@@ -104,6 +106,9 @@ app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
+app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
+    ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
+)

app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
@@ -113,12 +118,17 @@ app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE

app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY

app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES

+app.state.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
+app.state.YOUTUBE_LOADER_TRANSLATION = None


def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
@@ -308,6 +318,11 @@ async def get_rag_config(user=Depends(get_admin_user)):
            "chunk_size": app.state.CHUNK_SIZE,
            "chunk_overlap": app.state.CHUNK_OVERLAP,
        },
+        "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+        "youtube": {
+            "language": app.state.YOUTUBE_LOADER_LANGUAGE,
+            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
+        },
    }
@@ -316,16 +331,53 @@ class ChunkParamUpdateForm(BaseModel):
    chunk_overlap: int


+class YoutubeLoaderConfig(BaseModel):
+    language: List[str]
+    translation: Optional[str] = None


class ConfigUpdateForm(BaseModel):
-    pdf_extract_images: bool
-    chunk: ChunkParamUpdateForm
+    pdf_extract_images: Optional[bool] = None
+    chunk: Optional[ChunkParamUpdateForm] = None
+    web_loader_ssl_verification: Optional[bool] = None
+    youtube: Optional[YoutubeLoaderConfig] = None


@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
-    app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
-    app.state.CHUNK_SIZE = form_data.chunk.chunk_size
-    app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
+    app.state.PDF_EXTRACT_IMAGES = (
+        form_data.pdf_extract_images
+        if form_data.pdf_extract_images != None
+        else app.state.PDF_EXTRACT_IMAGES
+    )
+
+    app.state.CHUNK_SIZE = (
+        form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE
+    )
+    app.state.CHUNK_OVERLAP = (
+        form_data.chunk.chunk_overlap
+        if form_data.chunk != None
+        else app.state.CHUNK_OVERLAP
+    )
+
+    app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
+        form_data.web_loader_ssl_verification
+        if form_data.web_loader_ssl_verification != None
+        else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
+    )
+
+    app.state.YOUTUBE_LOADER_LANGUAGE = (
+        form_data.youtube.language
+        if form_data.youtube != None
+        else app.state.YOUTUBE_LOADER_LANGUAGE
+    )
+    app.state.YOUTUBE_LOADER_TRANSLATION = (
+        form_data.youtube.translation
+        if form_data.youtube != None
+        else app.state.YOUTUBE_LOADER_TRANSLATION
+    )

    return {
        "status": True,
@@ -334,6 +386,11 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
            "chunk_size": app.state.CHUNK_SIZE,
            "chunk_overlap": app.state.CHUNK_OVERLAP,
        },
+        "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+        "youtube": {
+            "language": app.state.YOUTUBE_LOADER_LANGUAGE,
+            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
+        },
    }
@@ -460,7 +517,12 @@ def query_collection_handler(
@app.post("/youtube")
def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
    try:
-        loader = YoutubeLoader.from_youtube_url(form_data.url, add_video_info=False)
+        loader = YoutubeLoader.from_youtube_url(
+            form_data.url,
+            add_video_info=True,
+            language=app.state.YOUTUBE_LOADER_LANGUAGE,
+            translation=app.state.YOUTUBE_LOADER_TRANSLATION,
+        )
        data = loader.load()

        collection_name = form_data.collection_name
@@ -485,7 +547,9 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
    try:
-        loader = get_web_loader(form_data.url)
+        loader = get_web_loader(
+            form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
+        )
        data = loader.load()

        collection_name = form_data.collection_name
@@ -506,11 +570,11 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
    )


-def get_web_loader(url: str):
+def get_web_loader(url: str, verify_ssl: bool = True):
    # Check if the URL is valid
    if isinstance(validators.url(url), validators.ValidationError):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    if not ENABLE_LOCAL_WEB_FETCH:
+    if not ENABLE_RAG_LOCAL_WEB_FETCH:
        # Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
        parsed_url = urllib.parse.urlparse(url)
        # Get IPv4 and IPv6 addresses
@@ -523,7 +587,7 @@ def get_web_loader(url: str):
        for ip in ipv6_addresses:
            if validators.ipv6(ip, private=True):
                raise ValueError(ERROR_MESSAGES.INVALID_URL)
-    return WebBaseLoader(url)
+    return WebBaseLoader(url, verify_ssl=verify_ssl)


def resolve_hostname(hostname):
@@ -594,7 +658,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
        for batch in create_batches(
            api=CHROMA_CLIENT,
-            ids=[str(uuid.uuid1()) for _ in texts],
+            ids=[str(uuid.uuid4()) for _ in texts],
            metadatas=metadatas,
            embeddings=embeddings,
            documents=texts,
...
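Since every field of `ConfigUpdateForm` above is now optional, clients can send partial updates that touch only the new settings. A sketch, assuming the RAG sub-app is mounted at `/rag/api/v1` and an admin token is available (both assumptions, not defined in this diff):

```python
# Sketch: partial update of the RAG config with only the new fields.
import requests

resp = requests.post(
    "http://localhost:8080/rag/api/v1/config/update",  # assumed mount prefix
    headers={"Authorization": "Bearer ADMIN_TOKEN"},
    json={
        "web_loader_ssl_verification": False,                        # skip TLS checks for web loads
        "youtube": {"language": ["de", "en"], "translation": "en"},  # YoutubeLoaderConfig
    },
    timeout=10,
)
resp.raise_for_status()
print(resp.json())  # echoes chunk, ssl-verification, and youtube settings
```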
@@ -271,14 +271,14 @@ def rag_messages(
    for doc in docs:
        context = None

-        collection = doc.get("collection_name")
-        if collection:
-            collection = [collection]
-        else:
-            collection = doc.get("collection_names", [])
+        collection_names = (
+            doc["collection_names"]
+            if doc["type"] == "collection"
+            else [doc["collection_name"]]
+        )

-        collection = set(collection).difference(extracted_collections)
-        if not collection:
+        collection_names = set(collection_names).difference(extracted_collections)
+        if not collection_names:
            log.debug(f"skipping {doc} as it has already been extracted")
            continue
@@ -288,11 +288,7 @@ def rag_messages(
            else:
                if hybrid_search:
                    context = query_collection_with_hybrid_search(
-                        collection_names=(
-                            doc["collection_names"]
-                            if doc["type"] == "collection"
-                            else [doc["collection_name"]]
-                        ),
+                        collection_names=collection_names,
                        query=query,
                        embedding_function=embedding_function,
                        k=k,
@@ -301,11 +297,7 @@ def rag_messages(
                    )
                else:
                    context = query_collection(
-                        collection_names=(
-                            doc["collection_names"]
-                            if doc["type"] == "collection"
-                            else [doc["collection_name"]]
-                        ),
+                        collection_names=collection_names,
                        query=query,
                        embedding_function=embedding_function,
                        k=k,
@@ -315,18 +307,31 @@ def rag_messages(
                    context = None

        if context:
-            relevant_contexts.append(context)
+            relevant_contexts.append({**context, "source": doc})

-        extracted_collections.extend(collection)
+        extracted_collections.extend(collection_names)

    context_string = ""
+
+    citations = []
    for context in relevant_contexts:
        try:
            if "documents" in context:
-                items = [item for item in context["documents"][0] if item is not None]
-                context_string += "\n\n".join(items)
+                context_string += "\n\n".join(
+                    [text for text in context["documents"][0] if text is not None]
+                )
+                if "metadatas" in context:
+                    citations.append(
+                        {
+                            "source": context["source"],
+                            "document": context["documents"][0],
+                            "metadata": context["metadatas"][0],
+                        }
+                    )
        except Exception as e:
            log.exception(e)

    context_string = context_string.strip()

    ra_content = rag_template(
@@ -355,7 +360,7 @@ def rag_messages(
    messages[last_user_message_idx] = new_user_message

-    return messages
+    return messages, citations


def get_model_path(model: str, update_model: bool = False):
...
@@ -33,7 +33,7 @@ from utils.utils import (
from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
-from config import WEBUI_AUTH_TRUSTED_EMAIL_HEADER
+from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER

router = APIRouter()
@@ -118,6 +118,21 @@ async def signin(request: Request, form_data: SigninForm):
            ),
        )
        user = Auths.authenticate_user_by_trusted_header(trusted_email)
+    if WEBUI_AUTH == False:
+        if Users.get_num_users() != 0:
+            raise HTTPException(400, detail=ERROR_MESSAGES.EXISTING_USERS)
+
+        admin_email = "admin@localhost"
+        admin_password = "admin"
+
+        if not Users.get_user_by_email(admin_email.lower()):
+            await signup(
+                request,
+                SignupForm(email=admin_email, password=admin_password, name="User"),
+            )
+
+        user = Auths.authenticate_user(admin_email.lower(), admin_password)
    else:
        user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
...
@@ -93,6 +93,31 @@ async def get_archived_session_user_chat_list(
    return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)

+############################
+# GetSharedChatById
+############################
+
+
+@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
+async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
+    if user.role == "pending":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+
+    if user.role == "user":
+        chat = Chats.get_chat_by_share_id(share_id)
+    elif user.role == "admin":
+        chat = Chats.get_chat_by_id(share_id)
+
+    if chat:
+        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
+        )
+

############################
# GetChats
############################
@@ -141,6 +166,55 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
    )

+############################
+# GetChatsByTags
+############################
+
+
+class TagNameForm(BaseModel):
+    name: str
+    skip: Optional[int] = 0
+    limit: Optional[int] = 50
+
+
+@router.post("/tags", response_model=List[ChatTitleIdResponse])
+async def get_user_chat_list_by_tag_name(
+    form_data: TagNameForm, user=Depends(get_current_user)
+):
+    print(form_data)
+    chat_ids = [
+        chat_id_tag.chat_id
+        for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(
+            form_data.name, user.id
+        )
+    ]
+
+    chats = Chats.get_chat_list_by_chat_ids(chat_ids, form_data.skip, form_data.limit)
+
+    if len(chats) == 0:
+        Tags.delete_tag_by_tag_name_and_user_id(form_data.name, user.id)
+
+    return chats
+
+
+############################
+# GetAllTags
+############################
+
+
+@router.get("/tags/all", response_model=List[TagModel])
+async def get_all_tags(user=Depends(get_current_user)):
+    try:
+        tags = Tags.get_tags_by_user_id(user.id)
+        return tags
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
+        )
+

############################
# GetChatById
############################
@@ -274,70 +348,6 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
    )
-############################
-# GetSharedChatById
-############################
-
-
-@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
-async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
-    if user.role == "pending":
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
-        )
-
-    if user.role == "user":
-        chat = Chats.get_chat_by_share_id(share_id)
-    elif user.role == "admin":
-        chat = Chats.get_chat_by_id(share_id)
-
-    if chat:
-        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
-        )
-
-
-############################
-# GetAllTags
-############################
-
-
-@router.get("/tags/all", response_model=List[TagModel])
-async def get_all_tags(user=Depends(get_current_user)):
-    try:
-        tags = Tags.get_tags_by_user_id(user.id)
-        return tags
-    except Exception as e:
-        log.exception(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
-        )
-
-
-############################
-# GetChatsByTags
-############################
-
-
-@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
-async def get_user_chat_list_by_tag_name(
-    tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
-):
-    chat_ids = [
-        chat_id_tag.chat_id
-        for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
-    ]
-
-    chats = Chats.get_chat_list_by_chat_ids(chat_ids, skip, limit)
-
-    if len(chats) == 0:
-        Tags.delete_tag_by_tag_name_and_user_id(tag_name, user.id)
-
-    return chats
-

############################
# GetChatTagsById
############################
...
@@ -18,6 +18,18 @@ from secrets import token_bytes
from constants import ERROR_MESSAGES

+####################################
+# Load .env file
+####################################
+
+try:
+    from dotenv import load_dotenv, find_dotenv
+
+    load_dotenv(find_dotenv("../.env"))
+except ImportError:
+    print("dotenv not installed, skipping...")
+
####################################
# LOGGING
####################################
@@ -59,23 +71,16 @@ for source in log_sources:

log.setLevel(SRC_LOG_LEVELS["CONFIG"])

-####################################
-# Load .env file
-####################################
-
-try:
-    from dotenv import load_dotenv, find_dotenv
-
-    load_dotenv(find_dotenv("../.env"))
-except ImportError:
-    log.warning("dotenv not installed, skipping...")
-
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
    WEBUI_NAME += " (Open WebUI)"

+WEBUI_URL = os.environ.get("WEBUI_URL", "http://localhost:3000")

WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"

####################################
# ENV (dev,test,prod)
####################################
@@ -408,7 +413,7 @@ WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100")
# WEBUI_AUTH (Required for security)
####################################

-WEBUI_AUTH = True
+WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
    "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
@@ -454,6 +459,11 @@ ENABLE_RAG_HYBRID_SEARCH = (
    os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
)

+ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
+    os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true"
+)

RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")

PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true"
@@ -483,13 +493,6 @@ RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)

-# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
-USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
-
-if USE_CUDA.lower() == "true":
-    DEVICE_TYPE = "cuda"
-else:
-    DEVICE_TYPE = "cpu"

if CHROMA_HTTP_HOST != "":
    CHROMA_CLIENT = chromadb.HttpClient(
@@ -509,6 +512,16 @@ else:
        database=CHROMA_DATABASE,
    )

+# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
+USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
+
+if USE_CUDA.lower() == "true":
+    DEVICE_TYPE = "cuda"
+else:
+    DEVICE_TYPE = "cpu"

CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500"))
CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100"))
@@ -531,7 +544,11 @@ RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)

-ENABLE_LOCAL_WEB_FETCH = os.getenv("ENABLE_LOCAL_WEB_FETCH", "False").lower() == "true"
+ENABLE_RAG_LOCAL_WEB_FETCH = (
+    os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
+)
+
+YOUTUBE_LOADER_LANGUAGE = os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(",")

####################################
# Transcribe
@@ -574,6 +591,8 @@ IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "")

AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY)
+AUDIO_OPENAI_API_MODEL = os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1")
+AUDIO_OPENAI_API_VOICE = os.getenv("AUDIO_OPENAI_API_VOICE", "alloy")

####################################
# LiteLLM
...
@@ -42,6 +42,9 @@ class ERROR_MESSAGES(str, Enum):
        "The password provided is incorrect. Please check for typos and try again."
    )
    INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."
+    EXISTING_USERS = "You can't turn off authentication because there are existing users. If you want to disable WEBUI_AUTH, make sure your web interface doesn't have any existing users and is a fresh installation."
    UNAUTHORIZED = "401 Unauthorized"
    ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
    ACTION_PROHIBITED = (
...
@@ -15,7 +15,7 @@ from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.responses import StreamingResponse, Response

from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
@@ -25,6 +25,8 @@ from apps.litellm.main import (
    start_litellm_background,
    shutdown_litellm_background,
)
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
@@ -41,6 +43,8 @@ from apps.rag.utils import rag_messages
from config import (
    CONFIG_DATA,
    WEBUI_NAME,
+    WEBUI_URL,
+    WEBUI_AUTH,
    ENV,
    VERSION,
    CHANGELOG,
@@ -74,7 +78,7 @@ class SPAStaticFiles(StaticFiles):
print(
-    f"""
+    rf"""
  ___                    __        __   _     _   _ ___
 / _ \ _ __   ___ _ __   \ \      / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \   \ \ /\ / / _ \ '_ \| | | || |
@@ -100,6 +104,8 @@ origins = ["*"]
class RAGMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
+        return_citations = False
+
        if request.method == "POST" and (
            "/api/chat" in request.url.path or "/chat/completions" in request.url.path
        ):
@@ -112,11 +118,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
            # Parse string to JSON
            data = json.loads(body_str) if body_str else {}

+            return_citations = data.get("citations", False)
+            if "citations" in data:
+                del data["citations"]
+
            # Example: Add a new key-value pair or modify existing ones
            # data["modified"] = True  # Example modification
            if "docs" in data:
                data = {**data}
-                data["messages"] = rag_messages(
+                data["messages"], citations = rag_messages(
                    docs=data["docs"],
                    messages=data["messages"],
                    template=rag_app.state.RAG_TEMPLATE,
@@ -128,7 +138,9 @@ class RAGMiddleware(BaseHTTPMiddleware):
                )
                del data["docs"]

-                log.debug(f"data['messages']: {data['messages']}")
+                log.debug(
+                    f"data['messages']: {data['messages']}, citations: {citations}"
+                )

            modified_body_bytes = json.dumps(data).encode("utf-8")
@@ -146,11 +158,36 @@ class RAGMiddleware(BaseHTTPMiddleware):
            ]

        response = await call_next(request)

+        if return_citations:
+            # Inject the citations into the response
+            if isinstance(response, StreamingResponse):
+                # If it's a streaming response, inject it as SSE event or NDJSON line
+                content_type = response.headers.get("Content-Type")
+                if "text/event-stream" in content_type:
+                    return StreamingResponse(
+                        self.openai_stream_wrapper(response.body_iterator, citations),
+                    )
+                if "application/x-ndjson" in content_type:
+                    return StreamingResponse(
+                        self.ollama_stream_wrapper(response.body_iterator, citations),
+                    )
+
        return response

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}
+    async def openai_stream_wrapper(self, original_generator, citations):
+        yield f"data: {json.dumps({'citations': citations})}\n\n"
+        async for data in original_generator:
+            yield data
+
+    async def ollama_stream_wrapper(self, original_generator, citations):
+        yield f"{json.dumps({'citations': citations})}\n"
+        async for data in original_generator:
+            yield data


app.add_middleware(RAGMiddleware)
@@ -204,6 +241,7 @@ async def get_app_config():
        "status": True,
        "name": WEBUI_NAME,
        "version": VERSION,
+        "auth": WEBUI_AUTH,
        "default_locale": default_locale,
        "images": images_app.state.ENABLED,
        "default_models": webui_app.state.DEFAULT_MODELS,
@@ -315,6 +353,21 @@ async def get_manifest_json():
    }

+@app.get("/opensearch.xml")
+async def get_opensearch_xml():
+    xml_content = rf"""
+    <OpenSearchDescription xmlns="http://a9.com/-/spec/opensearch/1.1/" xmlns:moz="http://www.mozilla.org/2006/browser/search/">
+    <ShortName>{WEBUI_NAME}</ShortName>
+    <Description>Search {WEBUI_NAME}</Description>
+    <InputEncoding>UTF-8</InputEncoding>
+    <Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/favicon.png</Image>
+    <Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/>
+    <moz:SearchForm>{WEBUI_URL}</moz:SearchForm>
+    </OpenSearchDescription>
+    """
+    return Response(content=xml_content, media_type="application/xml")
+

app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
...
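To see the citation injection above from the client side, a streamed request can set `citations: true` and read the extra first SSE event that `openai_stream_wrapper` emits. A sketch; the proxy path, token, and the payload fields other than `citations`/`docs` are assumptions about a typical deployment:

```python
# Sketch: consume a streamed chat completion and pick up the injected citations.
import json
import requests

resp = requests.post(
    "http://localhost:8080/openai/chat/completions",  # assumed OpenAI proxy path
    headers={"Authorization": "Bearer YOUR_TOKEN"},
    json={
        "model": "gpt-3.5-turbo",
        "stream": True,
        "citations": True,  # stripped by RAGMiddleware, triggers injection
        "docs": [{"type": "collection", "collection_names": ["my-docs"]}],
        "messages": [{"role": "user", "content": "What do my documents say about X?"}],
    },
    stream=True,
)

citations = None
for raw in resp.iter_lines():
    if not raw.startswith(b"data: "):
        continue
    chunk = raw[len(b"data: "):]
    if chunk == b"[DONE]":
        break
    event = json.loads(chunk)
    if citations is None and "citations" in event:
        # First event injected by openai_stream_wrapper:
        # [{"source": ..., "document": [...], "metadata": [...]}, ...]
        citations = event["citations"]
    # normal completion deltas follow in subsequent events
```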
@@ -9,7 +9,6 @@ Flask-Cors==4.0.0
python-socketio==5.11.2
python-jose==3.3.0
passlib[bcrypt]==1.7.4
-uuid==1.30

requests==2.31.0
aiohttp==3.9.5
@@ -19,7 +18,6 @@ psycopg2-binary==2.9.9
PyMySQL==1.1.0
bcrypt==4.1.2

-litellm==1.35.28
litellm[proxy]==1.35.28

boto3==1.34.95
@@ -54,9 +52,9 @@ rank-bm25==0.2.2

faster-whisper==1.0.1

-PyJWT==2.8.0
PyJWT[crypto]==2.8.0

black==24.4.2
langfuse==2.27.3
-youtube-transcript-api
+youtube-transcript-api==0.6.2
+pytube
\ No newline at end of file
@@ -8,7 +8,7 @@ KEY_FILE=.webui_secret_key
PORT="${PORT:-8080}"
HOST="${HOST:-0.0.0.0}"
if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then
-  echo "No WEBUI_SECRET_KEY provided"
+  echo "Loading WEBUI_SECRET_KEY from file, not provided as an environment variable."
  if ! [ -e "$KEY_FILE" ]; then
    echo "Generating WEBUI_SECRET_KEY"
...
@@ -13,7 +13,7 @@ SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%"
:: Check if WEBUI_SECRET_KEY and WEBUI_JWT_SECRET_KEY are not set
IF "%WEBUI_SECRET_KEY%%WEBUI_JWT_SECRET_KEY%" == " " (
-    echo No WEBUI_SECRET_KEY provided
+    echo Loading WEBUI_SECRET_KEY from file, not provided as an environment variable.
    IF NOT EXIST "%KEY_FILE%" (
        echo Generating WEBUI_SECRET_KEY
...
@@ -38,9 +38,10 @@ def calculate_sha256_string(string):
def validate_email_format(email: str) -> bool:
-    if not re.match(r"[^@]+@[^@]+\.[^@]+", email):
-        return False
-    return True
+    if email.endswith("@localhost"):
+        return True
+
+    return bool(re.match(r"[^@]+@[^@]+\.[^@]+", email))


def sanitize_filename(file_name):
...