Unverified Commit 9763d885 authored by lainedfles, committed by GitHub

Merge Updates & Dockerfile improvements

parent fdef2abd
name: Python CI
on:
  push:
    branches:
      - main
      - dev
  pull_request:
    branches:
      - main
      - dev

jobs:
  build:
    name: 'Format Backend'
    env:
      PUBLIC_API_BASE_URL: ''
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.11]
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install black
      - name: Format backend
        run: npm run format:backend
      - name: Check for changes after format
        run: git diff --exit-code
name: Frontend Build
on:
  push:
    branches:
      - main
      - dev
  pull_request:
    branches:
      - main
      - dev

jobs:
  build:
    name: 'Format & Build Frontend'
    env:
      PUBLIC_API_BASE_URL: ''
    runs-on: ubuntu-latest
    steps:
      - name: Checkout Repository
        uses: actions/checkout@v4

      - name: Setup Node.js
        uses: actions/setup-node@v3
        with:
          node-version: '20' # Or specify any other version you want to use

      - name: Install Dependencies
        run: npm install

      - name: Format Frontend
        run: npm run format

      - name: Check for Changes After Format
        run: git diff --exit-code

      - name: Build Frontend
        run: npm run build
@@ -166,7 +166,7 @@ cython_debug/
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

# Logs
logs
......
@@ -5,6 +5,86 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.1.116] - 2024-03-31
### Added
- **🔄 Enhanced UI**: Model selector now conveniently located in the navbar, enabling seamless switching between multiple models during conversations.
- **🔍 Improved Model Selector**: Pull a model directly from the selector; models now display detailed information for better understanding.
- **💬 Webhook Support**: Now compatible with Google Chat and Microsoft Teams.
- **🌐 Localization**: Korean translation (I18n) now available.
- **🌑 Dark Theme**: OLED dark theme introduced for reduced strain during prolonged usage.
- **🏷️ Tag Autocomplete**: Dropdown feature added for effortless chat tagging.
### Fixed
- **🔽 Auto-Scrolling**: Addressed OpenAI auto-scrolling issue.
- **🏷️ Tag Validation**: Implemented tag validation to prevent empty string tags.
- **🚫 Model Whitelisting**: Resolved LiteLLM model whitelisting issue.
- **✅ Spelling**: Corrected various spelling issues for improved readability.
## [0.1.115] - 2024-03-24
### Added
- **🔍 Custom Model Selector**: Easily find and select custom models with the new search filter feature.
- **🛑 Cancel Model Download**: Added the ability to cancel model downloads.
- **🎨 Image Generation ComfyUI**: Image generation now supports ComfyUI.
- **🌟 Updated Light Theme**: Updated the light theme for a fresh look.
- **🌍 Additional Language Support**: Now supporting Bulgarian, Italian, Portuguese, Japanese, and Dutch.
### Fixed
- **🔧 Fixed Broken Experimental GGUF Upload**: Resolved issues with experimental GGUF upload functionality.
### Changed
- **🔄 Vector Storage Reset Button**: Moved the reset vector storage button to document settings.
## [0.1.114] - 2024-03-20
### Added
- **🔗 Webhook Integration**: Now you can subscribe to new user sign-up events via webhook. Simply navigate to the admin panel > admin settings > webhook URL.
- **🛡️ Enhanced Model Filtering**: Alongside Ollama, OpenAI proxy model whitelisting, we've added model filtering functionality for LiteLLM proxy.
- **🌍 Expanded Language Support**: Spanish, Catalan, and Vietnamese languages are now available, with improvements made to others.
### Fixed
- **🔧 Input Field Spelling**: Resolved issue with spelling mistakes in input fields.
- **🖊️ Light Mode Styling**: Fixed styling issue with light mode in document adding.
### Changed
- **🔄 Language Sorting**: Languages are now sorted alphabetically by their code for improved organization.
## [0.1.113] - 2024-03-18
### Added
- 🌍 **Localization**: You can now change the UI language in Settings > General. We support Ukrainian, German, Farsi (Persian), Traditional and Simplified Chinese, and French translations. You can help us to translate the UI into your language! More info in our [CONTRIBUTING.md](https://github.com/open-webui/open-webui/blob/main/docs/CONTRIBUTING.md#-translations-and-internationalization).
- 🎨 **System-wide Theme**: Introducing a new system-wide theme for enhanced visual experience.
### Fixed
- 🌑 **Dark Background on Select Fields**: Improved readability by adding a dark background to select fields, addressing issues on certain browsers/devices.
- **Multiple OPENAI_API_BASE_URLS Issue**: Resolved issue where multiple base URLs caused conflicts when one wasn't functioning.
- **RAG Encoding Issue**: Fixed encoding problem in RAG.
- **npm Audit Fix**: Addressed npm audit findings.
- **Reduced Scroll Threshold**: Improved auto-scroll experience by reducing the scroll threshold from 50px to 5px.
### Changed
- 🔄 **Sidebar UI Update**: Updated sidebar UI to feature a chat menu dropdown, replacing two icons for improved navigation.
## [0.1.112] - 2024-03-15
### Fixed
- 🗨️ Resolved chat malfunction after image generation.
- 🎨 Fixed various RAG issues.
- 🧪 Rectified experimental broken GGUF upload logic.
## [0.1.111] - 2024-03-10
### Added
......
@@ -2,6 +2,8 @@
# Initialize device type args
# use build args in the docker build command with --build-arg="BUILDARG=true"
ARG USE_CUDA=false
ARG USE_CUDA_VER=cu121
ARG USE_EMBEDDING_MODEL=all-MiniLM-L6-v2
ARG USE_MPS=false
ARG INCLUDE_OLLAMA=false
@@ -28,8 +30,9 @@ RUN npm run build
######## WebUI backend ########
FROM python:3.11-slim-bookworm as base

# Use args
ARG USE_CUDA
ARG USE_CUDA_VER
ARG USE_EMBEDDING_MODEL
ARG USE_MPS
ARG INCLUDE_OLLAMA
@@ -39,7 +42,9 @@ ENV ENV=prod \
    # pass build args to the build
    INCLUDE_OLLAMA_DOCKER=${INCLUDE_OLLAMA} \
    USE_MPS_DOCKER=${USE_MPS} \
    USE_CUDA_DOCKER=${USE_CUDA} \
    USE_CUDA_DOCKER_VER=${USE_CUDA_VER} \
    USE_EMBEDDING_MODEL_DOCKER=${USE_EMBEDDING_MODEL}
## Base URL Config ##
ENV OLLAMA_BASE_URL="/ollama" \

@@ -61,7 +66,7 @@ ENV WHISPER_MODEL="base" \
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# for better performance and multilanguage support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
# IMPORTANT: If you switch away from the default model (all-MiniLM-L6-v2), or back to it, you can't use RAG Chat with documents previously loaded in the WebUI; you need to re-embed them.
ENV RAG_EMBEDDING_MODEL="$USE_EMBEDDING_MODEL_DOCKER" \
    RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models" \
    SENTENCE_TRANSFORMERS_HOME="/app/backend/data/cache/embedding/models" \
    # device type for whisper tts and embedding models - "cpu" (default) or "mps" (apple silicon) - choosing the right one can lead to better performance
@@ -78,8 +83,10 @@ WORKDIR /app/backend
COPY ./backend/requirements.txt ./requirements.txt

RUN if [ "$USE_CUDA" = "true" ]; then \
    pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/$USE_CUDA_DOCKER_VER --no-cache-dir && \
    pip3 install -r requirements.txt --no-cache-dir; \
    python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])" && \
    python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device='cpu')"; \
    elif [ "$USE_MPS" = "true" ]; then \
    pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir && \
    pip3 install -r requirements.txt --no-cache-dir && \
......
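For reference, the new build arguments can be combined in a single `docker build` invocation like the sketch below (the image tag and embedding-model choice are illustrative, not part of this commit):

```sh
# Hypothetical invocation: CUDA wheels from the cu121 index plus a larger
# multilingual embedding model; the tag name is an arbitrary example.
docker build \
  --build-arg USE_CUDA=true \
  --build-arg USE_CUDA_VER=cu121 \
  --build-arg USE_EMBEDDING_MODEL=intfloat/multilingual-e5-base \
  -t open-webui:cuda .
```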
@@ -8,6 +8,8 @@ remove:
start:
	@docker-compose start

startAndBuild:
	docker-compose up -d --build

stop:
	@docker-compose stop
......
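The new target mirrors the existing ones; usage is simply:

```sh
# Rebuild the images and start the stack in detached mode
make startAndBuild
```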
@@ -11,7 +11,7 @@
[![Discord](https://img.shields.io/badge/Discord-Open_WebUI-blue?logo=discord&logoColor=white)](https://discord.gg/5rJgQTnV4s)
[![](https://img.shields.io/static/v1?label=Sponsor&message=%E2%9D%A4&logo=GitHub&color=%23fe8e86)](https://github.com/sponsors/tjbck)

Open WebUI is an extensible, feature-rich, and user-friendly self-hosted WebUI designed to operate entirely offline. It supports various LLM runners, including Ollama and OpenAI-compatible APIs. For more information, be sure to check out our [Open WebUI Documentation](https://docs.openwebui.com/).

![Open WebUI Demo](./demo.gif)

@@ -79,6 +79,8 @@ User-friendly WebUI for LLMs, supported LLM runners include Ollama and OpenAI-co
- 🔒 **Backend Reverse Proxy Support**: Bolster security through direct communication between Open WebUI backend and Ollama. This key feature eliminates the need to expose Ollama over LAN. Requests made to the '/ollama/api' route from the web UI are seamlessly redirected to Ollama from the backend, enhancing overall system security.

- 🌐🌍 **Multilingual Support**: Experience Open WebUI in your preferred language with our internationalization (i18n) support. Join us in expanding our supported languages! We're actively seeking contributors!

- 🌟 **Continuous Updates**: We are committed to improving Open WebUI with regular updates and new features.

## 🔗 Also Check Out Open WebUI Community!
......
import os
import logging
from fastapi import (
    FastAPI,
    Request,
@@ -21,11 +22,24 @@ from utils.utils import (
)
from utils.misc import calculate_sha256

from config import (
    SRC_LOG_LEVELS,
    CACHE_DIR,
    UPLOAD_DIR,
    WHISPER_MODEL,
    WHISPER_MODEL_DIR,
    DEVICE_TYPE,
)

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["AUDIO"])

whisper_device_type = DEVICE_TYPE
if whisper_device_type != "cuda":
    whisper_device_type = "cpu"

log.info(f"whisper_device_type: {whisper_device_type}")

app = FastAPI()
app.add_middleware(
@@ -42,7 +56,7 @@ def transcribe(
    file: UploadFile = File(...),
    user=Depends(get_current_user),
):
    log.info(f"file.content_type: {file.content_type}")
    if file.content_type not in ["audio/mpeg", "audio/wav"]:
        raise HTTPException(
@@ -66,7 +80,7 @@ def transcribe(
        )

        segments, info = model.transcribe(file_path, beam_size=5)
        log.info(
            "Detected language '%s' with probability %f"
            % (info.language, info.language_probability)
        )
@@ -76,7 +90,7 @@ def transcribe(
        return {"text": transcript.strip()}
    except Exception as e:
        log.exception(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
......
@@ -18,6 +18,8 @@ from utils.utils import (
    get_current_user,
    get_admin_user,
)

from apps.images.utils.comfyui import ImageGenerationPayload, comfyui_generate_image
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
@@ -25,9 +27,13 @@ from pathlib import Path
import uuid
import base64
import json
import logging

from config import SRC_LOG_LEVELS, CACHE_DIR, AUTOMATIC1111_BASE_URL, COMFYUI_BASE_URL

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["IMAGES"])

IMAGE_CACHE_DIR = Path(CACHE_DIR).joinpath("./image/generations/")
IMAGE_CACHE_DIR.mkdir(parents=True, exist_ok=True)

@@ -49,6 +55,8 @@ app.state.MODEL = ""

app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL

app.state.IMAGE_SIZE = "512x512"
app.state.IMAGE_STEPS = 50
@@ -71,32 +79,48 @@ async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user
    return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}


class EngineUrlUpdateForm(BaseModel):
    AUTOMATIC1111_BASE_URL: Optional[str] = None
    COMFYUI_BASE_URL: Optional[str] = None


@app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)):
    return {
        "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
        "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
    }


@app.post("/url/update")
async def update_engine_url(
    form_data: EngineUrlUpdateForm, user=Depends(get_admin_user)
):
    if form_data.AUTOMATIC1111_BASE_URL is None:
        app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
    else:
        url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
        try:
            r = requests.head(url)
            app.state.AUTOMATIC1111_BASE_URL = url
        except Exception as e:
            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

    if form_data.COMFYUI_BASE_URL is None:
        app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
    else:
        url = form_data.COMFYUI_BASE_URL.strip("/")
        try:
            r = requests.head(url)
            app.state.COMFYUI_BASE_URL = url
        except Exception as e:
            raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

    return {
        "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
        "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
        "status": True,
    }
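A minimal client sketch for the reworked endpoint (host, mount path, and token are assumed placeholders; omitting a field resets that engine's URL to its configured default):

```python
import requests

# Hypothetical deployment values; adjust for your instance.
BASE = "http://localhost:8080/images"
headers = {"Authorization": "Bearer <admin-token>"}

# AUTOMATIC1111_BASE_URL is omitted (None), so it resets to the default.
r = requests.post(
    f"{BASE}/url/update",
    json={"COMFYUI_BASE_URL": "http://127.0.0.1:8188"},
    headers=headers,
)
print(r.json())  # {"AUTOMATIC1111_BASE_URL": ..., "COMFYUI_BASE_URL": ..., "status": True}
```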
@@ -186,6 +210,18 @@ def get_models(user=Depends(get_current_user)):
            {"id": "dall-e-2", "name": "DALL·E 2"},
            {"id": "dall-e-3", "name": "DALL·E 3"},
        ]
    elif app.state.ENGINE == "comfyui":
        r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
        info = r.json()

        return list(
            map(
                lambda model: {"id": model, "name": model},
                info["CheckpointLoaderSimple"]["input"]["required"]["ckpt_name"][0],
            )
        )
    else:
        r = requests.get(
            url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
@@ -207,6 +243,8 @@ async def get_default_model(user=Depends(get_admin_user)):
    try:
        if app.state.ENGINE == "openai":
            return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
        elif app.state.ENGINE == "comfyui":
            return {"model": app.state.MODEL if app.state.MODEL else ""}
        else:
            r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
            options = r.json()
@@ -221,10 +259,12 @@ class UpdateModelForm(BaseModel):
def set_model_handler(model: str):
    if app.state.ENGINE == "openai":
        app.state.MODEL = model
        return app.state.MODEL
    if app.state.ENGINE == "comfyui":
        app.state.MODEL = model
        return app.state.MODEL
    else:
        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
        options = r.json()
@@ -268,7 +308,24 @@ def save_b64_image(b64_str):
        return image_id
    except Exception as e:
        log.error(f"Error saving image: {e}")
        return None


def save_url_image(url):
    image_id = str(uuid.uuid4())
    file_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.png")

    try:
        r = requests.get(url)
        r.raise_for_status()

        with open(file_path, "wb") as image_file:
            image_file.write(r.content)

        return image_id
    except Exception as e:
        log.exception(f"Error saving image: {e}")
        return None
@@ -278,6 +335,8 @@ def generate_image(
    user=Depends(get_current_user),
):
    width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))

    r = None
    try:
        if app.state.ENGINE == "openai":
@@ -293,6 +352,7 @@ def generate_image(
                "size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
                "response_format": "b64_json",
            }

            r = requests.post(
                url=f"https://api.openai.com/v1/images/generations",
                json=data,
@@ -300,7 +360,6 @@ def generate_image(
            )

            r.raise_for_status()
            res = r.json()

            images = []
@@ -315,12 +374,47 @@ def generate_image(

            return images

        elif app.state.ENGINE == "comfyui":
            data = {
                "prompt": form_data.prompt,
                "width": width,
                "height": height,
                "n": form_data.n,
            }

            if app.state.IMAGE_STEPS is not None:
                data["steps"] = app.state.IMAGE_STEPS

            if form_data.negative_prompt is not None:
                data["negative_prompt"] = form_data.negative_prompt

            data = ImageGenerationPayload(**data)

            res = comfyui_generate_image(
                app.state.MODEL,
                data,
                user.id,
                app.state.COMFYUI_BASE_URL,
            )
            log.debug(f"res: {res}")

            images = []

            for image in res["data"]:
                image_id = save_url_image(image["url"])
                images.append({"url": f"/cache/image/generations/{image_id}.png"})
                file_body_path = IMAGE_CACHE_DIR.joinpath(f"{image_id}.json")

                with open(file_body_path, "w") as f:
                    json.dump(data.model_dump(exclude_none=True), f)

            log.debug(f"images: {images}")
            return images

        else:
            if form_data.model:
                set_model_handler(form_data.model)

            data = {
                "prompt": form_data.prompt,
                "batch_size": form_data.n,
@@ -341,7 +435,7 @@ def generate_image(

            res = r.json()
            log.debug(f"res: {res}")

            images = []
@@ -356,7 +450,10 @@ def generate_image(
            return images

    except Exception as e:
        error = e
        if r is not None:
            data = r.json()
            if "error" in data:
                error = data["error"]["message"]
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(error))
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json
import urllib.request
import urllib.parse
import random
import logging
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["COMFYUI"])
from pydantic import BaseModel
from typing import Optional
COMFYUI_DEFAULT_PROMPT = """
{
"3": {
"inputs": {
"seed": 0,
"steps": 20,
"cfg": 8,
"sampler_name": "euler",
"scheduler": "normal",
"denoise": 1,
"model": [
"4",
0
],
"positive": [
"6",
0
],
"negative": [
"7",
0
],
"latent_image": [
"5",
0
]
},
"class_type": "KSampler",
"_meta": {
"title": "KSampler"
}
},
"4": {
"inputs": {
"ckpt_name": "model.safetensors"
},
"class_type": "CheckpointLoaderSimple",
"_meta": {
"title": "Load Checkpoint"
}
},
"5": {
"inputs": {
"width": 512,
"height": 512,
"batch_size": 1
},
"class_type": "EmptyLatentImage",
"_meta": {
"title": "Empty Latent Image"
}
},
"6": {
"inputs": {
"text": "Prompt",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"7": {
"inputs": {
"text": "Negative Prompt",
"clip": [
"4",
1
]
},
"class_type": "CLIPTextEncode",
"_meta": {
"title": "CLIP Text Encode (Prompt)"
}
},
"8": {
"inputs": {
"samples": [
"3",
0
],
"vae": [
"4",
2
]
},
"class_type": "VAEDecode",
"_meta": {
"title": "VAE Decode"
}
},
"9": {
"inputs": {
"filename_prefix": "ComfyUI",
"images": [
"8",
0
]
},
"class_type": "SaveImage",
"_meta": {
"title": "Save Image"
}
}
}
"""
def queue_prompt(prompt, client_id, base_url):
log.info("queue_prompt")
p = {"prompt": prompt, "client_id": client_id}
data = json.dumps(p).encode("utf-8")
req = urllib.request.Request(f"{base_url}/prompt", data=data)
return json.loads(urllib.request.urlopen(req).read())
def get_image(filename, subfolder, folder_type, base_url):
log.info("get_image")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
with urllib.request.urlopen(f"{base_url}/view?{url_values}") as response:
return response.read()
def get_image_url(filename, subfolder, folder_type, base_url):
log.info("get_image")
data = {"filename": filename, "subfolder": subfolder, "type": folder_type}
url_values = urllib.parse.urlencode(data)
return f"{base_url}/view?{url_values}"
def get_history(prompt_id, base_url):
log.info("get_history")
with urllib.request.urlopen(f"{base_url}/history/{prompt_id}") as response:
return json.loads(response.read())
def get_images(ws, prompt, client_id, base_url):
prompt_id = queue_prompt(prompt, client_id, base_url)["prompt_id"]
output_images = []
while True:
out = ws.recv()
if isinstance(out, str):
message = json.loads(out)
if message["type"] == "executing":
data = message["data"]
if data["node"] is None and data["prompt_id"] == prompt_id:
break # Execution is done
else:
continue # previews are binary data
history = get_history(prompt_id, base_url)[prompt_id]
for o in history["outputs"]:
for node_id in history["outputs"]:
node_output = history["outputs"][node_id]
if "images" in node_output:
for image in node_output["images"]:
url = get_image_url(
image["filename"], image["subfolder"], image["type"], base_url
)
output_images.append({"url": url})
return {"data": output_images}
class ImageGenerationPayload(BaseModel):
prompt: str
negative_prompt: Optional[str] = ""
steps: Optional[int] = None
seed: Optional[int] = None
width: int
height: int
n: int = 1
def comfyui_generate_image(
model: str, payload: ImageGenerationPayload, client_id, base_url
):
host = base_url.replace("http://", "").replace("https://", "")
comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
comfyui_prompt["5"]["inputs"]["width"] = payload.width
comfyui_prompt["5"]["inputs"]["height"] = payload.height
# set the text prompt for our positive CLIPTextEncode
comfyui_prompt["6"]["inputs"]["text"] = payload.prompt
comfyui_prompt["7"]["inputs"]["text"] = payload.negative_prompt
if payload.steps:
comfyui_prompt["3"]["inputs"]["steps"] = payload.steps
comfyui_prompt["3"]["inputs"]["seed"] = (
payload.seed if payload.seed else random.randint(0, 18446744073709551614)
)
try:
ws = websocket.WebSocket()
ws.connect(f"ws://{host}/ws?clientId={client_id}")
log.info("WebSocket connection established.")
except Exception as e:
log.exception(f"Failed to connect to WebSocket server: {e}")
return None
try:
images = get_images(ws, comfyui_prompt, client_id, base_url)
except Exception as e:
log.exception(f"Error while receiving images: {e}")
images = None
ws.close()
return images
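A rough usage sketch for the new helper, assuming a reachable ComfyUI instance and a checkpoint name taken from its `/object_info` listing:

```python
# All values here are hypothetical examples.
payload = ImageGenerationPayload(
    prompt="a watercolor lighthouse at dusk",
    negative_prompt="blurry",
    width=512,
    height=512,
    n=1,
    steps=20,
)

res = comfyui_generate_image(
    "sd_xl_base_1.0.safetensors",  # assumed ckpt_name known to the server
    payload,
    client_id="example-client-id",
    base_url="http://127.0.0.1:8188",
)
# On success: {"data": [{"url": "http://127.0.0.1:8188/view?..."}]}; on failure: None
```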
import logging

from litellm.proxy.proxy_server import ProxyConfig, initialize
from litellm.proxy.proxy_server import app

from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.responses import JSONResponse

from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json

from utils.utils import get_http_authorization_cred, get_current_user
from config import SRC_LOG_LEVELS, ENV

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])

from config import (
    MODEL_FILTER_ENABLED,
    MODEL_FILTER_LIST,
)

proxy_config = ProxyConfig()
@@ -26,16 +43,58 @@ async def on_startup():
    await startup()

app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST


@app.middleware("http")
async def auth_middleware(request: Request, call_next):
    auth_header = request.headers.get("Authorization", "")
    request.state.user = None

    if ENV != "dev":
        try:
            user = get_current_user(get_http_authorization_cred(auth_header))
            log.debug(f"user: {user}")
            request.state.user = user
        except Exception as e:
            return JSONResponse(status_code=400, content={"detail": str(e)})

    response = await call_next(request)
    return response
class ModifyModelsResponseMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
response = await call_next(request)
user = request.state.user
if "/models" in request.url.path:
if isinstance(response, StreamingResponse):
# Read the content of the streaming response
body = b""
async for chunk in response.body_iterator:
body += chunk
data = json.loads(body.decode("utf-8"))
if app.state.MODEL_FILTER_ENABLED:
if user and user.role == "user":
data["data"] = list(
filter(
lambda model: model["id"]
in app.state.MODEL_FILTER_LIST,
data["data"],
)
)
# Modified Flag
data["modified"] = True
return JSONResponse(content=data)
return response
app.add_middleware(ModifyModelsResponseMiddleware)
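The whitelist step in isolation, as a self-contained sketch with assumed example values:

```python
# Stand-alone illustration of the filter applied by the middleware above.
MODEL_FILTER_LIST = ["gpt-3.5-turbo", "llama2:7b"]  # assumed example whitelist

data = {"data": [{"id": "gpt-4"}, {"id": "gpt-3.5-turbo"}, {"id": "llama2:7b"}]}
data["data"] = list(
    filter(lambda model: model["id"] in MODEL_FILTER_LIST, data["data"])
)
print([m["id"] for m in data["data"]])  # ['gpt-3.5-turbo', 'llama2:7b']
```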
from fastapi import (
    FastAPI,
    Request,
    Response,
    HTTPException,
    Depends,
    status,
    UploadFile,
    File,
    BackgroundTasks,
)
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, ConfigDict
import os
import copy
import random
import requests
import json
import uuid
import aiohttp
import asyncio
import logging

from urllib.parse import urlparse
from typing import Optional, List, Union

from apps.web.models.users import Users
from constants import ERROR_MESSAGES
from utils.utils import decode_token, get_current_user, get_admin_user

from config import (
    SRC_LOG_LEVELS,
    OLLAMA_BASE_URLS,
    MODEL_FILTER_ENABLED,
    MODEL_FILTER_LIST,
    UPLOAD_DIR,
)
from utils.misc import calculate_sha256

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
app = FastAPI()
app.add_middleware(
@@ -69,7 +94,7 @@ class UrlUpdateForm(BaseModel):
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
    app.state.OLLAMA_BASE_URLS = form_data.urls
    log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}")

    return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}

@@ -90,7 +115,7 @@ async def fetch_url(url):
                return await response.json()
    except Exception as e:
        # Handle connection error here
        log.error(f"Connection error: {e}")
        return None
@@ -98,6 +123,7 @@ def merge_models_lists(model_lists):
    merged_models = {}

    for idx, model_list in enumerate(model_lists):
        if model_list is not None:
            for model in model_list:
                digest = model["digest"]
                if digest not in merged_models:
@@ -113,16 +139,16 @@ def merge_models_lists(model_lists):

async def get_all_models():
    log.info("get_all_models()")
    tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS]
    responses = await asyncio.gather(*tasks)

    models = {
        "models": merge_models_lists(
            map(lambda response: response["models"] if response else None, responses)
        )
    }

    app.state.MODELS = {model["model"]: model for model in models["models"]}

    return models
@@ -154,7 +180,7 @@ async def get_ollama_tags(

        return r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
@@ -181,11 +207,17 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
        responses = await asyncio.gather(*tasks)
        responses = list(filter(lambda x: x is not None, responses))

        if len(responses) > 0:
            lowest_version = min(
                responses, key=lambda x: tuple(map(int, x["version"].split(".")))
            )

            return {"version": lowest_version["version"]}
        else:
            raise HTTPException(
                status_code=500,
                detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
            )
    else:
        url = app.state.OLLAMA_BASE_URLS[url_idx]
        try:
@@ -194,7 +226,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
            return r.json()
        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
@@ -220,18 +252,33 @@ async def pull_model(
    form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
):
    url = app.state.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None

    def get_request():
        nonlocal url
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    yield json.dumps({"id": request_id, "done": False}) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            log.warning("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                        if request_id in REQUEST_POOL:
                            REQUEST_POOL.remove(request_id)

            r = requests.request(
                method="POST",
@@ -252,8 +299,9 @@ async def pull_model(
    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"

        if r is not None:
            try:
@@ -292,7 +340,7 @@ async def push_model(
    )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    log.debug(f"url: {url}")

    r = None
@@ -324,7 +372,7 @@ async def push_model(
    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
@@ -352,9 +400,9 @@ class CreateModelForm(BaseModel):
async def create_model(
    form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
):
    log.debug(f"form_data: {form_data}")

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None
@@ -376,7 +424,7 @@ async def create_model(

            r.raise_for_status()

            log.debug(f"r: {r}")

            return StreamingResponse(
                stream_content(),
@@ -389,7 +437,7 @@ async def create_model(
    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
@@ -427,7 +475,7 @@ async def copy_model(
    )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    try:
        r = requests.request(
@@ -437,11 +485,11 @@ async def copy_model(
        )
        r.raise_for_status()

        log.debug(f"r.text: {r.text}")

        return True
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
@@ -474,7 +522,7 @@ async def delete_model(
    )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    try:
        r = requests.request(
@@ -484,11 +532,11 @@ async def delete_model(
        )
        r.raise_for_status()

        log.debug(f"r.text: {r.text}")

        return True
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
@@ -514,7 +562,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
    url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    try:
        r = requests.request(
@@ -526,7 +574,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_use
        return r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
@@ -566,7 +614,7 @@ async def generate_embeddings(
    )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    try:
        r = requests.request(
@@ -578,7 +626,7 @@ async def generate_embeddings(
        return r.json()
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
@@ -622,11 +670,11 @@ async def generate_completion(
    else:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
        )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None
@@ -647,7 +695,7 @@ async def generate_completion(
                    if request_id in REQUEST_POOL:
                        yield chunk
                    else:
                        log.warning("User: canceled request")
                        break
                finally:
                    if hasattr(r, "close"):
@@ -702,7 +750,7 @@ class GenerateChatCompletionForm(BaseModel):
    format: Optional[str] = None
    options: Optional[dict] = None
    template: Optional[str] = None
    stream: Optional[bool] = None
    keep_alive: Optional[Union[int, str]] = None

@@ -724,11 +772,15 @@ async def generate_chat_completion(
    )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None

    log.debug(
        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
            form_data.model_dump_json(exclude_none=True).encode()
        )
    )

    def get_request():
        nonlocal form_data
@@ -747,7 +799,7 @@ async def generate_chat_completion(
                    if request_id in REQUEST_POOL:
                        yield chunk
                    else:
                        log.warning("User: canceled request")
                        break
                finally:
                    if hasattr(r, "close"):
@@ -770,7 +822,7 @@ async def generate_chat_completion(
                headers=dict(r.headers),
            )
    except Exception as e:
        log.exception(e)
        raise e

    try:
@@ -824,7 +876,7 @@ async def generate_openai_chat_completion(
    )

    url = app.state.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

    r = None
@@ -847,7 +899,7 @@ async def generate_openai_chat_completion(
                    if request_id in REQUEST_POOL:
                        yield chunk
                    else:
                        log.warning("User: canceled request")
                        break
                finally:
                    if hasattr(r, "close"):
@@ -890,6 +942,220 @@ async def generate_openai_chat_completion(
    )
class UrlForm(BaseModel):
url: str
class UploadBlobForm(BaseModel):
filename: str
def parse_huggingface_url(hf_url):
try:
# Parse the URL
parsed_url = urlparse(hf_url)
# Get the path and split it into components
path_components = parsed_url.path.split("/")
# Extract the desired output
user_repo = "/".join(path_components[1:3])
model_file = path_components[-1]
return model_file
except ValueError:
return None
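For example, the helper reduces a typical `resolve` URL to its file name (URL borrowed from the commented example further down):

```python
url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
print(parse_huggingface_url(url))  # -> "stablelm-zephyr-3b.Q2_K.gguf"
```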
async def download_file_stream(
ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
):
done = False
if os.path.exists(file_path):
current_size = os.path.getsize(file_path)
else:
current_size = 0
headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(file_url, headers=headers) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size
with open(file_path, "ab+") as file:
async for data in response.content.iter_chunked(chunk_size):
current_size += len(data)
file.write(data)
done = current_size == total_size
progress = round((current_size / total_size) * 100, 2)
yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
if done:
file.seek(0)
hashed = calculate_sha256(file)
file.seek(0)
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
response = requests.post(url, data=file)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file_name,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise "Ollama: Could not create blob, Please try again."
# def number_generator():
# for i in range(1, 101):
# yield f"data: {i}\n"
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
@app.post("/models/download")
@app.post("/models/download/{url_idx}")
async def download_model(
form_data: UrlForm,
url_idx: Optional[int] = None,
):
allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
if not any(form_data.url.startswith(host) for host in allowed_hosts):
raise HTTPException(
status_code=400,
detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
)
    if url_idx is None:
url_idx = 0
url = app.state.OLLAMA_BASE_URLS[url_idx]
file_name = parse_huggingface_url(form_data.url)
if file_name:
file_path = f"{UPLOAD_DIR}/{file_name}"
return StreamingResponse(
download_file_stream(url, form_data.url, file_path, file_name),
)
else:
return None
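A client-side sketch of consuming the download stream (host and route prefix are assumed placeholders; the endpoint emits SSE-style `data:` lines with progress):

```python
import requests

# Hypothetical deployment values; the mount prefix depends on the main app.
resp = requests.post(
    "http://localhost:8080/ollama/models/download",
    json={"url": "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"},
    stream=True,
)
for line in resp.iter_lines():
    if line:
        print(line.decode())  # e.g. data: {"progress": 42.0, "completed": ..., "total": ...}
```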
@app.post("/models/upload")
@app.post("/models/upload/{url_idx}")
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
    if url_idx is None:
url_idx = 0
ollama_url = app.state.OLLAMA_BASE_URLS[url_idx]
file_path = f"{UPLOAD_DIR}/{file.filename}"
# Save file in chunks
with open(file_path, "wb+") as f:
for chunk in file.file:
f.write(chunk)
def file_process_stream():
nonlocal ollama_url
total_size = os.path.getsize(file_path)
chunk_size = 1024 * 1024
try:
with open(file_path, "rb") as f:
total = 0
done = False
while not done:
chunk = f.read(chunk_size)
if not chunk:
done = True
continue
total += len(chunk)
progress = round((total / total_size) * 100, 2)
res = {
"progress": progress,
"total": total_size,
"completed": total,
}
yield f"data: {json.dumps(res)}\n\n"
if done:
f.seek(0)
hashed = calculate_sha256(f)
f.seek(0)
url = f"{ollama_url}/api/blobs/sha256:{hashed}"
response = requests.post(url, data=f)
if response.ok:
res = {
"done": done,
"blob": f"sha256:{hashed}",
"name": file.filename,
}
os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n"
else:
raise Exception(
"Ollama: Could not create blob, Please try again."
)
except Exception as e:
res = {"error": str(e)}
yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
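And a matching sketch for the upload route, again with assumed host and prefix:

```python
import requests

# Hypothetical example: upload a local GGUF and print the progress events.
with open("stablelm-zephyr-3b.Q2_K.gguf", "rb") as f:
    resp = requests.post(
        "http://localhost:8080/ollama/models/upload",
        files={"file": f},
        stream=True,
    )
    for line in resp.iter_lines():
        if line:
            print(line.decode())
```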
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None:
# url_idx = 0
# url = app.state.OLLAMA_BASE_URLS[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size
# async def file_upload_generator(file):
# print(file)
# try:
# async with aiofiles.open(file_location, "wb") as f:
# completed_size = 0
# while True:
# chunk = await file.read(1024*1024)
# if not chunk:
# break
# await f.write(chunk)
# completed_size += len(chunk)
# progress = (completed_size / total_size) * 100
# print(progress)
# yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n'
# except Exception as e:
# print(e)
# yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n"
# finally:
# await file.close()
# print("done")
# yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n'
# return StreamingResponse(
# file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream"
# )
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"]) @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)): async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)):
url = app.state.OLLAMA_BASE_URLS[0] url = app.state.OLLAMA_BASE_URLS[0]
...@@ -940,7 +1206,7 @@ async def deprecated_proxy(path: str, request: Request, user=Depends(get_current ...@@ -940,7 +1206,7 @@ async def deprecated_proxy(path: str, request: Request, user=Depends(get_current
if request_id in REQUEST_POOL: if request_id in REQUEST_POOL:
yield chunk yield chunk
else: else:
print("User: canceled request") log.warning("User: canceled request")
break break
finally: finally:
if hasattr(r, "close"): if hasattr(r, "close"):
......
@@ -6,6 +6,7 @@ import requests
import aiohttp
import asyncio
import json
import logging

from pydantic import BaseModel
@@ -19,6 +20,7 @@ from utils.utils import (
    get_admin_user,
)
from config import (
    SRC_LOG_LEVELS,
    OPENAI_API_BASE_URLS,
    OPENAI_API_KEYS,
    CACHE_DIR,
@@ -31,6 +33,9 @@ from typing import List, Optional
import hashlib
from pathlib import Path

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OPENAI"])

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
@@ -111,6 +116,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
        headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}"
        headers["Content-Type"] = "application/json"

        r = None
        try:
            r = requests.post(
                url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech",
@@ -133,7 +139,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
            return FileResponse(file_path)

        except Exception as e:
            log.exception(e)
            error_detail = "Open WebUI: Server Connection Error"
            if r is not None:
                try:
@@ -143,7 +149,9 @@ async def speech(request: Request, user=Depends(get_verified_user)):
                except:
                    error_detail = f"External: {e}"

            raise HTTPException(
                status_code=r.status_code if r else 500, detail=error_detail
            )

    except ValueError:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
@@ -157,7 +165,7 @@ async def fetch_url(url, key):
                return await response.json()
    except Exception as e:
        # Handle connection error here
        log.error(f"Connection error: {e}")
        return None

@@ -165,6 +173,7 @@ def merge_models_lists(model_lists):
    merged_list = []

    for idx, models in enumerate(model_lists):
        if models is not None and "error" not in models:
            merged_list.extend(
                [
                    {**model, "urlIdx": idx}
@@ -178,7 +187,7 @@ def merge_models_lists(model_lists):
async def get_all_models(): async def get_all_models():
print("get_all_models") log.info("get_all_models()")
if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "": if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "":
models = {"data": []} models = {"data": []}
...@@ -187,15 +196,24 @@ async def get_all_models(): ...@@ -187,15 +196,24 @@ async def get_all_models():
fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx]) fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS) for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
] ]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
responses = list(
filter(lambda x: x is not None and "error" not in x, responses)
)
models = { models = {
"data": merge_models_lists( "data": merge_models_lists(
list(map(lambda response: response["data"], responses)) list(
map(
lambda response: (
response["data"]
if response and "data" in response
else None
),
responses,
)
)
) )
} }
log.info(f"models: {models}")
app.state.MODELS = {model["id"]: model for model in models["data"]} app.state.MODELS = {model["id"]: model for model in models["data"]}
return models return models
...@@ -218,6 +236,9 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use ...@@ -218,6 +236,9 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
return models return models
else: else:
url = app.state.OPENAI_API_BASE_URLS[url_idx] url = app.state.OPENAI_API_BASE_URLS[url_idx]
r = None
try: try:
r = requests.request(method="GET", url=f"{url}/models") r = requests.request(method="GET", url=f"{url}/models")
r.raise_for_status() r.raise_for_status()
...@@ -230,7 +251,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use ...@@ -230,7 +251,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
return response_data return response_data
except Exception as e: except Exception as e:
print(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:
...@@ -264,7 +285,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): ...@@ -264,7 +285,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
if body.get("model") == "gpt-4-vision-preview": if body.get("model") == "gpt-4-vision-preview":
if "max_tokens" not in body: if "max_tokens" not in body:
body["max_tokens"] = 4000 body["max_tokens"] = 4000
print("Modified body_dict:", body) log.debug("Modified body_dict:", body)
# Fix for ChatGPT calls failing because the num_ctx key is in body # Fix for ChatGPT calls failing because the num_ctx key is in body
if "num_ctx" in body: if "num_ctx" in body:
...@@ -276,7 +297,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): ...@@ -276,7 +297,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
# Convert the modified body back to JSON # Convert the modified body back to JSON
body = json.dumps(body) body = json.dumps(body)
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
print("Error loading request body into a dictionary:", e) log.error("Error loading request body into a dictionary:", e)
url = app.state.OPENAI_API_BASE_URLS[idx] url = app.state.OPENAI_API_BASE_URLS[idx]
key = app.state.OPENAI_API_KEYS[idx] key = app.state.OPENAI_API_KEYS[idx]
...@@ -290,6 +311,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): ...@@ -290,6 +311,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
headers["Authorization"] = f"Bearer {key}" headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
r = None
try: try:
r = requests.request( r = requests.request(
method=request.method, method=request.method,
...@@ -312,7 +335,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): ...@@ -312,7 +335,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
response_data = r.json() response_data = r.json()
return response_data return response_data
except Exception as e: except Exception as e:
print(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
if r is not None: if r is not None:
try: try:
...@@ -322,4 +345,6 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): ...@@ -322,4 +345,6 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
except: except:
error_detail = f"External: {e}" error_detail = f"External: {e}"
raise HTTPException(status_code=r.status_code, detail=error_detail) raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)
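Several hunks above apply the same defensive pattern: `r` is initialized to `None` before the request, and the error path falls back to a 500 when no response object exists. A condensed illustration of the pattern follows; the function name and URL here are hypothetical, not part of the codebase.

```python
# Condensed illustration of the guarded-response pattern introduced above.
import requests
from fastapi import HTTPException


def fetch_models(url: str) -> dict:
    r = None  # lets the except-branch test whether a response ever existed
    try:
        r = requests.get(f"{url}/models")
        r.raise_for_status()
        return r.json()
    except Exception as e:
        # Without the `if r else 500` guard, a pure connection error
        # (r still None) would raise AttributeError on r.status_code
        # instead of producing a clean HTTP error.
        raise HTTPException(
            status_code=r.status_code if r else 500, detail=f"External: {e}"
        )
```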
@@ -8,7 +8,7 @@ from fastapi import (
     Form,
 )
 from fastapi.middleware.cors import CORSMiddleware
-import os, shutil
+import os, shutil, logging
 from pathlib import Path
 from typing import List
@@ -21,6 +21,7 @@ from langchain_community.document_loaders import (
     TextLoader,
     PyPDFLoader,
     CSVLoader,
+    BSHTMLLoader,
     Docx2txtLoader,
     UnstructuredEPubLoader,
     UnstructuredWordDocumentLoader,
@@ -54,6 +55,7 @@ from utils.misc import (
 )
 from utils.utils import get_current_user, get_admin_user
 from config import (
+    SRC_LOG_LEVELS,
     UPLOAD_DIR,
     DOCS_DIR,
     RAG_EMBEDDING_MODEL,
@@ -66,6 +68,9 @@ from config import (
 from constants import ERROR_MESSAGES

+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])
+
 #
 # if RAG_EMBEDDING_MODEL:
 #     sentence_transformer_ef = SentenceTransformer(
@@ -111,39 +116,6 @@ class StoreWebForm(CollectionNameForm):
     url: str

-def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
-    text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
-    )
-    docs = text_splitter.split_documents(data)
-
-    texts = [doc.page_content for doc in docs]
-    metadatas = [doc.metadata for doc in docs]
-
-    try:
-        if overwrite:
-            for collection in CHROMA_CLIENT.list_collections():
-                if collection_name == collection.name:
-                    print(f"deleting existing collection {collection_name}")
-                    CHROMA_CLIENT.delete_collection(name=collection_name)
-
-        collection = CHROMA_CLIENT.create_collection(
-            name=collection_name,
-            embedding_function=app.state.sentence_transformer_ef,
-        )
-
-        collection.add(
-            documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
-        )
-        return True
-    except Exception as e:
-        print(e)
-        if e.__class__.__name__ == "UniqueConstraintError":
-            return True
-
-        return False
-
 @app.get("/")
 async def get_status():
     return {
@@ -273,7 +245,7 @@ def query_doc_handler(
             embedding_function=app.state.sentence_transformer_ef,
         )
     except Exception as e:
-        print(e)
+        log.exception(e)
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(e),
@@ -317,13 +289,69 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
             "filename": form_data.url,
         }
     except Exception as e:
-        print(e)
+        log.exception(e)
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=ERROR_MESSAGES.DEFAULT(e),
         )

+def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=app.state.CHUNK_SIZE,
+        chunk_overlap=app.state.CHUNK_OVERLAP,
+        add_start_index=True,
+    )
+    docs = text_splitter.split_documents(data)
+
+    if len(docs) > 0:
+        return store_docs_in_vector_db(docs, collection_name, overwrite), None
+    else:
+        raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
+
+
+def store_text_in_vector_db(
+    text, metadata, collection_name, overwrite: bool = False
+) -> bool:
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=app.state.CHUNK_SIZE,
+        chunk_overlap=app.state.CHUNK_OVERLAP,
+        add_start_index=True,
+    )
+    docs = text_splitter.create_documents([text], metadatas=[metadata])
+    return store_docs_in_vector_db(docs, collection_name, overwrite)
+
+
+def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
+    texts = [doc.page_content for doc in docs]
+    metadatas = [doc.metadata for doc in docs]
+
+    try:
+        if overwrite:
+            for collection in CHROMA_CLIENT.list_collections():
+                if collection_name == collection.name:
+                    log.info(f"deleting existing collection {collection_name}")
+                    CHROMA_CLIENT.delete_collection(name=collection_name)
+
+        collection = CHROMA_CLIENT.create_collection(
+            name=collection_name,
+            embedding_function=app.state.sentence_transformer_ef,
+        )
+
+        collection.add(
+            documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
+        )
+        return True
+    except Exception as e:
+        log.exception(e)
+        if e.__class__.__name__ == "UniqueConstraintError":
+            return True
+
+        return False
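Both splitter call sites in the refactor above now pass `add_start_index=True`, so each chunk's metadata records the character offset where the chunk began in the source text, which makes retrieved chunks traceable. A small stand-alone example (the chunk sizes are arbitrary here, not the `app.state.CHUNK_SIZE` / `CHUNK_OVERLAP` values used above, and the classic `langchain` import path is assumed):

```python
# Illustration of add_start_index with arbitrary chunk sizes.
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    chunk_size=50, chunk_overlap=10, add_start_index=True
)
docs = splitter.create_documents(
    ["A long body of text that will be split into several overlapping chunks..."],
    metadatas=[{"name": "example"}],
)
for doc in docs:
    # Each chunk carries the original metadata plus its start offset,
    # e.g. {'name': 'example', 'start_index': 0}
    print(doc.metadata)
```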
 def get_loader(filename: str, file_content_type: str, file_path: str):
     file_ext = filename.split(".")[-1].lower()
     known_type = True
@@ -381,6 +409,8 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
         loader = UnstructuredRSTLoader(file_path, mode="elements")
     elif file_ext == "xml":
         loader = UnstructuredXMLLoader(file_path)
+    elif file_ext in ["htm", "html"]:
+        loader = BSHTMLLoader(file_path, open_encoding="unicode_escape")
     elif file_ext == "md":
         loader = UnstructuredMarkdownLoader(file_path)
     elif file_content_type == "application/epub+zip":
@@ -399,9 +429,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
     elif file_ext in known_source_ext or (
         file_content_type and file_content_type.find("text/") >= 0
     ):
-        loader = TextLoader(file_path)
+        loader = TextLoader(file_path, autodetect_encoding=True)
     else:
-        loader = TextLoader(file_path)
+        loader = TextLoader(file_path, autodetect_encoding=True)
         known_type = False

     return loader, known_type
@@ -415,7 +445,7 @@ def store_doc(
 ):
     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
-    print(file.content_type)
+    log.info(f"file.content_type: {file.content_type}")
     try:
         filename = file.filename
         file_path = f"{UPLOAD_DIR}/{filename}"
@@ -431,6 +461,8 @@ def store_doc(
         loader, known_type = get_loader(file.filename, file.content_type, file_path)
         data = loader.load()
+
+        try:
             result = store_data_in_vector_db(data, collection_name)

             if result:
@@ -440,13 +472,13 @@ def store_doc(
                     "filename": filename,
                     "known_type": known_type,
                 }
-        else:
+        except Exception as e:
             raise HTTPException(
                 status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                detail=ERROR_MESSAGES.DEFAULT(),
+                detail=e,
             )
     except Exception as e:
-        print(e)
+        log.exception(e)
         if "No pandoc was found" in str(e):
             raise HTTPException(
                 status_code=status.HTTP_400_BAD_REQUEST,
@@ -459,6 +491,37 @@ def store_doc(
         )

+class TextRAGForm(BaseModel):
+    name: str
+    content: str
+    collection_name: Optional[str] = None
+
+
+@app.post("/text")
+def store_text(
+    form_data: TextRAGForm,
+    user=Depends(get_current_user),
+):
+    collection_name = form_data.collection_name
+    if collection_name == None:
+        collection_name = calculate_sha256_string(form_data.content)
+
+    result = store_text_in_vector_db(
+        form_data.content,
+        metadata={"name": form_data.name, "created_by": user.id},
+        collection_name=collection_name,
+    )
+
+    if result:
+        return {"status": True, "collection_name": collection_name}
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=ERROR_MESSAGES.DEFAULT(),
+        )
+
 @app.get("/scan")
 def scan_docs_dir(user=Depends(get_admin_user)):
     for path in Path(DOCS_DIR).rglob("./**/*"):
@@ -477,6 +540,7 @@ def scan_docs_dir(user=Depends(get_admin_user)):
                 )
                 data = loader.load()
+
+                try:
                     result = store_data_in_vector_db(data, collection_name)

                     if result:
@@ -509,9 +573,12 @@ def scan_docs_dir(user=Depends(get_admin_user)):
                                 }
                             ),
                         )
+                except Exception as e:
+                    log.exception(e)
+                    pass
         except Exception as e:
-            print(e)
+            log.exception(e)

     return True
@@ -532,11 +599,11 @@ def reset(user=Depends(get_admin_user)) -> bool:
         elif os.path.isdir(file_path):
             shutil.rmtree(file_path)
     except Exception as e:
-        print("Failed to delete %s. Reason: %s" % (file_path, e))
+        log.error("Failed to delete %s. Reason: %s" % (file_path, e))

     try:
         CHROMA_CLIENT.reset()
     except Exception as e:
-        print(e)
+        log.exception(e)

     return True
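The new `/text` route above accepts raw text, derives a collection name from a SHA-256 of the content when none is supplied, and indexes the text for retrieval. A hypothetical client call follows; the host, route prefix, and bearer token are illustrative assumptions about a deployment, not values taken from this diff.

```python
# Hypothetical client-side call to the new /text route.
import requests

token = "<your-api-token>"  # placeholder, obtain from your own session
resp = requests.post(
    "http://localhost:8080/rag/api/v1/text",  # mount prefix is an assumption
    headers={"Authorization": f"Bearer {token}"},
    json={"name": "meeting-notes", "content": "Q1 planning: ship RAG text intake."},
)
# Expected shape per the handler above:
# {"status": True, "collection_name": "<sha256 of content>"}
print(resp.json())
```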
 import re
+import logging
 from typing import List
-from config import CHROMA_CLIENT
+from config import SRC_LOG_LEVELS, CHROMA_CLIENT
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["RAG"])

 def query_doc(collection_name: str, query: str, k: int, embedding_function):
@@ -91,14 +95,13 @@ def query_collection(

 def rag_template(template: str, context: str, query: str):
-    template = re.sub(r"\[context\]", context, template)
-    template = re.sub(r"\[query\]", query, template)
+    template = template.replace("[context]", context)
+    template = template.replace("[query]", query)
     return template

 def rag_messages(docs, messages, template, k, embedding_function):
-    print(docs)
+    log.debug(f"docs: {docs}")
     last_user_message_idx = None
     for i in range(len(messages) - 1, -1, -1):
@@ -138,6 +141,8 @@ def rag_messages(docs, messages, template, k, embedding_function):
                     k=k,
                     embedding_function=embedding_function,
                 )
+            elif doc["type"] == "text":
+                context = doc["content"]
             else:
                 context = query_doc(
                     collection_name=doc["collection_name"],
@@ -146,11 +151,13 @@ def rag_messages(docs, messages, template, k, embedding_function):
                     embedding_function=embedding_function,
                 )
         except Exception as e:
-            print(e)
+            log.exception(e)
             context = None

         relevant_contexts.append(context)

+    log.debug(f"relevant_contexts: {relevant_contexts}")
+
     context_string = ""
     for context in relevant_contexts:
         if context:
...
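Swapping `re.sub` for `str.replace` in `rag_template` is more than a style change: `re.sub` interprets backslashes in the replacement string as escape sequences, so retrieved context containing something like `\1` or a Windows path can raise `re.error` or be silently mangled, while `str.replace` inserts the context verbatim. A minimal demonstration:

```python
import re

template = "Use [context] to answer: [query]"
context = r"See C:\new\results, group ref \1"

# The old approach would fail on this context:
#   re.sub(r"\[context\]", context, template)
#   -> re.error: invalid group reference 1 (because of the "\1"),
#   and "\n" in the path would otherwise become a newline.
safe = template.replace("[context]", context).replace("[query]", "why?")
print(safe)  # context appears exactly as stored
```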
 from peewee import *
-from config import DATA_DIR
+from config import SRC_LOG_LEVELS, DATA_DIR
 import os
+import logging
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["DB"])

 # Check if the file exists
 if os.path.exists(f"{DATA_DIR}/ollama.db"):
     # Rename the file
     os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db")
-    print("File renamed successfully.")
+    log.info("File renamed successfully.")
 else:
     pass
...
@@ -19,6 +19,7 @@ from config import (
     DEFAULT_USER_ROLE,
     ENABLE_SIGNUP,
     USER_PERMISSIONS,
+    WEBHOOK_URL,
 )

 app = FastAPI()
@@ -32,6 +33,7 @@ app.state.DEFAULT_MODELS = DEFAULT_MODELS
 app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
 app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
 app.state.USER_PERMISSIONS = USER_PERMISSIONS
+app.state.WEBHOOK_URL = WEBHOOK_URL

 app.add_middleware(
...
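`WEBHOOK_URL` is threaded into app state here; presumably it is consumed elsewhere to notify an external service about events such as signups. A minimal sketch of such a notifier follows; the payload shape, timeout, and function name are assumptions for illustration, not taken from this diff.

```python
# Minimal sketch of a webhook notifier; payload fields are assumed.
import logging
import requests

log = logging.getLogger(__name__)


def post_webhook(url: str, message: str, event: dict) -> bool:
    if not url:
        return False  # feature is effectively off when WEBHOOK_URL is empty
    try:
        r = requests.post(url, json={"text": message, **event}, timeout=10)
        r.raise_for_status()
        return True
    except Exception as e:
        log.exception(e)
        return False
```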
@@ -2,6 +2,7 @@ from pydantic import BaseModel
 from typing import List, Union, Optional
 import time
 import uuid
+import logging

 from peewee import *
 from apps.web.models.users import UserModel, Users
@@ -9,6 +10,11 @@ from utils.utils import verify_password
 from apps.web.internal.db import DB

+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
 ####################
 # DB MODEL
 ####################
@@ -86,7 +92,7 @@ class AuthsTable:
     def insert_new_auth(
         self, email: str, password: str, name: str, role: str = "pending"
     ) -> Optional[UserModel]:
-        print("insert_new_auth")
+        log.info("insert_new_auth")
         id = str(uuid.uuid4())
@@ -103,7 +109,7 @@ class AuthsTable:
             return None

     def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
-        print("authenticate_user", email)
+        log.info(f"authenticate_user: {email}")
         try:
             auth = Auth.get(Auth.email == email, Auth.active == True)
             if auth:
...
@@ -95,20 +95,6 @@ class ChatTable:
         except:
             return None

-    def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
-        try:
-            query = Chat.update(
-                chat=json.dumps(chat),
-                title=chat["title"] if "title" in chat else "New Chat",
-                timestamp=int(time.time()),
-            ).where(Chat.id == id)
-            query.execute()
-
-            chat = Chat.get(Chat.id == id)
-            return ChatModel(**model_to_dict(chat))
-        except:
-            return None
-
     def get_chat_lists_by_user_id(
         self, user_id: str, skip: int = 0, limit: int = 50
     ) -> List[ChatModel]:
...
@@ -3,6 +3,7 @@ from peewee import *
 from playhouse.shortcuts import model_to_dict
 from typing import List, Union, Optional
 import time
+import logging

 from utils.utils import decode_token
 from utils.misc import get_gravatar_url
@@ -11,6 +12,11 @@ from apps.web.internal.db import DB
 import json

+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
 ####################
 # Documents DB Schema
 ####################
@@ -118,7 +124,7 @@ class DocumentsTable:
             doc = Document.get(Document.name == form_data.name)
             return DocumentModel(**model_to_dict(doc))
         except Exception as e:
-            print(e)
+            log.exception(e)
             return None

     def update_doc_content_by_name(
@@ -138,7 +144,7 @@ class DocumentsTable:
             doc = Document.get(Document.name == name)
             return DocumentModel(**model_to_dict(doc))
         except Exception as e:
-            print(e)
+            log.exception(e)
             return None

     def delete_doc_by_name(self, name: str) -> bool:
...