"src/vscode:/vscode.git/clone" did not exist on "f6d4bd3a53e249b98380f59e80e112ac49a84443"
Unverified commit 437d7ff6, authored by Timothy Jaeryang Baek and committed by GitHub

Merge pull request #897 from open-webui/main

dev
parents 02f364bf 81eceb48
@@ -3,4 +3,10 @@
OLLAMA_API_BASE_URL='http://localhost:11434/api'
OPENAI_API_BASE_URL=''
OPENAI_API_KEY=''
\ No newline at end of file
# AUTOMATIC1111_BASE_URL="http://localhost:7860"
# DO NOT TRACK
SCARF_NO_ANALYTICS=true
DO_NOT_TRACK=true
\ No newline at end of file
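For reference, the commented-out `AUTOMATIC1111_BASE_URL` above only does something once a Stable Diffusion WebUI instance is actually reachable at that address. A minimal sketch of starting one with its REST API enabled (the `--api` flag is real; everything else about your local setup is an assumption):

```bash
# Launch AUTOMATIC1111/stable-diffusion-webui with the REST API enabled,
# then uncomment AUTOMATIC1111_BASE_URL in the .env file to point at it.
./webui.sh --api --listen
```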
*.sh text eol=lf
\ No newline at end of file
## Pull Request Checklist
- [ ] **Description:** Briefly describe the changes in this pull request.
- [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
- [ ] **Documentation:** Have you updated relevant documentation?
- [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation?
---
## Description
[Insert a brief description of the changes made in this pull request]
---
### Changelog Entry
### Added
- [List any new features or additions]
### Fixed
- [List any fixes or corrections]
### Changed
- [List any changes or updates]
### Removed
- [List any removed features or files]
name: Release

on:
  push:
    branches:
      - main # or whatever branch you want to use

jobs:
  release:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout repository
        uses: actions/checkout@v2

      - name: Check for changes in package.json
        run: |
          git diff --cached --diff-filter=d package.json || {
            echo "No changes to package.json"
            exit 1
          }

      - name: Get version number from package.json
        id: get_version
        run: |
          VERSION=$(jq -r '.version' package.json)
          echo "::set-output name=version::$VERSION"

      - name: Create GitHub release
        uses: actions/github-script@v5
        with:
          github-token: ${{ secrets.GITHUB_TOKEN }}
          script: |
            const release = await github.rest.repos.createRelease({
              owner: context.repo.owner,
              repo: context.repo.repo,
              tag_name: `v${{ steps.get_version.outputs.version }}`,
              name: `v${{ steps.get_version.outputs.version }}`,
              body: 'Automatically created new release',
            })
            console.log(`Created release ${release.data.html_url}`)

      - name: Upload package to GitHub release
        uses: actions/upload-artifact@v3
        with:
          name: package
          path: .
        env:
          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
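The `get_version` step above simply reads the `version` field from the manifest with jq; run locally (assuming jq is installed), it behaves like this:

```bash
# Print the bare version string from package.json (e.g. 0.1.102).
jq -r '.version' package.json
```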
@@ -40,15 +40,21 @@ jobs:
          registry: ${{ env.REGISTRY }}
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}

-      # This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels.
-      - name: Extract metadata (tags, labels) for Docker
      - name: Extract metadata for Docker images
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}
-      # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages.
-      # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository.
-      # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step.
          # This configuration dynamically generates tags based on the branch, tag, commit, and custom suffix for lite version.
          tags: |
            type=ref,event=branch
            type=ref,event=tag
            type=sha,prefix=git-
            type=semver,pattern={{version}}
          flavor: |
            latest=${{ github.ref == 'refs/heads/main' }}

      - name: Build and push Docker image
        uses: docker/build-push-action@v5
        with:
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.1.102] - 2024-02-22
### Added
- **🖼️ Image Generation**: Generate Images using the AUTOMATIC1111/stable-diffusion-webui API. You can set this up in Settings > Images.
- **📝 Change title generation prompt**: Change the prompt used to generate titles for your chats. You can set this up in the Settings > Interface.
- **🤖 Change embedding model**: Change the embedding model used to generate embeddings for your chats in the Dockerfile. Use any sentence transformer model from huggingface.co.
- **📢 CHANGELOG.md/Popup**: This popup will show you the latest changes.
## [0.1.101] - 2024-02-22
### Fixed
- LaTeX output formatting issue (#828)
### Changed
- Switched from the previous 1.0.0-alpha.101 scheme to semantic versioning, in keeping with global conventions.
@@ -5,9 +5,10 @@ FROM node:alpine as build
WORKDIR /app

# wget embedding model weight from alpine (does not exist from slim-buster)
-RUN wget "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz"
RUN wget "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz" -O - | \
    tar -xzf - -C /app

COPY package.json package-lock.json ./
RUN npm ci

COPY . .
@@ -17,35 +18,65 @@ RUN npm run build
FROM python:3.11-slim-bookworm as base

ENV ENV=prod
ENV PORT ""
ENV OLLAMA_API_BASE_URL "/ollama/api"
ENV OPENAI_API_BASE_URL ""
ENV OPENAI_API_KEY ""
-ENV WEBUI_JWT_SECRET_KEY "SECRET_KEY"
ENV WEBUI_SECRET_KEY ""
ENV SCARF_NO_ANALYTICS true
ENV DO_NOT_TRACK true

-WORKDIR /app
-
-# copy embedding weight from build
-RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
-COPY --from=build /app/onnx.tar.gz /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
-RUN cd /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2 &&\
-    tar -xzf onnx.tar.gz

######## Preloaded models ########
# whisper TTS Settings
ENV WHISPER_MODEL="base"
ENV WHISPER_MODEL_DIR="/app/backend/data/cache/whisper/models"

# RAG Embedding Model Settings
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
# for better performance and multilanguage support use "intfloat/multilingual-e5-large" (~2.5GB) or "intfloat/multilingual-e5-base" (~1.5GB)
# IMPORTANT: If you switch away from the default model (all-MiniLM-L6-v2) or back to it, RAG Chat will no longer work with documents already loaded in the WebUI; you need to re-embed them.
ENV RAG_EMBEDDING_MODEL="all-MiniLM-L6-v2"
# device type for whisper tts and embedding models - "cpu" (default), "cuda" (nvidia gpu and CUDA required) or "mps" (apple silicon) - choosing the right one can lead to better performance
ENV RAG_EMBEDDING_MODEL_DEVICE_TYPE="cpu"
ENV RAG_EMBEDDING_MODEL_DIR="/app/backend/data/cache/embedding/models"
ENV SENTENCE_TRANSFORMERS_HOME $RAG_EMBEDDING_MODEL_DIR
######## Preloaded models ########

-# copy built frontend files
-COPY --from=build /app/build /app/build

WORKDIR /app/backend
# install python dependencies
COPY ./backend/requirements.txt ./requirements.txt
-RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
-RUN pip3 install -r requirements.txt
RUN pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu --no-cache-dir
RUN pip3 install -r requirements.txt --no-cache-dir
# Install pandoc and netcat
# RUN python -c "import pypandoc; pypandoc.download_pandoc()"
RUN apt-get update \
&& apt-get install -y pandoc netcat-openbsd \
&& rm -rf /var/lib/apt/lists/*
# RUN python -c "from sentence_transformers import SentenceTransformer; model = SentenceTransformer('all-MiniLM-L6-v2')" # preload embedding model
RUN python -c "import os; from chromadb.utils import embedding_functions; sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name=os.environ['RAG_EMBEDDING_MODEL'], device=os.environ['RAG_EMBEDDING_MODEL_DEVICE_TYPE'])"
# preload tts model
RUN python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='auto', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"
# copy embedding weight from build
RUN mkdir -p /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2
COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx
# copy built frontend files
COPY --from=build /app/build /app/build
COPY --from=build /app/CHANGELOG.md /app/CHANGELOG.md
COPY --from=build /app/package.json /app/package.json
# copy backend files
COPY ./backend .

-CMD [ "sh", "start.sh"]
CMD [ "bash", "start.sh"]
\ No newline at end of file
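Because the RAG settings above are plain `ENV` instructions, they can be overridden when the container starts, without rebuilding the image. A hedged sketch (the model name and volume are illustrative, and per the Dockerfile comment, previously loaded documents would need re-embedding):

```bash
# Swap the embedding model at run time instead of baking it into the image.
docker run -d -p 3000:8080 \
  -e RAG_EMBEDDING_MODEL="intfloat/multilingual-e5-base" \
  -v open-webui:/app/backend/data \
  --name open-webui \
  ghcr.io/open-webui/open-webui:main
```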
-### Installing Both Ollama and Ollama Web UI Using Kustomize
### Installing Both Ollama and Open WebUI Using Kustomize

For cpu-only pod

@@ -12,7 +12,7 @@ For gpu-enabled pod
kubectl apply -k ./kubernetes/manifest
```

-### Installing Both Ollama and Ollama Web UI Using Helm
### Installing Both Ollama and Open WebUI Using Helm

Package Helm file first
-# Ollama Web UI Troubleshooting Guide
# Open WebUI Troubleshooting Guide

-## Understanding the Ollama WebUI Architecture
## Understanding the Open WebUI Architecture

The Open WebUI system is designed to streamline interactions between the client (your browser) and the Ollama API. At the heart of this design is a backend reverse proxy, enhancing security and resolving CORS issues.

- **How it Works**: The Open WebUI is designed to interact with the Ollama API through a specific route. When a request is made from the WebUI to Ollama, it is not directly sent to the Ollama API. Initially, the request is sent to the Open WebUI backend via the `/ollama/api` route. From there, the backend forwards the request to the Ollama API, using the route specified in the `OLLAMA_API_BASE_URL` environment variable. Therefore, a request made to `/ollama/api` in the WebUI is effectively the same as making a request to `OLLAMA_API_BASE_URL` in the backend. For instance, a request to `/ollama/api/tags` in the WebUI is equivalent to `OLLAMA_API_BASE_URL/tags` in the backend (see the sketch after this list).
- **Security Benefits**: This design prevents direct exposure of the Ollama API to the frontend, safeguarding against potential CORS (Cross-Origin Resource Sharing) issues and unauthorized access. Requiring authentication to access the Ollama API further enhances this security layer.
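For illustration only, a minimal sketch of the forwarding pattern described above. This is not the actual Open WebUI backend code: the route name and environment variable come from the text, everything else (GET-only handling, no auth) is a simplifying assumption:

```python
# Minimal reverse-proxy sketch: /ollama/api/<path> -> OLLAMA_API_BASE_URL/<path>
import os

import requests
from fastapi import FastAPI

app = FastAPI()
OLLAMA_API_BASE_URL = os.getenv("OLLAMA_API_BASE_URL", "http://localhost:11434/api")


@app.get("/ollama/api/{path:path}")
def forward(path: str):
    # A request to /ollama/api/tags here becomes OLLAMA_API_BASE_URL/tags upstream.
    r = requests.get(f"{OLLAMA_API_BASE_URL}/{path}")
    return r.json()
```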
-## Ollama WebUI: Server Connection Error
## Open WebUI: Server Connection Error
If you're experiencing connection issues, it’s often due to the WebUI docker container not being able to reach the Ollama server at 127.0.0.1:11434 (host.docker.internal:11434) inside the container. Use the `--network=host` flag in your docker command to resolve this. Note that the port changes from 3000 to 8080, resulting in the link: `http://localhost:8080`.
**Example Docker Command**:

```bash
-docker run -d --network=host -v ollama-webui:/app/backend/data -e OLLAMA_API_BASE_URL=http://127.0.0.1:11434/api --name ollama-webui --restart always ghcr.io/ollama-webui/ollama-webui:main
docker run -d --network=host -v open-webui:/app/backend/data -e OLLAMA_API_BASE_URL=http://127.0.0.1:11434/api --name open-webui --restart always ghcr.io/open-webui/open-webui:main
```
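Equivalently, a docker-compose sketch of the same host-network setup (the file and service names are illustrative):

```yaml
# docker-compose.yaml - same host-network setup as the command above.
services:
  open-webui:
    image: ghcr.io/open-webui/open-webui:main
    network_mode: host   # the WebUI then listens directly on http://localhost:8080
    environment:
      - OLLAMA_API_BASE_URL=http://127.0.0.1:11434/api
    volumes:
      - open-webui:/app/backend/data
    restart: always

volumes:
  open-webui:
```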
### General Connection Errors

-**Ensure Ollama Version is Up-to-Date**: Always start by checking that you have the latest version of Ollama. Visit [Ollama's official site](https://ollama.ai/) for the latest updates.
**Ensure Ollama Version is Up-to-Date**: Always start by checking that you have the latest version of Ollama. Visit [Ollama's official site](https://ollama.com/) for the latest updates.

**Troubleshooting Steps**:
1. **Verify Ollama URL Format**:
   - When running the Web UI container, ensure the `OLLAMA_API_BASE_URL` is correctly set, including the `/api` suffix (e.g., `http://192.168.1.1:11434/api` for different host setups). A quick command-line check is shown after this list.
-   - In the Ollama WebUI, navigate to "Settings" > "General".
   - In the Open WebUI, navigate to "Settings" > "General".
   - Confirm that the Ollama Server URL is correctly set to `[OLLAMA URL]/api` (e.g., `http://localhost:11434/api`), including the `/api` suffix.
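For example, assuming Ollama runs on its default port, you can confirm the URL responds before wiring it into the WebUI (the `/tags` endpoint is the one this guide already references):

```bash
# Should return a JSON list of locally available models if the URL is correct.
curl http://localhost:11434/api/tags
```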
By following these enhanced troubleshooting steps, connection issues should be effectively resolved. For further assistance or queries, feel free to reach out to us on our community Discord.
@@ -6,4 +6,6 @@ uploads
*.db
_test
Pipfile
data/*
\ No newline at end of file
!data/config.json
.webui_secret_key
\ No newline at end of file
import os
from fastapi import (
FastAPI,
Request,
Depends,
HTTPException,
status,
UploadFile,
File,
Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from constants import ERROR_MESSAGES
from utils.utils import (
decode_token,
get_current_user,
get_verified_user,
get_admin_user,
)
from utils.misc import calculate_sha256
from config import CACHE_DIR, UPLOAD_DIR, WHISPER_MODEL, WHISPER_MODEL_DIR
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.post("/transcribe")
def transcribe(
file: UploadFile = File(...),
user=Depends(get_current_user),
):
print(file.content_type)
if file.content_type not in ["audio/mpeg", "audio/wav"]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
)
try:
filename = file.filename
file_path = f"{UPLOAD_DIR}/{filename}"
contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)
model = WhisperModel(
WHISPER_MODEL,
device="auto",
compute_type="int8",
download_root=WHISPER_MODEL_DIR,
)
segments, info = model.transcribe(file_path, beam_size=5)
print(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
transcript = "".join([segment.text for segment in list(segments)])
return {"text": transcript.strip()}
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
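For illustration, a request against this endpoint might look like the following. The host, port, mount prefix (`/audio/api/v1` here), and auth token are assumptions about the deployment, not values defined in this file:

```bash
# Upload a WAV file for transcription; adjust the URL and token to your deployment.
curl -X POST "http://localhost:8080/audio/api/v1/transcribe" \
  -H "Authorization: Bearer YOUR_TOKEN" \
  -F "file=@sample.wav;type=audio/wav"
```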
import re
import requests
from fastapi import (
FastAPI,
Request,
Depends,
HTTPException,
status,
UploadFile,
File,
Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from constants import ERROR_MESSAGES
from utils.utils import (
get_current_user,
get_admin_user,
)
from utils.misc import calculate_sha256
from typing import Optional
from pydantic import BaseModel
from config import AUTOMATIC1111_BASE_URL
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.ENABLED = app.state.AUTOMATIC1111_BASE_URL != ""
app.state.IMAGE_SIZE = "512x512"
@app.get("/enabled", response_model=bool)
async def get_enable_status(request: Request, user=Depends(get_admin_user)):
return app.state.ENABLED
@app.get("/enabled/toggle", response_model=bool)
async def toggle_enabled(request: Request, user=Depends(get_admin_user)):
try:
r = requests.head(app.state.AUTOMATIC1111_BASE_URL)
app.state.ENABLED = not app.state.ENABLED
return app.state.ENABLED
except Exception as e:
        # NOTE: use a fixed 400 here; `r` is never assigned when the HEAD request itself fails.
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class UrlUpdateForm(BaseModel):
url: str
@app.get("/url")
async def get_openai_url(user=Depends(get_admin_user)):
return {"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL}
@app.post("/url/update")
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
if form_data.url == "":
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else:
app.state.AUTOMATIC1111_BASE_URL = form_data.url.strip("/")
return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
"status": True,
}
class ImageSizeUpdateForm(BaseModel):
size: str
@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_SIZE": app.state.IMAGE_SIZE}
@app.post("/size/update")
async def update_image_size(
form_data: ImageSizeUpdateForm, user=Depends(get_admin_user)
):
pattern = r"^\d+x\d+$" # Regular expression pattern
if re.match(pattern, form_data.size):
app.state.IMAGE_SIZE = form_data.size
return {
"IMAGE_SIZE": app.state.IMAGE_SIZE,
"status": True,
}
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.INCORRECT_FORMAT(" (e.g., 512x512)."),
)
@app.get("/models")
def get_models(user=Depends(get_current_user)):
try:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models")
models = r.json()
return models
except Exception as e:
        # NOTE: use a fixed 400 here; `r` is unbound if the GET request itself fails.
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
try:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
return {"model": options["sd_model_checkpoint"]}
except Exception as e:
        # NOTE: use a fixed 400 here; `r` is unbound if the GET request itself fails.
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
class UpdateModelForm(BaseModel):
model: str
def set_model_handler(model: str):
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
options = r.json()
if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
)
return options
@app.post("/models/default/update")
def update_default_model(
form_data: UpdateModelForm,
user=Depends(get_current_user),
):
return set_model_handler(form_data.model)
class GenerateImageForm(BaseModel):
model: Optional[str] = None
prompt: str
n: int = 1
size: str = "512x512"
negative_prompt: Optional[str] = None
@app.post("/generations")
def generate_image(
form_data: GenerateImageForm,
user=Depends(get_current_user),
):
print(form_data)
try:
if form_data.model:
set_model_handler(form_data.model)
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
data = {
"prompt": form_data.prompt,
"batch_size": form_data.n,
"width": width,
"height": height,
}
if form_data.negative_prompt != None:
data["negative_prompt"] = form_data.negative_prompt
print(data)
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data,
)
return r.json()
except Exception as e:
print(e)
        # NOTE: use a fixed 400 here; `r` is unbound if the POST request itself fails.
        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
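Likewise, a sketch of calling the generation endpoint above. The mount prefix (`/images/api/v1`) and token are assumptions; the `size` handling comes from the app state configured earlier in this file:

```bash
# Request one 512x512 image from the configured AUTOMATIC1111 backend.
curl -X POST "http://localhost:8080/images/api/v1/generations" \
  -H "Authorization: Bearer YOUR_TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"prompt": "a watercolor fox", "n": 1, "size": "512x512"}'
```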
-from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool

import requests
import json
import uuid

from pydantic import BaseModel

from apps.web.models.users import Users
from constants import ERROR_MESSAGES
-from utils.utils import decode_token, get_current_user
from utils.utils import decode_token, get_current_user, get_admin_user
from config import OLLAMA_API_BASE_URL, WEBUI_AUTH

app = FastAPI()

@@ -26,12 +27,12 @@ app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
# TARGET_SERVER_URL = OLLAMA_API_BASE_URL

REQUEST_POOL = []


@app.get("/url")
-async def get_ollama_api_url(user=Depends(get_current_user)):
-    if user and user.role == "admin":
-        return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
async def get_ollama_api_url(user=Depends(get_admin_user)):
    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}


class UrlUpdateForm(BaseModel):

@@ -39,12 +40,17 @@ class UrlUpdateForm(BaseModel):
@app.post("/url/update")
-async def update_ollama_api_url(
-    form_data: UrlUpdateForm, user=Depends(get_current_user)
-):
-    if user and user.role == "admin":
-        app.state.OLLAMA_API_BASE_URL = form_data.url
-        return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
    app.state.OLLAMA_API_BASE_URL = form_data.url
    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}


@app.get("/cancel/{request_id}")
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
    if user:
        if request_id in REQUEST_POOL:
            REQUEST_POOL.remove(request_id)
        return True
    else:
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
@@ -60,21 +66,45 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    if path in ["pull", "delete", "push", "copy", "create"]:
        if user.role != "admin":
            raise HTTPException(
-                status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
            )
    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

-    headers.pop("Host", None)
-    headers.pop("Authorization", None)
-    headers.pop("Origin", None)
-    headers.pop("Referer", None)
    headers.pop("host", None)
    headers.pop("authorization", None)
    headers.pop("origin", None)
    headers.pop("referer", None)

    r = None

    def get_request():
        nonlocal r

        request_id = str(uuid.uuid4())
        try:
            REQUEST_POOL.append(request_id)

            def stream_content():
                try:
                    if path in ["chat"]:
                        yield json.dumps({"id": request_id, "done": False}) + "\n"

                    for chunk in r.iter_content(chunk_size=8192):
                        if request_id in REQUEST_POOL:
                            yield chunk
                        else:
                            print("User: canceled request")
                            break
                finally:
                    if hasattr(r, "close"):
                        r.close()
                    REQUEST_POOL.remove(request_id)

            r = requests.request(
                method=request.method,
                url=target_url,

@@ -85,8 +115,10 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
            r.raise_for_status()

            # r.close()

            return StreamingResponse(
-                r.iter_content(chunk_size=8192),
                stream_content(),
                status_code=r.status_code,
                headers=dict(r.headers),
            )

@@ -96,7 +128,7 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    try:
        return await run_in_threadpool(get_request)
    except Exception as e:
-        error_detail = "Ollama WebUI: Server Connection Error"
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
@@ -61,7 +61,7 @@ async def update_ollama_api_url(
    # yield line
    # except Exception as e:
    # print(e)
-    # error_detail = "Ollama WebUI: Server Connection Error"
    # error_detail = "Open WebUI: Server Connection Error"
    # yield json.dumps({"error": error_detail, "message": str(e)}).encode()

@@ -110,7 +110,7 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
    except Exception as e:
        print(e)
-        error_detail = "Ollama WebUI: Server Connection Error"
        error_detail = "Open WebUI: Server Connection Error"
        if response is not None:
            try:
from fastapi import FastAPI, Request, Response, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse

import requests
import json

from pydantic import BaseModel

from apps.web.models.users import Users
from constants import ERROR_MESSAGES
-from utils.utils import decode_token, get_current_user
-from config import OPENAI_API_BASE_URL, OPENAI_API_KEY
from utils.utils import (
    decode_token,
    get_current_user,
    get_verified_user,
    get_admin_user,
)
from config import OPENAI_API_BASE_URL, OPENAI_API_KEY, CACHE_DIR

import hashlib
from pathlib import Path

app = FastAPI()
app.add_middleware(

@@ -33,60 +42,114 @@ class KeyUpdateForm(BaseModel):
@app.get("/url")
-async def get_openai_url(user=Depends(get_current_user)):
-    if user and user.role == "admin":
-        return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
-    else:
-        raise HTTPException(status_code=401,
-                            detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
async def get_openai_url(user=Depends(get_admin_user)):
    return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}


@app.post("/url/update")
-async def update_openai_url(form_data: UrlUpdateForm,
-                            user=Depends(get_current_user)):
-    if user and user.role == "admin":
-        app.state.OPENAI_API_BASE_URL = form_data.url
-        return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
-    else:
-        raise HTTPException(status_code=401,
-                            detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
    app.state.OPENAI_API_BASE_URL = form_data.url
    return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}


@app.get("/key")
-async def get_openai_key(user=Depends(get_current_user)):
-    if user and user.role == "admin":
-        return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
-    else:
-        raise HTTPException(status_code=401,
-                            detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
async def get_openai_key(user=Depends(get_admin_user)):
    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}


@app.post("/key/update")
-async def update_openai_key(form_data: KeyUpdateForm,
-                            user=Depends(get_current_user)):
-    if user and user.role == "admin":
-        app.state.OPENAI_API_KEY = form_data.key
-        return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
-    else:
-        raise HTTPException(status_code=401,
-                            detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_admin_user)):
    app.state.OPENAI_API_KEY = form_data.key
    return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}


@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
    target_url = f"{app.state.OPENAI_API_BASE_URL}/audio/speech"

    if app.state.OPENAI_API_KEY == "":
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    body = await request.body()
    name = hashlib.sha256(body).hexdigest()

    SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
    SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    file_path = SPEECH_CACHE_DIR.joinpath(f"{name}.mp3")
    file_body_path = SPEECH_CACHE_DIR.joinpath(f"{name}.json")

    # Check if the file already exists in the cache
    if file_path.is_file():
        return FileResponse(file_path)

    headers = {}
    headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
    headers["Content-Type"] = "application/json"

    try:
        print("openai")
        r = requests.post(
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Save the streaming content to a file
        with open(file_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)

        with open(file_body_path, "w") as f:
            json.dump(json.loads(body.decode("utf-8")), f)

        # Return the saved file
        return FileResponse(file_path)

    except Exception as e:
        print(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']}"
            except:
                error_detail = f"External: {e}"

        raise HTTPException(status_code=r.status_code, detail=error_detail)


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
-async def proxy(path: str, request: Request, user=Depends(get_current_user)):
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
    print(target_url, app.state.OPENAI_API_KEY)

-    if user.role not in ["user", "admin"]:
-        raise HTTPException(status_code=401,
-                            detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
    if app.state.OPENAI_API_KEY == "":
-        raise HTTPException(status_code=401,
-                            detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

    body = await request.body()

-    # headers = dict(request.headers)
-    # print(headers)
    # TODO: Remove below after gpt-4-vision fix from Open AI
    # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
    try:
        body = body.decode("utf-8")
        body = json.loads(body)

        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
        # This is a workaround until OpenAI fixes the issue with this model
        if body.get("model") == "gpt-4-vision-preview":
            if "max_tokens" not in body:
                body["max_tokens"] = 4000
            print("Modified body_dict:", body)

        # Convert the modified body back to JSON
        body = json.dumps(body)
    except json.JSONDecodeError as e:
        print("Error loading request body into a dictionary:", e)

    headers = {}
    headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"

@@ -121,17 +184,15 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
        response_data = r.json()

-        print(type(response_data))

        if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
            response_data["data"] = list(
-                filter(lambda model: "gpt" in model["id"],
-                       response_data["data"]))
                filter(lambda model: "gpt" in model["id"], response_data["data"])
            )

        return response_data
    except Exception as e:
        print(e)
-        error_detail = "Ollama WebUI: Server Connection Error"
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
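The `/audio/speech` handler above keys its on-disk cache on the SHA-256 of the raw request body, so byte-identical requests are served straight from disk. A sketch of exercising it (the mount prefix and token are assumptions about the deployment; the JSON fields follow OpenAI's TTS request format):

```bash
# First call hits the OpenAI API; repeating the identical body returns the cached MP3.
curl -X POST "http://localhost:8080/openai/api/audio/speech" \
  -H "Authorization: Bearer YOUR_TOKEN" \
  -H "Content-Type: application/json" \
  -d '{"model": "tts-1", "input": "Hello from Open WebUI", "voice": "alloy"}' \
  --output hello.mp3
```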
from fastapi import (
    FastAPI,
-    Request,
    Depends,
    HTTPException,
    status,

@@ -11,7 +10,11 @@ from fastapi import (
from fastapi.middleware.cors import CORSMiddleware
import os, shutil

-# from chromadb.utils import embedding_functions
from pathlib import Path
from typing import List

from sentence_transformers import SentenceTransformer
from chromadb.utils import embedding_functions
from langchain_community.document_loaders import (
    WebBaseLoader,

@@ -19,29 +22,71 @@ from langchain_community.document_loaders import (
    PyPDFLoader,
    CSVLoader,
    Docx2txtLoader,
UnstructuredEPubLoader,
UnstructuredWordDocumentLoader,
UnstructuredMarkdownLoader,
UnstructuredXMLLoader,
UnstructuredRSTLoader,
UnstructuredExcelLoader,
)

from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.vectorstores import Chroma
-from langchain.chains import RetrievalQA

from pydantic import BaseModel
from typing import Optional
import mimetypes
import uuid
-import time
import json
from apps.web.models.documents import (
Documents,
DocumentForm,
DocumentResponse,
)
from utils.misc import (
calculate_sha256,
calculate_sha256_string,
sanitize_filename,
extract_folders_after_data_docs,
)
from utils.utils import get_current_user, get_admin_user
from config import (
UPLOAD_DIR,
DOCS_DIR,
RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_DEVICE_TYPE,
CHROMA_CLIENT,
CHUNK_SIZE,
CHUNK_OVERLAP,
RAG_TEMPLATE,
)
-from utils.misc import calculate_sha256
-from utils.utils import get_current_user
-from config import UPLOAD_DIR, EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
from constants import ERROR_MESSAGES

-# EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
-#     model_name=EMBED_MODEL
-# )
#
# if RAG_EMBEDDING_MODEL:
#     sentence_transformer_ef = SentenceTransformer(
# model_name_or_path=RAG_EMBEDDING_MODEL,
# cache_folder=RAG_EMBEDDING_MODEL_DIR,
# device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
# )
app = FastAPI()
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
)
)
origins = ["*"] origins = ["*"]
app.add_middleware( app.add_middleware(
...@@ -63,7 +108,7 @@ class StoreWebForm(CollectionNameForm): ...@@ -63,7 +108,7 @@ class StoreWebForm(CollectionNameForm):
def store_data_in_vector_db(data, collection_name) -> bool: def store_data_in_vector_db(data, collection_name) -> bool:
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
) )
docs = text_splitter.split_documents(data) docs = text_splitter.split_documents(data)
...@@ -71,7 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool: ...@@ -71,7 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool:
metadatas = [doc.metadata for doc in docs] metadatas = [doc.metadata for doc in docs]
try: try:
collection = CHROMA_CLIENT.create_collection(name=collection_name) collection = CHROMA_CLIENT.create_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
        collection.add(
            documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]

@@ -87,22 +135,112 @@ def store_data_in_vector_db(data, collection_name) -> bool:
@app.get("/")
async def get_status():
-    return {"status": True}
    return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"template": app.state.RAG_TEMPLATE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
}
@app.get("/query/{collection_name}") @app.get("/embedding/model")
def query_collection( async def get_embedding_model(user=Depends(get_admin_user)):
collection_name: str, return {
query: str, "status": True,
k: Optional[int] = 4, "embedding_model": app.state.RAG_EMBEDDING_MODEL,
}
class EmbeddingModelUpdateForm(BaseModel):
embedding_model: str
@app.post("/embedding/model/update")
async def update_embedding_model(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.sentence_transformer_ef = (
embedding_functions.SentenceTransformerEmbeddingFunction(
model_name=app.state.RAG_EMBEDDING_MODEL,
device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
)
)
return {
"status": True,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
}
@app.get("/chunk")
async def get_chunk_params(user=Depends(get_admin_user)):
return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
}
class ChunkParamUpdateForm(BaseModel):
chunk_size: int
chunk_overlap: int
@app.post("/chunk/update")
async def update_chunk_params(
form_data: ChunkParamUpdateForm, user=Depends(get_admin_user)
):
app.state.CHUNK_SIZE = form_data.chunk_size
app.state.CHUNK_OVERLAP = form_data.chunk_overlap
return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
}
@app.get("/template")
async def get_rag_template(user=Depends(get_current_user)):
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
}
class RAGTemplateForm(BaseModel):
template: str
@app.post("/template/update")
async def update_rag_template(form_data: RAGTemplateForm, user=Depends(get_admin_user)):
# TODO: check template requirements
app.state.RAG_TEMPLATE = (
form_data.template if form_data.template != "" else RAG_TEMPLATE
)
return {"status": True, "template": app.state.RAG_TEMPLATE}
class QueryDocForm(BaseModel):
collection_name: str
query: str
k: Optional[int] = 4
@app.post("/query/doc")
def query_doc(
form_data: QueryDocForm,
    user=Depends(get_current_user),
):
    try:
# if you use docker use the model from the environment variable
        collection = CHROMA_CLIENT.get_collection(
-            name=collection_name,
            name=form_data.collection_name,
embedding_function=app.state.sentence_transformer_ef,
        )
-        result = collection.query(query_texts=[query], n_results=k)
        result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
        return result
    except Exception as e:
        print(e)
@@ -112,14 +250,99 @@ def query_collection(
        )
class QueryCollectionsForm(BaseModel):
collection_names: List[str]
query: str
k: Optional[int] = 4
def merge_and_sort_query_results(query_results, k):
# Initialize lists to store combined data
combined_ids = []
combined_distances = []
combined_metadatas = []
combined_documents = []
# Combine data from each dictionary
for data in query_results:
combined_ids.extend(data["ids"][0])
combined_distances.extend(data["distances"][0])
combined_metadatas.extend(data["metadatas"][0])
combined_documents.extend(data["documents"][0])
# Create a list of tuples (distance, id, metadata, document)
combined = list(
zip(combined_distances, combined_ids, combined_metadatas, combined_documents)
)
# Sort the list based on distances
combined.sort(key=lambda x: x[0])
# Unzip the sorted list
sorted_distances, sorted_ids, sorted_metadatas, sorted_documents = zip(*combined)
# Slicing the lists to include only k elements
sorted_distances = list(sorted_distances)[:k]
sorted_ids = list(sorted_ids)[:k]
sorted_metadatas = list(sorted_metadatas)[:k]
sorted_documents = list(sorted_documents)[:k]
# Create the output dictionary
merged_query_results = {
"ids": [sorted_ids],
"distances": [sorted_distances],
"metadatas": [sorted_metadatas],
"documents": [sorted_documents],
"embeddings": None,
"uris": None,
"data": None,
}
return merged_query_results
@app.post("/query/collection")
def query_collection(
form_data: QueryCollectionsForm,
user=Depends(get_current_user),
):
results = []
for collection_name in form_data.collection_names:
try:
# if you use docker use the model from the environment variable
collection = CHROMA_CLIENT.get_collection(
name=collection_name,
embedding_function=app.state.sentence_transformer_ef,
)
result = collection.query(
query_texts=[form_data.query], n_results=form_data.k
)
results.append(result)
except:
pass
return merge_and_sort_query_results(results, form_data.k)
@app.post("/web") @app.post("/web")
def store_web(form_data: StoreWebForm, user=Depends(get_current_user)): def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try: try:
loader = WebBaseLoader(form_data.url) loader = WebBaseLoader(form_data.url)
data = loader.load() data = loader.load()
store_data_in_vector_db(data, form_data.collection_name)
return {"status": True, "collection_name": form_data.collection_name} collection_name = form_data.collection_name
if collection_name == "":
collection_name = calculate_sha256_string(form_data.url)[:63]
store_data_in_vector_db(data, collection_name)
return {
"status": True,
"collection_name": collection_name,
"filename": form_data.url,
}
    except Exception as e:
        print(e)
        raise HTTPException(

@@ -128,6 +351,87 @@ def store_web(form_data: StoreWebForm, user=Depends(get_current_user)):
        )
def get_loader(filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower()
known_type = True
known_source_ext = [
"go",
"py",
"java",
"sh",
"bat",
"ps1",
"cmd",
"js",
"ts",
"css",
"cpp",
"hpp",
"h",
"c",
"cs",
"sql",
"log",
"ini",
"pl",
"pm",
"r",
"dart",
"dockerfile",
"env",
"php",
"hs",
"hsc",
"lua",
"nginxconf",
"conf",
"m",
"mm",
"plsql",
"perl",
"rb",
"rs",
"db2",
"scala",
"bash",
"swift",
"vue",
"svelte",
]
if file_ext == "pdf":
loader = PyPDFLoader(file_path)
elif file_ext == "csv":
loader = CSVLoader(file_path)
elif file_ext == "rst":
loader = UnstructuredRSTLoader(file_path, mode="elements")
elif file_ext == "xml":
loader = UnstructuredXMLLoader(file_path)
elif file_ext == "md":
loader = UnstructuredMarkdownLoader(file_path)
elif file_content_type == "application/epub+zip":
loader = UnstructuredEPubLoader(file_path)
elif (
file_content_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
or file_ext in ["doc", "docx"]
):
loader = Docx2txtLoader(file_path)
elif file_content_type in [
"application/vnd.ms-excel",
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
] or file_ext in ["xls", "xlsx"]:
loader = UnstructuredExcelLoader(file_path)
elif file_ext in known_source_ext or (file_content_type and file_content_type.find("text/") >= 0):
loader = TextLoader(file_path)
else:
loader = TextLoader(file_path)
known_type = False
return loader, known_type
@app.post("/doc") @app.post("/doc")
def store_doc( def store_doc(
collection_name: Optional[str] = Form(None), collection_name: Optional[str] = Form(None),
...@@ -136,17 +440,7 @@ def store_doc( ...@@ -136,17 +440,7 @@ def store_doc(
): ):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
if file.content_type not in [ print(file.content_type)
"application/pdf",
"text/plain",
"text/csv",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
]:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_NOT_SUPPORTED,
)
    try:
        filename = file.filename
        file_path = f"{UPLOAD_DIR}/{filename}"

@@ -160,23 +454,17 @@ def store_doc(
            collection_name = calculate_sha256(f)[:63]
        f.close()

-        if file.content_type == "application/pdf":
        loader, known_type = get_loader(file.filename, file.content_type, file_path)
-            loader = PyPDFLoader(file_path)
-        elif (
-            file.content_type
-            == "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
-        ):
-            loader = Docx2txtLoader(file_path)
-        elif file.content_type == "text/plain":
-            loader = TextLoader(file_path)
-        elif file.content_type == "text/csv":
-            loader = CSVLoader(file_path)
        data = loader.load()
        result = store_data_in_vector_db(data, collection_name)

        if result:
-            return {"status": True, "collection_name": collection_name}
            return {
"status": True,
"collection_name": collection_name,
"filename": filename,
"known_type": known_type,
}
        else:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,

@@ -184,45 +472,96 @@ def store_doc(
            )
    except Exception as e:
        print(e)
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=ERROR_MESSAGES.DEFAULT(e),
-        )
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
@app.get("/scan")
def scan_docs_dir(user=Depends(get_admin_user)):
for path in Path(DOCS_DIR).rglob("./**/*"):
try:
if path.is_file() and not path.name.startswith("."):
tags = extract_folders_after_data_docs(path)
filename = path.name
file_content_type = mimetypes.guess_type(path)
f = open(path, "rb")
collection_name = calculate_sha256(f)[:63]
f.close()
loader, known_type = get_loader(
filename, file_content_type[0], str(path)
)
data = loader.load()
result = store_data_in_vector_db(data, collection_name)
if result:
sanitized_filename = sanitize_filename(filename)
doc = Documents.get_doc_by_name(sanitized_filename)
if doc == None:
doc = Documents.insert_new_doc(
user.id,
DocumentForm(
**{
"name": sanitized_filename,
"title": filename,
"collection_name": collection_name,
"filename": filename,
"content": (
json.dumps(
{
"tags": list(
map(
lambda name: {"name": name},
tags,
)
)
}
)
if len(tags)
else "{}"
),
}
),
)
except Exception as e:
print(e)
return True
@app.get("/reset/db") @app.get("/reset/db")
def reset_vector_db(user=Depends(get_current_user)): def reset_vector_db(user=Depends(get_admin_user)):
if user.role == "admin": CHROMA_CLIENT.reset()
CHROMA_CLIENT.reset()
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
@app.get("/reset") @app.get("/reset")
def reset(user=Depends(get_current_user)) -> bool: def reset(user=Depends(get_admin_user)) -> bool:
if user.role == "admin": folder = f"{UPLOAD_DIR}"
folder = f"{UPLOAD_DIR}" for filename in os.listdir(folder):
for filename in os.listdir(folder): file_path = os.path.join(folder, filename)
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e:
print("Failed to delete %s. Reason: %s" % (file_path, e))
try: try:
CHROMA_CLIENT.reset() if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path)
elif os.path.isdir(file_path):
shutil.rmtree(file_path)
except Exception as e: except Exception as e:
print(e) print("Failed to delete %s. Reason: %s" % (file_path, e))
return True try:
else: CHROMA_CLIENT.reset()
raise HTTPException( except Exception as e:
status_code=status.HTTP_403_FORBIDDEN, print(e)
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) return True
from peewee import *
from config import DATA_DIR

-DB = SqliteDatabase("./data/ollama.db")
DB = SqliteDatabase(f"{DATA_DIR}/ollama.db")
DB.connect()
from fastapi import FastAPI, Depends
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware

-from apps.web.routers import auths, users, chats, modelfiles, prompts, configs, utils
-from config import WEBUI_VERSION, WEBUI_AUTH
from apps.web.routers import (
    auths,
users,
chats,
documents,
modelfiles,
prompts,
configs,
utils,
)
from config import (
WEBUI_VERSION,
WEBUI_AUTH,
DEFAULT_MODELS,
DEFAULT_PROMPT_SUGGESTIONS,
DEFAULT_USER_ROLE,
ENABLE_SIGNUP,
USER_PERMISSIONS,
)
app = FastAPI()

origins = ["*"]

-app.state.ENABLE_SIGNUP = True
-app.state.DEFAULT_MODELS = None
app.state.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.JWT_EXPIRES_IN = "-1"
app.state.DEFAULT_MODELS = DEFAULT_MODELS
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.USER_PERMISSIONS = USER_PERMISSIONS
app.add_middleware(
    CORSMiddleware,

@@ -22,9 +45,8 @@ app.add_middleware(
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
-app.include_router(modelfiles.router,
-                   prefix="/modelfiles",
-                   tags=["modelfiles"])
app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(configs.router, prefix="/configs", tags=["configs"])

@@ -35,7 +57,7 @@ app.include_router(utils.router, prefix="/utils", tags=["utils"])
async def get_status():
    return {
        "status": True,
-        "version": WEBUI_VERSION,
        "auth": WEBUI_AUTH,
        "default_models": app.state.DEFAULT_MODELS,
        "default_prompt_suggestions": app.state.DEFAULT_PROMPT_SUGGESTIONS,
    }