Unverified Commit 89634046 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #1107 from open-webui/dev

0.1.111
parents 04ddbf43 88d324b5
......@@ -29,11 +29,11 @@ jobs:
- name: Extract latest CHANGELOG entry
id: changelog
run: |
CHANGELOG_CONTENT=$(awk '/^## \[/{n++} n==1' CHANGELOG.md)
echo "CHANGELOG_CONTENT<<EOF"
echo "$CHANGELOG_CONTENT"
echo "EOF"
echo "::set-output name=content::${CHANGELOG_CONTENT}"
CHANGELOG_CONTENT=$(awk 'BEGIN {print_section=0;} /^## \[/ {if (print_section == 0) {print_section=1;} else {exit;}} print_section {print;}' CHANGELOG.md)
CHANGELOG_ESCAPED=$(echo "$CHANGELOG_CONTENT" | sed ':a;N;$!ba;s/\n/%0A/g')
echo "Extracted latest release notes from CHANGELOG.md:"
echo -e "$CHANGELOG_CONTENT"
echo "::set-output name=content::$CHANGELOG_ESCAPED"
- name: Create GitHub release
uses: actions/github-script@v5
......
......@@ -5,6 +5,24 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.1.111] - 2024-03-10
### Added
- 🛡️ **Model Whitelisting**: Admins now have the ability to whitelist models for users with the 'user' role.
- 🔄 **Update All Models**: Added a convenient button to update all models at once.
- 📄 **Toggle PDF OCR**: Users can now toggle PDF OCR option for improved parsing performance.
- 🎨 **DALL-E Integration**: Introduced DALL-E integration for image generation alongside automatic1111.
- 🛠️ **RAG API Refactoring**: Refactored RAG logic and exposed its API, with additional documentation to follow.
### Fixed
- 🔒 **Max Token Settings**: Added max token settings for anthropic/claude-3-sonnet-20240229 (Issue #1094).
- 🔧 **Misalignment Issue**: Corrected misalignment of Edit and Delete Icons when Chat Title is Empty (Issue #1104).
- 🔄 **Context Loss Fix**: Resolved RAG losing context on model response regeneration with Groq models via API key (Issue #1105).
- 📁 **File Handling Bug**: Addressed File Not Found Notification when Dropping a Conversation Element (Issue #1098).
- 🖱️ **Dragged File Styling**: Fixed dragged file layover styling issue.
## [0.1.110] - 2024-03-06
### Added
......
......@@ -41,7 +41,7 @@ ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
# for better persormance and multilangauge support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
# device type for whisper tts and ebbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
# device type for whisper tts and embbeding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models"
ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR
......@@ -81,4 +81,4 @@ COPY --from=build /app/package.json /app/package.json
# copy backend files
COPY ./backend .
CMD [ "bash", "start.sh"]
\ No newline at end of file
CMD [ "bash", "start.sh"]
......@@ -53,8 +53,6 @@ User-friendly WebUI for LLMs, supported LLM runners include Ollama and OpenAI-co
- 💬 **Collaborative Chat**: Harness the collective intelligence of multiple models by seamlessly orchestrating group conversations. Use the `@` command to specify the model, enabling dynamic and diverse dialogues within your chat interface. Immerse yourself in the collective intelligence woven into your chat environment.
- 🤝 **OpenAI API Integration**: Effortlessly integrate OpenAI-compatible API for versatile conversations alongside Ollama models. Customize the API Base URL to link with **LMStudio, Mistral, OpenRouter, and more**.
- 🔄 **Regeneration History Access**: Easily revisit and explore your entire regeneration history.
- 📜 **Chat History**: Effortlessly access and manage your conversation history.
......@@ -65,8 +63,18 @@ User-friendly WebUI for LLMs, supported LLM runners include Ollama and OpenAI-co
- ⚙️ **Fine-Tuned Control with Advanced Parameters**: Gain a deeper level of control by adjusting parameters such as temperature and defining your system prompts to tailor the conversation to your specific preferences and needs.
- 🎨🤖 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using AUTOMATIC1111 API (local) and DALL-E, enriching your chat experience with dynamic visual content.
- 🤝 **OpenAI API Integration**: Effortlessly integrate OpenAI-compatible API for versatile conversations alongside Ollama models. Customize the API Base URL to link with **LMStudio, Mistral, OpenRouter, and more**.
-**Multiple OpenAI-Compatible API Support**: Seamlessly integrate and customize various OpenAI-compatible APIs, enhancing the versatility of your chat interactions.
- 🔗 **External Ollama Server Connection**: Seamlessly link to an external Ollama server hosted on a different address by configuring the environment variable.
- 🔀 **Multiple Ollama Instance Load Balancing**: Effortlessly distribute chat requests across multiple Ollama instances for enhanced performance and reliability.
- 👥 **Multi-User Management**: Easily oversee and administer users via our intuitive admin panel, streamlining user management processes.
- 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
- 🔒 **Backend Reverse Proxy Support**: Bolster security through direct communication between Open WebUI backend and Ollama. This key feature eliminates the need to expose Ollama over LAN. Requests made to the '/ollama/api' route from the web UI are seamlessly redirected to Ollama from the backend, enhancing overall system security.
......
......@@ -21,7 +21,16 @@ from utils.utils import (
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
from config import AUTOMATIC1111_BASE_URL
from pathlib import Path
import uuid
import base64
import json
from config import CACHE_DIR, AUTOMATIC1111_BASE_URL
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
app = FastAPI()
app.add_middleware(
......@@ -32,25 +41,34 @@ app.add_middleware(
allow_headers=["*"],
)
app.state.ENGINE = ""
app.state.ENABLED = False
app.state.OPENAI_API_KEY = ""
app.state.MODEL = ""
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
app.state.IMAGE_SIZE = "512x512"
app.state.IMAGE_STEPS = 50
@app.get("/enabled", response_model=bool)
async def get_enable_status(request: Request, user=Depends(get_admin_user)):
return app.state.ENABLED
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
@app.get("/enabled/toggle", response_model=bool)
async def toggle_enabled(request: Request, user=Depends(get_admin_user)):
try:
r = requests.head(app.state.AUTOMATIC1111_BASE_URL)
app.state.ENABLED = not app.state.ENABLED
return app.state.ENABLED
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class ConfigUpdateForm(BaseModel):
engine: str
enabled: bool
@app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.ENGINE = form_data.engine
app.state.ENABLED = form_data.enabled
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
class UrlUpdateForm(BaseModel):
......@@ -58,17 +76,24 @@ class UrlUpdateForm(BaseModel):
@app.get("/url")
async def get_openai_url(user=Depends(get_admin_user)):
async def get_automatic1111_url(user=Depends(get_admin_user)):
return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
@app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
async def update_automatic1111_url(
form_data: UrlUpdateForm, user=Depends(get_admin_user)
):
if form_data.url == "":
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else:
app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/")
url = form_data.url.strip("/")
try:
r = requests.head(url)
app.state.AUTOMATIC1111_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
......@@ -76,6 +101,30 @@ async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_use
}
class OpenAIKeyUpdateForm(BaseModel):
key: str
@app.get("/key")
async def get_openai_key(user=Depends(get_admin_user)):
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
@app.post("/key/update")
async def update_openai_key(
form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user)
):
if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_KEY = form_data.key
return {
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
"status": True,
}
class ImageSizeUpdateForm(BaseModel):
size: str
......@@ -132,9 +181,22 @@ async def update_image_size(
@app.get("/models")
def get_models(user=Depends(get_current_user)):
try:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models")
models = r.json()
return models
if app.state.ENGINE == "openai":
return [
{"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"},
]
else:
r = requests.get(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
)
models = r.json()
return list(
map(
lambda model: {"id": model["title"], "name": model["model_name"]},
models,
)
)
except Exception as e:
app.state.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
......@@ -143,10 +205,12 @@ def get_models(user=Depends(get_current_user)):
@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
try:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
return {"model": options["sd_model_checkpoint"]}
if app.state.ENGINE == "openai":
return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
return {"model": options["sd_model_checkpoint"]}
except Exception as e:
app.state.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
......@@ -157,16 +221,21 @@ class UpdateModelForm(BaseModel):
def set_model_handler(model: str):
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
)
if app.state.ENGINE == "openai":
app.state.MODEL = model
return app.state.MODEL
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
)
return options
return options
@app.post("/models/default/update")
......@@ -181,45 +250,113 @@ class GenerateImageForm(BaseModel):
model: Optional[str] = None
prompt: str
n: int = 1
size: str = "512x512"
size: Optional[str] = None
negative_prompt: Optional[str] = None
def save_b64_image(b64_str):
image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
try:
# Split the base64 string to get the actual image data
img_data = base64.b64decode(b64_str)
# Write the image data to a file
with open(file_path, "wb") as f:
f.write(img_data)
return image_id
except Exception as e:
print(f"Error saving image: {e}")
return None
@app.post("/generations")
def generate_image(
form_data: GenerateImageForm,
user=Depends(get_current_user),
):
print(form_data)
r = None
try:
if form_data.model:
set_model_handler(form_data.model)
if app.state.ENGINE == "openai":
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
data = {
"prompt": form_data.prompt,
"batch_size": form_data.n,
"width": width,
"height": height,
}
data = {
"model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
"prompt": form_data.prompt,
"n": form_data.n,
"size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
"response_format": "b64_json",
}
r = requests.post(
url=f"https://api.openai.com/v1/images/generations",
json=data,
headers=headers,
)
if app.state.IMAGE_STEPS != None:
data["steps"] = app.state.IMAGE_STEPS
r.raise_for_status()
if form_data.negative_prompt != None:
data["negative_prompt"] = form_data.negative_prompt
res = r.json()
print(data)
images = []
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data,
)
for image in res["data"]:
image_id = save_b64_image(image["b64_json"])
images.append({"url": f"/cache/image/generations/{image_id}.png"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f:
json.dump(data, f)
return images
else:
if form_data.model:
set_model_handler(form_data.model)
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
data = {
"prompt": form_data.prompt,
"batch_size": form_data.n,
"width": width,
"height": height,
}
if app.state.IMAGE_STEPS != None:
data["steps"] = app.state.IMAGE_STEPS
if form_data.negative_prompt != None:
data["negative_prompt"] = form_data.negative_prompt
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data,
)
res = r.json()
print(res)
images = []
for image in res["images"]:
image_id = save_b64_image(image)
images.append({"url": f"/cache/image/generations/{image_id}.png"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f:
json.dump({**data, "info": res["info"]}, f)
return images
return r.json()
except Exception as e:
print(e)
if r:
print(r.json())
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
from litellm.proxy.proxy_server import ProxyConfig, initialize
from litellm.proxy.proxy_server import app
from fastapi import FastAPI, Request, Depends, status
from fastapi.responses import JSONResponse
from utils.utils import get_http_authorization_cred, get_current_user
from config import ENV
proxy_config = ProxyConfig()
async def config():
router, model_list, general_settings = await proxy_config.load_config(
router=None, config_file_path="./data/litellm/config.yaml"
)
await initialize(config="./data/litellm/config.yaml", telemetry=False)
async def startup():
await config()
@app.on_event("startup")
async def on_startup():
await startup()
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
auth_header = request.headers.get("Authorization", "")
if ENV != "dev":
try:
user = get_current_user(get_http_authorization_cred(auth_header))
print(user)
except Exception as e:
return JSONResponse(status_code=400, content={"detail": str(e)})
response = await call_next(request)
return response
......@@ -15,7 +15,7 @@ import asyncio
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user, get_admin_user
from config import OLLAMA_BASE_URLS
from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST
from typing import Optional, List, Union
......@@ -29,6 +29,10 @@ app.add_middleware(
allow_headers=["*"],
)
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}
......@@ -129,9 +133,19 @@ async def get_all_models():
async def get_ollama_tags(
url_idx: Optional[int] = None, user=Depends(get_current_user)
):
if url_idx == None:
return await get_all_models()
models = await get_all_models()
if app.state.MODEL_FILTER_ENABLED:
if user.role == "user":
models["models"] = list(
filter(
lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
models["models"],
)
)
return models
return models
else:
url = app.state.OLLAMA_BASE_URLS[url_idx]
try:
......
......@@ -18,7 +18,13 @@ from utils.utils import (
get_verified_user,
get_admin_user,
)
from config import OPENAI_API_BASE_URLS, OPENAI_API_KEYS, CACHE_DIR
from config import (
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
CACHE_DIR,
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
)
from typing import List, Optional
......@@ -34,6 +40,9 @@ app.add_middleware(
allow_headers=["*"],
)
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
......@@ -186,12 +195,21 @@ async def get_all_models():
return models
# , user=Depends(get_current_user)
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None):
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
if url_idx == None:
return await get_all_models()
models = await get_all_models()
if app.state.MODEL_FILTER_ENABLED:
if user.role == "user":
models["data"] = list(
filter(
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
models["data"],
)
)
return models
return models
else:
url = app.state.OPENAI_API_BASE_URLS[url_idx]
try:
......
......@@ -44,6 +44,8 @@ from apps.web.models.documents import (
DocumentResponse,
)
from apps.rag.utils import query_doc, query_collection
from utils.misc import (
calculate_sha256,
calculate_sha256_string,
......@@ -75,6 +77,7 @@ from constants import ERROR_MESSAGES
app = FastAPI()
app.state.PDF_EXTRACT_IMAGES = False
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE
......@@ -182,12 +185,15 @@ async def update_embedding_model(
}
@app.get("/chunk")
async def get_chunk_params(user=Depends(get_admin_user)):
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
"chunk": {
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
},
}
......@@ -196,17 +202,24 @@ class ChunkParamUpdateForm(BaseModel):
chunk_overlap: int
@app.post("/chunk/update")
async def update_chunk_params(
form_data: ChunkParamUpdateForm, user=Depends(get_admin_user)
):
app.state.CHUNK_SIZE = form_data.chunk_size
app.state.CHUNK_OVERLAP = form_data.chunk_overlap
class ConfigUpdateForm(BaseModel):
pdf_extract_images: bool
chunk: ChunkParamUpdateForm
@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
app.state.CHUNK_SIZE = form_data.chunk.chunk_size
app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
"chunk": {
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
},
}
......@@ -248,21 +261,18 @@ class QueryDocForm(BaseModel):
@app.post("/query/doc")
def query_doc(
def query_doc_handler(
form_data: QueryDocForm,
user=Depends(get_current_user),
):
try:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
return query_doc(
collection_name=form_data.collection_name,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
result = collection.query(
query_texts=[form_data.query],
n_results=form_data.k if form_data.k else app.state.TOP_K,
)
return result
except Exception as e:
print(e)
raise HTTPException(
......@@ -277,76 +287,16 @@ class QueryCollectionsForm(BaseModel):
k: Optional[int] = None
def merge_and_sort_query_results(query_results, k):
# Initialize lists to store combined data
combined_ids = []
combined_distances = []
combined_metadatas = []
combined_documents = []
# Combine data from each dictionary
for data in query_results:
combined_ids.extend(data["ids"][0])
combined_distances.extend(data["distances"][0])
combined_metadatas.extend(data["metadatas"][0])
combined_documents.extend(data["documents"][0])
# Create a list of tuples (distance, id, metadata, document)
combined = list(
zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
)
# Sort the list based on distances
combined.sort(key=lambda x: x[0])
# Unzip the sorted list
sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
# Slicing the lists to include only k elements
sorted_distances = list(sorted_distances)[:k]
sorted_ids = list(sorted_ids)[:k]
sorted_metadatas = list(sorted_metadatas)[:k]
sorted_documents = list(sorted_documents)[:k]
# Create the output dictionary
merged_query_results = {
"ids": [sorted_ids],
"distances": [sorted_distances],
"metadatas": [sorted_metadatas],
"documents": [sorted_documents],
"embeddings": None,
"uris": None,
"data": None,
}
return merged_query_results
@app.post("/query/collection")
def query_collection(
def query_collection_handler(
form_data: QueryCollectionsForm,
user=Depends(get_current_user),
):
results = []
for collection_name in form_data.collection_names:
try:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
result = collection.query(
query_texts=[form_data.query],
n_results=form_data.k if form_data.k else app.state.TOP_K,
)
results.append(result)
except:
pass
return merge_and_sort_query_results(
results, form_data.k if form_data.k else app.state.TOP_K
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
......@@ -425,7 +375,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
]
if file_ext == "pdf":
loader = PyPDFLoader(file_path, extract_images=True)
loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
elif file_ext == "csv":
loader = CSVLoader(file_path)
elif file_ext == "rst":
......
import re
from typing import List
from config import CHROMA_CLIENT
def query_doc(collection_name: str, query: str, k: int, embedding_function):
try:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
embedding_function=embedding_function,
)
result = collection.query(
query_texts=[query],
n_results=k,
)
return result
except Exception as e:
raise e
def merge_and_sort_query_results(query_results, k):
# Initialize lists to store combined data
combined_ids = []
combined_distances = []
combined_metadatas = []
combined_documents = []
# Combine data from each dictionary
for data in query_results:
combined_ids.extend(data["ids"][0])
combined_distances.extend(data["distances"][0])
combined_metadatas.extend(data["metadatas"][0])
combined_documents.extend(data["documents"][0])
# Create a list of tuples (distance, id, metadata, document)
combined = list(
zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
)
# Sort the list based on distances
combined.sort(key=lambda x: x[0])
# Unzip the sorted list
sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
# Slicing the lists to include only k elements
sorted_distances = list(sorted_distances)[:k]
sorted_ids = list(sorted_ids)[:k]
sorted_metadatas = list(sorted_metadatas)[:k]
sorted_documents = list(sorted_documents)[:k]
# Create the output dictionary
merged_query_results = {
"ids": [sorted_ids],
"distances": [sorted_distances],
"metadatas": [sorted_metadatas],
"documents": [sorted_documents],
"embeddings": None,
"uris": None,
"data": None,
}
return merged_query_results
def query_collection(
collection_names: List[str], query: str, k: int, embedding_function
):
results = []
for collection_name in collection_names:
try:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
embedding_function=embedding_function,
)
result = collection.query(
query_texts=[query],
n_results=k,
)
results.append(result)
except:
pass
return merge_and_sort_query_results(results, k)
def rag_template(template: str, context: str, query: str):
template = re.sub(r"\[context\]", context, template)
template = re.sub(r"\[query\]", query, template)
return template
......@@ -251,7 +251,7 @@ OPENAI_API_BASE_URLS = (
OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
)
OPENAI_API_BASE_URLS = [url.strip() for url in OPENAI_API_BASE_URL.split(";")]
OPENAI_API_BASE_URLS = [url.strip() for url in OPENAI_API_BASE_URLS.split(";")]
####################################
......@@ -292,6 +292,11 @@ DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending")
USER_PERMISSIONS = {"chat": {"deletion": True}}
MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False)
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
####################################
# WEBUI_VERSION
####################################
......
{
"version": "0.0.1",
"ui": {
"prompt_suggestions": [
{
......
......@@ -9,27 +9,37 @@ import requests
from fastapi import FastAPI, Request, Depends, status
from fastapi.staticfiles import StaticFiles
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.wsgi import WSGIMiddleware
from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from litellm.proxy.proxy_server import ProxyConfig, initialize
from litellm.proxy.proxy_server import app as litellm_app
from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
from apps.litellm.main import app as litellm_app, startup as litellm_app_startup
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.web.main import app as webui_app
from pydantic import BaseModel
from typing import List
from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
from constants import ERROR_MESSAGES
from utils.utils import get_http_authorization_cred, get_current_user
from utils.utils import get_admin_user
from apps.rag.utils import query_doc, query_collection, rag_template
from config import (
WEBUI_NAME,
ENV,
VERSION,
CHANGELOG,
FRONTEND_BUILD_DIR,
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
)
from constants import ERROR_MESSAGES
class SPAStaticFiles(StaticFiles):
......@@ -43,23 +53,11 @@ class SPAStaticFiles(StaticFiles):
raise ex
proxy_config = ProxyConfig()
async def config():
router, model_list, general_settings = await proxy_config.load_config(
router=None, config_file_path="./data/litellm/config.yaml"
)
await initialize(config="./data/litellm/config.yaml", telemetry=False)
async def startup():
await config()
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
origins = ["*"]
app.add_middleware(
......@@ -73,7 +71,127 @@ app.add_middleware(
@app.on_event("startup")
async def on_startup():
await startup()
await litellm_app_startup()
class RAGMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if request.method == "POST" and (
"/api/chat" in request.url.path or "/chat/completions" in request.url.path
):
print(request.url.path)
# Read the original request body
body = await request.body()
# Decode body to string
body_str = body.decode("utf-8")
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
# Example: Add a new key-value pair or modify existing ones
# data["modified"] = True # Example modification
if "docs" in data:
docs = data["docs"]
print(docs)
last_user_message_idx = None
for i in range(len(data["messages"]) - 1, -1, -1):
if data["messages"][i]["role"] == "user":
last_user_message_idx = i
break
user_message = data["messages"][last_user_message_idx]
if isinstance(user_message["content"], list):
# Handle list content input
content_type = "list"
query = ""
for content_item in user_message["content"]:
if content_item["type"] == "text":
query = content_item["text"]
break
elif isinstance(user_message["content"], str):
# Handle text content input
content_type = "text"
query = user_message["content"]
else:
# Fallback in case the input does not match expected types
content_type = None
query = ""
relevant_contexts = []
for doc in docs:
context = None
try:
if doc["type"] == "collection":
context = query_collection(
collection_names=doc["collection_names"],
query=query,
k=rag_app.state.TOP_K,
embedding_function=rag_app.state.sentence_transformer_ef,
)
else:
context = query_doc(
collection_name=doc["collection_name"],
query=query,
k=rag_app.state.TOP_K,
embedding_function=rag_app.state.sentence_transformer_ef,
)
except Exception as e:
print(e)
context = None
relevant_contexts.append(context)
context_string = ""
for context in relevant_contexts:
if context:
context_string += " ".join(context["documents"][0]) + "\n"
ra_content = rag_template(
template=rag_app.state.RAG_TEMPLATE,
context=context_string,
query=query,
)
if content_type == "list":
new_content = []
for content_item in user_message["content"]:
if content_item["type"] == "text":
# Update the text item's content with ra_content
new_content.append({"type": "text", "text": ra_content})
else:
# Keep other types of content as they are
new_content.append(content_item)
new_user_message = {**user_message, "content": new_content}
else:
new_user_message = {
**user_message,
"content": ra_content,
}
data["messages"][last_user_message_idx] = new_user_message
del data["docs"]
print(data["messages"])
modified_body_bytes = json.dumps(data).encode("utf-8")
# Create a new request with the modified body
scope = request.scope
scope["body"] = modified_body_bytes
request = Request(scope, receive=lambda: self._receive(modified_body_bytes))
response = await call_next(request)
return response
async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False}
app.add_middleware(RAGMiddleware)
@app.middleware("http")
......@@ -86,21 +204,6 @@ async def check_url(request: Request, call_next):
return response
@litellm_app.middleware("http")
async def auth_middleware(request: Request, call_next):
auth_header = request.headers.get("Authorization", "")
if ENV != "dev":
try:
user = get_current_user(get_http_authorization_cred(auth_header))
print(user)
except Exception as e:
return JSONResponse(status_code=400, content={"detail": str(e)})
response = await call_next(request)
return response
app.mount("/api/v1", webui_app)
app.mount("/litellm/api", litellm_app)
......@@ -125,6 +228,39 @@ async def get_app_config():
}
@app.get("/api/config/model/filter")
async def get_model_filter_config(user=Depends(get_admin_user)):
return {
"enabled": app.state.MODEL_FILTER_ENABLED,
"models": app.state.MODEL_FILTER_LIST,
}
class ModelFilterConfigForm(BaseModel):
enabled: bool
models: List[str]
@app.post("/api/config/model/filter")
async def get_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
):
app.state.MODEL_FILTER_ENABLED = form_data.enabled
app.state.MODEL_FILTER_LIST = form_data.models
ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
return {
"enabled": app.state.MODEL_FILTER_ENABLED,
"models": app.state.MODEL_FILTER_LIST,
}
@app.get("/api/version")
async def get_app_config():
......@@ -156,6 +292,7 @@ async def get_app_latest_release_version():
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/cache", StaticFiles(directory="data/cache"), name="cache")
app.mount(
......
......@@ -16,7 +16,8 @@ aiohttp
peewee
bcrypt
litellm
litellm==1.30.7
argon2-cffi
apscheduler
google-generativeai
......
{
"name": "open-webui",
"version": "0.1.110",
"version": "0.1.111",
"private": true,
"scripts": {
"dev": "vite dev --host",
......
import { IMAGES_API_BASE_URL } from '$lib/constants';
export const getImageGenerationEnabledStatus = async (token: string = '') => {
export const getImageGenerationConfig = async (token: string = '') => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/enabled`, {
const res = await fetch(`${IMAGES_API_BASE_URL}/config`, {
method: 'GET',
headers: {
Accept: 'application/json',
......@@ -32,10 +32,50 @@ export const getImageGenerationEnabledStatus = async (token: string = '') => {
return res;
};
export const toggleImageGenerationEnabledStatus = async (token: string = '') => {
export const updateImageGenerationConfig = async (
token: string = '',
engine: string,
enabled: boolean
) => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/enabled/toggle`, {
const res = await fetch(`${IMAGES_API_BASE_URL}/config/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
engine,
enabled
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res;
};
export const getOpenAIKey = async (token: string = '') => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/key`, {
method: 'GET',
headers: {
Accept: 'application/json',
......@@ -61,7 +101,42 @@ export const toggleImageGenerationEnabledStatus = async (token: string = '') =>
throw error;
}
return res;
return res.OPENAI_API_KEY;
};
export const updateOpenAIKey = async (token: string = '', key: string) => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/key/update`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
},
body: JSON.stringify({
key: key
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
if ('detail' in err) {
error = err.detail;
} else {
error = 'Server connection failed';
}
return null;
});
if (error) {
throw error;
}
return res.OPENAI_API_KEY;
};
export const getAUTOMATIC1111Url = async (token: string = '') => {
......@@ -263,7 +338,7 @@ export const updateImageSteps = async (token: string = '', steps: number) => {
return res.IMAGE_STEPS;
};
export const getDiffusionModels = async (token: string = '') => {
export const getImageGenerationModels = async (token: string = '') => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/models`, {
......@@ -295,7 +370,7 @@ export const getDiffusionModels = async (token: string = '') => {
return res;
};
export const getDefaultDiffusionModel = async (token: string = '') => {
export const getDefaultImageGenerationModel = async (token: string = '') => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/models/default`, {
......@@ -327,7 +402,7 @@ export const getDefaultDiffusionModel = async (token: string = '') => {
return res.model;
};
export const updateDefaultDiffusionModel = async (token: string = '', model: string) => {
export const updateDefaultImageGenerationModel = async (token: string = '', model: string) => {
let error = null;
const res = await fetch(`${IMAGES_API_BASE_URL}/models/default/update`, {
......
......@@ -77,3 +77,65 @@ export const getVersionUpdates = async () => {
return res;
};
export const getModelFilterConfig = async (token: string) => {
let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err;
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateModelFilterConfig = async (
token: string,
enabled: boolean,
models: string[]
) => {
let error = null;
const res = await fetch(`${WEBUI_BASE_URL}/api/config/model/filter`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
enabled: enabled,
models: models
})
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err;
return null;
});
if (error) {
throw error;
}
return res;
};
......@@ -77,6 +77,7 @@ type AddLiteLLMModelForm = {
api_base: string;
api_key: string;
rpm: string;
max_tokens: string;
};
export const addLiteLLMModel = async (token: string = '', payload: AddLiteLLMModelForm) => {
......@@ -95,7 +96,8 @@ export const addLiteLLMModel = async (token: string = '', payload: AddLiteLLMMod
model: payload.model,
...(payload.api_base === '' ? {} : { api_base: payload.api_base }),
...(payload.api_key === '' ? {} : { api_key: payload.api_key }),
...(isNaN(parseInt(payload.rpm)) ? {} : { rpm: parseInt(payload.rpm) })
...(isNaN(parseInt(payload.rpm)) ? {} : { rpm: parseInt(payload.rpm) }),
...(payload.max_tokens === '' ? {} : { max_tokens: payload.max_tokens })
}
})
})
......
import { RAG_API_BASE_URL } from '$lib/constants';
export const getChunkParams = async (token: string) => {
export const getRAGConfig = async (token: string) => {
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/chunk`, {
const res = await fetch(`${RAG_API_BASE_URL}/config`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
......@@ -27,18 +27,27 @@ export const getChunkParams = async (token: string) => {
return res;
};
export const updateChunkParams = async (token: string, size: number, overlap: number) => {
type ChunkConfigForm = {
chunk_size: number;
chunk_overlap: number;
};
type RAGConfigForm = {
pdf_extract_images: boolean;
chunk: ChunkConfigForm;
};
export const updateRAGConfig = async (token: string, payload: RAGConfigForm) => {
let error = null;
const res = await fetch(`${RAG_API_BASE_URL}/chunk/update`, {
const res = await fetch(`${RAG_API_BASE_URL}/config/update`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
},
body: JSON.stringify({
chunk_size: size,
chunk_overlap: overlap
...payload
})
})
.then(async (res) => {
......@@ -252,7 +261,7 @@ export const queryCollection = async (
token: string,
collection_names: string,
query: string,
k: number
k: number | null = null
) => {
let error = null;
......
<script lang="ts">
import { getModelFilterConfig, updateModelFilterConfig } from '$lib/apis';
import { getSignUpEnabledStatus, toggleSignUpEnabledStatus } from '$lib/apis/auths';
import { getUserPermissions, updateUserPermissions } from '$lib/apis/users';
import { models } from '$lib/stores';
import { onMount } from 'svelte';
export let saveHandler: Function;
let whitelistEnabled = false;
let whitelistModels = [''];
let permissions = {
chat: {
deletion: true
......@@ -13,6 +17,13 @@
onMount(async () => {
permissions = await getUserPermissions(localStorage.token);
const res = await getModelFilterConfig(localStorage.token);
if (res) {
whitelistEnabled = res.enabled;
whitelistModels = res.models.length > 0 ? res.models : [''];
}
});
</script>
......@@ -21,6 +32,8 @@
on:submit|preventDefault={async () => {
// console.log('submit');
await updateUserPermissions(localStorage.token, permissions);
await updateModelFilterConfig(localStorage.token, whitelistEnabled, whitelistModels);
saveHandler();
}}
>
......@@ -69,6 +82,106 @@
</button>
</div>
</div>
<hr class=" dark:border-gray-700 my-2" />
<div class="mt-2 space-y-3 pr-1.5">
<div>
<div class="mb-2">
<div class="flex justify-between items-center text-xs">
<div class=" text-sm font-medium">Manage Models</div>
</div>
</div>
<div class=" space-y-3">
<div>
<div class="flex justify-between items-center text-xs">
<div class=" text-xs font-medium">Model Whitelisting</div>
<button
class=" text-xs font-medium text-gray-500"
type="button"
on:click={() => {
whitelistEnabled = !whitelistEnabled;
}}>{whitelistEnabled ? 'On' : 'Off'}</button
>
</div>
</div>
{#if whitelistEnabled}
<div>
<div class=" space-y-1.5">
{#each whitelistModels as modelId, modelIdx}
<div class="flex w-full">
<div class="flex-1 mr-2">
<select
class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
bind:value={modelId}
placeholder="Select a model"
>
<option value="" disabled selected>Select a model</option>
{#each $models.filter((model) => model.id) as model}
<option value={model.id} class="bg-gray-100 dark:bg-gray-700"
>{model.name}</option
>
{/each}
</select>
</div>
{#if modelIdx === 0}
<button
class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-900 dark:text-white rounded-lg transition"
type="button"
on:click={() => {
if (whitelistModels.at(-1) !== '') {
whitelistModels = [...whitelistModels, ''];
}
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path
d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
/>
</svg>
</button>
{:else}
<button
class="px-2.5 bg-gray-100 hover:bg-gray-200 text-gray-800 dark:bg-gray-900 dark:text-white rounded-lg transition"
type="button"
on:click={() => {
whitelistModels.splice(modelIdx, 1);
whitelistModels = whitelistModels;
}}
>
<svg
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 16 16"
fill="currentColor"
class="w-4 h-4"
>
<path d="M3.75 7.25a.75.75 0 0 0 0 1.5h8.5a.75.75 0 0 0 0-1.5h-8.5Z" />
</svg>
</button>
{/if}
</div>
{/each}
</div>
<div class="flex justify-end items-center text-xs mt-1.5 text-right">
<div class=" text-xs font-medium">
{whitelistModels.length} Model(s) Whitelisted
</div>
</div>
</div>
{/if}
</div>
</div>
</div>
</div>
<div class="flex justify-end pt-3 text-sm font-medium">
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment