Unverified Commit 72354e06 authored by Timothy Jaeryang Baek, committed by GitHub

Merge pull request #2476 from open-webui/dev

0.2.0
parents 36e2a5e6 207e2503
......@@ -10,8 +10,4 @@ OPENAI_API_KEY=''
# DO NOT TRACK
SCARF_NO_ANALYTICS=true
DO_NOT_TRACK=true
ANONYMIZED_TELEMETRY=false
# Use locally bundled version of the LiteLLM cost map json
# to avoid repetitive startup connections
LITELLM_LOCAL_MODEL_COST_MAP="True"
\ No newline at end of file
ANONYMIZED_TELEMETRY=false
\ No newline at end of file
......@@ -11,7 +11,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v2
uses: actions/checkout@v4
- name: Check for changes in package.json
run: |
......@@ -36,7 +36,7 @@ jobs:
echo "::set-output name=content::$CHANGELOG_ESCAPED"
- name: Create GitHub release
uses: actions/github-script@v5
uses: actions/github-script@v7
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
......@@ -51,7 +51,7 @@ jobs:
console.log(`Created release ${release.data.html_url}`)
- name: Upload package to GitHub release
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v4
with:
name: package
path: .
......
name: Deploy to HuggingFace Spaces
on:
push:
branches:
- dev
- main
workflow_dispatch:
jobs:
check-secret:
runs-on: ubuntu-latest
outputs:
token-set: ${{ steps.check-key.outputs.defined }}
steps:
- id: check-key
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
if: "${{ env.HF_TOKEN != '' }}"
run: echo "defined=true" >> $GITHUB_OUTPUT
deploy:
runs-on: ubuntu-latest
needs: [check-secret]
if: needs.check-secret.outputs.token-set == 'true'
env:
HF_TOKEN: ${{ secrets.HF_TOKEN }}
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Remove git history
run: rm -rf .git
- name: Prepend YAML front matter to README.md
run: |
echo "---" > temp_readme.md
echo "title: Open WebUI" >> temp_readme.md
echo "emoji: 🐳" >> temp_readme.md
echo "colorFrom: purple" >> temp_readme.md
echo "colorTo: gray" >> temp_readme.md
echo "sdk: docker" >> temp_readme.md
echo "app_port: 8080" >> temp_readme.md
echo "---" >> temp_readme.md
cat README.md >> temp_readme.md
mv temp_readme.md README.md
- name: Configure git
run: |
git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com"
git config --global user.name "github-actions[bot]"
- name: Set up Git and push to Space
run: |
git init --initial-branch=main
git lfs track "*.ttf"
rm demo.gif
git add .
git commit -m "GitHub deploy: ${{ github.sha }}"
git push --force https://open-webui:${HF_TOKEN}@huggingface.co/spaces/open-webui/open-webui main
......@@ -84,6 +84,8 @@ jobs:
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
build-args: |
BUILD_HASH=${{ github.sha }}
- name: Export digest
run: |
......@@ -170,7 +172,9 @@ jobs:
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
build-args: USE_CUDA=true
build-args: |
BUILD_HASH=${{ github.sha }}
USE_CUDA=true
- name: Export digest
run: |
......@@ -257,7 +261,9 @@ jobs:
outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
build-args: USE_OLLAMA=true
build-args: |
BUILD_HASH=${{ github.sha }}
USE_OLLAMA=true
- name: Export digest
run: |
......
......@@ -23,7 +23,7 @@ jobs:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
......
......@@ -19,7 +19,7 @@ jobs:
uses: actions/checkout@v4
- name: Setup Node.js
uses: actions/setup-node@v3
uses: actions/setup-node@v4
with:
node-version: '20' # Or specify any other version you want to use
......
......@@ -20,7 +20,11 @@ jobs:
- name: Build and run Compose Stack
run: |
docker compose --file docker-compose.yaml --file docker-compose.api.yaml up --detach --build
docker compose \
--file docker-compose.yaml \
--file docker-compose.api.yaml \
--file docker-compose.a1111-test.yaml \
up --detach --build
- name: Wait for Ollama to be up
timeout-minutes: 5
......@@ -95,7 +99,7 @@ jobs:
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v2
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
......
name: Release to PyPI
on:
push:
branches:
- main # or whatever branch you want to use
jobs:
release:
runs-on: ubuntu-latest
environment:
name: pypi
url: https://pypi.org/p/open-webui
permissions:
id-token: write
steps:
- name: Checkout repository
uses: actions/checkout@v4
- uses: actions/setup-node@v4
with:
node-version: 18
- uses: actions/setup-python@v5
with:
python-version: 3.11
- name: Build
run: |
python -m pip install --upgrade pip
pip install build
python -m build .
- name: Publish package distributions to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
......@@ -5,6 +5,46 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.2.0] - 2024-06-01
### Added
- **🔧 Pipelines Support**: Open WebUI now includes a plugin framework for enhanced customization and functionality (https://github.com/open-webui/pipelines). Easily add custom logic and integrate Python libraries, from AI agents to home automation APIs.
- **🔗 Function Calling via Pipelines**: Integrate function calling seamlessly through Pipelines.
- **⚖️ User Rate Limiting via Pipelines**: Implement user-specific rate limits to manage API usage efficiently.
- **📊 Usage Monitoring with Langfuse**: Track and analyze usage statistics with Langfuse integration through Pipelines.
- **🕒 Conversation Turn Limits**: Set limits on conversation turns to manage interactions better through Pipelines.
- **🛡️ Toxic Message Filtering**: Automatically filter out toxic messages to maintain a safe environment using Pipelines.
- **🔍 Web Search Support**: Introducing built-in web search capabilities via RAG API, allowing users to search using SearXNG, Google Programmatic Search Engine, Brave Search, serpstack, and serper. Activate it effortlessly by adding necessary variables from Document settings > Web Params.
- **🗂️ Models Workspace**: Create and manage model presets for both Ollama/OpenAI API. Note: The old Modelfiles workspace is deprecated.
- **🛠️ Model Builder Feature**: Build and edit all models with persistent builder mode.
- **🏷️ Model Tagging Support**: Organize models with tagging features in the models workspace.
- **📋 Model Ordering Support**: Effortlessly organize models by dragging and dropping them into the desired positions within the models workspace.
- **📈 OpenAI Generation Stats**: Access detailed generation statistics for OpenAI models.
- **📅 System Prompt Variables**: New variables added: '{{CURRENT_DATE}}' and '{{USER_NAME}}' for dynamic prompts.
- **📢 Global Banner Support**: Manage global banners from admin settings > banners.
- **🗃️ Enhanced Archived Chats Modal**: Search and export archived chats easily.
- **📂 Archive All Button**: Quickly archive all chats from settings > chats.
- **🌐 Improved Translations**: Added and improved translations for French, Croatian, Cebuano, and Vietnamese.
### Fixed
- **🔍 Archived Chats Visibility**: Resolved issue with archived chats not showing in the admin panel.
- **💬 Message Styling**: Fixed styling issues affecting message appearance.
- **🔗 Shared Chat Responses**: Corrected the issue where shared chat response messages were not read-only.
- **🖥️ UI Enhancement**: Fixed the scrollbar overlapping issue with the message box in the user interface.
### Changed
- **💾 User Settings Storage**: User settings are now saved on the backend, ensuring consistency across all devices.
- **📡 Unified API Requests**: The API request for getting models is now unified to '/api/models' for easier usage.
- **🔄 Versioning Update**: Our versioning will now follow the format 0.x for major updates and 0.x.y for patches.
- **📦 Export All Chats (All Users)**: Moved this functionality to the Admin Panel settings for better organization and accessibility.
### Removed
- **🚫 Bundled LiteLLM Support Deprecated**: Migrate your LiteLLM config.yaml to a self-hosted LiteLLM instance. LiteLLM can still be added via OpenAI Connections. Download the LiteLLM config.yaml from admin settings > database > export LiteLLM config.yaml.
## [0.1.125] - 2024-05-19
### Added
......
......@@ -28,6 +28,7 @@ Examples of unacceptable behavior include:
- Public or private harassment
- Publishing others' private information, such as a physical or email address, without their explicit permission
- **Spamming of any kind**
- Aggressive sales tactics targeting our community members are strictly prohibited. You can mention your product if it's relevant to the discussion, but under no circumstances should you push it forcefully.
- Other conduct which could reasonably be considered inappropriate in a professional setting
## Enforcement Responsibilities
......@@ -59,6 +60,7 @@ Community leaders will follow these Community Impact Guidelines in determining t
**Community Impact**: Repeated or severe violations of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals.
**Consequence**: A permanent ban from any sort of public interaction within the community.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
......
......@@ -11,12 +11,14 @@ ARG USE_CUDA_VER=cu121
# IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
ARG USE_RERANKING_MODEL=""
ARG BUILD_HASH=dev-build
# Override at your own risk - non-root configurations are untested
ARG UID=0
ARG GID=0
######## WebUI frontend ########
FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build
ARG BUILD_HASH
WORKDIR /app
......@@ -24,6 +26,7 @@ COPY package.json package-lock.json ./
RUN npm ci
COPY . .
ENV APP_BUILD_HASH=${BUILD_HASH}
RUN npm run build
######## WebUI backend ########
......@@ -59,11 +62,6 @@ ENV OPENAI_API_KEY="" \
DO_NOT_TRACK=true \
ANONYMIZED_TELEMETRY=false
# Use locally bundled version of the LiteLLM cost map json
# to avoid repetitive startup connections
ENV LITELLM_LOCAL_MODEL_COST_MAP="True"
#### Other models #########################################################
## whisper TTS model settings ##
ENV WHISPER_MODEL="base" \
......@@ -83,10 +81,10 @@ WORKDIR /app/backend
ENV HOME /root
# Create user and group if not root
RUN if [ $UID -ne 0 ]; then \
if [ $GID -ne 0 ]; then \
addgroup --gid $GID app; \
fi; \
adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \
if [ $GID -ne 0 ]; then \
addgroup --gid $GID app; \
fi; \
adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \
fi
RUN mkdir -p $HOME/.cache/chroma
......@@ -132,7 +130,8 @@ RUN pip3 install uv && \
uv pip install --system -r requirements.txt --no-cache-dir && \
python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
fi
fi; \
chown -R $UID:$GID /app/backend/data/
......@@ -154,4 +153,7 @@ HEALTHCHECK CMD curl --silent --fail http://localhost:8080/health | jq -e '.stat
USER $UID:$GID
ARG BUILD_HASH
ENV WEBUI_BUILD_VERSION=${BUILD_HASH}
CMD [ "bash", "start.sh"]
......@@ -15,93 +15,39 @@ Open WebUI is an extensible, feature-rich, and user-friendly self-hosted WebUI d
![Open WebUI Demo](./demo.gif)
## Features ⭐
## Key Features of Open WebUI
- 🖥️ **Intuitive Interface**: Our chat interface takes inspiration from ChatGPT, ensuring a user-friendly experience.
- 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience with support for both `:ollama` and `:cuda` tagged images.
- 📱 **Responsive Design**: Enjoy a seamless experience on both desktop and mobile devices.
- 🤝 **Ollama/OpenAI API Integration**: Effortlessly integrate OpenAI-compatible APIs for versatile conversations alongside Ollama models. Customize the OpenAI API URL to link with **LMStudio, GroqCloud, Mistral, OpenRouter, and more**.
- **Swift Responsiveness**: Enjoy fast and responsive performance.
- 🧩 **Pipelines, Open WebUI Plugin Support**: Seamlessly integrate custom logic and Python libraries into Open WebUI using [Pipelines Plugin Framework](https://github.com/open-webui/pipelines). Launch your Pipelines instance, set the OpenAI URL to the Pipelines URL, and explore endless possibilities. [Examples](https://github.com/open-webui/pipelines/examples) include **Function Calling**, User **Rate Limiting** to control access, **Usage Monitoring** with tools like Langfuse, **Live Translation with LibreTranslate** for multilingual support, **Toxic Message Filtering** and much more.
- 🚀 **Effortless Setup**: Install seamlessly using Docker or Kubernetes (kubectl, kustomize or helm) for a hassle-free experience.
- 📱 **Responsive Design**: Enjoy a seamless experience across Desktop PC, Laptop, and Mobile devices.
- 🌈 **Theme Customization**: Choose from a variety of themes to personalize your Open WebUI experience.
- 💻 **Code Syntax Highlighting**: Enjoy enhanced code readability with our syntax highlighting feature.
- 📱 **Progressive Web App (PWA) for Mobile**: Enjoy a native app-like experience on your mobile device with our PWA, providing offline access on localhost and a seamless user interface.
- ✒️🔢 **Full Markdown and LaTeX Support**: Elevate your LLM experience with comprehensive Markdown and LaTeX capabilities for enriched interaction.
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with the groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using `#` command in the prompt. In its alpha phase, occasional issues may arise as we actively refine and enhance this feature to ensure optimal performance and reliability.
- 🔍 **RAG Embedding Support**: Change the RAG embedding model directly in document settings, enhancing document processing. This feature supports Ollama and OpenAI models.
- 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by the URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
- 📜 **Prompt Preset Support**: Instantly access preset prompts using the `/` command in the chat input. Load predefined conversation starters effortlessly and expedite your interactions. Effortlessly import prompts through [Open WebUI Community](https://openwebui.com/) integration.
- 👍👎 **RLHF Annotation**: Empower your messages by rating them with thumbs up and thumbs down, followed by the option to provide textual feedback, facilitating the creation of datasets for Reinforcement Learning from Human Feedback (RLHF). Utilize your messages to train or fine-tune models, all while ensuring the confidentiality of locally saved data.
- 🏷️ **Conversation Tagging**: Effortlessly categorize and locate specific chats for quick reference and streamlined data collection.
- 📥🗑️ **Download/Delete Models**: Easily download or remove models directly from the web UI.
- 🔄 **Update All Ollama Models**: Easily update locally installed models all at once with a convenient button, streamlining model management.
- 🛠️ **Model Builder**: Easily create Ollama models via the Web UI. Create and add custom characters/agents, customize chat elements, and import models effortlessly through [Open WebUI Community](https://openwebui.com/) integration.
- ⬆️ **GGUF File Model Creation**: Effortlessly create Ollama models by uploading GGUF files directly from the web UI. Streamlined process with options to upload from your machine or download GGUF files from Hugging Face.
- 📚 **Local RAG Integration**: Dive into the future of chat interactions with groundbreaking Retrieval Augmented Generation (RAG) support. This feature seamlessly integrates document interactions into your chat experience. You can load documents directly into the chat or add files to your document library, effortlessly accessing them using the `#` command before a query.
- 🤖 **Multiple Model Support**: Seamlessly switch between different chat models for diverse interactions.
- 🔍 **Web Search for RAG**: Perform web searches using providers like `SearXNG`, `Google PSE`, `Brave Search`, `serpstack`, and `serper`, and inject the results directly into your chat experience.
- 🔄 **Multi-Modal Support**: Seamlessly engage with models that support multimodal interactions, including images (e.g., LLava).
- 🌐 **Web Browsing Capability**: Seamlessly integrate websites into your chat experience using the `#` command followed by a URL. This feature allows you to incorporate web content directly into your conversations, enhancing the richness and depth of your interactions.
- 🧩 **Modelfile Builder**: Easily create Ollama modelfiles via the web UI. Create and add characters/agents, customize chat elements, and import modelfiles effortlessly through [Open WebUI Community](https://openwebui.com/) integration.
- 🎨 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using options such as AUTOMATIC1111 API or ComfyUI (local), and OpenAI's DALL-E (external), enriching your chat experience with dynamic visual content.
- ⚙️ **Many Models Conversations**: Effortlessly engage with various models simultaneously, harnessing their unique strengths for optimal responses. Enhance your experience by leveraging a diverse set of models in parallel.
- 💬 **Collaborative Chat**: Harness the collective intelligence of multiple models by seamlessly orchestrating group conversations. Use the `@` command to specify the model, enabling dynamic and diverse dialogues within your chat interface. Immerse yourself in the collective intelligence woven into your chat environment.
- 🗨️ **Local Chat Sharing**: Generate and share chat links seamlessly between users, enhancing collaboration and communication.
- 🔄 **Regeneration History Access**: Easily revisit and explore your entire regeneration history.
- 📜 **Chat History**: Effortlessly access and manage your conversation history.
- 📬 **Archive Chats**: Effortlessly store away completed conversations with LLMs for future reference, maintaining a tidy and clutter-free chat interface while allowing for easy retrieval and reference.
- 📤📥 **Import/Export Chat History**: Seamlessly move your chat data in and out of the platform.
- 🗣️ **Voice Input Support**: Engage with your model through voice interactions; enjoy the convenience of talking to your model directly. Additionally, explore the option for sending voice input automatically after 3 seconds of silence for a streamlined experience.
- 🔊 **Configurable Text-to-Speech Endpoint**: Customize your Text-to-Speech experience with configurable OpenAI endpoints.
- ⚙️ **Fine-Tuned Control with Advanced Parameters**: Gain a deeper level of control by adjusting parameters such as temperature and defining your system prompts to tailor the conversation to your specific preferences and needs.
- 🎨🤖 **Image Generation Integration**: Seamlessly incorporate image generation capabilities using options such as AUTOMATIC1111 API (local), ComfyUI (local), and DALL-E, enriching your chat experience with dynamic visual content.
- 🤝 **OpenAI API Integration**: Effortlessly integrate OpenAI-compatible API for versatile conversations alongside Ollama models. Customize the API Base URL to link with **LMStudio, Mistral, OpenRouter, and more**.
-**Multiple OpenAI-Compatible API Support**: Seamlessly integrate and customize various OpenAI-compatible APIs, enhancing the versatility of your chat interactions.
- 🔑 **API Key Generation Support**: Generate secret keys to leverage Open WebUI with OpenAI libraries, simplifying integration and development.
- 🔗 **External Ollama Server Connection**: Seamlessly link to an external Ollama server hosted on a different address by configuring the environment variable.
- 🔀 **Multiple Ollama Instance Load Balancing**: Effortlessly distribute chat requests across multiple Ollama instances for enhanced performance and reliability.
- 👥 **Multi-User Management**: Easily oversee and administer users via our intuitive admin panel, streamlining user management processes.
- 🔗 **Webhook Integration**: Subscribe to new user sign-up events via webhook (compatible with Google Chat and Microsoft Teams), providing real-time notifications and automation capabilities.
- 🛡️ **Model Whitelisting**: Admins can whitelist models for users with the 'user' role, enhancing security and access control.
- 📧 **Trusted Email Authentication**: Authenticate using a trusted email header, adding an additional layer of security and authentication.
- 🔐 **Role-Based Access Control (RBAC)**: Ensure secure access with restricted permissions; only authorized individuals can access your Ollama, and exclusive model creation/pulling rights are reserved for administrators.
- 🔒 **Backend Reverse Proxy Support**: Bolster security through direct communication between Open WebUI backend and Ollama. This key feature eliminates the need to expose Ollama over LAN. Requests made to the '/ollama/api' route from the web UI are seamlessly redirected to Ollama from the backend, enhancing overall system security.
- 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!
- 🌟 **Continuous Updates**: We are committed to improving Open WebUI with regular updates and new features.
- 🌟 **Continuous Updates**: We are committed to improving Open WebUI with regular updates, fixes, and new features.
Want to learn more about Open WebUI's features? Check out our [Open WebUI documentation](https://docs.openwebui.com/features) for a comprehensive overview!
## 🔗 Also Check Out Open WebUI Community!
......
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
import logging
from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
import time
import requests
from pydantic import BaseModel, ConfigDict
from typing import Optional, List
from utils.utils import get_verified_user, get_current_user, get_admin_user
from config import SRC_LOG_LEVELS, ENV
from constants import MESSAGES
import os
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
from config import (
ENABLE_LITELLM,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
DATA_DIR,
LITELLM_PROXY_PORT,
LITELLM_PROXY_HOST,
)
import warnings
warnings.simplefilter("ignore")
from litellm.utils import get_llm_provider
import asyncio
import subprocess
import yaml
@asynccontextmanager
async def lifespan(app: FastAPI):
log.info("startup_event")
# TODO: Check config.yaml file and create one
asyncio.create_task(start_litellm_background())
yield
app = FastAPI(lifespan=lifespan)
origins = ["*"]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml"
with open(LITELLM_CONFIG_DIR, "r") as file:
litellm_config = yaml.safe_load(file)
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value
app.state.ENABLE = ENABLE_LITELLM
app.state.CONFIG = litellm_config
# Global variable to store the subprocess reference
background_process = None
CONFLICT_ENV_VARS = [
# Uvicorn uses PORT, so LiteLLM might use it as well
"PORT",
# LiteLLM uses DATABASE_URL for Prisma connections
"DATABASE_URL",
]
async def run_background_process(command):
global background_process
log.info("run_background_process")
try:
# Log the command to be executed
log.info(f"Executing command: {command}")
# Filter environment variables known to conflict with litellm
env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS}
# Execute the command and create a subprocess
process = await asyncio.create_subprocess_exec(
*command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
)
background_process = process
log.info("Subprocess started successfully.")
# Capture STDERR for debugging purposes
stderr_output = await process.stderr.read()
stderr_text = stderr_output.decode().strip()
if stderr_text:
log.info(f"Subprocess STDERR: {stderr_text}")
# log.info output line by line
async for line in process.stdout:
log.info(line.decode().strip())
# Wait for the process to finish
returncode = await process.wait()
log.info(f"Subprocess exited with return code {returncode}")
except Exception as e:
log.error(f"Failed to start subprocess: {e}")
raise # Optionally re-raise the exception if you want it to propagate
async def start_litellm_background():
log.info("start_litellm_background")
# Command to run in the background
command = [
"litellm",
"--port",
str(LITELLM_PROXY_PORT),
"--host",
LITELLM_PROXY_HOST,
"--telemetry",
"False",
"--config",
LITELLM_CONFIG_DIR,
]
await run_background_process(command)
async def shutdown_litellm_background():
log.info("shutdown_litellm_background")
global background_process
if background_process:
background_process.terminate()
await background_process.wait() # Ensure the process has terminated
log.info("Subprocess terminated")
background_process = None
@app.get("/")
async def get_status():
return {"status": True}
async def restart_litellm():
"""
Endpoint to restart the litellm background service.
"""
log.info("Requested restart of litellm service.")
try:
# Shut down the existing process if it is running
await shutdown_litellm_background()
log.info("litellm service shutdown complete.")
# Restart the background service
asyncio.create_task(start_litellm_background())
log.info("litellm service restart complete.")
return {
"status": "success",
"message": "litellm service restarted successfully.",
}
except Exception as e:
log.info(f"Error restarting litellm service: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
@app.get("/restart")
async def restart_litellm_handler(user=Depends(get_admin_user)):
return await restart_litellm()
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
return app.state.CONFIG
class LiteLLMConfigForm(BaseModel):
general_settings: Optional[dict] = None
litellm_settings: Optional[dict] = None
model_list: Optional[List[dict]] = None
router_settings: Optional[dict] = None
model_config = ConfigDict(protected_namespaces=())
@app.post("/config/update")
async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
app.state.CONFIG = form_data.model_dump(exclude_none=True)
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return app.state.CONFIG
@app.get("/models")
@app.get("/v1/models")
async def get_models(user=Depends(get_current_user)):
if app.state.ENABLE:
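# The LiteLLM proxy subprocess may still be starting; wait until it is up before querying it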
while not background_process:
await asyncio.sleep(0.1)
url = f"http://localhost:{LITELLM_PROXY_PORT}/v1"
r = None
try:
r = requests.request(method="GET", url=f"{url}/models")
r.raise_for_status()
data = r.json()
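# Apply the model whitelist so users with the 'user' role only see permitted models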
if app.state.ENABLE_MODEL_FILTER:
if user and user.role == "user":
data["data"] = list(
filter(
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
data["data"],
)
)
return data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"
return {
"data": [
{
"id": model["model_name"],
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
}
for model in app.state.CONFIG["model_list"]
],
"object": "list",
}
else:
return {
"data": [],
"object": "list",
}
@app.get("/model/info")
async def get_model_list(user=Depends(get_admin_user)):
return {"data": app.state.CONFIG["model_list"]}
class AddLiteLLMModelForm(BaseModel):
model_name: str
litellm_params: dict
model_config = ConfigDict(protected_namespaces=())
@app.post("/model/new")
async def add_model_to_config(
form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
):
try:
get_llm_provider(model=form_data.model_name)
app.state.CONFIG["model_list"].append(form_data.model_dump())
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
class DeleteLiteLLMModelForm(BaseModel):
id: str
@app.post("/model/delete")
async def delete_model_from_config(
form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user)
):
app.state.CONFIG["model_list"] = [
model
for model in app.state.CONFIG["model_list"]
if model["model_name"] != form_data.id
]
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return {"message": MESSAGES.MODEL_DELETED(form_data.id)}
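A minimal client sketch for the two model-config endpoints above (hypothetical: the sub-app's mount path and the admin token are placeholders, and "litellm_params" follows LiteLLM's model_list schema):

import requests

BASE = "http://localhost:8080/litellm"  # hypothetical mount point of this sub-app
HEADERS = {"Authorization": "Bearer <admin-token>"}  # admin credentials assumed

# Register a model with the LiteLLM proxy config
requests.post(
    f"{BASE}/model/new",
    json={"model_name": "my-gpt", "litellm_params": {"model": "openai/gpt-3.5-turbo"}},
    headers=HEADERS,
)

# Remove it again; note the form field is named "id" but is matched against model_name
requests.post(f"{BASE}/model/delete", json={"id": "my-gpt"}, headers=HEADERS)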
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
body = await request.body()
url = f"http://localhost:{LITELLM_PROXY_PORT}"
target_url = f"{url}/{path}"
headers = {}
# headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
r = None
try:
r = requests.request(
method=request.method,
url=target_url,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
return StreamingResponse(
r.iter_content(chunk_size=8192),
status_code=r.status_code,
headers=dict(r.headers),
)
else:
response_data = r.json()
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)
......@@ -29,8 +29,8 @@ import time
from urllib.parse import urlparse
from typing import Optional, List, Union
from apps.web.models.users import Users
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import (
decode_token,
......@@ -39,10 +39,13 @@ from utils.utils import (
get_admin_user,
)
from utils.models import get_model_id_from_custom_model_id
from config import (
SRC_LOG_LEVELS,
OLLAMA_BASE_URLS,
ENABLE_OLLAMA_API,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
UPLOAD_DIR,
......@@ -67,6 +70,7 @@ app.state.config = AppConfig()
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}
......@@ -96,6 +100,21 @@ async def get_status():
return {"status": True}
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
class OllamaConfigForm(BaseModel):
enable_ollama_api: Optional[bool] = None
@app.post("/config/update")
async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
app.state.config.ENABLE_OLLAMA_API = form_data.enable_ollama_api
return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
@app.get("/urls")
async def get_ollama_api_urls(user=Depends(get_admin_user)):
return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
......@@ -156,14 +175,23 @@ def merge_models_lists(model_lists):
async def get_all_models():
log.info("get_all_models()")
tasks = [fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS]
responses = await asyncio.gather(*tasks)
models = {
"models": merge_models_lists(
map(lambda response: response["models"] if response else None, responses)
)
}
if app.state.config.ENABLE_OLLAMA_API:
tasks = [
fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS
]
responses = await asyncio.gather(*tasks)
models = {
"models": merge_models_lists(
map(
lambda response: response["models"] if response else None, responses
)
)
}
else:
models = {"models": []}
app.state.MODELS = {model["model"]: model for model in models["models"]}
......@@ -191,6 +219,8 @@ async def get_ollama_tags(
return models
else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
r = None
try:
r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status()
......@@ -242,6 +272,8 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
)
else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
r = None
try:
r = requests.request(method="GET", url=f"{url}/api/version")
r.raise_for_status()
......@@ -278,6 +310,9 @@ async def pull_model(
r = None
# Admin should be able to pull models from any source
payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
def get_request():
nonlocal url
nonlocal r
......@@ -305,7 +340,7 @@ async def pull_model(
r = requests.request(
method="POST",
url=f"{url}/api/pull",
data=form_data.model_dump_json(exclude_none=True).encode(),
data=json.dumps(payload),
stream=True,
)
......@@ -848,14 +883,93 @@ async def generate_chat_completion(
user=Depends(get_verified_user),
):
if url_idx == None:
model = form_data.model
log.debug(
"form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
form_data.model_dump_json(exclude_none=True).encode()
)
)
if ":" not in model:
model = f"{model}:latest"
payload = {
**form_data.model_dump(exclude_none=True),
}
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
model_id = form_data.model
model_info = Models.get_model_by_id(model_id)
if model_info:
print(model_info)
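# A preset may reference a base model; substitute the underlying model id for the request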
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
if model_info.params:
payload["options"] = {}
payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
payload["options"]["mirostat_eta"] = model_info.params.get(
"mirostat_eta", None
)
payload["options"]["mirostat_tau"] = model_info.params.get(
"mirostat_tau", None
)
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
payload["options"]["repeat_last_n"] = model_info.params.get(
"repeat_last_n", None
)
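# The OpenAI-style "frequency_penalty" preset value maps onto Ollama's "repeat_penalty" option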
payload["options"]["repeat_penalty"] = model_info.params.get(
"frequency_penalty", None
)
payload["options"]["temperature"] = model_info.params.get(
"temperature", None
)
payload["options"]["seed"] = model_info.params.get("seed", None)
payload["options"]["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
payload["options"]["num_predict"] = model_info.params.get(
"max_tokens", None
)
payload["options"]["top_k"] = model_info.params.get("top_k", None)
payload["options"]["top_p"] = model_info.params.get("top_p", None)
if model_info.params.get("system", None):
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None) + message["content"]
)
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
},
)
if url_idx == None:
if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest"
if payload["model"] in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
else:
raise HTTPException(
status_code=400,
......@@ -865,16 +979,12 @@ async def generate_chat_completion(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
print(payload)
log.debug(
"form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
form_data.model_dump_json(exclude_none=True).encode()
)
)
r = None
def get_request():
nonlocal form_data
nonlocal payload
nonlocal r
request_id = str(uuid.uuid4())
......@@ -883,7 +993,7 @@ async def generate_chat_completion(
def stream_content():
try:
if form_data.stream:
if payload.get("stream", None):
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
......@@ -901,7 +1011,7 @@ async def generate_chat_completion(
r = requests.request(
method="POST",
url=f"{url}/api/chat",
data=form_data.model_dump_json(exclude_none=True).encode(),
data=json.dumps(payload),
stream=True,
)
......@@ -957,14 +1067,62 @@ async def generate_openai_chat_completion(
user=Depends(get_verified_user),
):
if url_idx == None:
model = form_data.model
payload = {
**form_data.model_dump(exclude_none=True),
}
if ":" not in model:
model = f"{model}:latest"
model_id = form_data.model
model_info = Models.get_model_by_id(model_id)
if model in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[model]["urls"])
if model_info:
print(model_info)
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
if model_info.params:
payload["temperature"] = model_info.params.get("temperature", None)
payload["top_p"] = model_info.params.get("top_p", None)
payload["max_tokens"] = model_info.params.get("max_tokens", None)
payload["frequency_penalty"] = model_info.params.get(
"frequency_penalty", None
)
payload["seed"] = model_info.params.get("seed", None)
payload["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
if model_info.params.get("system", None):
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None) + message["content"]
)
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
},
)
if url_idx == None:
if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest"
if payload["model"] in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
else:
raise HTTPException(
status_code=400,
......@@ -977,7 +1135,7 @@ async def generate_openai_chat_completion(
r = None
def get_request():
nonlocal form_data
nonlocal payload
nonlocal r
request_id = str(uuid.uuid4())
......@@ -986,7 +1144,7 @@ async def generate_openai_chat_completion(
def stream_content():
try:
if form_data.stream:
if payload.get("stream"):
yield json.dumps(
{"request_id": request_id, "done": False}
) + "\n"
......@@ -1006,7 +1164,7 @@ async def generate_openai_chat_completion(
r = requests.request(
method="POST",
url=f"{url}/v1/chat/completions",
data=form_data.model_dump_json(exclude_none=True).encode(),
data=json.dumps(payload),
stream=True,
)
......
......@@ -10,8 +10,8 @@ import logging
from pydantic import BaseModel
from apps.web.models.users import Users
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import (
decode_token,
......@@ -53,7 +53,6 @@ app.state.config = AppConfig()
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
......@@ -185,13 +184,10 @@ async def speech(request: Request, user=Depends(get_verified_user)):
async def fetch_url(url, key):
timeout = aiohttp.ClientTimeout(total=5)
try:
if key != "":
headers = {"Authorization": f"Bearer {key}"}
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as response:
return await response.json()
else:
return None
headers = {"Authorization": f"Bearer {key}"}
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as response:
return await response.json()
except Exception as e:
# Handle connection error here
log.error(f"Connection error: {e}")
......@@ -199,14 +195,20 @@ async def fetch_url(url, key):
def merge_models_lists(model_lists):
log.info(f"merge_models_lists {model_lists}")
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
for idx, models in enumerate(model_lists):
if models is not None and "error" not in models:
merged_list.extend(
[
{**model, "urlIdx": idx}
{
**model,
"name": model.get("name", model["id"]),
"owned_by": "openai",
"openai": model,
"urlIdx": idx,
}
for model in models
if "api.openai.com"
not in app.state.config.OPENAI_API_BASE_URLS[idx]
......@@ -217,7 +219,7 @@ def merge_models_lists(model_lists):
return merged_list
async def get_all_models():
async def get_all_models(raw: bool = False):
log.info("get_all_models()")
if (
......@@ -232,7 +234,10 @@ async def get_all_models():
]
responses = await asyncio.gather(*tasks)
log.info(f"get_all_models:responses() {responses}")
log.debug(f"get_all_models:responses() {responses}")
if raw:
return responses
models = {
"data": merge_models_lists(
......@@ -249,10 +254,10 @@ async def get_all_models():
)
}
log.info(f"models: {models}")
log.debug(f"models: {models}")
app.state.MODELS = {model["id"]: model for model in models["data"]}
return models
return models
@app.get("/models")
......@@ -272,11 +277,16 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
return models
else:
url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
key = app.state.config.OPENAI_API_KEYS[url_idx]
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
r = None
try:
r = requests.request(method="GET", url=f"{url}/models")
r = requests.request(method="GET", url=f"{url}/models", headers=headers)
r.raise_for_status()
response_data = r.json()
......@@ -310,39 +320,107 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
body = await request.body()
# TODO: Remove below after gpt-4-vision fix from Open AI
# Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
payload = None
try:
body = body.decode("utf-8")
body = json.loads(body)
idx = app.state.MODELS[body.get("model")]["urlIdx"]
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model
if body.get("model") == "gpt-4-vision-preview":
if "max_tokens" not in body:
body["max_tokens"] = 4000
log.debug("Modified body_dict:", body)
# Fix for ChatGPT calls failing because the num_ctx key is in body
if "num_ctx" in body:
# If 'num_ctx' is in the dictionary, delete it
# Leaving it there generates an error with the
# OpenAI API (Feb 2024)
del body["num_ctx"]
# Convert the modified body back to JSON
body = json.dumps(body)
if "chat/completions" in path:
body = body.decode("utf-8")
body = json.loads(body)
payload = {**body}
model_id = body.get("model")
model_info = Models.get_model_by_id(model_id)
if model_info:
print(model_info)
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
if model_info.params:
if model_info.params.get("temperature", None):
payload["temperature"] = int(
model_info.params.get("temperature")
)
if model_info.params.get("top_p", None):
payload["top_p"] = int(model_info.params.get("top_p", None))
if model_info.params.get("max_tokens", None):
payload["max_tokens"] = int(
model_info.params.get("max_tokens", None)
)
if model_info.params.get("frequency_penalty", None):
payload["frequency_penalty"] = int(
model_info.params.get("frequency_penalty", None)
)
if model_info.params.get("seed", None):
payload["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None):
payload["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
if model_info.params.get("system", None):
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None)
+ message["content"]
)
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
},
)
else:
pass
model = app.state.MODELS[payload.get("model")]
idx = model["urlIdx"]
if "pipeline" in model and model.get("pipeline"):
payload["user"] = {"name": user.name, "id": user.id}
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model
if payload.get("model") == "gpt-4-vision-preview":
if "max_tokens" not in payload:
payload["max_tokens"] = 4000
log.debug("Modified payload:", payload)
# Convert the modified body back to JSON
payload = json.dumps(payload)
except json.JSONDecodeError as e:
log.error("Error loading request body into a dictionary:", e)
print(payload)
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
target_url = f"{url}/{path}"
if key == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
......@@ -353,7 +431,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
r = requests.request(
method=request.method,
url=target_url,
data=body,
data=payload if payload else body,
headers=headers,
stream=True,
)
......@@ -376,6 +454,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
if r is not None:
try:
res = r.json()
print(res)
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
......
......@@ -11,7 +11,7 @@ from fastapi.middleware.cors import CORSMiddleware
import os, shutil, logging, re
from pathlib import Path
from typing import List
from typing import List, Union, Sequence
from chromadb.utils.batch_utils import create_batches
......@@ -46,7 +46,7 @@ import json
import sentence_transformers
from apps.web.models.documents import (
from apps.webui.models.documents import (
Documents,
DocumentForm,
DocumentResponse,
......@@ -61,6 +61,14 @@ from apps.rag.utils import (
query_collection_with_hybrid_search,
)
from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
from utils.misc import (
calculate_sha256,
calculate_sha256_string,
......@@ -95,6 +103,17 @@ from config import (
RAG_TEMPLATE,
ENABLE_RAG_LOCAL_WEB_FETCH,
YOUTUBE_LOADER_LANGUAGE,
ENABLE_RAG_WEB_SEARCH,
RAG_WEB_SEARCH_ENGINE,
SEARXNG_QUERY_URL,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
BRAVE_SEARCH_API_KEY,
SERPSTACK_API_KEY,
SERPSTACK_HTTPS,
SERPER_API_KEY,
RAG_WEB_SEARCH_RESULT_COUNT,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
AppConfig,
)
......@@ -134,6 +153,20 @@ app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
def update_embedding_model(
embedding_model: str,
update_model: bool = False,
......@@ -201,6 +234,10 @@ class UrlForm(CollectionNameForm):
url: str
class SearchForm(CollectionNameForm):
query: str
@app.get("/")
async def get_status():
return {
......@@ -326,11 +363,26 @@ async def get_rag_config(user=Depends(get_admin_user)):
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
"web": {
"ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"search": {
"enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
"searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
"google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
"google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
"brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
"serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
"serpstack_https": app.state.config.SERPSTACK_HTTPS,
"serper_api_key": app.state.config.SERPER_API_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
},
},
}
......@@ -344,11 +396,30 @@ class YoutubeLoaderConfig(BaseModel):
translation: Optional[str] = None
class WebSearchConfig(BaseModel):
enabled: bool
engine: Optional[str] = None
searxng_query_url: Optional[str] = None
google_pse_api_key: Optional[str] = None
google_pse_engine_id: Optional[str] = None
brave_search_api_key: Optional[str] = None
serpstack_api_key: Optional[str] = None
serpstack_https: Optional[bool] = None
serper_api_key: Optional[str] = None
result_count: Optional[int] = None
concurrent_requests: Optional[int] = None
class WebConfig(BaseModel):
search: WebSearchConfig
web_loader_ssl_verification: Optional[bool] = None
class ConfigUpdateForm(BaseModel):
pdf_extract_images: Optional[bool] = None
chunk: Optional[ChunkParamUpdateForm] = None
web_loader_ssl_verification: Optional[bool] = None
youtube: Optional[YoutubeLoaderConfig] = None
web: Optional[WebConfig] = None
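For reference, a minimal request body for the /config/update endpoint below, enabling web search via SearXNG (a sketch with placeholder values; every field is optional per the forms above):

example_config_update = {
    "web": {
        "web_loader_ssl_verification": True,
        "search": {
            "enabled": True,
            "engine": "searxng",
            "searxng_query_url": "http://searxng:8080/search?q=<query>",  # placeholder URL
            "result_count": 10,
            "concurrent_requests": 10,
        },
    },
}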
@app.post("/config/update")
......@@ -359,35 +430,36 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
else app.state.config.PDF_EXTRACT_IMAGES
)
app.state.config.CHUNK_SIZE = (
form_data.chunk.chunk_size
if form_data.chunk is not None
else app.state.config.CHUNK_SIZE
)
if form_data.chunk is not None:
app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
app.state.config.CHUNK_OVERLAP = (
form_data.chunk.chunk_overlap
if form_data.chunk is not None
else app.state.config.CHUNK_OVERLAP
)
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None
else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
if form_data.youtube is not None:
app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation
app.state.config.YOUTUBE_LOADER_LANGUAGE = (
form_data.youtube.language
if form_data.youtube is not None
else app.state.config.YOUTUBE_LOADER_LANGUAGE
)
if form_data.web is not None:
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web.web_loader_ssl_verification
)
app.state.YOUTUBE_LOADER_TRANSLATION = (
form_data.youtube.translation
if form_data.youtube is not None
else app.state.YOUTUBE_LOADER_TRANSLATION
)
app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url
app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key
app.state.config.GOOGLE_PSE_ENGINE_ID = (
form_data.web.search.google_pse_engine_id
)
app.state.config.BRAVE_SEARCH_API_KEY = (
form_data.web.search.brave_search_api_key
)
app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
form_data.web.search.concurrent_requests
)
return {
"status": True,
......@@ -396,11 +468,26 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
"web": {
"ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"search": {
"enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
"searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
"google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
"google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
"brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
"serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
"serpstack_https": app.state.config.SERPSTACK_HTTPS,
"serper_api_key": app.state.config.SERPER_API_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
},
},
}
......@@ -589,24 +676,40 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
)
def get_web_loader(url: str, verify_ssl: bool = True):
def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
# Check if the URL is valid
if isinstance(validators.url(url), validators.ValidationError):
if not validate_url(url):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
if not ENABLE_RAG_LOCAL_WEB_FETCH:
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
parsed_url = urllib.parse.urlparse(url)
# Get IPv4 and IPv6 addresses
ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
# Check if any of the resolved addresses are private
# This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
for ip in ipv4_addresses:
if validators.ipv4(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
for ip in ipv6_addresses:
if validators.ipv6(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return WebBaseLoader(url, verify_ssl=verify_ssl)
return WebBaseLoader(
url,
verify_ssl=verify_ssl,
requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
continue_on_failure=True,
)
def validate_url(url: Union[str, Sequence[str]]):
if isinstance(url, str):
if isinstance(validators.url(url), validators.ValidationError):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
if not ENABLE_RAG_LOCAL_WEB_FETCH:
# Local web fetch is disabled, filter out any URLs that resolve to private IP addresses
parsed_url = urllib.parse.urlparse(url)
# Get IPv4 and IPv6 addresses
ipv4_addresses, ipv6_addresses = resolve_hostname(parsed_url.hostname)
# Check if any of the resolved addresses are private
# This is technically still vulnerable to DNS rebinding attacks, as we don't control WebBaseLoader
for ip in ipv4_addresses:
if validators.ipv4(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
for ip in ipv6_addresses:
if validators.ipv6(ip, private=True):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return True
elif isinstance(url, Sequence):
return all(validate_url(u) for u in url)
else:
return False
def resolve_hostname(hostname):
......@@ -620,6 +723,114 @@ def resolve_hostname(hostname):
return ipv4_addresses, ipv6_addresses
def search_web(engine: str, query: str) -> list[SearchResult]:
"""Search the web using a search engine and return the results as a list of SearchResult objects.
Will look for a search engine API key in environment variables in the following order:
- SEARXNG_QUERY_URL
- GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
- BRAVE_SEARCH_API_KEY
- SERPSTACK_API_KEY
- SERPER_API_KEY
Args:
query (str): The query to search for
"""
# TODO: add playwright to search the web
if engine == "searxng":
if app.state.config.SEARXNG_QUERY_URL:
return search_searxng(
app.state.config.SEARXNG_QUERY_URL,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception("No SEARXNG_QUERY_URL found in environment variables")
elif engine == "google_pse":
if (
app.state.config.GOOGLE_PSE_API_KEY
and app.state.config.GOOGLE_PSE_ENGINE_ID
):
return search_google_pse(
app.state.config.GOOGLE_PSE_API_KEY,
app.state.config.GOOGLE_PSE_ENGINE_ID,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception(
"No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
)
elif engine == "brave":
if app.state.config.BRAVE_SEARCH_API_KEY:
return search_brave(
app.state.config.BRAVE_SEARCH_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
elif engine == "serpstack":
if app.state.config.SERPSTACK_API_KEY:
return search_serpstack(
app.state.config.SERPSTACK_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
https_enabled=app.state.config.SERPSTACK_HTTPS,
)
else:
raise Exception("No SERPSTACK_API_KEY found in environment variables")
elif engine == "serper":
if app.state.config.SERPER_API_KEY:
return search_serper(
app.state.config.SERPER_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception("No SERPER_API_KEY found in environment variables")
else:
raise Exception("No search engine API key found in environment variables")
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
try:
web_results = search_web(
app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
)
except Exception as e:
log.exception(e)
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
)
try:
urls = [result.link for result in web_results]
loader = get_web_loader(urls)
data = loader.load()
collection_name = form_data.collection_name
if collection_name == "":
collection_name = calculate_sha256_string(form_data.query)[:63]
store_data_in_vector_db(data, collection_name, overwrite=True)
return {
"status": True,
"collection_name": collection_name,
"filenames": urls,
}
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
......
import logging
import requests
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
"""Search using Brave's Search API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Brave Search API key
query (str): The query to search for
"""
url = "https://api.search.brave.com/res/v1/web/search"
headers = {
"Accept": "application/json",
"Accept-Encoding": "gzip",
"X-Subscription-Token": api_key,
}
params = {"q": query, "count": count}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("web", {}).get("results", [])
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
)
for result in results[:count]
]
import json
import logging
import requests
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_google_pse(
api_key: str, search_engine_id: str, query: str, count: int
) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
Args:
api_key (str): A Programmable Search Engine API key
search_engine_id (str): A Programmable Search Engine ID
query (str): The query to search for
"""
url = "https://www.googleapis.com/customsearch/v1"
headers = {"Content-Type": "application/json"}
params = {
"cx": search_engine_id,
"q": query,
"key": api_key,
"num": count,
}
response = requests.request("GET", url, headers=headers, params=params)
response.raise_for_status()
json_response = response.json()
results = json_response.get("items", [])
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("snippet"),
)
for result in results
]
from typing import Optional
from pydantic import BaseModel
class SearchResult(BaseModel):
link: str
title: Optional[str]
snippet: Optional[str]
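As a small usage sketch tying these search helpers together (the API key and engine ID are placeholders):

from apps.rag.search.google_pse import search_google_pse

results = search_google_pse(
    api_key="<GOOGLE_PSE_API_KEY>",
    search_engine_id="<GOOGLE_PSE_ENGINE_ID>",
    query="open webui",
    count=3,
)
for result in results:
    # Each hit is a SearchResult with a required link and optional title/snippet
    print(result.link, result.title)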