Commit 45311bfa authored by Anuraag Jain

Merge branch 'main' into feat/cancel-model-download

# Conflicts:
#	src/lib/components/chat/Settings/Models.svelte
parents ae97a963 2fa94956
......@@ -7,7 +7,6 @@ node_modules
/package
.env
.env.*
!.env.example
vite.config.js.timestamp-*
vite.config.ts.timestamp-*
__pycache__
......
# Ollama URL for the backend to connect
# The path '/ollama/api' will be redirected to the specified backend URL
OLLAMA_API_BASE_URL='http://localhost:11434/api'
# The path '/ollama' will be redirected to the specified backend URL
OLLAMA_BASE_URL='http://localhost:11434'
OPENAI_API_BASE_URL=''
OPENAI_API_KEY=''
# AUTOMATIC1111_BASE_URL="http://localhost:7860"
# DO NOT TRACK
SCARF_NO_ANALYTICS=true
DO_NOT_TRACK=true
\ No newline at end of file
......@@ -32,7 +32,7 @@ assignees: ''
**Confirmation:**
- [ ] I have read and followed all the instructions provided in the README.md.
- [ ] I have reviewed the troubleshooting.md document.
- [ ] I am on the latest version of both Open WebUI and Ollama.
- [ ] I have included the browser console logs.
- [ ] I have included the Docker container logs.
......
## Pull Request Checklist
- [ ] **Description:** Briefly describe the changes in this pull request.
- [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
- [ ] **Documentation:** Have you updated relevant documentation?
- [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation?
---
## Description
[Insert a brief description of the changes made in this pull request]
---
### Changelog Entry
### Added
- [List any new features or additions]
### Fixed
- [List any fixes or corrections]
### Changed
- [List any changes or updates]
### Removed
- [List any removed features or files]
name: Release
on:
push:
branches:
- main # or whatever branch you want to use
jobs:
release:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v2
- name: Check for changes in package.json
run: |
git diff --cached --diff-filter=d package.json || {
echo "No changes to package.json"
exit 1
}
- name: Get version number from package.json
id: get_version
run: |
VERSION=$(jq -r '.version' package.json)
echo "::set-output name=version::$VERSION"
- name: Extract latest CHANGELOG entry
id: changelog
run: |
CHANGELOG_CONTENT=$(awk 'BEGIN {print_section=0;} /^## \[/ {if (print_section == 0) {print_section=1;} else {exit;}} print_section {print;}' CHANGELOG.md)
CHANGELOG_ESCAPED=$(echo "$CHANGELOG_CONTENT" | sed ':a;N;$!ba;s/\n/%0A/g')
echo "Extracted latest release notes from CHANGELOG.md:"
echo -e "$CHANGELOG_CONTENT"
echo "::set-output name=content::$CHANGELOG_ESCAPED"
- name: Create GitHub release
uses: actions/github-script@v5
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const changelog = `${{ steps.changelog.outputs.content }}`;
const release = await github.rest.repos.createRelease({
owner: context.repo.owner,
repo: context.repo.repo,
tag_name: `v${{ steps.get_version.outputs.version }}`,
name: `v${{ steps.get_version.outputs.version }}`,
body: changelog,
})
console.log(`Created release ${release.data.html_url}`)
- name: Upload package to GitHub release
uses: actions/upload-artifact@v3
with:
name: package
path: .
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
......@@ -52,6 +52,7 @@ jobs:
type=ref,event=tag
type=sha,prefix=git-
type=semver,pattern={{version}}
type=semver,pattern={{major}}.{{minor}}
flavor: |
latest=${{ github.ref == 'refs/heads/main' }}
......
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.1.114] - 2024-03-20
### Added
- **🔗 Webhook Integration**: Now you can subscribe to new user sign-up events via webhook. Simply navigate to the admin panel > admin settings > webhook URL.
- **🛡️ Enhanced Model Filtering**: Alongside Ollama, OpenAI proxy model whitelisting, we've added model filtering functionality for LiteLLM proxy.
- **🌍 Expanded Language Support**: Spanish, Catalan, and Vietnamese languages are now available, with improvements made to others.
### Fixed
- **🔧 Input Field Spelling**: Resolved issue with spelling mistakes in input fields.
- **🖊️ Light Mode Styling**: Fixed styling issue with light mode in document adding.
### Changed
- **🔄 Language Sorting**: Languages are now sorted alphabetically by their code for improved organization.
## [0.1.113] - 2024-03-18
### Added
- 🌍 **Localization**: You can now change the UI language in Settings > General. We support Ukrainian, German, Farsi (Persian), Traditional and Simplified Chinese and French translations. You can help us translate the UI into your language! More info in our [CONTRIBUTING.md](https://github.com/open-webui/open-webui/blob/main/docs/CONTRIBUTING.md#-translations-and-internationalization).
- 🎨 **System-wide Theme**: Introducing a new system-wide theme for enhanced visual experience.
### Fixed
- 🌑 **Dark Background on Select Fields**: Improved readability by adding a dark background to select fields, addressing issues on certain browsers/devices.
- **Multiple OPENAI_API_BASE_URLS Issue**: Resolved issue where multiple base URLs caused conflicts when one wasn't functioning.
- **RAG Encoding Issue**: Fixed encoding problem in RAG.
- **npm Audit Fix**: Addressed npm audit findings.
- **Reduced Scroll Threshold**: Improved auto-scroll experience by reducing the scroll threshold from 50px to 5px.
### Changed
- 🔄 **Sidebar UI Update**: Updated sidebar UI to feature a chat menu dropdown, replacing two icons for improved navigation.
## [0.1.112] - 2024-03-15
### Fixed
- 🗨️ Resolved chat malfunction after image generation.
- 🎨 Fixed various RAG issues.
- 🧪 Rectified experimental broken GGUF upload logic.
## [0.1.111] - 2024-03-10
### Added
- 🛡️ **Model Whitelisting**: Admins now have the ability to whitelist models for users with the 'user' role.
- 🔄 **Update All Models**: Added a convenient button to update all models at once.
- 📄 **Toggle PDF OCR**: Users can now toggle PDF OCR option for improved parsing performance.
- 🎨 **DALL-E Integration**: Introduced DALL-E integration for image generation alongside automatic1111.
- 🛠️ **RAG API Refactoring**: Refactored RAG logic and exposed its API, with additional documentation to follow.
### Fixed
- 🔒 **Max Token Settings**: Added max token settings for anthropic/claude-3-sonnet-20240229 (Issue #1094).
- 🔧 **Misalignment Issue**: Corrected misalignment of Edit and Delete Icons when Chat Title is Empty (Issue #1104).
- 🔄 **Context Loss Fix**: Resolved RAG losing context on model response regeneration with Groq models via API key (Issue #1105).
- 📁 **File Handling Bug**: Addressed File Not Found Notification when Dropping a Conversation Element (Issue #1098).
- 🖱️ **Dragged File Styling**: Fixed dragged file layover styling issue.
## [0.1.110] - 2024-03-06
### Added
- **🌐 Multiple OpenAI Servers Support**: Enjoy seamless integration with multiple OpenAI-compatible APIs, now supported natively.
### Fixed
- **🔍 OCR Issue**: Resolved PDF parsing issue caused by OCR malfunction.
- **🚫 RAG Issue**: Fixed the RAG functionality, ensuring it operates smoothly.
- **📄 "Add Docs" Model Button**: Addressed the non-functional behavior of the "Add Docs" model button.
## [0.1.109] - 2024-03-06
### Added
- **🔄 Multiple Ollama Servers Support**: Enjoy enhanced scalability and performance with support for multiple Ollama servers in a single WebUI. Load balancing features are now available, providing improved efficiency (#788, #278).
- **🔧 Support for Claude 3 and Gemini**: Responding to user requests, we've expanded our toolset to include Claude 3 and Gemini, offering a wider range of functionalities within our platform (#1064).
- **🔍 OCR Functionality for PDF Loader**: We've augmented our PDF loader with Optical Character Recognition (OCR) capabilities. Now, extract text from scanned documents and images within PDFs, broadening the scope of content processing (#1050).
### Fixed
- **🛠️ RAG Collection**: Implemented a dynamic mechanism to recreate RAG collections, ensuring users have up-to-date and accurate data (#1031).
- **📝 User Agent Headers**: Fixed issue of RAG web requests being sent with empty user_agent headers, reducing rejections from certain websites. Realistic headers are now utilized for these requests (#1024).
- **⏹️ Playground Cancel Functionality**: Introducing a new "Cancel" option for stopping Ollama generation in the Playground, enhancing user control and usability (#1006).
- **🔤 Typographical Error in 'ASSISTANT' Field**: Corrected a typographical error in the 'ASSISTANT' field within the GGUF model upload template for accuracy and consistency (#1061).
### Changed
- **🔄 Refactored Message Deletion Logic**: Streamlined message deletion process for improved efficiency and user experience, simplifying interactions within the platform (#1004).
- **⚠️ Deprecation of `OLLAMA_API_BASE_URL`**: Deprecated `OLLAMA_API_BASE_URL` environment variable; recommend using `OLLAMA_BASE_URL` instead. Refer to our documentation for further details.
## [0.1.108] - 2024-03-02
### Added
- **🎮 Playground Feature (Beta)**: Explore the full potential of the raw API through an intuitive UI with our new playground feature, accessible to admins. Simply click on the bottom name area of the sidebar to access it. The playground feature offers two modes: text completion (notebook) and chat completion. As it's in beta, please report any issues you encounter.
- **🛠️ Direct Database Download for Admins**: Admins can now download the database directly from the WebUI via the admin settings.
- **🎨 Additional RAG Settings**: Customize your RAG process with the ability to edit the TOP K value. Navigate to Documents > Settings > General to make changes.
- **🖥️ UI Improvements**: Tooltips now available in the input area and sidebar handle. More tooltips will be added across other parts of the UI.
### Fixed
- Resolved input autofocus issue on mobile when the sidebar is open, making it easier to use.
- Corrected numbered list display issue in Safari (#963).
- Restricted user ability to delete chats without proper permissions (#993).
### Changed
- **Simplified Ollama Settings**: Ollama settings now don't require the `/api` suffix. You can now utilize the Ollama base URL directly, e.g., `http://localhost:11434`. Also, an `OLLAMA_BASE_URL` environment variable has been added.
- **Database Renaming**: Starting from this release, `ollama.db` will be automatically renamed to `webui.db`.
## [0.1.107] - 2024-03-01
### Added
- **🚀 Makefile and LLM Update Script**: Included Makefile and a script for LLM updates in the repository.
### Fixed
- Corrected issue where links in the settings modal didn't appear clickable (#960).
- Fixed problem with web UI port not taking effect due to incorrect environment variable name in run-compose.sh (#996).
- Enhanced user experience by displaying chat in browser title and enabling automatic scrolling to the bottom (#992).
### Changed
- Upgraded toast library from `svelte-french-toast` to `svelte-sonner` for a more polished UI.
- Enhanced accessibility with the addition of dark mode on the authentication page.
## [0.1.106] - 2024-02-27
### Added
- **🎯 Auto-focus Feature**: The input area now automatically focuses when initiating or opening a chat conversation.
### Fixed
- Corrected typo from "HuggingFace" to "Hugging Face" (Issue #924).
- Resolved bug causing errors in chat completion API calls to OpenAI due to missing "num_ctx" parameter (Issue #927).
- Fixed issues preventing text editing, selection, and cursor retention in the input field (Issue #940).
- Fixed a bug where defining an OpenAI-compatible API server using 'OPENAI_API_BASE_URL' containing 'openai' string resulted in hiding models not containing 'gpt' string from the model menu. (Issue #930)
## [0.1.105] - 2024-02-25
### Added
- **📄 Document Selection**: Now you can select and delete multiple documents at once for easier management.
### Changed
- **🏷️ Document Pre-tagging**: Simply click the "+" button at the top, enter tag names in the popup window, or select from a list of existing tags. Then, upload files with the added tags for streamlined organization.
## [0.1.104] - 2024-02-25
### Added
- **🔄 Check for Updates**: Keep your system current by checking for updates conveniently located in Settings > About.
- **🗑️ Automatic Tag Deletion**: Unused tags on the sidebar will now be deleted automatically with just a click.
### Changed
- **🎨 Modernized Styling**: Enjoy a refreshed look with updated styling for a more contemporary experience.
## [0.1.103] - 2024-02-25
### Added
- **🔗 Built-in LiteLLM Proxy**: Now includes LiteLLM proxy within Open WebUI for enhanced functionality.
- Easily integrate existing LiteLLM configurations using `-v /path/to/config.yaml:/app/backend/data/litellm/config.yaml` flag.
- When utilizing Docker container to run Open WebUI, ensure connections to localhost use `host.docker.internal`.
- **🖼️ Image Generation Enhancements**: Introducing Advanced Settings with Image Preview Feature.
- Customize image generation by setting the number of steps; defaults to A1111 value.
### Fixed
- Resolved issue with RAG scan halting document loading upon encountering unsupported MIME types or exceptions (Issue #866).
### Changed
- Ollama is no longer required to run Open WebUI.
- Access our comprehensive documentation at [Open WebUI Documentation](https://docs.openwebui.com/).
## [0.1.102] - 2024-02-22
### Added
- **🖼️ Image Generation**: Generate Images using the AUTOMATIC1111/stable-diffusion-webui API. You can set this up in Settings > Images.
- **📝 Change title generation prompt**: Change the prompt used to generate titles for your chats. You can set this up in the Settings > Interface.
- **🤖 Change embedding model**: Change the embedding model used to generate embeddings for your chats in the Dockerfile. Use any sentence transformer model from huggingface.co.
- **📢 CHANGELOG.md/Popup**: This popup will show you the latest changes.
## [0.1.101] - 2024-02-22
### Fixed
- LaTeX output formatting issue (#828)
### Changed
- Instead of having the previous 1.0.0-alpha.101, we switched to semantic versioning as a way to respect global conventions.
......@@ -20,7 +20,7 @@ FROM python:3.11-slim-bookworm as base
ENV ENV=prod
ENV PORT ""
ENV OLLAMA_API_BASE_URL "/ollama/api"
ENV OLLAMA_BASE_URL "/ollama"
ENV OPENAI_API_BASE_URL ""
ENV OPENAI_API_KEY ""
......@@ -30,15 +30,31 @@ ENV WEBUI_SECRET_KEY ""
ENV SCARF_NO_ANALYTICS true
ENV DO_NOT_TRACK true
#Whisper TTS Settings
######## Preloaded models ########
# whisper TTS Settings
ENV WHISPER_MODEL="base"
ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"
# RAG Embedding Model Settings
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# for better performance and multilanguage support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
# IMPORTANT: If you change the default model (all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
# device type for whisper tts and embedding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing this right can lead to better performance
ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models"
ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR
######## Preloaded models ########
WORKDIR /app/backend
# install python dependencies
COPY ./backend/requirements.txt ./requirements.txt
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir
RUN pip3 install -r requirements.txt --no-cache-dir
......@@ -48,9 +64,10 @@ RUN apt-get update \
&& apt-get install -y pandoc netcat-openbsd \
&& rm -rf /var/lib/apt/lists/*
# RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')"
RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
# preload embedding model
RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])"
# preload tts model
RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
# copy embedding weight from build
RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
......@@ -58,8 +75,10 @@ COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onn
# copy built frontend files
COPY --from=build /app/build /app/build
COPY --from=build /app/CHANGELOG.md /app/CHANGELOG.md
COPY --from=build /app/package.json /app/package.json
# copy backend files
COPY ./backend .
CMD [ "bash", "start.sh"]
\ No newline at end of file
CMD [ "bash", "start.sh"]
install:
@docker-compose up -d
remove:
@chmod +x confirm_remove.sh
@./confirm_remove.sh
start:
@docker-compose start
stop:
@docker-compose stop
update:
# Calls the LLM update script
chmod +x update_ollama_models.sh
@./update_ollama_models.sh
@git pull
@docker-compose down
# Make sure the ollama-webui container is stopped before rebuilding
@docker stop open-webui || true
@docker-compose up --build -d
@docker-compose start
......@@ -4,7 +4,7 @@
The Open WebUI system is designed to streamline interactions between the client (your browser) and the Ollama API. At the heart of this design is a backend reverse proxy, enhancing security and resolving CORS issues.
- **How it Works**: The Open WebUI is designed to interact with the Ollama API through a specific route. When a request is made from the WebUI to Ollama, it is not directly sent to the Ollama API. Initially, the request is sent to the Open WebUI backend via `/ollama/api` route. From there, the backend is responsible for forwarding the request to the Ollama API. This forwarding is accomplished by using the route specified in the `OLLAMA_API_BASE_URL` environment variable. Therefore, a request made to `/ollama/api` in the WebUI is effectively the same as making a request to `OLLAMA_API_BASE_URL` in the backend. For instance, a request to `/ollama/api/tags` in the WebUI is equivalent to `OLLAMA_API_BASE_URL/tags` in the backend.
- **How it Works**: The Open WebUI is designed to interact with the Ollama API through a specific route. When a request is made from the WebUI to Ollama, it is not directly sent to the Ollama API. Initially, the request is sent to the Open WebUI backend via `/ollama` route. From there, the backend is responsible for forwarding the request to the Ollama API. This forwarding is accomplished by using the route specified in the `OLLAMA_BASE_URL` environment variable. Therefore, a request made to `/ollama` in the WebUI is effectively the same as making a request to `OLLAMA_BASE_URL` in the backend. For instance, a request to `/ollama/api/tags` in the WebUI is equivalent to `OLLAMA_BASE_URL/api/tags` in the backend.
- **Security Benefits**: This design prevents direct exposure of the Ollama API to the frontend, safeguarding against potential CORS (Cross-Origin Resource Sharing) issues and unauthorized access. Requiring authentication to access the Ollama API further enhances this security layer.
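The route mapping described above can be sketched in a few lines of Python. This is an illustration only, not the backend's actual proxy code: `proxy_target` and `PROXY_PREFIX` are hypothetical names, and the base URL is assumed to match the value in the example env file.

```python
OLLAMA_BASE_URL = "http://localhost:11434"  # assumed value, as in the example env file
PROXY_PREFIX = "/ollama"  # route the WebUI frontend calls on the backend


def proxy_target(webui_path: str, base_url: str = OLLAMA_BASE_URL) -> str:
    """Map a WebUI-side path under /ollama to the URL the backend would forward to."""
    if webui_path != PROXY_PREFIX and not webui_path.startswith(PROXY_PREFIX + "/"):
        raise ValueError(f"not a proxied path: {webui_path!r}")
    # Drop the proxy prefix and append the remainder to the configured base URL.
    remainder = webui_path[len(PROXY_PREFIX):]
    return base_url.rstrip("/") + remainder


print(proxy_target("/ollama/api/tags"))  # http://localhost:11434/api/tags
```

Because the prefix is stripped rather than rewritten, the Ollama path (`/api/tags`) reaches the target server unchanged, which is why the base URL no longer carries the `/api` suffix.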
......@@ -15,7 +15,7 @@ If you're experiencing connection issues, it’s often due to the WebUI docker c
**Example Docker Command**:
```bash
docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_API_BASE_URL=http://127.0.0.1:11434/api --name open-webui --restart always ghcr.io/open-webui/open-webui:main
docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_BASE_URL=http://127.0.0.1:11434 --name open-webui --restart always ghcr.io/open-webui/open-webui:main
```
### General Connection Errors
......@@ -25,8 +25,8 @@ docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_API_BASE_
**Troubleshooting Steps**:
1. **Verify Ollama URL Format**:
- When running the Web UI container, ensure the `OLLAMA_API_BASE_URL` is correctly set, including the `/api` suffix. (e.g., `http://192.168.1.1:11434/api` for different host setups).
- When running the Web UI container, ensure the `OLLAMA_BASE_URL` is correctly set. (e.g., `http://192.168.1.1:11434` for different host setups).
- In the Open WebUI, navigate to "Settings" > "General".
- Confirm that the Ollama Server URL is correctly set to `[OLLAMA URL]/api` (e.g., `http://localhost:11434/api`), including the `/api` suffix.
- Confirm that the Ollama Server URL is correctly set to `[OLLAMA URL]` (e.g., `http://localhost:11434`).
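The URL format check in the step above can also be done with a short script before starting the container. The sketch below is a hypothetical helper (`check_ollama_base_url` is not part of Open WebUI); it assumes the new convention, in which `OLLAMA_BASE_URL` carries no `/api` suffix, and it only inspects the string without contacting the server.

```python
from urllib.parse import urlparse


def check_ollama_base_url(url: str) -> list[str]:
    """Return a list of human-readable problems found in an Ollama base URL."""
    problems = []
    parsed = urlparse(url)
    if parsed.scheme not in ("http", "https"):
        problems.append("URL should start with http:// or https://")
    if not parsed.netloc:
        problems.append("URL is missing a host")
    # The legacy OLLAMA_API_BASE_URL format ended in /api; the new one does not.
    if parsed.path.rstrip("/").endswith("/api"):
        problems.append("drop the legacy '/api' suffix")
    return problems


print(check_ollama_base_url("http://192.168.1.1:11434"))   # []
print(check_ollama_base_url("http://localhost:11434/api"))  # ["drop the legacy '/api' suffix"]
```

An empty list means the value is well-formed; it does not confirm that the Ollama server is actually reachable from inside the container.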
By following these enhanced troubleshooting steps, connection issues should be effectively resolved. For further assistance or queries, feel free to reach out to us on our community Discord.
......@@ -4,4 +4,11 @@ _old
uploads
.ipynb_checkpoints
*.db
_test
\ No newline at end of file
_test
!/data
/data/*
!/data/litellm
/data/litellm/*
!data/litellm/config.yaml
!data/config.json
\ No newline at end of file
......@@ -6,5 +6,11 @@ uploads
*.db
_test
Pipfile
data/*
!/data
/data/*
!/data/litellm
/data/litellm/*
!data/litellm/config.yaml
!data/config.json
.webui_secret_key
\ No newline at end of file
......@@ -56,7 +56,7 @@ def transcribe(
model = WhisperModel(
WHISPER_MODEL,
device="cpu",
device="auto",
compute_type="int8",
download_root=WHISPER_MODEL_DIR,
)
......
import re
import requests
from fastapi import (
FastAPI,
Request,
Depends,
HTTPException,
status,
UploadFile,
File,
Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from constants import ERROR_MESSAGES
from utils.utils import (
get_current_user,
get_admin_user,
)
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
from pathlib import Path
import uuid
import base64
import json
from config import CACHE_DIR, AUTOMATIC1111_BASE_URL
IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.ENGINE = ""
app.state.ENABLED = False
app.state.OPENAI_API_KEY = ""
app.state.MODEL = ""
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.IMAGE_SIZE = "512x512"
app.state.IMAGE_STEPS = 50
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
class ConfigUpdateForm(BaseModel):
engine: str
enabled: bool
@app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.ENGINE = form_data.engine
app.state.ENABLED = form_data.enabled
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
class UrlUpdateForm(BaseModel):
url: str
@app.get("/url")
async def get_automatic1111_url(user=Depends(get_admin_user)):
return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
@app.post("/url/update")
async def update_automatic1111_url(
form_data: UrlUpdateForm, user=Depends(get_admin_user)
):
if form_data.url == "":
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else:
url = form_data.url.strip("/")
try:
r = requests.head(url)
app.state.AUTOMATIC1111_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
"status": True,
}
class OpenAIKeyUpdateForm(BaseModel):
key: str
@app.get("/key")
async def get_openai_key(user=Depends(get_admin_user)):
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
@app.post("/key/update")
async def update_openai_key(
form_data: OpenAIKeyUpdateForm, user=Depends(get_admin_user)
):
if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_KEY = form_data.key
return {
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
"status": True,
}
class ImageSizeUpdateForm(BaseModel):
size: str
@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_SIZE": app.state.IMAGE_SIZE}
@app.post("/size/update")
async def update_image_size(
form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
):
pattern = r"^\d+x\d+$" # Regular expression pattern
if re.match(pattern, form_data.size):
app.state.IMAGE_SIZE = form_data.size
return {
"IMAGE_SIZE": app.state.IMAGE_SIZE,
"status": True,
}
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
)
class ImageStepsUpdateForm(BaseModel):
steps: int
@app.get("/steps")
async def get_image_steps(user=Depends(get_admin_user)):
return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
@app.post("/steps/update")
async def update_image_steps(
form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
):
if form_data.steps >= 0:
app.state.IMAGE_STEPS = form_data.steps
return {
"IMAGE_STEPS": app.state.IMAGE_STEPS,
"status": True,
}
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 50)."),
)
@app.get("/models")
def get_models(user=Depends(get_current_user)):
try:
if app.state.ENGINE == "openai":
return [
{"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"},
]
else:
r = requests.get(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
)
models = r.json()
return list(
map(
lambda model: {"id": model["title"], "name": model["model_name"]},
models,
)
)
except Exception as e:
app.state.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
try:
if app.state.ENGINE == "openai":
return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
return {"model": options["sd_model_checkpoint"]}
except Exception as e:
app.state.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class UpdateModelForm(BaseModel):
model: str
def set_model_handler(model: str):
if app.state.ENGINE == "openai":
app.state.MODEL = model
return app.state.MODEL
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
)
return options
@app.post("/models/default/update")
def update_default_model(
form_data: UpdateModelForm,
user=Depends(get_current_user),
):
return set_model_handler(form_data.model)
class GenerateImageForm(BaseModel):
model: Optional[str] = None
prompt: str
n: int = 1
size: Optional[str] = None
negative_prompt: Optional[str] = None
def save_b64_image(b64_str):
image_id = str(uuid.uuid4())
file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")
try:
# Split the base64 string to get the actual image data
img_data = base64.b64decode(b64_str)
# Write the image data to a file
with open(file_path, "wb") as f:
f.write(img_data)
return image_id
except Exception as e:
print(f"Error saving image: {e}")
return None
@app.post("/generations")
def generate_image(
form_data: GenerateImageForm,
user=Depends(get_current_user),
):
r = None
try:
if app.state.ENGINE == "openai":
headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
data = {
"model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
"prompt": form_data.prompt,
"n": form_data.n,
"size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
"response_format": "b64_json",
}
r = requests.post(
url="https://api.openai.com/v1/images/generations",
json=data,
headers=headers,
)
r.raise_for_status()
res = r.json()
images = []
for image in res["data"]:
image_id = save_b64_image(image["b64_json"])
images.append({"url": f"/cache/image/generations/{image_id}.png"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f:
json.dump(data, f)
return images
else:
if form_data.model:
set_model_handler(form_data.model)
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
data = {
"prompt": form_data.prompt,
"batch_size": form_data.n,
"width": width,
"height": height,
}
if app.state.IMAGE_STEPS is not None:
data["steps"] = app.state.IMAGE_STEPS
if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data,
)
res = r.json()
print(res)
images = []
for image in res["images"]:
image_id = save_b64_image(image)
images.append({"url": f"/cache/image/generations/{image_id}.png"})
file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")
with open(file_body_path, "w") as f:
json.dump({**data, "info": res["info"]}, f)
return images
except Exception as e:
error = e
if r is not None:
data = r.json()
if "error" in data:
error = data["error"]["message"]
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
from litellm.proxy.proxy_server import ProxyConfig, initialize
from litellm.proxy.proxy_server import app
from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
from utils.utils import get_http_authorization_cred, get_current_user
from config import ENV
from config import (
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
)
proxy_config = ProxyConfig()
async def config():
router, model_list, general_settings = await proxy_config.load_config(
router=None, config_file_path="./data/litellm/config.yaml"
)
await initialize(config="./data/litellm/config.yaml", telemetry=False)
async def startup():
await config()
@app.on_event("startup")
async def on_startup():
await startup()
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
@app.middleware("http")
async def auth_middleware(request: Request, call_next):
auth_header = request.headers.get("Authorization", "")
request.state.user = None
try:
user = get_current_user(get_http_authorization_cred(auth_header))
print(user)
request.state.user = user
except Exception as e:
return JSONResponse(status_code=400, content={"detail": str(e)})
response = await call_next(request)
return response
class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
response = await call_next(request)
user = request.state.user
if "/models" in request.url.path:
if isinstance(response, StreamingResponse):
# Read the content of the streaming response
body = b""
async for chunk in response.body_iterator:
body += chunk
data = json.loads(body.decode("utf-8"))
if app.state.MODEL_FILTER_ENABLED:
if user and user.role == "user":
data["data"] = list(
filter(
lambda model: model["id"]
in app.state.MODEL_FILTER_LIST,
data["data"],
)
)
# Modified Flag
data["modified"] = True
return JSONResponse(content=data)
return response
app.add_middleware(ModifyModelsResponseMiddleware)
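Review note: the filtering step inside `ModifyModelsResponseMiddleware` can be isolated as a pure function, which makes it easier to unit test without a running proxy. The sketch below is an illustration only; `filter_models_for_role` and its parameters are hypothetical names that do not appear in this commit, and the role string `"user"` is taken from the middleware above.

```python
from typing import Any, Dict, List, Optional


def filter_models_for_role(
    data: Dict[str, Any],
    role: Optional[str],
    filter_enabled: bool,
    allowed_ids: List[str],
) -> Dict[str, Any]:
    """Return a copy of a /models payload, restricted for plain users.

    Hypothetical helper: mirrors the condition used in the middleware
    (filter enabled and role == "user"); admins and unauthenticated
    callers get the payload unchanged.
    """
    result = dict(data)  # shallow copy so the caller's payload is not mutated
    if filter_enabled and role == "user":
        result["data"] = [
            model for model in data.get("data", []) if model["id"] in allowed_ids
        ]
    return result
```

Called with the parsed response body, the user's role, and the two `app.state` settings, it returns the payload the middleware would serialize.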
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
import requests
import json
from pydantic import BaseModel
from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
import aiohttp
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
# TARGET_SERVER_URL = OLLAMA_API_BASE_URL
@app.get("/url")
async def get_ollama_api_url(user=Depends(get_current_user)):
if user and user.role == "admin":
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
class UrlUpdateForm(BaseModel):
url: str
@app.post("/url/update")
async def update_ollama_api_url(
form_data: UrlUpdateForm, user=Depends(get_current_user)
):
if user and user.role == "admin":
app.state.OLLAMA_API_BASE_URL = form_data.url
return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
# async def fetch_sse(method, target_url, body, headers):
# async with aiohttp.ClientSession() as session:
# try:
# async with session.request(
# method, target_url, data=body, headers=headers
# ) as response:
# print(response.status)
# async for line in response.content:
# yield line
# except Exception as e:
# print(e)
# error_detail = "Open WebUI: Server Connection Error"
# yield json.dumps({"error": error_detail, "message": str(e)}).encode()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_current_user)):
target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
print(target_url)
body = await request.body()
headers = dict(request.headers)
if user.role in ["user", "admin"]:
if path in ["pull", "delete", "push", "copy", "create"]:
if user.role != "admin":
raise HTTPException(
status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
headers.pop("Host", None)
headers.pop("Authorization", None)
headers.pop("Origin", None)
headers.pop("Referer", None)
session = aiohttp.ClientSession()
response = None
try:
response = await session.request(
request.method, target_url, data=body, headers=headers
)
print(response)
if not response.ok:
data = await response.json()
print(data)
response.raise_for_status()
async def generate():
async for line in response.content:
print(line)
yield line
await session.close()
return StreamingResponse(generate(), response.status)
except Exception as e:
print(e)
error_detail = "Open WebUI: Server Connection Error"
if response is not None:
try:
res = await response.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except Exception:
error_detail = f"Ollama: {e}"
await session.close()
raise HTTPException(
status_code=response.status if response else 500,
detail=error_detail,
)
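Review note: the role check at the top of the Ollama `proxy` handler reduces to a small decision function. The sketch below restates it under the assumption that the nesting is as the flattened diff suggests (admin-only paths inside the `user`/`admin` branch); `ADMIN_ONLY_PATHS` and `is_request_allowed` are hypothetical names introduced for illustration.

```python
# Paths that modify the model store; taken from the list in the proxy handler.
ADMIN_ONLY_PATHS = frozenset({"pull", "delete", "push", "copy", "create"})


def is_request_allowed(role: str, path: str) -> bool:
    """Return True if a user with `role` may call the proxied `path`.

    Hypothetical helper: roles other than "user" and "admin" are rejected
    outright, and the admin-only paths are rejected for non-admins.
    """
    if role not in ("user", "admin"):
        return False
    if path in ADMIN_ONLY_PATHS and role != "admin":
        return False
    return True
```

In the handler, a `False` result corresponds to the `HTTPException` with `ERROR_MESSAGES.ACCESS_PROHIBITED`.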
......@@ -3,7 +3,10 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
import requests
import aiohttp
import asyncio
import json
from pydantic import BaseModel
......@@ -15,7 +18,15 @@ from utils.utils import (
get_verified_user,
get_admin_user,
)
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR
from config import (
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
CACHE_DIR,
MODEL_FILTER_ENABLED,
MODEL_FILTER_LIST,
)
from typing import List, Optional
import hashlib
from pathlib import Path
......@@ -29,116 +40,241 @@ app.add_middleware(
allow_headers=["*"],
)
app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = OPENAI_API_KEY
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.MODELS = {}
@app.middleware("http")
async def check_url(request: Request, call_next):
if len(app.state.MODELS) == 0:
await get_all_models()
else:
pass
response = await call_next(request)
return response
class UrlUpdateForm(BaseModel):
url: str
class UrlsUpdateForm(BaseModel):
urls: List[str]
class KeyUpdateForm(BaseModel):
key: str
class KeysUpdateForm(BaseModel):
keys: List[str]
@app.get("/url")
async def get_openai_url(user=Depends(get_admin_user)):
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
@app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
app.state.OPENAI_API_BASE_URL = form_data.url
return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
app.state.OPENAI_API_BASE_URLS = form_data.urls
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
@app.get("/key")
async def get_openai_key(user=Depends(get_admin_user)):
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
@app.post("/key/update")
async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_admin_user)):
app.state.OPENAI_API_KEY = form_data.key
return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
@app.post("/keys/update")
async def update_openai_keys(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
app.state.OPENAI_API_KEYS = form_data.keys
return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"
idx = None
try:
idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
body = await request.body()
name = hashlib.sha256(body).hexdigest()
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
# Check if the file already exists in the cache
if file_path.is_file():
return FileResponse(file_path)
headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}"
headers["Content-Type"] = "application/json"
r = None
try:
r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech",
data=body,
headers=headers,
stream=True,
)
if app.state.OPENAI_API_KEY == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
r.raise_for_status()
body = await request.body()
# Save the streaming content to a file
with open(file_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
name = hashlib.sha256(body).hexdigest()
with open(file_body_path, "w") as f:
json.dump(json.loads(body.decode("utf-8")), f)
SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")
# Return the saved file
return FileResponse(file_path)
# Check if the file already exists in the cache
if file_path.is_file():
return FileResponse(file_path)
except Exception as e:
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"
headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)
try:
print("openai")
r = requests.post(
url=target_url,
data=body,
headers=headers,
stream=True,
)
except ValueError:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
r.raise_for_status()
# Save the streaming content to a file
with open(file_path, "wb") as f:
for chunk in r.iter_content(chunk_size=8192):
f.write(chunk)
async def fetch_url(url, key):
try:
headers = {"Authorization": f"Bearer {key}"}
async with aiohttp.ClientSession() as session:
async with session.get(url, headers=headers) as response:
return await response.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
return None
def merge_models_lists(model_lists):
merged_list = []
for idx, models in enumerate(model_lists):
if models is not None and "error" not in models:
merged_list.extend(
[
{**model, "urlIdx": idx}
for model in models
if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx]
or "gpt" in model["id"]
]
)
with open(file_body_path, "w") as f:
json.dump(json.loads(body.decode("utf-8")), f)
return merged_list
async def get_all_models():
print("get_all_models")
if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "":
models = {"data": []}
else:
tasks = [
fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
]
responses = await asyncio.gather(*tasks)
models = {
"data": merge_models_lists(
list(
map(
lambda response: (
response["data"]
if response and "data" in response
else None
),
responses,
)
)
)
}
print(models)
app.state.MODELS = {model["id"]: model for model in models["data"]}
return models
@app.get("/models")
@app.get("/models/{url_idx}")
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
if url_idx is None:
models = await get_all_models()
if app.state.MODEL_FILTER_ENABLED:
if user.role == "user":
models["data"] = list(
filter(
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
models["data"],
)
)
return models
return models
else:
url = app.state.OPENAI_API_BASE_URLS[url_idx]
# Return the saved file
return FileResponse(file_path)
r = None
except Exception as e:
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"
try:
r = requests.request(method="GET", url=f"{url}/models")
r.raise_for_status()
response_data = r.json()
if "api.openai.com" in url:
response_data["data"] = list(
filter(lambda model: "gpt" in model["id"], response_data["data"])
)
raise HTTPException(status_code=r.status_code, detail=error_detail)
return response_data
except Exception as e:
print(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except Exception:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
print(target_url, app.state.OPENAI_API_KEY)
if app.state.OPENAI_API_KEY == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
idx = 0
body = await request.body()
# TODO: Remove below after gpt-4-vision fix from Open AI
# Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
try:
body = body.decode("utf-8")
body = json.loads(body)
idx = app.state.MODELS[body.get("model")]["urlIdx"]
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model
if body.get("model") == "gpt-4-vision-preview":
......@@ -146,15 +282,32 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
body["max_tokens"] = 4000
print("Modified body_dict:", body)
# Fix for ChatGPT calls failing because the num_ctx key is in body
if "num_ctx" in body:
# If 'num_ctx' is in the dictionary, delete it
# Leaving it there generates an error with the
# OpenAI API (Feb 2024)
del body["num_ctx"]
# Convert the modified body back to JSON
body = json.dumps(body)
except json.JSONDecodeError as e:
print("Error loading request body into a dictionary:", e)
url = app.state.OPENAI_API_BASE_URLS[idx]
key = app.state.OPENAI_API_KEYS[idx]
target_url = f"{url}/{path}"
if key == "":
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
r = None
try:
r = requests.request(
method=request.method,
......@@ -174,21 +327,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
headers=dict(r.headers),
)
else:
# For non-SSE, read the response and return it
# response_data = (
# r.json()
# if r.headers.get("Content-Type", "")
# == "application/json"
# else r.text
# )
response_data = r.json()
if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
response_data["data"] = list(
filter(lambda model: "gpt" in model["id"], response_data["data"])
)
return response_data
except Exception as e:
print(e)
......@@ -201,4 +340,6 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
except:
error_detail = f"External: {e}"
raise HTTPException(status_code=r.status_code, detail=error_detail)
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)
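Review note: the multi-backend routing added in this file rests on two steps: tagging each model with the index of the backend it came from (`urlIdx`), then looking that index up by model id when proxying. The sketch below shows both steps as standalone functions. It takes the base URLs as a parameter instead of reading `app.state`, and `build_model_index` is a hypothetical name for the dictionary comprehension used in `get_all_models`.

```python
from typing import Any, Dict, List, Optional


def merge_models_lists(
    model_lists: List[Optional[List[Dict[str, Any]]]],
    base_urls: List[str],
) -> List[Dict[str, Any]]:
    """Flatten per-backend model lists, tagging each model with its backend index.

    A None entry stands for a backend that could not be reached and is skipped.
    For api.openai.com, only models whose id contains "gpt" are kept, as in the diff.
    """
    merged: List[Dict[str, Any]] = []
    for idx, models in enumerate(model_lists):
        if models is None:
            continue
        for model in models:
            if "api.openai.com" in base_urls[idx] and "gpt" not in model["id"]:
                continue
            merged.append({**model, "urlIdx": idx})
    return merged


def build_model_index(merged: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Map model id to model entry, so a request body's "model" yields a urlIdx."""
    return {model["id"]: model for model in merged}
```

One consequence of keying by id: if two backends expose the same model id, the later backend wins in the index.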
from fastapi import (
FastAPI,
Request,
Depends,
HTTPException,
status,
......@@ -10,9 +9,12 @@ from fastapi import (
)
from fastapi.middleware.cors import CORSMiddleware
import os, shutil
from pathlib import Path
from typing import List
# from chromadb.utils import embedding_functions
from sentence_transformers import SentenceTransformer
from chromadb.utils import embedding_functions
from langchain_community.document_loaders import (
WebBaseLoader,
......@@ -28,27 +30,68 @@ from langchain_community.document_loaders import (
UnstructuredExcelLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from pydantic import BaseModel
from typing import Optional
import mimetypes
import uuid
import time
import json
from utils.misc import calculate_sha256, calculate_sha256_string
from apps.web.models.documents import (
Documents,
DocumentForm,
DocumentResponse,
)
from apps.rag.utils import query_doc, query_collection
from utils.misc import (
calculate_sha256,
calculate_sha256_string,
sanitize_filename,
extract_folders_after_data_docs,
)
from utils.utils import get_current_user, get_admin_user
from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from config import (
UPLOAD_DIR,
DOCS_DIR,
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_DEVICE_TYPE,
CHROMA_CLIENT,
CHUNK_SIZE,
CHUNK_OVERLAP,
RAG_TEMPLATE,
)
from constants import ERROR_MESSAGES
# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
# model_name=EMBED_MODEL
# )
#
# if RAG_EMBEDDING_MODEL:
# sentence_transformer_ef = SentenceTransformer(
# model_name_or_path=RAG_EMBEDDING_MODEL,
# cache_folder=RAG_EMBEDDING_MODEL_DIR,
# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
# )
app = FastAPI()
app.state.PDF_EXTRACT_IMAGES = False
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.TOP_K = 4
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
)
)
origins = ["*"]
app.add_middleware(
......@@ -68,9 +111,9 @@ class StoreWebForm(CollectionNameForm):
url: str
def store_data_in_vector_db(data, collection_name) -> bool:
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
)
docs = text_splitter.split_documents(data)
......@@ -78,7 +121,16 @@ def store_data_in_vector_db(data, collection_name) -> bool:
metadatas = [doc.metadata for doc in docs]
try:
collection = CHROMA_CLIENT.create_collection(name=collection_name)
if overwrite:
for collection in CHROMA_CLIENT.list_collections():
if collection_name == collection.name:
print(f"deleting existing collection {collection_name}")
CHROMA_CLIENT.delete_collection(name=collection_name)
collection = CHROMA_CLIENT.create_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
collection.add(
documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
......@@ -94,26 +146,133 @@ def store_data_in_vector_db(data, collection_name) -> bool:
@app.get("/")
async def get_status():
return {"status": True}
return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"template": app.state.RAG_TEMPLATE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
}
@app.get("/embedding/model")
async def get_embedding_model(user=Depends(get_admin_user)):
return {
"status": True,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
}
class EmbeddingModelUpdateForm(BaseModel):
embedding_model: str
@app.post("/embedding/model/update")
async def update_embedding_model(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
)
)
return {
"status": True,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
}
@app.get("/config")
async def get_rag_config(user=Depends(get_admin_user)):
return {
"status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
"chunk": {
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
},
}
class ChunkParamUpdateForm(BaseModel):
chunk_size: int
chunk_overlap: int
class ConfigUpdateForm(BaseModel):
pdf_extract_images: bool
chunk: ChunkParamUpdateForm
@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.PDF_EXTRACT_IMAGES = form_data.pdf_extract_images
app.state.CHUNK_SIZE = form_data.chunk.chunk_size
app.state.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
return {
"status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
"chunk": {
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
},
}
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
}
@app.get("/query/settings")
async def get_query_settings(user=Depends(get_admin_user)):
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
"k": app.state.TOP_K,
}
class QuerySettingsForm(BaseModel):
k: Optional[int] = None
template: Optional[str] = None
@app.post("/query/settings/update")
async def update_query_settings(
form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
app.state.TOP_K = form_data.k if form_data.k else 4
return {"status": True, "template": app.state.RAG_TEMPLATE}
class QueryDocForm(BaseModel):
collection_name: str
query: str
k: Optional[int] = 4
k: Optional[int] = None
@app.post("/query/doc")
def query_doc(
def query_doc_handler(
form_data: QueryDocForm,
user=Depends(get_current_user),
):
try:
collection = CHROMA_CLIENT.get_collection(
name=form_data.collection_name,
return query_doc(
collection_name=form_data.collection_name,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
return result
except Exception as e:
print(e)
raise HTTPException(
......@@ -125,74 +284,20 @@ def query_doc(
class QueryCollectionsForm(BaseModel):
collection_names: List[str]
query: str
k: Optional[int] = 4
def merge_and_sort_query_results(query_results, k):
# Initialize lists to store combined data
combined_ids = []
combined_distances = []
combined_metadatas = []
combined_documents = []
# Combine data from each dictionary
for data in query_results:
combined_ids.extend(data["ids"][0])
combined_distances.extend(data["distances"][0])
combined_metadatas.extend(data["metadatas"][0])
combined_documents.extend(data["documents"][0])
# Create a list of tuples (distance, id, metadata, document)
combined = list(
zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
)
# Sort the list based on distances
combined.sort(key=lambda x: x[0])
# Unzip the sorted list
sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
# Slicing the lists to include only k elements
sorted_distances = list(sorted_distances)[:k]
sorted_ids = list(sorted_ids)[:k]
sorted_metadatas = list(sorted_metadatas)[:k]
sorted_documents = list(sorted_documents)[:k]
# Create the output dictionary
merged_query_results = {
"ids": [sorted_ids],
"distances": [sorted_distances],
"metadatas": [sorted_metadatas],
"documents": [sorted_documents],
"embeddings": None,
"uris": None,
"data": None,
}
return merged_query_results
k: Optional[int] = None
@app.post("/query/collection")
def query_collection(
def query_collection_handler(
form_data: QueryCollectionsForm,
user=Depends(get_current_user),
):
results = []
for collection_name in form_data.collection_names:
try:
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
)
result = collection.query(
query_texts=[form_data.query], n_results=form_data.k
)
results.append(result)
except:
pass
return merge_and_sort_query_results(results, form_data.k)
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
k=form_data.k if form_data.k else app.state.TOP_K,
embedding_function=app.state.sentence_transformer_ef,
)
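Review note: this hunk removes `merge_and_sort_query_results` from the file and delegates to `query_collection` imported from `apps.rag.utils`, whose body is not shown in this diff. For reference, the merge step can be sketched as below. This is a standalone restatement of the removed helper, not the code in `apps.rag.utils`; it also returns empty lists when no collection produced results, which is an assumption about the desired behavior (the removed version would fail on `zip(*combined)` in that case).

```python
from typing import Any, Dict, List


def merge_and_sort_query_results(
    query_results: List[Dict[str, Any]], k: int
) -> Dict[str, Any]:
    """Merge Chroma-style query results and keep the k nearest hits.

    Each input is expected to hold single-query results, i.e. lists nested
    one level deep under "ids", "distances", "metadatas" and "documents".
    """
    combined = []
    for data in query_results:
        combined.extend(
            zip(
                data["distances"][0],
                data["ids"][0],
                data["metadatas"][0],
                data["documents"][0],
            )
        )

    # Smaller distance means a closer match; sort only on the distance field.
    combined.sort(key=lambda item: item[0])
    top = combined[:k]

    return {
        "ids": [[item[1] for item in top]],
        "distances": [[item[0] for item in top]],
        "metadatas": [[item[2] for item in top]],
        "documents": [[item[3] for item in top]],
    }
```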
@app.post("/web")
......@@ -206,7 +311,7 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
if collection_name == "":
collection_name = calculate_sha256_string(form_data.url)[:63]
store_data_in_vector_db(data, collection_name)
store_data_in_vector_db(data, collection_name, overwrite=True)
return {
"status": True,
"collection_name": collection_name,
......@@ -220,8 +325,8 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
)
def get_loader(file, file_path):
file_ext = file.filename.split(".")[-1].lower()
def get_loader(filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower()
known_type = True
known_source_ext = [
......@@ -270,7 +375,7 @@ def get_loader(file, file_path):
]
if file_ext == "pdf":
loader = PyPDFLoader(file_path)
loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
elif file_ext == "csv":
loader = CSVLoader(file_path)
elif file_ext == "rst":
......@@ -279,23 +384,25 @@ def get_loader(file, file_path):
loader = UnstructuredXMLLoader(file_path)
elif file_ext == "md":
loader = UnstructuredMarkdownLoader(file_path)
elif file.content_type == "application/epub+zip":
elif file_content_type == "application/epub+zip":
loader = UnstructuredEPubLoader(file_path)
elif (
file.content_type
file_content_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
or file_ext in ["doc", "docx"]
):
loader = Docx2txtLoader(file_path)
elif file.content_type in [
elif file_content_type in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
] or file_ext in ["xls", "xlsx"]:
loader = UnstructuredExcelLoader(file_path)
elif file_ext in known_source_ext or file.content_type.find("text/") >= 0:
loader = TextLoader(file_path)
elif file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
loader = TextLoader(file_path)
loader = TextLoader(file_path, autodetect_encoding=True)
known_type = False
return loader, known_type
......@@ -323,7 +430,7 @@ def store_doc(
collection_name = calculate_sha256(f)[:63]
f.close()
loader, known_type = get_loader(file, file_path)
loader, known_type = get_loader(file.filename, file.content_type, file_path)
data = loader.load()
result = store_data_in_vector_db(data, collection_name)
......@@ -353,6 +460,63 @@ def store_doc(
)
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
for path in Path(DOCS_DIR).rglob("./**/*"):
try:
if path.is_file() and not path.name.startswith("."):
tags = extract_folders_after_data_docs(path)
filename = path.name
file_content_type = mimetypes.guess_type(path)
f = open(path, "rb")
collection_name = calculate_sha256(f)[:63]
f.close()
loader, known_type = get_loader(
filename, file_content_type[0], str(path)
)
data = loader.load()
result = store_data_in_vector_db(data, collection_name)
if result:
sanitized_filename = sanitize_filename(filename)
doc = Documents.get_doc_by_name(sanitized_filename)
if doc is None:
doc = Documents.insert_new_doc(
user.id,
DocumentForm(
**{
"name": sanitized_filename,
"title": filename,
"collection_name": collection_name,
"filename": filename,
"content": (
json.dumps(
{
"tags": list(
map(
lambda name: {"name": name},
tags,
)
)
}
)
if len(tags)
else "{}"
),
}
),
)
except Exception as e:
print(e)
return True
@app.get("/reset/db")
def reset_vector_db(user=Depends(get_admin_user)):
CHROMA_CLIENT.reset()
......