Commit be5534c6 (unverified), authored by Timothy Jaeryang Baek, committed by GitHub

Merge pull request #2376 from open-webui/dev

0.1.125
parents 90503be2 bcc2bab6
## Pull Request Checklist # Pull Request Checklist
- [ ] **Target branch:** Pull requests should target the `dev` branch. ### Note to first-time contributors: Please open a discussion post in [Discussions](https://github.com/open-webui/open-webui/discussions) and describe your changes before submitting a pull request.
- [ ] **Description:** Briefly describe the changes in this pull request.
**Before submitting, make sure you've checked the following:**
- [ ] **Target branch:** Please verify that the pull request targets the `dev` branch.
- [ ] **Description:** Provide a concise description of the changes made in this pull request.
- [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description. - [ ] **Changelog:** Ensure a changelog entry following the format of [Keep a Changelog](https://keepachangelog.com/) is added at the bottom of the PR description.
- [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources? - [ ] **Documentation:** Have you updated relevant documentation [Open WebUI Docs](https://github.com/open-webui/docs), or other documentation sources?
- [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation? - [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation?
- [ ] **Testing:** Have you written and run sufficient tests for the changes? - [ ] **Testing:** Have you written and run sufficient tests for validating the changes?
- [ ] **Code Review:** Have you self-reviewed your code and addressed any coding standard issues? - [ ] **Code review:** Have you performed a self-review of your code, addressing any coding standard issues and ensuring adherence to the project's coding standards?
- [ ] **Label:** To clearly categorize this pull request, assign a relevant label to the pull request title, using one of the following:
--- - **BREAKING CHANGE**: Significant changes that may affect compatibility
- **build**: Changes that affect the build system or external dependencies
## Description - **ci**: Changes to our continuous integration processes or workflows
- **chore**: Refactor, cleanup, or other non-functional code changes
[Insert a brief description of the changes made in this pull request, including any relevant motivation and impact.] - **docs**: Documentation update or addition
- **feat**: Introduces a new feature or enhancement to the codebase
--- - **fix**: Bug fix or error correction
- **i18n**: Internationalization or localization changes
### Changelog Entry - **perf**: Performance improvement
- **refactor**: Code restructuring for better maintainability, readability, or scalability
- **style**: Changes that do not affect the meaning of the code (white-space, formatting, missing semi-colons, etc.)
- **test**: Adding missing tests or correcting existing tests
- **WIP**: Work in progress, a temporary label for incomplete or ongoing work
# Changelog Entry
### Description
- [Concisely describe the changes made in this pull request, including any relevant motivation and impact (e.g., fixing a bug, adding a feature, or improving performance)]
### Added ### Added
- [List any new features, functionalities, or additions] - [List any new features, functionalities, or additions]
### Fixed
- [List any fixes, corrections, or bug fixes]
### Changed ### Changed
- [List any changes, updates, refactorings, or optimizations] - [List any changes, updates, refactorings, or optimizations]
### Deprecated
- [List any deprecated functionality or features that have been removed]
### Removed ### Removed
- [List any removed features, files, or deprecated functionalities] - [List any removed features, files, or functionalities]
### Fixed
- [List any fixes, corrections, or bug fixes]
### Security ### Security
...@@ -40,12 +58,15 @@ ...@@ -40,12 +58,15 @@
### Breaking Changes ### Breaking Changes
- [List any breaking changes affecting compatibility or functionality] - **BREAKING CHANGE**: [List any breaking changes affecting compatibility or functionality]
--- ---
### Additional Information ### Additional Information
- [Insert any additional context, notes, or explanations for the changes] - [Insert any additional context, notes, or explanations for the changes]
- [Reference any related issues, commits, or other relevant information]
### Screenshots or Videos
- [Reference any related issues, commits, or other relevant information] - [Attach any relevant screenshots or videos demonstrating the changes]
...@@ -37,3 +37,21 @@ jobs: ...@@ -37,3 +37,21 @@ jobs:
- name: Build Frontend - name: Build Frontend
run: npm run build run: npm run build
test-frontend:
name: 'Frontend Unit Tests'
runs-on: ubuntu-latest
steps:
- name: Checkout Repository
uses: actions/checkout@v4
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: '20'
- name: Install Dependencies
run: npm ci
- name: Run vitest
run: npm run test:frontend
...@@ -16,6 +16,10 @@ __pycache__/ ...@@ -16,6 +16,10 @@ __pycache__/
# C extensions # C extensions
*.so *.so
# Pyodide distribution
static/pyodide/*
!static/pyodide/pyodide-lock.json
# Distribution / packaging # Distribution / packaging
.Python .Python
build/ build/
......
# Ignore files for PNPM, NPM and YARN
pnpm-lock.yaml
package-lock.json
yarn.lock
kubernetes/
# Copy of .gitignore
.DS_Store .DS_Store
node_modules node_modules
/build /build
...@@ -6,11 +14,303 @@ node_modules ...@@ -6,11 +14,303 @@ node_modules
.env .env
.env.* .env.*
!.env.example !.env.example
vite.config.js.timestamp-*
vite.config.ts.timestamp-*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
lerna-debug.log*
.pnpm-debug.log*
# Diagnostic reports (https://nodejs.org/api/report.html)
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
# Runtime data
pids
*.pid
*.seed
*.pid.lock
# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov
# Coverage directory used by tools like istanbul
coverage
*.lcov
# nyc test coverage
.nyc_output
# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
.grunt
# Bower dependency directory (https://bower.io/)
bower_components
# node-waf configuration
.lock-wscript
# Compiled binary addons (https://nodejs.org/api/addons.html)
build/Release
# Dependency directories
node_modules/
jspm_packages/
# Snowpack dependency directory (https://snowpack.dev/)
web_modules/
# TypeScript cache
*.tsbuildinfo
# Optional npm cache directory
.npm
# Optional eslint cache
.eslintcache
# Optional stylelint cache
.stylelintcache
# Microbundle cache
.rpt2_cache/
.rts2_cache_cjs/
.rts2_cache_es/
.rts2_cache_umd/
# Optional REPL history
.node_repl_history
# Output of 'npm pack'
*.tgz
# Yarn Integrity file
.yarn-integrity
# dotenv environment variable files
.env
.env.development.local
.env.test.local
.env.production.local
.env.local
# parcel-bundler cache (https://parceljs.org/)
.cache
.parcel-cache
# Next.js build output
.next
out
# Nuxt.js build / generate output
.nuxt
dist
# Gatsby files
.cache/
# Comment in the public line in if your project uses Gatsby and not Next.js
# https://nextjs.org/blog/next-9-1#public-directory-support
# public
# vuepress build output
.vuepress/dist
# vuepress v2.x temp and cache directory
.temp
.cache
# Docusaurus cache and generated files
.docusaurus
# Serverless directories
.serverless/
# FuseBox cache
.fusebox/
# DynamoDB Local files
.dynamodb/
# TernJS port file
.tern-port
# Stores VSCode versions used for testing VSCode extensions
.vscode-test
# yarn v2
.yarn/cache
.yarn/unplugged
.yarn/build-state.yml
.yarn/install-state.gz
.pnp.*
# cypress artifacts
cypress/videos
cypress/screenshots
# Ignore files for PNPM, NPM and YARN
pnpm-lock.yaml
package-lock.json
yarn.lock
# Ignore kubernetes files /static/*
kubernetes \ No newline at end of file
\ No newline at end of file
...@@ -5,6 +5,25 @@ All notable changes to this project will be documented in this file. ...@@ -5,6 +5,25 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.1.125] - 2024-05-19
### Added
- **🔄 Updated UI**: Chat interface revamped with chat bubbles. Easily switch back to the old style via settings > interface > chat bubble UI.
- **📂 Enhanced Sidebar UI**: Model files, documents, prompts, and playground merged into Workspace for streamlined access.
- **🚀 Improved Many Model Interaction**: All responses now displayed simultaneously for a smoother experience.
- **🐍 Python Code Execution**: Execute Python code locally in the browser with libraries like 'requests', 'beautifulsoup4', 'numpy', 'pandas', 'seaborn', 'matplotlib', 'scikit-learn', 'scipy', 'regex'.
- **🧠 Experimental Memory Feature**: Manually input personal information you want LLMs to remember via settings > personalization > memory.
- **💾 Persistent Settings**: Settings now saved as config.json for convenience.
- **🩺 Health Check Endpoint**: Added for Docker deployment.
- **↕️ RTL Support**: Toggle chat direction via settings > interface > chat direction.
- **🖥️ PowerPoint Support**: RAG pipeline now supports PowerPoint documents.
- **🌐 Language Updates**: Ukrainian, Turkish, Arabic, Chinese, Serbian, Vietnamese updated; Punjabi added.
### Changed
- **👤 Shared Chat Update**: Shared chat now includes creator user information.
## [0.1.124] - 2024-05-08 ## [0.1.124] - 2024-05-08
### Added ### Added
......
...@@ -11,6 +11,9 @@ ARG USE_CUDA_VER=cu121 ...@@ -11,6 +11,9 @@ ARG USE_CUDA_VER=cu121
# IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them. # IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2 ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
ARG USE_RERANKING_MODEL="" ARG USE_RERANKING_MODEL=""
# Override at your own risk - non-root configurations are untested
ARG UID=0
ARG GID=0
######## WebUI frontend ######## ######## WebUI frontend ########
FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build
...@@ -32,6 +35,8 @@ ARG USE_OLLAMA ...@@ -32,6 +35,8 @@ ARG USE_OLLAMA
ARG USE_CUDA_VER ARG USE_CUDA_VER
ARG USE_EMBEDDING_MODEL ARG USE_EMBEDDING_MODEL
ARG USE_RERANKING_MODEL ARG USE_RERANKING_MODEL
ARG UID
ARG GID
## Basis ## ## Basis ##
ENV ENV=prod \ ENV ENV=prod \
...@@ -76,17 +81,28 @@ ENV HF_HOME="/app/backend/data/cache/embedding/models" ...@@ -76,17 +81,28 @@ ENV HF_HOME="/app/backend/data/cache/embedding/models"
WORKDIR /app/backend WORKDIR /app/backend
ENV HOME /root ENV HOME /root
# Create user and group if not root
RUN if [ $UID -ne 0 ]; then \
if [ $GID -ne 0 ]; then \
addgroup --gid $GID app; \
fi; \
adduser --uid $UID --gid $GID --home $HOME --disabled-password --no-create-home app; \
fi
RUN mkdir -p $HOME/.cache/chroma RUN mkdir -p $HOME/.cache/chroma
RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry_user_id
# Make sure the user has access to the app and root directory
RUN chown -R $UID:$GID /app $HOME
RUN if [ "$USE_OLLAMA" = "true" ]; then \ RUN if [ "$USE_OLLAMA" = "true" ]; then \
apt-get update && \ apt-get update && \
# Install pandoc and netcat # Install pandoc and netcat
apt-get install -y --no-install-recommends pandoc netcat-openbsd && \ apt-get install -y --no-install-recommends pandoc netcat-openbsd curl && \
# for RAG OCR # for RAG OCR
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \ apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
# install helper tools # install helper tools
apt-get install -y --no-install-recommends curl && \ apt-get install -y --no-install-recommends curl jq && \
# install ollama # install ollama
curl -fsSL https://ollama.com/install.sh | sh && \ curl -fsSL https://ollama.com/install.sh | sh && \
# cleanup # cleanup
...@@ -94,7 +110,7 @@ RUN if [ "$USE_OLLAMA" = "true" ]; then \ ...@@ -94,7 +110,7 @@ RUN if [ "$USE_OLLAMA" = "true" ]; then \
else \ else \
apt-get update && \ apt-get update && \
# Install pandoc and netcat # Install pandoc and netcat
apt-get install -y --no-install-recommends pandoc netcat-openbsd && \ apt-get install -y --no-install-recommends pandoc netcat-openbsd curl jq && \
# for RAG OCR # for RAG OCR
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \ apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
# cleanup # cleanup
...@@ -102,7 +118,7 @@ RUN if [ "$USE_OLLAMA" = "true" ]; then \ ...@@ -102,7 +118,7 @@ RUN if [ "$USE_OLLAMA" = "true" ]; then \
fi fi
# install python dependencies # install python dependencies
COPY ./backend/requirements.txt ./requirements.txt COPY --chown=$UID:$GID ./backend/requirements.txt ./requirements.txt
RUN pip3 install uv && \ RUN pip3 install uv && \
if [ "$USE_CUDA" = "true" ]; then \ if [ "$USE_CUDA" = "true" ]; then \
...@@ -125,13 +141,17 @@ RUN pip3 install uv && \ ...@@ -125,13 +141,17 @@ RUN pip3 install uv && \
# COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx # COPY --from=build /app/onnx /root/.cache/chroma/onnx_models/all-MiniLM-L6-v2/onnx
# copy built frontend files # copy built frontend files
COPY --from=build /app/build /app/build COPY --chown=$UID:$GID --from=build /app/build /app/build
COPY --from=build /app/CHANGELOG.md /app/CHANGELOG.md COPY --chown=$UID:$GID --from=build /app/CHANGELOG.md /app/CHANGELOG.md
COPY --from=build /app/package.json /app/package.json COPY --chown=$UID:$GID --from=build /app/package.json /app/package.json
# copy backend files # copy backend files
COPY ./backend . COPY --chown=$UID:$GID ./backend .
EXPOSE 8080 EXPOSE 8080
HEALTHCHECK CMD curl --silent --fail http://localhost:8080/health | jq -e '.status == true' || exit 1
USER $UID:$GID
CMD [ "bash", "start.sh"] CMD [ "bash", "start.sh"]
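The `HEALTHCHECK` added above passes only when the request succeeds and the `jq` filter `.status == true` holds. As a rough illustration of that logic, here is a minimal Python sketch. It assumes the `/health` endpoint returns a JSON object with a boolean `status` field, which the `jq` filter implies but this diff does not show; the function names `is_healthy` and `check` are hypothetical and not part of the project.

```python
import json
import urllib.request


def is_healthy(body: str) -> bool:
    """Mirror `jq -e '.status == true'`: pass only when "status" is JSON true."""
    try:
        payload = json.loads(body)
    except json.JSONDecodeError:
        return False
    # A string such as "true" or a non-object body does not count as healthy.
    return isinstance(payload, dict) and payload.get("status") is True


def check(url: str = "http://localhost:8080/health", timeout: float = 5.0) -> bool:
    """Fetch the endpoint; HTTP errors and network failures count as unhealthy,
    similar to `curl --fail`."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return is_healthy(resp.read().decode("utf-8"))
    except (OSError, ValueError):
        return False
```

Keeping the body check in a separate pure function makes it testable without a running container.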
...@@ -45,6 +45,7 @@ from config import ( ...@@ -45,6 +45,7 @@ from config import (
AUDIO_OPENAI_API_KEY, AUDIO_OPENAI_API_KEY,
AUDIO_OPENAI_API_MODEL, AUDIO_OPENAI_API_MODEL,
AUDIO_OPENAI_API_VOICE, AUDIO_OPENAI_API_VOICE,
AppConfig,
) )
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
...@@ -59,11 +60,11 @@ app.add_middleware( ...@@ -59,11 +60,11 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.config = AppConfig()
app.state.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
app.state.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
app.state.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE
# setting device type for whisper model # setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu" whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
...@@ -83,10 +84,10 @@ class OpenAIConfigUpdateForm(BaseModel): ...@@ -83,10 +84,10 @@ class OpenAIConfigUpdateForm(BaseModel):
@app.get("/config") @app.get("/config")
async def get_openai_config(user=Depends(get_admin_user)): async def get_openai_config(user=Depends(get_admin_user)):
return { return {
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY, "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.OPENAI_API_MODEL, "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
"OPENAI_API_VOICE": app.state.OPENAI_API_VOICE, "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
} }
...@@ -97,17 +98,17 @@ async def update_openai_config( ...@@ -97,17 +98,17 @@ async def update_openai_config(
if form_data.key == "": if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_BASE_URL = form_data.url app.state.config.OPENAI_API_BASE_URL = form_data.url
app.state.OPENAI_API_KEY = form_data.key app.state.config.OPENAI_API_KEY = form_data.key
app.state.OPENAI_API_MODEL = form_data.model app.state.config.OPENAI_API_MODEL = form_data.model
app.state.OPENAI_API_VOICE = form_data.speaker app.state.config.OPENAI_API_VOICE = form_data.speaker
return { return {
"status": True, "status": True,
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY, "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.OPENAI_API_MODEL, "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
"OPENAI_API_VOICE": app.state.OPENAI_API_VOICE, "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
} }
...@@ -124,13 +125,13 @@ async def speech(request: Request, user=Depends(get_verified_user)): ...@@ -124,13 +125,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path) return FileResponse(file_path)
headers = {} headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
r = None r = None
try: try:
r = requests.post( r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URL}/audio/speech", url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech",
data=body, data=body,
headers=headers, headers=headers,
stream=True, stream=True,
......
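The hunks above move settings from `app.state.X` to `app.state.config.X`, and the changelog mentions that settings are now saved as config.json. The definition of `AppConfig` is not part of this diff, so the following is only a sketch of one way such an object could behave: a hypothetical `SketchConfig` class whose attribute writes are mirrored to a JSON file. The class name, the file path, and the example value are all assumptions.

```python
import json
import tempfile
from pathlib import Path


class SketchConfig:
    """Hypothetical stand-in for AppConfig (not the project's implementation)."""

    def __init__(self, path: Path):
        # Use object.__setattr__ so internal fields bypass our own __setattr__.
        object.__setattr__(self, "_path", path)
        stored = json.loads(path.read_text()) if path.exists() else {}
        object.__setattr__(self, "_values", stored)

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails, i.e. for config keys.
        if key.startswith("_"):
            raise AttributeError(key)
        try:
            return self._values[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        # Every write updates memory and is persisted immediately.
        self._values[key] = value
        self._path.write_text(json.dumps(self._values, indent=2))


# Usage: a value written through one instance is visible after reloading.
config_path = Path(tempfile.mkdtemp()) / "config.json"
config = SketchConfig(config_path)
config.OPENAI_API_VOICE = "alloy"  # example value, assumed for illustration
reloaded = SketchConfig(config_path)
```

With this shape, call sites such as `app.state.config.OPENAI_API_KEY = form_data.key` need no explicit save step, which would match the mechanical rename seen in the diff.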
...@@ -42,6 +42,7 @@ from config import ( ...@@ -42,6 +42,7 @@ from config import (
IMAGE_GENERATION_MODEL, IMAGE_GENERATION_MODEL,
IMAGE_SIZE, IMAGE_SIZE,
IMAGE_STEPS, IMAGE_STEPS,
AppConfig,
) )
...@@ -60,26 +61,31 @@ app.add_middleware( ...@@ -60,26 +61,31 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.ENGINE = IMAGE_GENERATION_ENGINE app.state.config = AppConfig()
app.state.ENABLED = ENABLE_IMAGE_GENERATION
app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
app.state.MODEL = IMAGE_GENERATION_MODEL app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.MODEL = IMAGE_GENERATION_MODEL
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.IMAGE_SIZE = IMAGE_SIZE
app.state.IMAGE_STEPS = IMAGE_STEPS app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS
@app.get("/config") @app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)): async def get_config(request: Request, user=Depends(get_admin_user)):
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} return {
"engine": app.state.config.ENGINE,
"enabled": app.state.config.ENABLED,
}
class ConfigUpdateForm(BaseModel): class ConfigUpdateForm(BaseModel):
...@@ -89,9 +95,12 @@ class ConfigUpdateForm(BaseModel): ...@@ -89,9 +95,12 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update") @app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.ENGINE = form_data.engine app.state.config.ENGINE = form_data.engine
app.state.ENABLED = form_data.enabled app.state.config.ENABLED = form_data.enabled
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED} return {
"engine": app.state.config.ENGINE,
"enabled": app.state.config.ENABLED,
}
class EngineUrlUpdateForm(BaseModel): class EngineUrlUpdateForm(BaseModel):
...@@ -102,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel): ...@@ -102,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel):
@app.get("/url") @app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)): async def get_engine_url(user=Depends(get_admin_user)):
return { return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
} }
...@@ -113,29 +122,29 @@ async def update_engine_url( ...@@ -113,29 +122,29 @@ async def update_engine_url(
): ):
if form_data.AUTOMATIC1111_BASE_URL == None: if form_data.AUTOMATIC1111_BASE_URL == None:
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else: else:
url = form_data.AUTOMATIC1111_BASE_URL.strip("/") url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
try: try:
r = requests.head(url) r = requests.head(url)
app.state.AUTOMATIC1111_BASE_URL = url app.state.config.AUTOMATIC1111_BASE_URL = url
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
if form_data.COMFYUI_BASE_URL == None: if form_data.COMFYUI_BASE_URL == None:
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
else: else:
url = form_data.COMFYUI_BASE_URL.strip("/") url = form_data.COMFYUI_BASE_URL.strip("/")
try: try:
r = requests.head(url) r = requests.head(url)
app.state.COMFYUI_BASE_URL = url app.state.config.COMFYUI_BASE_URL = url
except Exception as e: except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
return { return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL, "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL, "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
"status": True, "status": True,
} }
...@@ -148,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel): ...@@ -148,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel):
@app.get("/openai/config") @app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)): async def get_openai_config(user=Depends(get_admin_user)):
return { return {
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY, "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
} }
...@@ -160,13 +169,13 @@ async def update_openai_config( ...@@ -160,13 +169,13 @@ async def update_openai_config(
if form_data.key == "": if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_BASE_URL = form_data.url app.state.config.OPENAI_API_BASE_URL = form_data.url
app.state.OPENAI_API_KEY = form_data.key app.state.config.OPENAI_API_KEY = form_data.key
return { return {
"status": True, "status": True,
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL, "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY, "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
} }
...@@ -176,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel): ...@@ -176,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel):
@app.get("/size") @app.get("/size")
async def get_image_size(user=Depends(get_admin_user)): async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_SIZE": app.state.IMAGE_SIZE} return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
@app.post("/size/update") @app.post("/size/update")
...@@ -185,9 +194,9 @@ async def update_image_size( ...@@ -185,9 +194,9 @@ async def update_image_size(
): ):
pattern = r"^\d+x\d+$" # Regular expression pattern pattern = r"^\d+x\d+$" # Regular expression pattern
if re.match(pattern, form_data.size): if re.match(pattern, form_data.size):
app.state.IMAGE_SIZE = form_data.size app.state.config.IMAGE_SIZE = form_data.size
return { return {
"IMAGE_SIZE": app.state.IMAGE_SIZE, "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
"status": True, "status": True,
} }
else: else:
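The size update handler in this hunk accepts a value only if it matches `^\d+x\d+$`, and `generate_image` later splits the stored string on `x` to obtain width and height. The two steps can be sketched together as below; the helper name `parse_image_size` is hypothetical, and raising `ValueError` stands in for the `HTTPException` the real handler would return.

```python
import re

# Same pattern as in the handler: digits, a literal "x", digits.
SIZE_PATTERN = re.compile(r"^\d+x\d+$")


def parse_image_size(size: str) -> tuple[int, int]:
    """Validate a "WIDTHxHEIGHT" string and return (width, height) as integers."""
    if not SIZE_PATTERN.match(size):
        raise ValueError(f"invalid image size: {size!r}")
    width, height = map(int, size.split("x"))
    return width, height
```

Validating before storing means the later `split("x")` can assume exactly two integer parts.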
...@@ -203,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel): ...@@ -203,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel):
@app.get("/steps") @app.get("/steps")
async def get_image_size(user=Depends(get_admin_user)): async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_STEPS": app.state.IMAGE_STEPS} return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
@app.post("/steps/update") @app.post("/steps/update")
...@@ -211,9 +220,9 @@ async def update_image_size( ...@@ -211,9 +220,9 @@ async def update_image_size(
form_data: ImageStepsUpdateForm, user=Depends(get_admin_user) form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
): ):
if form_data.steps >= 0: if form_data.steps >= 0:
app.state.IMAGE_STEPS = form_data.steps app.state.config.IMAGE_STEPS = form_data.steps
return { return {
"IMAGE_STEPS": app.state.IMAGE_STEPS, "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
"status": True, "status": True,
} }
else: else:
...@@ -226,14 +235,14 @@ async def update_image_size( ...@@ -226,14 +235,14 @@ async def update_image_size(
@app.get("/models") @app.get("/models")
def get_models(user=Depends(get_current_user)): def get_models(user=Depends(get_current_user)):
try: try:
if app.state.ENGINE == "openai": if app.state.config.ENGINE == "openai":
return [ return [
{"id": "dall-e-2", "name": "DALL·E 2"}, {"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"}, {"id": "dall-e-3", "name": "DALL·E 3"},
] ]
elif app.state.ENGINE == "comfyui": elif app.state.config.ENGINE == "comfyui":
r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info") r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
info = r.json() info = r.json()
return list( return list(
...@@ -245,7 +254,7 @@ def get_models(user=Depends(get_current_user)): ...@@ -245,7 +254,7 @@ def get_models(user=Depends(get_current_user)):
else: else:
r = requests.get( r = requests.get(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models" url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
) )
models = r.json() models = r.json()
return list( return list(
...@@ -255,23 +264,29 @@ def get_models(user=Depends(get_current_user)): ...@@ -255,23 +264,29 @@ def get_models(user=Depends(get_current_user)):
) )
) )
except Exception as e: except Exception as e:
app.state.ENABLED = False app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@app.get("/models/default") @app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)): async def get_default_model(user=Depends(get_admin_user)):
try: try:
if app.state.ENGINE == "openai": if app.state.config.ENGINE == "openai":
return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"} return {
elif app.state.ENGINE == "comfyui": "model": (
return {"model": app.state.MODEL if app.state.MODEL else ""} app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
)
}
elif app.state.config.ENGINE == "comfyui":
return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
else: else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
)
options = r.json() options = r.json()
return {"model": options["sd_model_checkpoint"]} return {"model": options["sd_model_checkpoint"]}
except Exception as e: except Exception as e:
app.state.ENABLED = False app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e)) raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
...@@ -280,20 +295,20 @@ class UpdateModelForm(BaseModel): ...@@ -280,20 +295,20 @@ class UpdateModelForm(BaseModel):
def set_model_handler(model: str): def set_model_handler(model: str):
if app.state.ENGINE == "openai": if app.state.config.ENGINE in ["openai", "comfyui"]:
app.state.MODEL = model app.state.config.MODEL = model
return app.state.MODEL return app.state.config.MODEL
if app.state.ENGINE == "comfyui":
app.state.MODEL = model
return app.state.MODEL
else: else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options") r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
)
options = r.json() options = r.json()
if model != options["sd_model_checkpoint"]: if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model options["sd_model_checkpoint"] = model
r = requests.post( r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
json=options,
) )
return options return options
...@@ -382,26 +397,32 @@ def generate_image( ...@@ -382,26 +397,32 @@ def generate_image(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x"))) width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
r = None r = None
try: try:
if app.state.ENGINE == "openai": if app.state.config.ENGINE == "openai":
headers = {} headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}" headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
data = { data = {
"model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2", "model": (
app.state.config.MODEL
if app.state.config.MODEL != ""
else "dall-e-2"
),
"prompt": form_data.prompt, "prompt": form_data.prompt,
"n": form_data.n, "n": form_data.n,
"size": form_data.size if form_data.size else app.state.IMAGE_SIZE, "size": (
form_data.size if form_data.size else app.state.config.IMAGE_SIZE
),
"response_format": "b64_json", "response_format": "b64_json",
} }
r = requests.post( r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URL}/images/generations", url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
json=data, json=data,
headers=headers, headers=headers,
) )
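Review note: the `generate_image` hunk above does two small things inline: it splits `IMAGE_SIZE` on `"x"` to get integer dimensions, and it falls back to `dall-e-2` when no model is configured. A minimal sketch of both steps in isolation follows. `parse_image_size` and `resolve_model` are hypothetical helper names used only for illustration; they do not appear in this diff.

```python
from typing import Tuple


def parse_image_size(size: str) -> Tuple[int, int]:
    """Split a "WIDTHxHEIGHT" string into two integers (hypothetical helper)."""
    width, height = tuple(map(int, size.split("x")))
    return width, height


def resolve_model(configured: str, default: str = "dall-e-2") -> str:
    """Return the configured model, or the default when it is an empty string."""
    return configured if configured != "" else default


print(parse_image_size("512x512"))
print(resolve_model(""))
```

A malformed size string such as `"512"` raises `ValueError` on unpacking, which in the handler above would happen before the `try` block is entered.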
...@@ -421,7 +442,7 @@ def generate_image( ...@@ -421,7 +442,7 @@ def generate_image(
return images return images
elif app.state.ENGINE == "comfyui": elif app.state.config.ENGINE == "comfyui":
data = { data = {
"prompt": form_data.prompt, "prompt": form_data.prompt,
...@@ -430,19 +451,19 @@ def generate_image( ...@@ -430,19 +451,19 @@ def generate_image(
"n": form_data.n, "n": form_data.n,
} }
if app.state.IMAGE_STEPS != None: if app.state.config.IMAGE_STEPS is not None:
data["steps"] = app.state.IMAGE_STEPS data["steps"] = app.state.config.IMAGE_STEPS
if form_data.negative_prompt != None: if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt data["negative_prompt"] = form_data.negative_prompt
data = ImageGenerationPayload(**data) data = ImageGenerationPayload(**data)
res = comfyui_generate_image( res = comfyui_generate_image(
app.state.MODEL, app.state.config.MODEL,
data, data,
user.id, user.id,
app.state.COMFYUI_BASE_URL, app.state.config.COMFYUI_BASE_URL,
) )
log.debug(f"res: {res}") log.debug(f"res: {res}")
...@@ -469,14 +490,14 @@ def generate_image( ...@@ -469,14 +490,14 @@ def generate_image(
"height": height, "height": height,
} }
if app.state.IMAGE_STEPS != None: if app.state.config.IMAGE_STEPS is not None:
data["steps"] = app.state.IMAGE_STEPS data["steps"] = app.state.config.IMAGE_STEPS
if form_data.negative_prompt != None: if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt data["negative_prompt"] = form_data.negative_prompt
r = requests.post( r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img", url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data, json=data,
) )
......
import sys import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException from fastapi import FastAPI, Depends, HTTPException
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
...@@ -46,7 +47,16 @@ import asyncio ...@@ -46,7 +47,16 @@ import asyncio
import subprocess import subprocess
import yaml import yaml
app = FastAPI()
@asynccontextmanager
async def lifespan(app: FastAPI):
log.info("startup_event")
# TODO: Check config.yaml file and create one
asyncio.create_task(start_litellm_background())
yield
app = FastAPI(lifespan=lifespan)
origins = ["*"] origins = ["*"]
...@@ -65,6 +75,10 @@ with open(LITELLM_CONFIG_DIR, "r") as file: ...@@ -65,6 +75,10 @@ with open(LITELLM_CONFIG_DIR, "r") as file:
litellm_config = yaml.safe_load(file) litellm_config = yaml.safe_load(file)
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value
app.state.ENABLE = ENABLE_LITELLM app.state.ENABLE = ENABLE_LITELLM
app.state.CONFIG = litellm_config app.state.CONFIG = litellm_config
...@@ -141,17 +155,6 @@ async def shutdown_litellm_background(): ...@@ -141,17 +155,6 @@ async def shutdown_litellm_background():
background_process = None background_process = None
@app.on_event("startup")
async def startup_event():
log.info("startup_event")
# TODO: Check config.yaml file and create one
asyncio.create_task(start_litellm_background())
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
@app.get("/") @app.get("/")
async def get_status(): async def get_status():
return {"status": True} return {"status": True}
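Review note: the hunks above replace the `@app.on_event("startup")` handler with a `lifespan` async context manager passed to `FastAPI(...)`. The sketch below shows the same shape using only the standard library, so it runs without FastAPI installed. `start_background` is a hypothetical stand-in for `start_litellm_background`, and the context manager is driven by hand rather than by the framework.

```python
import asyncio
from contextlib import asynccontextmanager

started = []


async def start_background():
    # Hypothetical stand-in for start_litellm_background(); records that it ran.
    started.append("background")


@asynccontextmanager
async def lifespan(app):
    # Code before the yield runs at startup; code after it runs at shutdown.
    task = asyncio.create_task(start_background())
    yield
    # Assumption of this sketch only: await the task on exit so the example
    # finishes deterministically. The diff above does not keep a reference.
    await task


async def main():
    # A framework would enter and exit this context around serving requests.
    async with lifespan(app=None):
        await asyncio.sleep(0)  # give the background task a chance to run


asyncio.run(main())
print(started)
```

One consequence of the pattern is that startup and shutdown logic live in one function, so state created before the `yield` is still in scope for cleanup after it.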
......
...@@ -46,6 +46,7 @@ from config import ( ...@@ -46,6 +46,7 @@ from config import (
ENABLE_MODEL_FILTER, ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
UPLOAD_DIR, UPLOAD_DIR,
AppConfig,
) )
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256
...@@ -61,11 +62,12 @@ app.add_middleware( ...@@ -61,11 +62,12 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.config = AppConfig()
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {} app.state.MODELS = {}
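Review note: `AppConfig` is imported from `config`, but its definition is not part of this excerpt. The `.value` reads in the LiteLLM hunk suggest that the config constants are wrapper objects rather than plain values. The following is one guess at how such a container could behave, written only to illustrate the `app.state.config.X` access pattern. `ConfigValue` and `AppConfigSketch` are hypothetical names, and the real class may differ in every detail.

```python
class ConfigValue:
    """Hypothetical wrapper that holds one named setting."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


class AppConfigSketch:
    """Hypothetical stand-in: stores wrappers, exposes their plain values."""

    def __init__(self):
        # Bypass our own __setattr__ while creating the backing dict.
        object.__setattr__(self, "_state", {})

    def __setattr__(self, key, value):
        if isinstance(value, ConfigValue):
            # First assignment registers the wrapper itself.
            self._state[key] = value
        else:
            # Later assignments update the wrapped value in place.
            # Assigning a plain value to an unregistered key raises KeyError.
            self._state[key].value = value

    def __getattr__(self, key):
        # Only called when normal attribute lookup fails, i.e. for config keys.
        state = self.__dict__.get("_state", {})
        try:
            return state[key].value
        except KeyError:
            raise AttributeError(key) from None


config = AppConfigSketch()
config.ENABLE_MODEL_FILTER = ConfigValue("ENABLE_MODEL_FILTER", False)
config.ENABLE_MODEL_FILTER = True  # e.g. an admin endpoint updating the setting
print(config.ENABLE_MODEL_FILTER)
```

Under this assumption, call sites read and write plain values while the container keeps a single place where a change could later be persisted.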
...@@ -96,7 +98,7 @@ async def get_status(): ...@@ -96,7 +98,7 @@ async def get_status():
@app.get("/urls") @app.get("/urls")
async def get_ollama_api_urls(user=Depends(get_admin_user)): async def get_ollama_api_urls(user=Depends(get_admin_user)):
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
class UrlUpdateForm(BaseModel): class UrlUpdateForm(BaseModel):
...@@ -105,10 +107,10 @@ class UrlUpdateForm(BaseModel): ...@@ -105,10 +107,10 @@ class UrlUpdateForm(BaseModel):
@app.post("/urls/update") @app.post("/urls/update")
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)): async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
app.state.OLLAMA_BASE_URLS = form_data.urls app.state.config.OLLAMA_BASE_URLS = form_data.urls
log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}") log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}")
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
@app.get("/cancel/{request_id}") @app.get("/cancel/{request_id}")
...@@ -122,8 +124,9 @@ async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)) ...@@ -122,8 +124,9 @@ async def cancel_ollama_request(request_id: str, user=Depends(get_current_user))
async def fetch_url(url): async def fetch_url(url):
timeout = aiohttp.ClientTimeout(total=5)
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url) as response: async with session.get(url) as response:
return await response.json() return await response.json()
except Exception as e: except Exception as e:
...@@ -153,7 +156,7 @@ def merge_models_lists(model_lists): ...@@ -153,7 +156,7 @@ def merge_models_lists(model_lists):
async def get_all_models(): async def get_all_models():
log.info("get_all_models()") log.info("get_all_models()")
tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS] tasks = [fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
models = { models = {
...@@ -175,18 +178,19 @@ async def get_ollama_tags( ...@@ -175,18 +178,19 @@ async def get_ollama_tags(
if url_idx is None: if url_idx is None:
models = await get_all_models() models = await get_all_models()
if app.state.ENABLE_MODEL_FILTER: if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user": if user.role == "user":
models["models"] = list( models["models"] = list(
filter( filter(
lambda model: model["name"] in app.state.MODEL_FILTER_LIST, lambda model: model["name"]
in app.state.config.MODEL_FILTER_LIST,
models["models"], models["models"],
) )
) )
return models return models
return models return models
else: else:
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
try: try:
r = requests.request(method="GET", url=f"{url}/api/tags") r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status() r.raise_for_status()
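Review note: the `get_ollama_tags` hunk above restricts the model list only when the filter is enabled and the caller has the `user` role; every other role sees the full list. A minimal sketch of that rule as a pure function follows. `filter_models_for_role` is a hypothetical name, and the model names in the usage lines are made-up sample data.

```python
def filter_models_for_role(models, role, filter_enabled, allowed_names):
    """Return only allowed models for role 'user' when filtering is enabled."""
    if filter_enabled and role == "user":
        return [m for m in models if m["name"] in allowed_names]
    return models


sample = [{"name": "model-a"}, {"name": "model-b"}]  # made-up sample data
print(filter_models_for_role(sample, "user", True, ["model-a"]))
print(filter_models_for_role(sample, "admin", True, ["model-a"]))
```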
...@@ -216,7 +220,9 @@ async def get_ollama_versions(url_idx: Optional[int] = None): ...@@ -216,7 +220,9 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
if url_idx is None: if url_idx is None:
# returns lowest version # returns lowest version
tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS] tasks = [
fetch_url(f"{url}/api/version") for url in app.state.config.OLLAMA_BASE_URLS
]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses)) responses = list(filter(lambda x: x is not None, responses))
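Review note: `fetch_url` now carries a 5-second total timeout, and callers gather all responses and drop the ones that came back as `None`. The sketch below reproduces that flow with the standard library only. It swaps `aiohttp.ClientTimeout` for `asyncio.wait_for`, which is a different mechanism with a similar effect, and `fetch_with_timeout`, `fast`, and `slow` are hypothetical names with made-up payloads.

```python
import asyncio


async def fetch_with_timeout(make_request, seconds=5):
    """Await make_request(); return None on timeout or any other error."""
    try:
        return await asyncio.wait_for(make_request(), timeout=seconds)
    except Exception:
        return None


async def fast():
    # Hypothetical backend that answers immediately.
    return {"version": "fast"}


async def slow():
    # Hypothetical backend that takes longer than the timeout allows.
    await asyncio.sleep(1)
    return {"version": "slow"}


async def main():
    tasks = [fetch_with_timeout(fast, 0.05), fetch_with_timeout(slow, 0.05)]
    responses = await asyncio.gather(*tasks)
    # Same idea as the filter above: keep only backends that answered.
    return [r for r in responses if r is not None]


results = asyncio.run(main())
print(results)
```

Because each request has its own bound, one unreachable backend delays the combined result by at most the timeout instead of stalling it indefinitely.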
...@@ -235,7 +241,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None): ...@@ -235,7 +241,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND, detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
) )
else: else:
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
try: try:
r = requests.request(method="GET", url=f"{url}/api/version") r = requests.request(method="GET", url=f"{url}/api/version")
r.raise_for_status() r.raise_for_status()
...@@ -267,7 +273,7 @@ class ModelNameForm(BaseModel): ...@@ -267,7 +273,7 @@ class ModelNameForm(BaseModel):
async def pull_model( async def pull_model(
form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user) form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
): ):
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
...@@ -355,7 +361,7 @@ async def push_model( ...@@ -355,7 +361,7 @@ async def push_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.debug(f"url: {url}") log.debug(f"url: {url}")
r = None r = None
...@@ -417,7 +423,7 @@ async def create_model( ...@@ -417,7 +423,7 @@ async def create_model(
form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user) form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
): ):
log.debug(f"form_data: {form_data}") log.debug(f"form_data: {form_data}")
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
...@@ -490,7 +496,7 @@ async def copy_model( ...@@ -490,7 +496,7 @@ async def copy_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
...@@ -537,7 +543,7 @@ async def delete_model( ...@@ -537,7 +543,7 @@ async def delete_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
...@@ -577,7 +583,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us ...@@ -577,7 +583,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
) )
url_idx = random.choice(app.state.MODELS[form_data.name]["urls"]) url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
...@@ -634,7 +640,7 @@ async def generate_embeddings( ...@@ -634,7 +640,7 @@ async def generate_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
...@@ -684,7 +690,7 @@ def generate_ollama_embeddings( ...@@ -684,7 +690,7 @@ def generate_ollama_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
try: try:
...@@ -753,7 +759,7 @@ async def generate_completion( ...@@ -753,7 +759,7 @@ async def generate_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
...@@ -856,7 +862,7 @@ async def generate_chat_completion( ...@@ -856,7 +862,7 @@ async def generate_chat_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
...@@ -965,7 +971,7 @@ async def generate_openai_chat_completion( ...@@ -965,7 +971,7 @@ async def generate_openai_chat_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model), detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
) )
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None r = None
...@@ -1041,11 +1047,12 @@ async def get_openai_models( ...@@ -1041,11 +1047,12 @@ async def get_openai_models(
if url_idx is None: if url_idx is None:
models = await get_all_models() models = await get_all_models()
if app.state.ENABLE_MODEL_FILTER: if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user": if user.role == "user":
models["models"] = list( models["models"] = list(
filter( filter(
lambda model: model["name"] in app.state.MODEL_FILTER_LIST, lambda model: model["name"]
in app.state.config.MODEL_FILTER_LIST,
models["models"], models["models"],
) )
) )
...@@ -1064,7 +1071,7 @@ async def get_openai_models( ...@@ -1064,7 +1071,7 @@ async def get_openai_models(
} }
else: else:
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
try: try:
r = requests.request(method="GET", url=f"{url}/api/tags") r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status() r.raise_for_status()
...@@ -1198,7 +1205,7 @@ async def download_model( ...@@ -1198,7 +1205,7 @@ async def download_model(
if url_idx is None: if url_idx is None:
url_idx = 0 url_idx = 0
url = app.state.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
file_name = parse_huggingface_url(form_data.url) file_name = parse_huggingface_url(form_data.url)
...@@ -1217,7 +1224,7 @@ async def download_model( ...@@ -1217,7 +1224,7 @@ async def download_model(
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
if url_idx is None: if url_idx is None:
url_idx = 0 url_idx = 0
ollama_url = app.state.OLLAMA_BASE_URLS[url_idx] ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
file_path = f"{UPLOAD_DIR}/{file.filename}" file_path = f"{UPLOAD_DIR}/{file.filename}"
...@@ -1282,7 +1289,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): ...@@ -1282,7 +1289,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None): # async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None: # if url_idx == None:
# url_idx = 0 # url_idx = 0
# url = app.state.OLLAMA_BASE_URLS[url_idx] # url = app.state.config.OLLAMA_BASE_URLS[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename) # file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size # total_size = file.size
...@@ -1319,7 +1326,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None): ...@@ -1319,7 +1326,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
async def deprecated_proxy( async def deprecated_proxy(
path: str, request: Request, user=Depends(get_verified_user) path: str, request: Request, user=Depends(get_verified_user)
): ):
url = app.state.OLLAMA_BASE_URLS[0] url = app.state.config.OLLAMA_BASE_URLS[0]
target_url = f"{url}/{path}" target_url = f"{url}/{path}"
body = await request.body() body = await request.body()
......
...@@ -21,11 +21,13 @@ from utils.utils import ( ...@@ -21,11 +21,13 @@ from utils.utils import (
) )
from config import ( from config import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
ENABLE_OPENAI_API,
OPENAI_API_BASE_URLS, OPENAI_API_BASE_URLS,
OPENAI_API_KEYS, OPENAI_API_KEYS,
CACHE_DIR, CACHE_DIR,
ENABLE_MODEL_FILTER, ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST, MODEL_FILTER_LIST,
AppConfig,
) )
from typing import List, Optional from typing import List, Optional
...@@ -45,11 +47,16 @@ app.add_middleware( ...@@ -45,11 +47,16 @@ app.add_middleware(
allow_headers=["*"], allow_headers=["*"],
) )
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS app.state.config = AppConfig()
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.MODELS = {} app.state.MODELS = {}
...@@ -65,6 +72,21 @@ async def check_url(request: Request, call_next): ...@@ -65,6 +72,21 @@ async def check_url(request: Request, call_next):
return response return response
@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
class OpenAIConfigForm(BaseModel):
enable_openai_api: Optional[bool] = None
@app.post("/config/update")
async def update_config(form_data: OpenAIConfigForm, user=Depends(get_admin_user)):
app.state.config.ENABLE_OPENAI_API = form_data.enable_openai_api
return {"ENABLE_OPENAI_API": app.state.config.ENABLE_OPENAI_API}
class UrlsUpdateForm(BaseModel): class UrlsUpdateForm(BaseModel):
urls: List[str] urls: List[str]
...@@ -75,32 +97,32 @@ class KeysUpdateForm(BaseModel): ...@@ -75,32 +97,32 @@ class KeysUpdateForm(BaseModel):
@app.get("/urls") @app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)): async def get_openai_urls(user=Depends(get_admin_user)):
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.post("/urls/update") @app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)): async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
await get_all_models() await get_all_models()
app.state.OPENAI_API_BASE_URLS = form_data.urls app.state.config.OPENAI_API_BASE_URLS = form_data.urls
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS} return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.get("/keys") @app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)): async def get_openai_keys(user=Depends(get_admin_user)):
return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS} return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
@app.post("/keys/update") @app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)): async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
app.state.OPENAI_API_KEYS = form_data.keys app.state.config.OPENAI_API_KEYS = form_data.keys
return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS} return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
@app.post("/audio/speech") @app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)): async def speech(request: Request, user=Depends(get_verified_user)):
idx = None idx = None
try: try:
idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1") idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
body = await request.body() body = await request.body()
name = hashlib.sha256(body).hexdigest() name = hashlib.sha256(body).hexdigest()
...@@ -114,13 +136,15 @@ async def speech(request: Request, user=Depends(get_verified_user)): ...@@ -114,13 +136,15 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path) return FileResponse(file_path)
headers = {} headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}" headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
headers["Content-Type"] = "application/json" headers["Content-Type"] = "application/json"
if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
headers["HTTP-Referer"] = "https://openwebui.com/"
headers["X-Title"] = "Open WebUI"
r = None r = None
try: try:
r = requests.post( r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech", url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
data=body, data=body,
headers=headers, headers=headers,
stream=True, stream=True,
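Review note: the `speech` hunk above adds `HTTP-Referer` and `X-Title` headers when the selected base URL contains `openrouter.ai`. A minimal sketch of that header construction as a standalone function follows. `build_headers` is a hypothetical name, and the URLs and key in the usage lines are illustrative values only.

```python
def build_headers(base_url: str, key: str) -> dict:
    """Build request headers, adding attribution headers for OpenRouter URLs."""
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    if "openrouter.ai" in base_url:
        headers["HTTP-Referer"] = "https://openwebui.com/"
        headers["X-Title"] = "Open WebUI"
    return headers


# Illustrative inputs; the key is a placeholder, not a real credential.
print(build_headers("https://openrouter.ai/api/v1", "placeholder-key"))
print(build_headers("https://api.openai.com/v1", "placeholder-key"))
```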
...@@ -159,11 +183,15 @@ async def speech(request: Request, user=Depends(get_verified_user)): ...@@ -159,11 +183,15 @@ async def speech(request: Request, user=Depends(get_verified_user)):
async def fetch_url(url, key): async def fetch_url(url, key):
timeout = aiohttp.ClientTimeout(total=5)
try: try:
if key != "":
headers = {"Authorization": f"Bearer {key}"} headers = {"Authorization": f"Bearer {key}"}
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as response: async with session.get(url, headers=headers) as response:
return await response.json() return await response.json()
else:
return None
except Exception as e: except Exception as e:
# Handle connection error here # Handle connection error here
log.error(f"Connection error: {e}") log.error(f"Connection error: {e}")
...@@ -180,7 +208,8 @@ def merge_models_lists(model_lists): ...@@ -180,7 +208,8 @@ def merge_models_lists(model_lists):
[ [
{**model, "urlIdx": idx} {**model, "urlIdx": idx}
for model in models for model in models
if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx] if "api.openai.com"
not in app.state.config.OPENAI_API_BASE_URLS[idx]
or "gpt" in model["id"] or "gpt" in model["id"]
] ]
) )
...@@ -191,12 +220,15 @@ def merge_models_lists(model_lists): ...@@ -191,12 +220,15 @@ def merge_models_lists(model_lists):
async def get_all_models(): async def get_all_models():
log.info("get_all_models()") log.info("get_all_models()")
if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "": if (
len(app.state.config.OPENAI_API_KEYS) == 1
and app.state.config.OPENAI_API_KEYS[0] == ""
) or not app.state.config.ENABLE_OPENAI_API:
models = {"data": []} models = {"data": []}
else: else:
tasks = [ tasks = [
fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx]) fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS) for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
] ]
responses = await asyncio.gather(*tasks) responses = await asyncio.gather(*tasks)
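Review note: the new condition in `get_all_models` skips the upstream requests in two cases: the only configured key is an empty string, or `ENABLE_OPENAI_API` is off. A minimal sketch of that guard as a predicate follows. `should_fetch_models` is a hypothetical name introduced here for illustration.

```python
def should_fetch_models(keys, api_enabled):
    """Return False when the API is disabled or the only key is blank."""
    only_blank_key = len(keys) == 1 and keys[0] == ""
    return not (only_blank_key or not api_enabled)


print(should_fetch_models([""], True))
print(should_fetch_models(["placeholder-key"], True))
print(should_fetch_models(["placeholder-key"], False))
```

When the predicate is false, the handler above returns `{"data": []}` without contacting any upstream.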
...@@ -228,18 +260,18 @@ async def get_all_models(): ...@@ -228,18 +260,18 @@ async def get_all_models():
async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)): async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
if url_idx is None: if url_idx is None:
models = await get_all_models() models = await get_all_models()
if app.state.ENABLE_MODEL_FILTER: if app.state.config.ENABLE_MODEL_FILTER:
if user.role == "user": if user.role == "user":
models["data"] = list( models["data"] = list(
filter( filter(
lambda model: model["id"] in app.state.MODEL_FILTER_LIST, lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
models["data"], models["data"],
) )
) )
return models return models
return models return models
else: else:
url = app.state.OPENAI_API_BASE_URLS[url_idx] url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
r = None r = None
...@@ -303,8 +335,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)): ...@@ -303,8 +335,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
except json.JSONDecodeError as e: except json.JSONDecodeError as e:
log.error(f"Error loading request body into a dictionary: {e}") log.error(f"Error loading request body into a dictionary: {e}")
url = app.state.OPENAI_API_BASE_URLS[idx] url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.OPENAI_API_KEYS[idx] key = app.state.config.OPENAI_API_KEYS[idx]
target_url = f"{url}/{path}" target_url = f"{url}/{path}"
......
...@@ -69,6 +69,7 @@ from utils.misc import ( ...@@ -69,6 +69,7 @@ from utils.misc import (
from utils.utils import get_current_user, get_admin_user from utils.utils import get_current_user, get_admin_user
from config import ( from config import (
ENV,
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
UPLOAD_DIR, UPLOAD_DIR,
DOCS_DIR, DOCS_DIR,
...@@ -93,6 +94,7 @@ from config import ( ...@@ -93,6 +94,7 @@ from config import (
RAG_TEMPLATE, RAG_TEMPLATE,
ENABLE_RAG_LOCAL_WEB_FETCH, ENABLE_RAG_LOCAL_WEB_FETCH,
YOUTUBE_LOADER_LANGUAGE, YOUTUBE_LOADER_LANGUAGE,
AppConfig,
) )
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
...@@ -102,30 +104,32 @@ log.setLevel(SRC_LOG_LEVELS["RAG"]) ...@@ -102,30 +104,32 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
app = FastAPI() app = FastAPI()
app.state.TOP_K = RAG_TOP_K app.state.config = AppConfig()
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH app.state.config.TOP_K = RAG_TOP_K
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
) )
app.state.CHUNK_SIZE = CHUNK_SIZE app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None app.state.YOUTUBE_LOADER_TRANSLATION = None
...@@ -133,7 +137,7 @@ def update_embedding_model( ...@@ -133,7 +137,7 @@ def update_embedding_model(
embedding_model: str, embedding_model: str,
update_model: bool = False, update_model: bool = False,
): ):
if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "": if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer( app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
get_model_path(embedding_model, update_model), get_model_path(embedding_model, update_model),
device=DEVICE_TYPE, device=DEVICE_TYPE,
...@@ -158,22 +162,22 @@ def update_reranking_model( ...@@ -158,22 +162,22 @@ def update_reranking_model(
update_embedding_model( update_embedding_model(
app.state.RAG_EMBEDDING_MODEL, app.state.config.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE, RAG_EMBEDDING_MODEL_AUTO_UPDATE,
) )
update_reranking_model( update_reranking_model(
app.state.RAG_RERANKING_MODEL, app.state.config.RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE, RAG_RERANKING_MODEL_AUTO_UPDATE,
) )
app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL, app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY, app.state.config.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL, app.state.config.OPENAI_API_BASE_URL,
) )
origins = ["*"] origins = ["*"]
...@@ -200,12 +204,12 @@ class UrlForm(CollectionNameForm): ...@@ -200,12 +204,12 @@ class UrlForm(CollectionNameForm):
async def get_status(): async def get_status():
return { return {
"status": True, "status": True,
"chunk_size": app.state.CHUNK_SIZE, "chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP, "chunk_overlap": app.state.config.CHUNK_OVERLAP,
"template": app.state.RAG_TEMPLATE, "template": app.state.config.RAG_TEMPLATE,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.RAG_RERANKING_MODEL, "reranking_model": app.state.config.RAG_RERANKING_MODEL,
} }
...@@ -213,18 +217,21 @@ async def get_status(): ...@@ -213,18 +217,21 @@ async def get_status():
async def get_embedding_config(user=Depends(get_admin_user)): async def get_embedding_config(user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"openai_config": { "openai_config": {
"url": app.state.OPENAI_API_BASE_URL, "url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.OPENAI_API_KEY, "key": app.state.config.OPENAI_API_KEY,
}, },
} }
@app.get("/reranking") @app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)): async def get_reraanking_config(user=Depends(get_admin_user)):
return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL} return {
"status": True,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
}
class OpenAIConfigForm(BaseModel): class OpenAIConfigForm(BaseModel):
...@@ -243,34 +250,34 @@ async def update_embedding_config( ...@@ -243,34 +250,34 @@ async def update_embedding_config(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user) form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
): ):
log.info( log.info(
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}" f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
) )
try: try:
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]: if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if form_data.openai_config != None: if form_data.openai_config != None:
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.OPENAI_API_KEY = form_data.openai_config.key app.state.config.OPENAI_API_KEY = form_data.openai_config.key
update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True) update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
app.state.EMBEDDING_FUNCTION = get_embedding_function( app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL, app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY, app.state.config.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL, app.state.config.OPENAI_API_BASE_URL,
) )
return { return {
"status": True, "status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE, "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL, "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"openai_config": { "openai_config": {
"url": app.state.OPENAI_API_BASE_URL, "url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.OPENAI_API_KEY, "key": app.state.config.OPENAI_API_KEY,
}, },
} }
except Exception as e: except Exception as e:
...@@ -290,16 +297,16 @@ async def update_reranking_config( ...@@ -290,16 +297,16 @@ async def update_reranking_config(
form_data: RerankingModelUpdateForm, user=Depends(get_admin_user) form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
): ):
log.info( log.info(
f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}" f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
) )
try: try:
app.state.RAG_RERANKING_MODEL = form_data.reranking_model app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
update_reranking_model(app.state.RAG_RERANKING_MODEL, True) update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)
return { return {
"status": True, "status": True,
"reranking_model": app.state.RAG_RERANKING_MODEL, "reranking_model": app.state.config.RAG_RERANKING_MODEL,
} }
except Exception as e: except Exception as e:
log.exception(f"Problem updating reranking model: {e}") log.exception(f"Problem updating reranking model: {e}")
...@@ -313,14 +320,14 @@ async def update_reranking_config( ...@@ -313,14 +320,14 @@ async def update_reranking_config(
async def get_rag_config(user=Depends(get_admin_user)): async def get_rag_config(user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"chunk": { "chunk": {
"chunk_size": app.state.CHUNK_SIZE, "chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP, "chunk_overlap": app.state.config.CHUNK_OVERLAP,
}, },
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": { "youtube": {
"language": app.state.YOUTUBE_LOADER_LANGUAGE, "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION, "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
}, },
} }
...@@ -345,50 +352,52 @@ class ConfigUpdateForm(BaseModel): ...@@ -345,50 +352,52 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update") @app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)): async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.PDF_EXTRACT_IMAGES = ( app.state.config.PDF_EXTRACT_IMAGES = (
form_data.pdf_extract_images form_data.pdf_extract_images
if form_data.pdf_extract_images != None if form_data.pdf_extract_images is not None
else app.state.PDF_EXTRACT_IMAGES else app.state.config.PDF_EXTRACT_IMAGES
) )
app.state.CHUNK_SIZE = ( app.state.config.CHUNK_SIZE = (
form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE form_data.chunk.chunk_size
if form_data.chunk is not None
else app.state.config.CHUNK_SIZE
) )
app.state.CHUNK_OVERLAP = ( app.state.config.CHUNK_OVERLAP = (
form_data.chunk.chunk_overlap form_data.chunk.chunk_overlap
if form_data.chunk != None if form_data.chunk is not None
else app.state.CHUNK_OVERLAP else app.state.config.CHUNK_OVERLAP
) )
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = ( app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web_loader_ssl_verification form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None if form_data.web_loader_ssl_verification != None
else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
) )
app.state.YOUTUBE_LOADER_LANGUAGE = ( app.state.config.YOUTUBE_LOADER_LANGUAGE = (
form_data.youtube.language form_data.youtube.language
if form_data.youtube != None if form_data.youtube is not None
else app.state.YOUTUBE_LOADER_LANGUAGE else app.state.config.YOUTUBE_LOADER_LANGUAGE
) )
app.state.YOUTUBE_LOADER_TRANSLATION = ( app.state.YOUTUBE_LOADER_TRANSLATION = (
form_data.youtube.translation form_data.youtube.translation
if form_data.youtube != None if form_data.youtube is not None
else app.state.YOUTUBE_LOADER_TRANSLATION else app.state.YOUTUBE_LOADER_TRANSLATION
) )
return { return {
"status": True, "status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES, "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"chunk": { "chunk": {
"chunk_size": app.state.CHUNK_SIZE, "chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP, "chunk_overlap": app.state.config.CHUNK_OVERLAP,
}, },
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION, "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": { "youtube": {
"language": app.state.YOUTUBE_LOADER_LANGUAGE, "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION, "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
}, },
} }
...@@ -398,7 +407,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_ ...@@ -398,7 +407,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
async def get_rag_template(user=Depends(get_current_user)): async def get_rag_template(user=Depends(get_current_user)):
return { return {
"status": True, "status": True,
"template": app.state.RAG_TEMPLATE, "template": app.state.config.RAG_TEMPLATE,
} }
...@@ -406,10 +415,10 @@ async def get_rag_template(user=Depends(get_current_user)): ...@@ -406,10 +415,10 @@ async def get_rag_template(user=Depends(get_current_user)):
async def get_query_settings(user=Depends(get_admin_user)): async def get_query_settings(user=Depends(get_admin_user)):
return { return {
"status": True, "status": True,
"template": app.state.RAG_TEMPLATE, "template": app.state.config.RAG_TEMPLATE,
"k": app.state.TOP_K, "k": app.state.config.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD, "r": app.state.config.RELEVANCE_THRESHOLD,
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
} }
...@@ -424,16 +433,20 @@ class QuerySettingsForm(BaseModel): ...@@ -424,16 +433,20 @@ class QuerySettingsForm(BaseModel):
async def update_query_settings( async def update_query_settings(
form_data: QuerySettingsForm, user=Depends(get_admin_user) form_data: QuerySettingsForm, user=Depends(get_admin_user)
): ):
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE app.state.config.RAG_TEMPLATE = (
app.state.TOP_K = form_data.k if form_data.k else 4 form_data.template if form_data.template else RAG_TEMPLATE
app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0 )
app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False app.state.config.TOP_K = form_data.k if form_data.k else 4
app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
form_data.hybrid if form_data.hybrid else False
)
return { return {
"status": True, "status": True,
"template": app.state.RAG_TEMPLATE, "template": app.state.config.RAG_TEMPLATE,
"k": app.state.TOP_K, "k": app.state.config.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD, "r": app.state.config.RELEVANCE_THRESHOLD,
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH, "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
} }
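The fallbacks in `update_query_settings` rely on truthiness (`form_data.k if form_data.k else 4`), while `update_rag_config` in this same diff moves to `is not None`. The sketch below shows how the two forms differ for falsy but valid inputs such as `0` or `0.0`. The helper names `fallback_truthy` and `fallback_not_none` are hypothetical and are not part of the codebase.

```python
from typing import Optional, TypeVar

T = TypeVar("T")


def fallback_truthy(value: Optional[T], default: T) -> T:
    """Mirror `value if value else default`: any falsy value is replaced."""
    return value if value else default


def fallback_not_none(value: Optional[T], default: T) -> T:
    """Mirror `value if value is not None else default`: only None is replaced."""
    return value if value is not None else default


# A submitted value of 0 is falsy, so the truthiness form swaps in the default,
# while the `is not None` form keeps the submitted value.
k_truthy = fallback_truthy(0, 4)
k_not_none = fallback_not_none(0, 4)

# An omitted field (None) falls back to the default in both forms.
r_truthy = fallback_truthy(None, 0.0)
r_not_none = fallback_not_none(None, 0.0)
```

For `k` and `r` the truthiness form happens to be harmless when the default is itself the falsy value (as with `0.0` and `False`), but it would silently replace an explicit `k=0` with `4`.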
...@@ -451,21 +464,23 @@ def query_doc_handler( ...@@ -451,21 +464,23 @@ def query_doc_handler(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
if app.state.ENABLE_RAG_HYBRID_SEARCH: if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
return query_doc_with_hybrid_search( return query_doc_with_hybrid_search(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
query=form_data.query, query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.config.TOP_K,
reranking_function=app.state.sentence_transformer_rf, reranking_function=app.state.sentence_transformer_rf,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, r=(
form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
),
) )
else: else:
return query_doc( return query_doc(
collection_name=form_data.collection_name, collection_name=form_data.collection_name,
query=form_data.query, query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.config.TOP_K,
) )
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
...@@ -489,21 +504,23 @@ def query_collection_handler( ...@@ -489,21 +504,23 @@ def query_collection_handler(
user=Depends(get_current_user), user=Depends(get_current_user),
): ):
try: try:
if app.state.ENABLE_RAG_HYBRID_SEARCH: if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
return query_collection_with_hybrid_search( return query_collection_with_hybrid_search(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
query=form_data.query, query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.config.TOP_K,
reranking_function=app.state.sentence_transformer_rf, reranking_function=app.state.sentence_transformer_rf,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD, r=(
form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
),
) )
else: else:
return query_collection( return query_collection(
collection_names=form_data.collection_names, collection_names=form_data.collection_names,
query=form_data.query, query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION, embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K, k=form_data.k if form_data.k else app.state.config.TOP_K,
) )
except Exception as e: except Exception as e:
...@@ -520,7 +537,7 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)): ...@@ -520,7 +537,7 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
loader = YoutubeLoader.from_youtube_url( loader = YoutubeLoader.from_youtube_url(
form_data.url, form_data.url,
add_video_info=True, add_video_info=True,
language=app.state.YOUTUBE_LOADER_LANGUAGE, language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
translation=app.state.YOUTUBE_LOADER_TRANSLATION, translation=app.state.YOUTUBE_LOADER_TRANSLATION,
) )
data = loader.load() data = loader.load()
...@@ -548,7 +565,8 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)): ...@@ -548,7 +565,8 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm" # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try: try:
loader = get_web_loader( loader = get_web_loader(
form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION form_data.url,
verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
) )
data = loader.load() data = loader.load()
...@@ -604,8 +622,8 @@ def resolve_hostname(hostname): ...@@ -604,8 +622,8 @@ def resolve_hostname(hostname):
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool: def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE, chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP, chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True, add_start_index=True,
) )
...@@ -622,8 +640,8 @@ def store_text_in_vector_db( ...@@ -622,8 +640,8 @@ def store_text_in_vector_db(
text, metadata, collection_name, overwrite: bool = False text, metadata, collection_name, overwrite: bool = False
) -> bool: ) -> bool:
text_splitter = RecursiveCharacterTextSplitter( text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE, chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP, chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True, add_start_index=True,
) )
docs = text_splitter.create_documents([text], metadatas=[metadata]) docs = text_splitter.create_documents([text], metadatas=[metadata])
...@@ -646,11 +664,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b ...@@ -646,11 +664,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
collection = CHROMA_CLIENT.create_collection(name=collection_name) collection = CHROMA_CLIENT.create_collection(name=collection_name)
embedding_func = get_embedding_function( embedding_func = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE, app.state.config.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL, app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef, app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY, app.state.config.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL, app.state.config.OPENAI_API_BASE_URL,
) )
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts)) embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
...@@ -724,7 +742,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str): ...@@ -724,7 +742,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
] ]
if file_ext == "pdf": if file_ext == "pdf":
loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES) loader = PyPDFLoader(
file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
)
elif file_ext == "csv": elif file_ext == "csv":
loader = CSVLoader(file_path) loader = CSVLoader(file_path)
elif file_ext == "rst": elif file_ext == "rst":
...@@ -932,3 +952,14 @@ def reset(user=Depends(get_admin_user)) -> bool: ...@@ -932,3 +952,14 @@ def reset(user=Depends(get_admin_user)) -> bool:
log.exception(e) log.exception(e)
return True return True
if ENV == "dev":

    @app.get("/ef")
    async def get_embeddings():
        return {"result": app.state.EMBEDDING_FUNCTION("hello world")}

    @app.get("/ef/{text}")
    async def get_embeddings_text(text: str):
        return {"result": app.state.EMBEDDING_FUNCTION(text)}
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    @migrator.create_model
    class Memory(pw.Model):
        id = pw.CharField(max_length=255, unique=True)
        user_id = pw.CharField(max_length=255)
        content = pw.TextField(null=False)
        updated_at = pw.BigIntegerField(null=False)
        created_at = pw.BigIntegerField(null=False)

        class Meta:
            table_name = "memory"


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""

    migrator.remove_model("memory")
...@@ -9,6 +9,7 @@ from apps.web.routers import ( ...@@ -9,6 +9,7 @@ from apps.web.routers import (
modelfiles, modelfiles,
prompts, prompts,
configs, configs,
memories,
utils, utils,
) )
from config import ( from config import (
...@@ -21,22 +22,27 @@ from config import ( ...@@ -21,22 +22,27 @@ from config import (
USER_PERMISSIONS, USER_PERMISSIONS,
WEBHOOK_URL, WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER, WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN,
AppConfig,
) )
app = FastAPI() app = FastAPI()
origins = ["*"] origins = ["*"]
app.state.ENABLE_SIGNUP = ENABLE_SIGNUP app.state.config = AppConfig()
app.state.JWT_EXPIRES_IN = "-1"
app.state.DEFAULT_MODELS = DEFAULT_MODELS app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.USER_PERMISSIONS = USER_PERMISSIONS app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.WEBHOOK_URL = WEBHOOK_URL app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=origins, allow_origins=origins,
...@@ -48,9 +54,12 @@ app.add_middleware( ...@@ -48,9 +54,12 @@ app.add_middleware(
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"]) app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"]) app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
app.include_router(configs.router, prefix="/configs", tags=["configs"]) app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(utils.router, prefix="/utils", tags=["utils"]) app.include_router(utils.router, prefix="/utils", tags=["utils"])
...@@ -61,6 +70,6 @@ async def get_status(): ...@@ -61,6 +70,6 @@ async def get_status():
return { return {
"status": True, "status": True,
"auth": WEBUI_AUTH, "auth": WEBUI_AUTH,
"default_models": app.state.DEFAULT_MODELS, "default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.DEFAULT_PROMPT_SUGGESTIONS, "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
} }
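This diff moves mutable settings from `app.state.X` to `app.state.config.X`, where `config` is an `AppConfig` instance imported from `config`. The implementation of `AppConfig` is not part of the lines shown here, so the sketch below uses a hypothetical stand-in, `SketchConfig`, purely to illustrate the access pattern. `SimpleNamespace` plays the role of FastAPI's `app.state`; both are assumptions for illustration only.

```python
from types import SimpleNamespace


class SketchConfig:
    """Hypothetical stand-in for AppConfig: a plain attribute container.

    The real AppConfig is not shown in this diff and may behave differently;
    this class only stores attributes.
    """

    def __init__(self, **defaults):
        for name, value in defaults.items():
            setattr(self, name, value)

    def as_dict(self):
        # Snapshot of all settings currently stored on the container.
        return dict(vars(self))


# `state` stands in for `app.state` in this sketch.
state = SimpleNamespace()
state.config = SketchConfig(ENABLE_SIGNUP=True, JWT_EXPIRES_IN="-1")

# Handlers read and write settings through the single `config` object,
# in the same way `toggle_sign_up` flips ENABLE_SIGNUP.
state.config.ENABLE_SIGNUP = not state.config.ENABLE_SIGNUP
snapshot = state.config.as_dict()
```

Routing every setting through one object gives a single place to inspect or intercept configuration changes, which separate attributes on `app.state` do not offer.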
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional

from apps.web.internal.db import DB
from apps.web.models.chats import Chats

import time
import uuid

####################
# Memory DB Schema
####################


class Memory(Model):
    id = CharField(unique=True)
    user_id = CharField()
    content = TextField()
    updated_at = BigIntegerField()
    created_at = BigIntegerField()

    class Meta:
        database = DB


class MemoryModel(BaseModel):
    id: str
    user_id: str
    content: str
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch


####################
# Forms
####################


class MemoriesTable:
    def __init__(self, db):
        self.db = db
        self.db.create_tables([Memory])

    def insert_new_memory(
        self,
        user_id: str,
        content: str,
    ) -> Optional[MemoryModel]:
        id = str(uuid.uuid4())

        memory = MemoryModel(
            **{
                "id": id,
                "user_id": user_id,
                "content": content,
                "created_at": int(time.time()),
                "updated_at": int(time.time()),
            }
        )
        result = Memory.create(**memory.model_dump())
        if result:
            return memory
        else:
            return None

    # Returns None (not an empty list) when the query fails.
    def get_memories(self) -> Optional[List[MemoryModel]]:
        try:
            memories = Memory.select()
            return [MemoryModel(**model_to_dict(memory)) for memory in memories]
        except Exception:
            return None

    # Returns None (not an empty list) when the query fails.
    def get_memories_by_user_id(self, user_id: str) -> Optional[List[MemoryModel]]:
        try:
            memories = Memory.select().where(Memory.user_id == user_id)
            return [MemoryModel(**model_to_dict(memory)) for memory in memories]
        except Exception:
            return None

    def get_memory_by_id(self, id) -> Optional[MemoryModel]:
        try:
            memory = Memory.get(Memory.id == id)
            return MemoryModel(**model_to_dict(memory))
        except Exception:
            return None

    def delete_memory_by_id(self, id: str) -> bool:
        try:
            query = Memory.delete().where(Memory.id == id)
            query.execute()  # Remove the rows, return number of rows removed.

            return True
        except Exception:
            return False

    def delete_memories_by_user_id(self, user_id: str) -> bool:
        try:
            query = Memory.delete().where(Memory.user_id == user_id)
            query.execute()

            return True
        except Exception:
            return False

    def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
        try:
            query = Memory.delete().where(Memory.id == id, Memory.user_id == user_id)
            query.execute()

            return True
        except Exception:
            return False


Memories = MemoriesTable(DB)
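`delete_memory_by_id_and_user_id` scopes a delete to both the row id and its owner. The following is a minimal sketch of that behavior using the standard-library `sqlite3` module in place of Peewee, chosen only so that it runs without extra dependencies. The column layout follows the `memory` migration in this diff; the helper names `insert_memory` and `delete_memory` are hypothetical.

```python
import sqlite3
import time
import uuid

conn = sqlite3.connect(":memory:")
conn.execute(
    """
    CREATE TABLE memory (
        id TEXT UNIQUE,
        user_id TEXT,
        content TEXT NOT NULL,
        updated_at INTEGER NOT NULL,
        created_at INTEGER NOT NULL
    )
    """
)


def insert_memory(user_id: str, content: str) -> str:
    """Insert a row with a fresh UUID and epoch timestamps; return its id."""
    memory_id = str(uuid.uuid4())
    now = int(time.time())
    conn.execute(
        "INSERT INTO memory VALUES (?, ?, ?, ?, ?)",
        (memory_id, user_id, content, now, now),
    )
    return memory_id


def delete_memory(memory_id: str, user_id: str) -> int:
    """Delete only when both id and owner match; return the rows removed."""
    cur = conn.execute(
        "DELETE FROM memory WHERE id = ? AND user_id = ?",
        (memory_id, user_id),
    )
    return cur.rowcount


alice_memory = insert_memory("alice", "prefers short answers")

removed_by_other = delete_memory(alice_memory, "bob")    # owner mismatch
removed_by_owner = delete_memory(alice_memory, "alice")  # owner matches
```

The Peewee methods in this file return `True` whenever the query runs without raising, even if no row matched; the row count returned by `query.execute()` is what would distinguish the two cases.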
...@@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm): ...@@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm):
if user: if user:
token = create_token( token = create_token(
data={"id": user.id}, data={"id": user.id},
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN), expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
) )
return { return {
...@@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm): ...@@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm):
@router.post("/signup", response_model=SigninResponse) @router.post("/signup", response_model=SigninResponse)
async def signup(request: Request, form_data: SignupForm): async def signup(request: Request, form_data: SignupForm):
if not request.app.state.ENABLE_SIGNUP and WEBUI_AUTH: if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
raise HTTPException( raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
) )
...@@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm): ...@@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm):
role = ( role = (
"admin" "admin"
if Users.get_num_users() == 0 if Users.get_num_users() == 0
else request.app.state.DEFAULT_USER_ROLE else request.app.state.config.DEFAULT_USER_ROLE
) )
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth( user = Auths.insert_new_auth(
...@@ -194,13 +194,13 @@ async def signup(request: Request, form_data: SignupForm): ...@@ -194,13 +194,13 @@ async def signup(request: Request, form_data: SignupForm):
if user: if user:
token = create_token( token = create_token(
data={"id": user.id}, data={"id": user.id},
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN), expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
) )
# response.set_cookie(key='token', value=token, httponly=True) # response.set_cookie(key='token', value=token, httponly=True)
if request.app.state.WEBHOOK_URL: if request.app.state.config.WEBHOOK_URL:
post_webhook( post_webhook(
request.app.state.WEBHOOK_URL, request.app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name), WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{ {
"action": "signup", "action": "signup",
...@@ -276,13 +276,13 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)): ...@@ -276,13 +276,13 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
@router.get("/signup/enabled", response_model=bool) @router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)): async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
return request.app.state.ENABLE_SIGNUP return request.app.state.config.ENABLE_SIGNUP
@router.get("/signup/enabled/toggle", response_model=bool) @router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP
return request.app.state.ENABLE_SIGNUP return request.app.state.config.ENABLE_SIGNUP
############################ ############################
...@@ -292,7 +292,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)): ...@@ -292,7 +292,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
@router.get("/signup/user/role") @router.get("/signup/user/role")
async def get_default_user_role(request: Request, user=Depends(get_admin_user)): async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
return request.app.state.DEFAULT_USER_ROLE return request.app.state.config.DEFAULT_USER_ROLE
class UpdateRoleForm(BaseModel): class UpdateRoleForm(BaseModel):
...@@ -304,8 +304,8 @@ async def update_default_user_role( ...@@ -304,8 +304,8 @@ async def update_default_user_role(
request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user) request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
): ):
if form_data.role in ["pending", "user", "admin"]: if form_data.role in ["pending", "user", "admin"]:
request.app.state.DEFAULT_USER_ROLE = form_data.role request.app.state.config.DEFAULT_USER_ROLE = form_data.role
return request.app.state.DEFAULT_USER_ROLE return request.app.state.config.DEFAULT_USER_ROLE
############################ ############################
...@@ -315,7 +315,7 @@ async def update_default_user_role( ...@@ -315,7 +315,7 @@ async def update_default_user_role(
@router.get("/token/expires") @router.get("/token/expires")
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)): async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
return request.app.state.JWT_EXPIRES_IN return request.app.state.config.JWT_EXPIRES_IN
class UpdateJWTExpiresDurationForm(BaseModel): class UpdateJWTExpiresDurationForm(BaseModel):
...@@ -332,10 +332,10 @@ async def update_token_expires_duration( ...@@ -332,10 +332,10 @@ async def update_token_expires_duration(
# Check if the input string matches the pattern # Check if the input string matches the pattern
if re.match(pattern, form_data.duration): if re.match(pattern, form_data.duration):
request.app.state.JWT_EXPIRES_IN = form_data.duration request.app.state.config.JWT_EXPIRES_IN = form_data.duration
return request.app.state.JWT_EXPIRES_IN return request.app.state.config.JWT_EXPIRES_IN
else: else:
return request.app.state.JWT_EXPIRES_IN return request.app.state.config.JWT_EXPIRES_IN
############################ ############################
@@ -58,7 +58,7 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user)
     if (
         user.role == "user"
-        and not request.app.state.USER_PERMISSIONS["chat"]["deletion"]
+        and not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]
     ):
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
@@ -266,7 +266,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
         result = Chats.delete_chat_by_id(id)
         return result
     else:
-        if not request.app.state.USER_PERMISSIONS["chat"]["deletion"]:
+        if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]:
             raise HTTPException(
                 status_code=status.HTTP_401_UNAUTHORIZED,
                 detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
......
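Across these files the change is mechanical: settings that used to sit directly on `app.state` are now read from and written to `app.state.config`. The sketch below illustrates the new access path and the chat-deletion check it feeds. It uses `types.SimpleNamespace` as a stand-in for the application state; the initial values and the `may_delete_chats` helper are assumptions made for illustration, not part of the diff.

```python
from types import SimpleNamespace

# Stand-in for FastAPI's `app.state`; only attribute access matters here.
state = SimpleNamespace()

# Settings are grouped under a single `config` object (values are made up).
state.config = SimpleNamespace(
    JWT_EXPIRES_IN="-1",
    USER_PERMISSIONS={"chat": {"deletion": True}},
)


def may_delete_chats(state: SimpleNamespace, role: str) -> bool:
    """Mirror the route checks: only the "user" role is gated by the permission."""
    return role != "user" or state.config.USER_PERMISSIONS["chat"]["deletion"]


# An admin-style update, as in the POST /permissions/user route.
state.config.USER_PERMISSIONS = {"chat": {"deletion": False}}
```

After the update, `may_delete_chats(state, "user")` is false while other roles are unaffected, which matches the condition under which the routes raise `HTTP_401_UNAUTHORIZED`.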
@@ -44,8 +44,8 @@ class SetDefaultSuggestionsForm(BaseModel):
 async def set_global_default_models(
     request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
 ):
-    request.app.state.DEFAULT_MODELS = form_data.models
-    return request.app.state.DEFAULT_MODELS
+    request.app.state.config.DEFAULT_MODELS = form_data.models
+    return request.app.state.config.DEFAULT_MODELS
 @router.post("/default/suggestions", response_model=List[PromptSuggestion])
@@ -55,5 +55,5 @@ async def set_global_default_suggestions(
     user=Depends(get_admin_user),
 ):
     data = form_data.model_dump()
-    request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
-    return request.app.state.DEFAULT_PROMPT_SUGGESTIONS
+    request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
+    return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS
from fastapi import Response, Request
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import logging
from apps.web.models.memories import Memories, MemoryModel
from utils.utils import get_verified_user
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
@router.get("/ef")
async def get_embeddings(request: Request):
    return {"result": request.app.state.EMBEDDING_FUNCTION("hello world")}
############################
# GetMemories
############################
@router.get("/", response_model=List[MemoryModel])
async def get_memories(user=Depends(get_verified_user)):
    return Memories.get_memories_by_user_id(user.id)
############################
# AddMemory
############################
class AddMemoryForm(BaseModel):
    content: str
@router.post("/add", response_model=Optional[MemoryModel])
async def add_memory(
    request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
):
    memory = Memories.insert_new_memory(user.id, form_data.content)
    memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
    collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
    collection.upsert(
        documents=[memory.content],
        ids=[memory.id],
        embeddings=[memory_embedding],
        metadatas=[{"created_at": memory.created_at}],
    )
    return memory
############################
# QueryMemory
############################
class QueryMemoryForm(BaseModel):
    content: str
@router.post("/query")
async def query_memory(
    request: Request, form_data: QueryMemoryForm, user=Depends(get_verified_user)
):
    query_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
    collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
    results = collection.query(
        query_embeddings=[query_embedding],
        n_results=1,  # how many results to return
    )
    return results
############################
# ResetMemoryFromVectorDB
############################
@router.get("/reset", response_model=bool)
async def reset_memory_from_vector_db(
    request: Request, user=Depends(get_verified_user)
):
    CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
    collection = CHROMA_CLIENT.get_or_create_collection(name=f"user-memory-{user.id}")
    memories = Memories.get_memories_by_user_id(user.id)
    for memory in memories:
        memory_embedding = request.app.state.EMBEDDING_FUNCTION(memory.content)
        collection.upsert(
            documents=[memory.content],
            ids=[memory.id],
            embeddings=[memory_embedding],
        )
    return True
############################
# DeleteMemoriesByUserId
############################
@router.delete("/user", response_model=bool)
async def delete_memory_by_user_id(user=Depends(get_verified_user)):
    result = Memories.delete_memories_by_user_id(user.id)
    if result:
        try:
            CHROMA_CLIENT.delete_collection(f"user-memory-{user.id}")
        except Exception as e:
            log.error(e)
        return True
    return False
############################
# DeleteMemoryById
############################
@router.delete("/{memory_id}", response_model=bool)
async def delete_memory_by_id(memory_id: str, user=Depends(get_verified_user)):
    result = Memories.delete_memory_by_id_and_user_id(memory_id, user.id)
    if result:
        collection = CHROMA_CLIENT.get_or_create_collection(
            name=f"user-memory-{user.id}"
        )
        collection.delete(ids=[memory_id])
        return True
    return False
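The memory routes above share one pattern: each user gets a collection named `user-memory-{user.id}`, memories are upserted with their embedding, and a query embeds the input text and asks for the single nearest entry. The sketch below reproduces that flow without any external service. `ToyCollection` and `toy_embedding` are hypothetical stand-ins for the Chroma collection and `EMBEDDING_FUNCTION`, and the squared Euclidean distance is an assumption chosen for simplicity.

```python
from typing import Dict, List, Tuple


def toy_embedding(text: str) -> List[float]:
    """Hypothetical embedding: per-vowel counts. Stands in for EMBEDDING_FUNCTION."""
    lowered = text.lower()
    return [float(lowered.count(vowel)) for vowel in "aeiou"]


class ToyCollection:
    """Hypothetical in-memory stand-in for a vector collection."""

    def __init__(self) -> None:
        self.rows: Dict[str, dict] = {}

    def upsert(self, ids, documents, embeddings) -> None:
        # Insert or overwrite one row per id.
        for id_, doc, emb in zip(ids, documents, embeddings):
            self.rows[id_] = {"document": doc, "embedding": emb}

    def delete(self, ids) -> None:
        for id_ in ids:
            self.rows.pop(id_, None)

    def query(self, query_embedding, n_results: int = 1) -> List[Tuple[str, str]]:
        # Rank rows by squared Euclidean distance to the query embedding.
        def distance(row: dict) -> float:
            return sum((a - b) ** 2 for a, b in zip(row["embedding"], query_embedding))

        ranked = sorted(self.rows.items(), key=lambda item: distance(item[1]))
        return [(id_, row["document"]) for id_, row in ranked[:n_results]]


collections: Dict[str, ToyCollection] = {}


def collection_for(user_id: str) -> ToyCollection:
    """Get or create the per-user collection, named as in the routes above."""
    return collections.setdefault(f"user-memory-{user_id}", ToyCollection())


def add_memory(user_id: str, memory_id: str, content: str) -> None:
    collection_for(user_id).upsert(
        ids=[memory_id], documents=[content], embeddings=[toy_embedding(content)]
    )


def query_memory(user_id: str, content: str) -> List[Tuple[str, str]]:
    return collection_for(user_id).query(toy_embedding(content), n_results=1)
```

Keying the collection on the user id is what keeps one user's query from returning another user's memories; a query for a user with no stored memories simply returns an empty list here.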
@@ -11,8 +11,9 @@ import logging
 from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
 from apps.web.models.auths import Auths
+from apps.web.models.chats import Chats
-from utils.utils import get_current_user, get_password_hash, get_admin_user
+from utils.utils import get_verified_user, get_password_hash, get_admin_user
 from constants import ERROR_MESSAGES
 from config import SRC_LOG_LEVELS
@@ -39,15 +40,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
 @router.get("/permissions/user")
 async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
-    return request.app.state.USER_PERMISSIONS
+    return request.app.state.config.USER_PERMISSIONS
 @router.post("/permissions/user")
 async def update_user_permissions(
     request: Request, form_data: dict, user=Depends(get_admin_user)
 ):
-    request.app.state.USER_PERMISSIONS = form_data
-    return request.app.state.USER_PERMISSIONS
+    request.app.state.config.USER_PERMISSIONS = form_data
+    return request.app.state.config.USER_PERMISSIONS
 ############################
@@ -67,6 +68,41 @@ async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin
     )
+############################
+# GetUserById
+############################
+class UserResponse(BaseModel):
+    name: str
+    profile_image_url: str
+@router.get("/{user_id}", response_model=UserResponse)
+async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
+    if user_id.startswith("shared-"):
+        chat_id = user_id.replace("shared-", "")
+        chat = Chats.get_chat_by_id(chat_id)
+        if chat:
+            user_id = chat.user_id
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.USER_NOT_FOUND,
+            )
+    user = Users.get_user_by_id(user_id)
+    if user:
+        return UserResponse(name=user.name, profile_image_url=user.profile_image_url)
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.USER_NOT_FOUND,
+        )
 ############################
 # UpdateUserById
 ############################
......
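The new `get_user_by_id` route accepts either a plain user id or a `shared-<chat_id>` value, in which case it first looks up the chat and substitutes the chat owner's id. A minimal sketch of that two-step resolution follows, using in-memory dictionaries as hypothetical stand-ins for the `Chats` and `Users` models. One deliberate variation: the sketch slices off the prefix, whereas the route calls `str.replace`, which would also remove any later occurrence of `shared-` inside the id.

```python
from typing import Dict, Optional

# Hypothetical in-memory tables standing in for the Chats and Users models.
CHAT_OWNERS: Dict[str, str] = {"chat-1": "user-a"}
USERS: Dict[str, dict] = {
    "user-a": {"name": "Ada", "profile_image_url": "/ada.png", "email": "ada@example.com"},
}

SHARED_PREFIX = "shared-"


def resolve_user(user_id: str) -> Optional[dict]:
    """Return the public fields of a user, or None where the route would raise."""
    if user_id.startswith(SHARED_PREFIX):
        # Strip only the leading prefix, then map the chat to its owner.
        chat_id = user_id[len(SHARED_PREFIX):]
        owner_id = CHAT_OWNERS.get(chat_id)
        if owner_id is None:
            return None
        user_id = owner_id

    user = USERS.get(user_id)
    if user is None:
        return None

    # Expose only the two fields that UserResponse declares.
    return {"name": user["name"], "profile_image_url": user["profile_image_url"]}
```

Returning only `name` and `profile_image_url` mirrors the `UserResponse` model, so other stored fields (the `email` key in this sketch) never leave the lookup.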