chenpangpang / open-webui · Commits · 564a3a29

Unverified commit 564a3a29, authored May 14, 2024 by Timothy Jaeryang Baek; committed by GitHub on May 14, 2024.

Merge branch 'dev' into fix/handlebars-harden

Parents: a789b785, 2290eefc

Changes: 57 — showing 20 changed files with 999 additions and 427 deletions (+999 −427)
.github/pull_request_template.md        +32  −19
.prettierignore                         +302 −6
Dockerfile                              +4   −2
backend/apps/audio/main.py              +20  −19
backend/apps/images/main.py             +89  −68
backend/apps/litellm/main.py            +11  −8
backend/apps/ollama/main.py             +29  −25
backend/apps/openai/main.py             +28  −19
backend/apps/rag/main.py                +126 −107
backend/apps/web/main.py                +13  −9
backend/apps/web/routers/auths.py       +16  −16
backend/apps/web/routers/chats.py       +2   −2
backend/apps/web/routers/configs.py     +4   −4
backend/apps/web/routers/users.py       +3   −3
backend/config.py                       +278 −73
backend/main.py                         +42  −39
docker-compose.api.yaml                 +0   −2
docker-compose.data.yaml                +0   −2
docker-compose.gpu.yaml                 +0   −2
docker-compose.yaml                     +0   −2
.github/pull_request_template.md — view file @ 564a3a29

# Pull Request Checklist

- [ ] **Target branch:** Pull requests should target the `dev` branch.
- [ ] **Description:** Briefly describe the changes in this pull request.
...

@@ -7,32 +7,46 @@

- [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation?
- [ ] **Testing:** Have you written and run sufficient tests for the changes?
- [ ] **Code Review:** Have you self-reviewed your code and addressed any coding standard issues?

---

## Description

[Insert a brief description of the changes made in this pull request, including any relevant motivation and impact.]

---

### Changelog Entry

- [ ] **Label title:** Ensure the pull request title is labeled properly using one of the following:
  - **BREAKING CHANGE**: Significant changes that may affect compatibility
  - **build**: Changes that affect the build system or external dependencies
  - **ci**: Changes to our continuous integration processes or workflows
  - **chore**: Refactor, cleanup, or other non-functional code changes
  - **docs**: Documentation update or addition
  - **feat**: Introduces a new feature or enhancement to the codebase
  - **fix**: Bug fix or error correction
  - **i18n**: Internationalization or localization changes
  - **perf**: Performance improvement
  - **refactor**: Code restructuring for better maintainability, readability, or scalability
  - **style**: Changes that do not affect the meaning of the code (white-space, formatting, missing semi-colons, etc.)
  - **test**: Adding missing tests or correcting existing tests
  - **WIP**: Work in progress, a temporary label for incomplete or ongoing work

### Description

- [Briefly describe the changes made in this pull request, including any relevant motivation and impact.]

### Added

- [List any new features, functionalities, or additions]

### Fixed

- [List any fixes, corrections, or bug fixes]

### Changed

- [List any changes, updates, refactorings, or optimizations]

### Deprecated

- [List any deprecated functionality or features that have been removed]

### Removed

- [List any removed features, files, or functionalities]

### Security

...

@@ -40,12 +54,11 @@

### Breaking Changes

- **BREAKING CHANGE**: [List any breaking changes affecting compatibility or functionality]

---

### Additional Information

- [Insert any additional context, notes, or explanations for the changes]
- [Reference any related issues, commits, or other relevant information]
.prettierignore — view file @ 564a3a29
# Ignore files for PNPM, NPM and YARN
pnpm-lock.yaml
package-lock.json
yarn.lock
kubernetes/
# Copy of .gitignore
.DS_Store
node_modules
/build
...
...
@@ -6,11 +14,299 @@ node_modules
.env
.env.*
!.env.example
vite.config.js.timestamp-*
vite.config.ts.timestamp-*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# Ignore files for PNPM, NPM and YARN
pnpm-lock.yaml
package-lock.json
yarn.lock
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
lerna-debug.log*
.pnpm-debug.log*
# Diagnostic reports (https://nodejs.org/api/report.html)
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
# Runtime data
pids
*.pid
*.seed
*.pid.lock
# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov
# Coverage directory used by tools like istanbul
coverage
*.lcov
# nyc test coverage
.nyc_output
# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
.grunt
# Bower dependency directory (https://bower.io/)
bower_components
# node-waf configuration
.lock-wscript
# Compiled binary addons (https://nodejs.org/api/addons.html)
build/Release
# Dependency directories
node_modules/
jspm_packages/
# Snowpack dependency directory (https://snowpack.dev/)
web_modules/
# TypeScript cache
*.tsbuildinfo
# Optional npm cache directory
.npm
# Optional eslint cache
.eslintcache
# Optional stylelint cache
.stylelintcache
# Microbundle cache
.rpt2_cache/
.rts2_cache_cjs/
.rts2_cache_es/
.rts2_cache_umd/
# Optional REPL history
.node_repl_history
# Output of 'npm pack'
*.tgz
# Yarn Integrity file
.yarn-integrity
# dotenv environment variable files
.env
.env.development.local
.env.test.local
.env.production.local
.env.local
# parcel-bundler cache (https://parceljs.org/)
.cache
.parcel-cache
# Next.js build output
.next
out
# Nuxt.js build / generate output
.nuxt
dist
# Gatsby files
.cache/
# Comment in the public line in if your project uses Gatsby and not Next.js
# https://nextjs.org/blog/next-9-1#public-directory-support
# public
# vuepress build output
.vuepress/dist
# vuepress v2.x temp and cache directory
.temp
.cache
# Docusaurus cache and generated files
.docusaurus
# Serverless directories
.serverless/
# FuseBox cache
.fusebox/
# DynamoDB Local files
.dynamodb/
# TernJS port file
.tern-port
# Stores VSCode versions used for testing VSCode extensions
.vscode-test
# yarn v2
.yarn/cache
.yarn/unplugged
.yarn/build-state.yml
.yarn/install-state.gz
.pnp.*
# Ignore kubernetes files
kubernetes
\ No newline at end of file
# cypress artifacts
cypress/videos
cypress/screenshots
Dockerfile — view file @ 564a3a29

...
@@ -82,7 +82,7 @@ RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry
 RUN if [ "$USE_OLLAMA" = "true" ]; then \
     apt-get update && \
     # Install pandoc and netcat
-    apt-get install -y --no-install-recommends pandoc netcat-openbsd && \
+    apt-get install -y --no-install-recommends pandoc netcat-openbsd curl && \
     # for RAG OCR
     apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
     # install helper tools
...
@@ -94,7 +94,7 @@ RUN if [ "$USE_OLLAMA" = "true" ]; then \
     else \
     apt-get update && \
     # Install pandoc and netcat
-    apt-get install -y --no-install-recommends pandoc netcat-openbsd && \
+    apt-get install -y --no-install-recommends pandoc netcat-openbsd curl && \
     # for RAG OCR
     apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
     # cleanup
...
@@ -134,4 +134,6 @@ COPY ./backend .
 EXPOSE 8080

+HEALTHCHECK CMD curl --fail http://localhost:8080 || exit 1
+
 CMD [ "bash", "start.sh"]
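The new HEALTHCHECK instruction runs `curl --fail http://localhost:8080 || exit 1`, which succeeds only when the endpoint answers with a non-error status. A hedged Python sketch of that same probe (the URL and timeout are illustrative, not from the repo):

```python
import urllib.error
import urllib.request


def is_healthy(url: str, timeout: float = 3.0) -> bool:
    """Return True when the endpoint answers with a non-error HTTP status,
    mirroring what `curl --fail` treats as success."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as resp:
            return 200 <= resp.status < 400
    except (urllib.error.URLError, OSError, ValueError):
        # connection refused, DNS failure, HTTP 4xx/5xx, bad URL, timeout
        return False
```

Docker marks the container unhealthy after the probe fails a few consecutive times, which orchestrators can use to restart it.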
backend/apps/audio/main.py — view file @ 564a3a29

...
@@ -45,6 +45,7 @@ from config import (
     AUDIO_OPENAI_API_KEY,
     AUDIO_OPENAI_API_MODEL,
     AUDIO_OPENAI_API_VOICE,
+    AppConfig,
 )

 log = logging.getLogger(__name__)
...
@@ -59,11 +60,11 @@ app.add_middleware(
     allow_headers=["*"],
 )

-app.state.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
-app.state.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
-app.state.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
-app.state.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE
+app.state.config = AppConfig()
+app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
+app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
+app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
+app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE

 # setting device type for whisper model
 whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
...
@@ -83,10 +84,10 @@ class OpenAIConfigUpdateForm(BaseModel):
 @app.get("/config")
 async def get_openai_config(user=Depends(get_admin_user)):
     return {
-        "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
-        "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
-        "OPENAI_API_MODEL": app.state.OPENAI_API_MODEL,
-        "OPENAI_API_VOICE": app.state.OPENAI_API_VOICE,
+        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
+        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
+        "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
+        "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
     }
...
@@ -97,17 +98,17 @@ async def update_openai_config(
     if form_data.key == "":
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

-    app.state.OPENAI_API_BASE_URL = form_data.url
-    app.state.OPENAI_API_KEY = form_data.key
-    app.state.OPENAI_API_MODEL = form_data.model
-    app.state.OPENAI_API_VOICE = form_data.speaker
+    app.state.config.OPENAI_API_BASE_URL = form_data.url
+    app.state.config.OPENAI_API_KEY = form_data.key
+    app.state.config.OPENAI_API_MODEL = form_data.model
+    app.state.config.OPENAI_API_VOICE = form_data.speaker

     return {
         "status": True,
-        "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
-        "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
-        "OPENAI_API_MODEL": app.state.OPENAI_API_MODEL,
-        "OPENAI_API_VOICE": app.state.OPENAI_API_VOICE,
+        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
+        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
+        "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
+        "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
     }
...
@@ -124,13 +125,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         return FileResponse(file_path)

     headers = {}
-    headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
+    headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
     headers["Content-Type"] = "application/json"

     r = None
     try:
         r = requests.post(
-            url=f"{app.state.OPENAI_API_BASE_URL}/audio/speech",
+            url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech",
             data=body,
             headers=headers,
             stream=True,
...
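Every backend app in this commit repeats the same move: settings migrate from bare `app.state.X` attributes to `app.state.config.X`, where `config` is an `AppConfig` instance imported from `backend/config.py`. That class is not shown in this excerpt; a minimal stand-in for the pattern, assuming a plain attribute container (hypothetical implementation, not the project's):

```python
from types import SimpleNamespace


class AppConfig(SimpleNamespace):
    """Hypothetical stand-in for the AppConfig imported from config.py:
    one mutable namespace per app, so every runtime setting lives under
    app.state.config.<NAME> instead of directly on app.state."""


class State(SimpleNamespace):
    """Stand-in for FastAPI's app.state object."""


state = State()
state.config = AppConfig()
state.config.OPENAI_API_BASE_URL = "https://api.openai.com/v1"  # example value
state.config.OPENAI_API_KEY = ""
```

Grouping settings under one object makes it possible for a single component (here, presumably the real `AppConfig`) to intercept reads and writes, rather than tracking loose attributes scattered across `app.state`.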
backend/apps/images/main.py — view file @ 564a3a29

...
@@ -42,6 +42,7 @@ from config import (
     IMAGE_GENERATION_MODEL,
     IMAGE_SIZE,
     IMAGE_STEPS,
+    AppConfig,
 )
...
@@ -60,26 +61,31 @@ app.add_middleware(
     allow_headers=["*"],
 )

-app.state.ENGINE = IMAGE_GENERATION_ENGINE
-app.state.ENABLED = ENABLE_IMAGE_GENERATION
+app.state.config = AppConfig()

-app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
-app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
+app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
+app.state.config.ENABLED = ENABLE_IMAGE_GENERATION

-app.state.MODEL = IMAGE_GENERATION_MODEL
+app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
+app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY

-app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
-app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
+app.state.config.MODEL = IMAGE_GENERATION_MODEL

-app.state.IMAGE_SIZE = IMAGE_SIZE
-app.state.IMAGE_STEPS = IMAGE_STEPS
+app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
+app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
+app.state.config.IMAGE_SIZE = IMAGE_SIZE
+app.state.config.IMAGE_STEPS = IMAGE_STEPS

 @app.get("/config")
 async def get_config(request: Request, user=Depends(get_admin_user)):
-    return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
+    return {
+        "engine": app.state.config.ENGINE,
+        "enabled": app.state.config.ENABLED,
+    }

 class ConfigUpdateForm(BaseModel):
...
@@ -89,9 +95,12 @@ class ConfigUpdateForm(BaseModel):
 @app.post("/config/update")
 async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
-    app.state.ENGINE = form_data.engine
-    app.state.ENABLED = form_data.enabled
-    return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
+    app.state.config.ENGINE = form_data.engine
+    app.state.config.ENABLED = form_data.enabled
+    return {
+        "engine": app.state.config.ENGINE,
+        "enabled": app.state.config.ENABLED,
+    }

 class EngineUrlUpdateForm(BaseModel):
...
@@ -102,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel):
 @app.get("/url")
 async def get_engine_url(user=Depends(get_admin_user)):
     return {
-        "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
-        "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
+        "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
+        "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
     }
...
@@ -113,29 +122,29 @@ async def update_engine_url(
 ):
     if form_data.AUTOMATIC1111_BASE_URL == None:
-        app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
+        app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
     else:
         url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
         try:
             r = requests.head(url)
-            app.state.AUTOMATIC1111_BASE_URL = url
+            app.state.config.AUTOMATIC1111_BASE_URL = url
         except Exception as e:
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

     if form_data.COMFYUI_BASE_URL == None:
-        app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
+        app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
     else:
         url = form_data.COMFYUI_BASE_URL.strip("/")
         try:
             r = requests.head(url)
-            app.state.COMFYUI_BASE_URL = url
+            app.state.config.COMFYUI_BASE_URL = url
         except Exception as e:
             raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

     return {
-        "AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
-        "COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
+        "AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
+        "COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
         "status": True,
     }
...
@@ -148,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel):
 @app.get("/openai/config")
 async def get_openai_config(user=Depends(get_admin_user)):
     return {
-        "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
-        "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
+        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
+        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
     }
...
@@ -160,13 +169,13 @@ async def update_openai_config(
     if form_data.key == "":
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)

-    app.state.OPENAI_API_BASE_URL = form_data.url
-    app.state.OPENAI_API_KEY = form_data.key
+    app.state.config.OPENAI_API_BASE_URL = form_data.url
+    app.state.config.OPENAI_API_KEY = form_data.key

     return {
         "status": True,
-        "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
-        "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
+        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
+        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
     }
...
@@ -176,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel):
 @app.get("/size")
 async def get_image_size(user=Depends(get_admin_user)):
-    return {"IMAGE_SIZE": app.state.IMAGE_SIZE}
+    return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}

 @app.post("/size/update")
...
@@ -185,9 +194,9 @@ async def update_image_size(
 ):
     pattern = r"^\d+x\d+$"  # Regular expression pattern
     if re.match(pattern, form_data.size):
-        app.state.IMAGE_SIZE = form_data.size
+        app.state.config.IMAGE_SIZE = form_data.size
         return {
-            "IMAGE_SIZE": app.state.IMAGE_SIZE,
+            "IMAGE_SIZE": app.state.config.IMAGE_SIZE,
             "status": True,
         }
     else:
...
@@ -203,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel):
 @app.get("/steps")
 async def get_image_size(user=Depends(get_admin_user)):
-    return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
+    return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}

 @app.post("/steps/update")
...
@@ -211,9 +220,9 @@ async def update_image_size(
     form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
 ):
     if form_data.steps >= 0:
-        app.state.IMAGE_STEPS = form_data.steps
+        app.state.config.IMAGE_STEPS = form_data.steps
         return {
-            "IMAGE_STEPS": app.state.IMAGE_STEPS,
+            "IMAGE_STEPS": app.state.config.IMAGE_STEPS,
             "status": True,
         }
     else:
...
@@ -226,14 +235,14 @@ async def update_image_size(
 @app.get("/models")
 def get_models(user=Depends(get_current_user)):
     try:
-        if app.state.ENGINE == "openai":
+        if app.state.config.ENGINE == "openai":
             return [
                 {"id": "dall-e-2", "name": "DALL·E 2"},
                 {"id": "dall-e-3", "name": "DALL·E 3"},
             ]
-        elif app.state.ENGINE == "comfyui":
-            r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
+        elif app.state.config.ENGINE == "comfyui":
+            r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
             info = r.json()

             return list(
...
@@ -245,7 +254,7 @@ def get_models(user=Depends(get_current_user)):
         else:
             r = requests.get(
-                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
+                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
             )
             models = r.json()
             return list(
...
@@ -255,23 +264,29 @@ def get_models(user=Depends(get_current_user)):
             )
         )
     except Exception as e:
-        app.state.ENABLED = False
+        app.state.config.ENABLED = False
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))

 @app.get("/models/default")
 async def get_default_model(user=Depends(get_admin_user)):
     try:
-        if app.state.ENGINE == "openai":
-            return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
-        elif app.state.ENGINE == "comfyui":
-            return {"model": app.state.MODEL if app.state.MODEL else ""}
+        if app.state.config.ENGINE == "openai":
+            return {
+                "model": (
+                    app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
+                )
+            }
+        elif app.state.config.ENGINE == "comfyui":
+            return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
         else:
-            r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
+            r = requests.get(
+                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
+            )
             options = r.json()
             return {"model": options["sd_model_checkpoint"]}
     except Exception as e:
-        app.state.ENABLED = False
+        app.state.config.ENABLED = False
         raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
...
@@ -280,20 +295,20 @@ class UpdateModelForm(BaseModel):
 def set_model_handler(model: str):
-    if app.state.ENGINE == "openai":
-        app.state.MODEL = model
-        return app.state.MODEL
-    if app.state.ENGINE == "comfyui":
-        app.state.MODEL = model
-        return app.state.MODEL
+    if app.state.config.ENGINE in ["openai", "comfyui"]:
+        app.state.config.MODEL = model
+        return app.state.config.MODEL
     else:
-        r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
+        r = requests.get(
+            url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
+        )
         options = r.json()

         if model != options["sd_model_checkpoint"]:
             options["sd_model_checkpoint"] = model
             r = requests.post(
-                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
+                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
+                json=options,
             )

         return options
...
@@ -382,26 +397,32 @@ def generate_image(
     user=Depends(get_current_user),
 ):
-    width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
+    width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))

     r = None
     try:
-        if app.state.ENGINE == "openai":
+        if app.state.config.ENGINE == "openai":
             headers = {}
-            headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
+            headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
             headers["Content-Type"] = "application/json"

             data = {
-                "model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
+                "model": (
+                    app.state.config.MODEL if app.state.config.MODEL != "" else "dall-e-2"
+                ),
                 "prompt": form_data.prompt,
                 "n": form_data.n,
-                "size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
+                "size": (
+                    form_data.size if form_data.size else app.state.config.IMAGE_SIZE
+                ),
                 "response_format": "b64_json",
             }

             r = requests.post(
-                url=f"{app.state.OPENAI_API_BASE_URL}/images/generations",
+                url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
                 json=data,
                 headers=headers,
             )
...
@@ -421,7 +442,7 @@ def generate_image(
             return images

-        elif app.state.ENGINE == "comfyui":
+        elif app.state.config.ENGINE == "comfyui":
             data = {
                 "prompt": form_data.prompt,
...
@@ -430,19 +451,19 @@ def generate_image(
                 "n": form_data.n,
             }

-            if app.state.IMAGE_STEPS != None:
-                data["steps"] = app.state.IMAGE_STEPS
+            if app.state.config.IMAGE_STEPS is not None:
+                data["steps"] = app.state.config.IMAGE_STEPS

-            if form_data.negative_prompt != None:
+            if form_data.negative_prompt is not None:
                 data["negative_prompt"] = form_data.negative_prompt

             data = ImageGenerationPayload(**data)

             res = comfyui_generate_image(
-                app.state.MODEL,
+                app.state.config.MODEL,
                 data,
                 user.id,
-                app.state.COMFYUI_BASE_URL,
+                app.state.config.COMFYUI_BASE_URL,
             )
             log.debug(f"res: {res}")
...
@@ -469,14 +490,14 @@ def generate_image(
                 "height": height,
             }

-            if app.state.IMAGE_STEPS != None:
-                data["steps"] = app.state.IMAGE_STEPS
+            if app.state.config.IMAGE_STEPS is not None:
+                data["steps"] = app.state.config.IMAGE_STEPS

-            if form_data.negative_prompt != None:
+            if form_data.negative_prompt is not None:
                 data["negative_prompt"] = form_data.negative_prompt

             r = requests.post(
-                url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
+                url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
                 json=data,
             )
...
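Besides the config migration, the images hunks make two incidental cleanups: `!= None` comparisons become `is not None` (an identity check, as PEP 8 recommends, immune to classes that override `__eq__`), and the `IMAGE_SIZE` string is parsed via `split("x")`. A small self-contained illustration of both (the function name is illustrative):

```python
def parse_size(size: str) -> tuple:
    """Parse an image size string like '512x512' into (width, height),
    mirroring the tuple(map(int, ....split("x"))) expression in the hunk."""
    width, height = tuple(map(int, size.split("x")))
    return width, height


data = {}
steps = 0  # falsy, but a perfectly valid step count
# `is not None` is an identity test: only an actual None is skipped,
# and no user-defined __eq__ gets a chance to answer the comparison.
if steps is not None:
    data["steps"] = steps
```

A truthiness check (`if steps:`) would silently drop the valid value `0`, which is why the explicit `is not None` form matters here.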
backend/apps/litellm/main.py — view file @ 564a3a29

 import sys
+from contextlib import asynccontextmanager

 from fastapi import FastAPI, Depends, HTTPException
 from fastapi.routing import APIRoute
...
@@ -46,7 +47,16 @@ import asyncio
 import subprocess
 import yaml

-app = FastAPI()
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    log.info("startup_event")
+    # TODO: Check config.yaml file and create one
+    asyncio.create_task(start_litellm_background())
+    yield
+
+app = FastAPI(lifespan=lifespan)

 origins = ["*"]
...
@@ -141,13 +151,6 @@ async def shutdown_litellm_background():
     background_process = None

-@app.on_event("startup")
-async def startup_event():
-    log.info("startup_event")
-    # TODO: Check config.yaml file and create one
-    asyncio.create_task(start_litellm_background())

 app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
backend/apps/ollama/main.py
View file @
564a3a29
...
...
@@ -46,6 +46,7 @@ from config import (
ENABLE_MODEL_FILTER
,
MODEL_FILTER_LIST
,
UPLOAD_DIR
,
AppConfig
,
)
from
utils.misc
import
calculate_sha256
...
...
@@ -61,11 +62,12 @@ app.add_middleware(
allow_headers
=
[
"*"
],
)
app
.
state
.
config
=
AppConfig
()
app
.
state
.
ENABLE_MODEL_FILTER
=
ENABLE_MODEL_FILTER
app
.
state
.
MODEL_FILTER_LIST
=
MODEL_FILTER_LIST
app
.
state
.
OLLAMA_BASE_URLS
=
OLLAMA_BASE_URLS
app
.
state
.
config
.
OLLAMA_BASE_URLS
=
OLLAMA_BASE_URLS
app
.
state
.
MODELS
=
{}
...
...
@@ -96,7 +98,7 @@ async def get_status():
@
app
.
get
(
"/urls"
)
async
def
get_ollama_api_urls
(
user
=
Depends
(
get_admin_user
)):
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
OLLAMA_BASE_URLS
}
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
config
.
OLLAMA_BASE_URLS
}
class
UrlUpdateForm
(
BaseModel
):
...
...
@@ -105,10 +107,10 @@ class UrlUpdateForm(BaseModel):
@
app
.
post
(
"/urls/update"
)
async
def
update_ollama_api_url
(
form_data
:
UrlUpdateForm
,
user
=
Depends
(
get_admin_user
)):
app
.
state
.
OLLAMA_BASE_URLS
=
form_data
.
urls
app
.
state
.
config
.
OLLAMA_BASE_URLS
=
form_data
.
urls
log
.
info
(
f
"app.state.OLLAMA_BASE_URLS:
{
app
.
state
.
OLLAMA_BASE_URLS
}
"
)
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
OLLAMA_BASE_URLS
}
log
.
info
(
f
"app.state.
config.
OLLAMA_BASE_URLS:
{
app
.
state
.
config
.
OLLAMA_BASE_URLS
}
"
)
return
{
"OLLAMA_BASE_URLS"
:
app
.
state
.
config
.
OLLAMA_BASE_URLS
}
@
app
.
get
(
"/cancel/{request_id}"
)
...
...
@@ -153,7 +155,7 @@ def merge_models_lists(model_lists):
async
def
get_all_models
():
log
.
info
(
"get_all_models()"
)
tasks
=
[
fetch_url
(
f
"
{
url
}
/api/tags"
)
for
url
in
app
.
state
.
OLLAMA_BASE_URLS
]
tasks
=
[
fetch_url
(
f
"
{
url
}
/api/tags"
)
for
url
in
app
.
state
.
config
.
OLLAMA_BASE_URLS
]
responses
=
await
asyncio
.
gather
(
*
tasks
)
models
=
{
...
...
@@ -186,7 +188,7 @@ async def get_ollama_tags(
return
models
return
models
else
:
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
try
:
r
=
requests
.
request
(
method
=
"GET"
,
url
=
f
"
{
url
}
/api/tags"
)
r
.
raise_for_status
()
...
...
@@ -216,7 +218,9 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
if
url_idx
==
None
:
# returns lowest version
tasks
=
[
fetch_url
(
f
"
{
url
}
/api/version"
)
for
url
in
app
.
state
.
OLLAMA_BASE_URLS
]
tasks
=
[
fetch_url
(
f
"
{
url
}
/api/version"
)
for
url
in
app
.
state
.
config
.
OLLAMA_BASE_URLS
]
responses
=
await
asyncio
.
gather
(
*
tasks
)
responses
=
list
(
filter
(
lambda
x
:
x
is
not
None
,
responses
))
...
...
@@ -235,7 +239,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
detail
=
ERROR_MESSAGES
.
OLLAMA_NOT_FOUND
,
)
else
:
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
try
:
r
=
requests
.
request
(
method
=
"GET"
,
url
=
f
"
{
url
}
/api/version"
)
r
.
raise_for_status
()
...
...
@@ -267,7 +271,7 @@ class ModelNameForm(BaseModel):
async
def
pull_model
(
form_data
:
ModelNameForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
):
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
...
...
@@ -355,7 +359,7 @@ async def push_model(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
name
),
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
debug
(
f
"url:
{
url
}
"
)
r
=
None
...
...
@@ -417,7 +421,7 @@ async def create_model(
form_data
:
CreateModelForm
,
url_idx
:
int
=
0
,
user
=
Depends
(
get_admin_user
)
):
log
.
debug
(
f
"form_data:
{
form_data
}
"
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
r
=
None
...
...
@@ -490,7 +494,7 @@ async def copy_model(
detail
=
ERROR_MESSAGES
.
MODEL_NOT_FOUND
(
form_data
.
source
),
)
url
=
app
.
state
.
OLLAMA_BASE_URLS
[
url_idx
]
url
=
app
.
state
.
config
.
OLLAMA_BASE_URLS
[
url_idx
]
log
.
info
(
f
"url:
{
url
}
"
)
try
:
...
...
@@ -537,7 +541,7 @@ async def delete_model(
             detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
         )

-    url = app.state.OLLAMA_BASE_URLS[url_idx]
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

     try:
...
...
@@ -577,7 +581,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
         )

     url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
-    url = app.state.OLLAMA_BASE_URLS[url_idx]
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

     try:
...
...
@@ -634,7 +638,7 @@ async def generate_embeddings(
             detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
         )

-    url = app.state.OLLAMA_BASE_URLS[url_idx]
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

     try:
...
...
@@ -684,7 +688,7 @@ def generate_ollama_embeddings(
             detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
         )

-    url = app.state.OLLAMA_BASE_URLS[url_idx]
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

     try:
...
...
@@ -753,7 +757,7 @@ async def generate_completion(
             detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
         )

-    url = app.state.OLLAMA_BASE_URLS[url_idx]
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

     r = None
...
...
@@ -856,7 +860,7 @@ async def generate_chat_completion(
             detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
         )

-    url = app.state.OLLAMA_BASE_URLS[url_idx]
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

     r = None
...
...
@@ -965,7 +969,7 @@ async def generate_openai_chat_completion(
             detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
         )

-    url = app.state.OLLAMA_BASE_URLS[url_idx]
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

     r = None
...
...
@@ -1064,7 +1068,7 @@ async def get_openai_models(
         }
     else:
-        url = app.state.OLLAMA_BASE_URLS[url_idx]
+        url = app.state.config.OLLAMA_BASE_URLS[url_idx]
         try:
             r = requests.request(method="GET", url=f"{url}/api/tags")
             r.raise_for_status()
...
...
@@ -1198,7 +1202,7 @@ async def download_model(
     if url_idx == None:
         url_idx = 0
-    url = app.state.OLLAMA_BASE_URLS[url_idx]
+    url = app.state.config.OLLAMA_BASE_URLS[url_idx]

     file_name = parse_huggingface_url(form_data.url)
...
...
@@ -1217,7 +1221,7 @@ async def download_model(
 def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
     if url_idx == None:
         url_idx = 0
-    ollama_url = app.state.OLLAMA_BASE_URLS[url_idx]
+    ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     file_path = f"{UPLOAD_DIR}/{file.filename}"
...
...
@@ -1282,7 +1286,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
 # async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
 #     if url_idx == None:
 #         url_idx = 0
-#     url = app.state.OLLAMA_BASE_URLS[url_idx]
+#     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 #     file_location = os.path.join(UPLOAD_DIR, file.filename)
 #     total_size = file.size
...
...
@@ -1319,7 +1323,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
 async def deprecated_proxy(
     path: str, request: Request, user=Depends(get_verified_user)
 ):
-    url = app.state.OLLAMA_BASE_URLS[0]
+    url = app.state.config.OLLAMA_BASE_URLS[0]
     target_url = f"{url}/{path}"

     body = await request.body()
...
...
backend/apps/openai/main.py
...
...
@@ -26,6 +26,7 @@ from config import (
CACHE_DIR
,
ENABLE_MODEL_FILTER
,
MODEL_FILTER_LIST
,
AppConfig
,
)
from
typing
import
List
,
Optional
...
...
@@ -45,11 +46,13 @@ app.add_middleware(
     allow_headers=["*"],
 )

+app.state.config = AppConfig()
+
 app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST

-app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
-app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
+app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
+app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS

 app.state.MODELS = {}
...
...
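The change repeated throughout this commit moves per-app settings from bare `app.state` attributes onto an `app.state.config` object. As a rough illustration of that pattern (a minimal sketch only — the real `AppConfig` in `backend/config.py` also persists values, which is not shown here), the container can be little more than attribute storage with an interception point where persistence would go:

```python
class AppConfig:
    """Minimal stand-in for a config container: plain attribute storage.

    The real AppConfig in backend/config.py additionally persists values;
    this sketch only shows the attribute-access pattern used in the diff.
    """

    def __init__(self):
        self._state = {}

    def __setattr__(self, name, value):
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            # A persistent implementation would also save `value` here.
            self._state[name] = value

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails.
        try:
            return self._state[name]
        except KeyError:
            raise AttributeError(name)


config = AppConfig()
config.OPENAI_API_BASE_URLS = ["https://api.openai.com/v1"]
print(config.OPENAI_API_BASE_URLS[0])  # https://api.openai.com/v1
```

Routing every read and write through one object is what makes the later hunks (`app.state.X` becoming `app.state.config.X`) mechanical: the call sites change, but the surrounding logic does not.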
@@ -75,32 +78,32 @@ class KeysUpdateForm(BaseModel):

 @app.get("/urls")
 async def get_openai_urls(user=Depends(get_admin_user)):
-    return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
+    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}


 @app.post("/urls/update")
 async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
     await get_all_models()
-    app.state.OPENAI_API_BASE_URLS = form_data.urls
-    return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
+    app.state.config.OPENAI_API_BASE_URLS = form_data.urls
+    return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}


 @app.get("/keys")
 async def get_openai_keys(user=Depends(get_admin_user)):
-    return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
+    return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}


 @app.post("/keys/update")
 async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
-    app.state.OPENAI_API_KEYS = form_data.keys
-    return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
+    app.state.config.OPENAI_API_KEYS = form_data.keys
+    return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}


 @app.post("/audio/speech")
 async def speech(request: Request, user=Depends(get_verified_user)):
     idx = None
     try:
-        idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
+        idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
         body = await request.body()
         name = hashlib.sha256(body).hexdigest()
...
...
@@ -114,13 +117,15 @@ async def speech(request: Request, user=Depends(get_verified_user)):
             return FileResponse(file_path)

         headers = {}
-        headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}"
+        headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
         headers["Content-Type"] = "application/json"
+        if "openrouter.ai" in app.state.config.OPENAI_API_BASE_URLS[idx]:
+            headers["HTTP-Referer"] = "https://openwebui.com/"
+            headers["X-Title"] = "Open WebUI"
         r = None
         try:
             r = requests.post(
-                url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech",
+                url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
                 data=body,
                 headers=headers,
                 stream=True,
...
...
@@ -180,7 +185,8 @@ def merge_models_lists(model_lists):
             [
                 {**model, "urlIdx": idx}
                 for model in models
-                if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx]
+                if "api.openai.com"
+                not in app.state.config.OPENAI_API_BASE_URLS[idx]
                 or "gpt" in model["id"]
             ]
         )
...
...
@@ -191,12 +197,15 @@ def merge_models_lists(model_lists):
 async def get_all_models():
     log.info("get_all_models()")

-    if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "":
+    if (
+        len(app.state.config.OPENAI_API_KEYS) == 1
+        and app.state.config.OPENAI_API_KEYS[0] == ""
+    ):
         models = {"data": []}
     else:
         tasks = [
-            fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
-            for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
+            fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
+            for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
         ]

         responses = await asyncio.gather(*tasks)
...
...
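Both the Ollama and OpenAI apps use the same fan-out pattern seen in the hunks above: fire one request per configured base URL with `asyncio.gather`, then drop the backends that failed (returned `None`). A self-contained sketch of that pattern, with stub coroutines standing in for the real HTTP calls (the URLs here are made up for illustration):

```python
import asyncio


async def fetch_url(url):
    # Stand-in for an HTTP GET; the real helper returns None on error.
    if "bad" in url:
        return None
    return {"url": url, "status": "ok"}


async def main():
    urls = ["http://a:11434", "http://bad:11434", "http://b:11434"]
    # One task per configured backend, launched concurrently.
    tasks = [fetch_url(f"{url}/api/version") for url in urls]
    responses = await asyncio.gather(*tasks)
    # Keep only the backends that answered.
    return list(filter(lambda x: x is not None, responses))


responses = asyncio.run(main())
print(len(responses))  # 2
```

Because `gather` preserves order, the surviving responses can still be matched back to their base URLs when needed.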
@@ -239,7 +248,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use

         return models
     else:
-        url = app.state.OPENAI_API_BASE_URLS[url_idx]
+        url = app.state.config.OPENAI_API_BASE_URLS[url_idx]

         r = None
...
...
@@ -303,8 +312,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     except json.JSONDecodeError as e:
         log.error("Error loading request body into a dictionary:", e)

-    url = app.state.OPENAI_API_BASE_URLS[idx]
-    key = app.state.OPENAI_API_KEYS[idx]
+    url = app.state.config.OPENAI_API_BASE_URLS[idx]
+    key = app.state.config.OPENAI_API_KEYS[idx]

     target_url = f"{url}/{path}"
...
...
backend/apps/rag/main.py
...
...
@@ -93,6 +93,7 @@ from config import (
RAG_TEMPLATE
,
ENABLE_RAG_LOCAL_WEB_FETCH
,
YOUTUBE_LOADER_LANGUAGE
,
AppConfig
,
)
from
constants
import
ERROR_MESSAGES
...
...
@@ -102,30 +103,32 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])

 app = FastAPI()

-app.state.TOP_K = RAG_TOP_K
-app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
+app.state.config = AppConfig()

-app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
-app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
+app.state.config.TOP_K = RAG_TOP_K
+app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
+
+app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
+app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
     ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
 )

-app.state.CHUNK_SIZE = CHUNK_SIZE
-app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
+app.state.config.CHUNK_SIZE = CHUNK_SIZE
+app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP

-app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
-app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
-app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
-app.state.RAG_TEMPLATE = RAG_TEMPLATE
+app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
+app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
+app.state.config.RAG_TEMPLATE = RAG_TEMPLATE

-app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
-app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
+app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
+app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY

-app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
+app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES

-app.state.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
+app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
 app.state.YOUTUBE_LOADER_TRANSLATION = None
...
...
@@ -133,7 +136,7 @@ def update_embedding_model(
     embedding_model: str,
     update_model: bool = False,
 ):
-    if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
+    if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
         app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
             get_model_path(embedding_model, update_model),
             device=DEVICE_TYPE,
...
...
@@ -158,22 +161,22 @@ def update_reranking_model(


 update_embedding_model(
-    app.state.RAG_EMBEDDING_MODEL,
+    app.state.config.RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
 )

 update_reranking_model(
-    app.state.RAG_RERANKING_MODEL,
+    app.state.config.RAG_RERANKING_MODEL,
     RAG_RERANKING_MODEL_AUTO_UPDATE,
 )

 app.state.EMBEDDING_FUNCTION = get_embedding_function(
-    app.state.RAG_EMBEDDING_ENGINE,
-    app.state.RAG_EMBEDDING_MODEL,
+    app.state.config.RAG_EMBEDDING_ENGINE,
+    app.state.config.RAG_EMBEDDING_MODEL,
     app.state.sentence_transformer_ef,
-    app.state.OPENAI_API_KEY,
-    app.state.OPENAI_API_BASE_URL,
+    app.state.config.OPENAI_API_KEY,
+    app.state.config.OPENAI_API_BASE_URL,
 )

 origins = ["*"]
...
...
@@ -200,12 +203,12 @@ class UrlForm(CollectionNameForm):
 async def get_status():
     return {
         "status": True,
-        "chunk_size": app.state.CHUNK_SIZE,
-        "chunk_overlap": app.state.CHUNK_OVERLAP,
-        "template": app.state.RAG_TEMPLATE,
-        "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
-        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
-        "reranking_model": app.state.RAG_RERANKING_MODEL,
+        "chunk_size": app.state.config.CHUNK_SIZE,
+        "chunk_overlap": app.state.config.CHUNK_OVERLAP,
+        "template": app.state.config.RAG_TEMPLATE,
+        "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
+        "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
+        "reranking_model": app.state.config.RAG_RERANKING_MODEL,
     }
...
...
@@ -213,18 +216,21 @@ async def get_status():
 async def get_embedding_config(user=Depends(get_admin_user)):
     return {
         "status": True,
-        "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
-        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+        "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
+        "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
         "openai_config": {
-            "url": app.state.OPENAI_API_BASE_URL,
-            "key": app.state.OPENAI_API_KEY,
+            "url": app.state.config.OPENAI_API_BASE_URL,
+            "key": app.state.config.OPENAI_API_KEY,
         },
     }


 @app.get("/reranking")
 async def get_reraanking_config(user=Depends(get_admin_user)):
-    return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
+    return {
+        "status": True,
+        "reranking_model": app.state.config.RAG_RERANKING_MODEL,
+    }


 class OpenAIConfigForm(BaseModel):
...
...
@@ -243,34 +249,34 @@ async def update_embedding_config(
     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
 ):
     log.info(
-        f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
+        f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
     )
     try:
-        app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
-        app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
+        app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
+        app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model

-        if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
+        if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
             if form_data.openai_config != None:
-                app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
-                app.state.OPENAI_API_KEY = form_data.openai_config.key
+                app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
+                app.state.config.OPENAI_API_KEY = form_data.openai_config.key

-        update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
+        update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL, True)

         app.state.EMBEDDING_FUNCTION = get_embedding_function(
-            app.state.RAG_EMBEDDING_ENGINE,
-            app.state.RAG_EMBEDDING_MODEL,
+            app.state.config.RAG_EMBEDDING_ENGINE,
+            app.state.config.RAG_EMBEDDING_MODEL,
             app.state.sentence_transformer_ef,
-            app.state.OPENAI_API_KEY,
-            app.state.OPENAI_API_BASE_URL,
+            app.state.config.OPENAI_API_KEY,
+            app.state.config.OPENAI_API_BASE_URL,
         )

         return {
             "status": True,
-            "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
-            "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+            "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
+            "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
             "openai_config": {
-                "url": app.state.OPENAI_API_BASE_URL,
-                "key": app.state.OPENAI_API_KEY,
+                "url": app.state.config.OPENAI_API_BASE_URL,
+                "key": app.state.config.OPENAI_API_KEY,
             },
         }
     except Exception as e:
...
...
@@ -290,16 +296,16 @@ async def update_reranking_config(
     form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
 ):
     log.info(
-        f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
+        f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
     )
     try:
-        app.state.RAG_RERANKING_MODEL = form_data.reranking_model
+        app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model

-        update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
+        update_reranking_model(app.state.config.RAG_RERANKING_MODEL, True)

         return {
             "status": True,
-            "reranking_model": app.state.RAG_RERANKING_MODEL,
+            "reranking_model": app.state.config.RAG_RERANKING_MODEL,
         }
     except Exception as e:
         log.exception(f"Problem updating reranking model: {e}")
...
...
@@ -313,14 +319,14 @@ async def update_reranking_config(
 async def get_rag_config(user=Depends(get_admin_user)):
     return {
         "status": True,
-        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
+        "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
         "chunk": {
-            "chunk_size": app.state.CHUNK_SIZE,
-            "chunk_overlap": app.state.CHUNK_OVERLAP,
+            "chunk_size": app.state.config.CHUNK_SIZE,
+            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
         },
-        "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+        "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
         "youtube": {
-            "language": app.state.YOUTUBE_LOADER_LANGUAGE,
+            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
             "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
         },
     }
...
...
@@ -345,50 +351,52 @@ class ConfigUpdateForm(BaseModel):
 @app.post("/config/update")
 async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
-    app.state.PDF_EXTRACT_IMAGES = (
+    app.state.config.PDF_EXTRACT_IMAGES = (
         form_data.pdf_extract_images
-        if form_data.pdf_extract_images != None
-        else app.state.PDF_EXTRACT_IMAGES
+        if form_data.pdf_extract_images is not None
+        else app.state.config.PDF_EXTRACT_IMAGES
     )

-    app.state.CHUNK_SIZE = (
-        form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE
+    app.state.config.CHUNK_SIZE = (
+        form_data.chunk.chunk_size
+        if form_data.chunk is not None
+        else app.state.config.CHUNK_SIZE
     )

-    app.state.CHUNK_OVERLAP = (
+    app.state.config.CHUNK_OVERLAP = (
         form_data.chunk.chunk_overlap
-        if form_data.chunk != None
-        else app.state.CHUNK_OVERLAP
+        if form_data.chunk is not None
+        else app.state.config.CHUNK_OVERLAP
     )

-    app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
+    app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
         form_data.web_loader_ssl_verification
         if form_data.web_loader_ssl_verification != None
-        else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
+        else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
     )

-    app.state.YOUTUBE_LOADER_LANGUAGE = (
+    app.state.config.YOUTUBE_LOADER_LANGUAGE = (
         form_data.youtube.language
-        if form_data.youtube != None
-        else app.state.YOUTUBE_LOADER_LANGUAGE
+        if form_data.youtube is not None
+        else app.state.config.YOUTUBE_LOADER_LANGUAGE
     )

     app.state.YOUTUBE_LOADER_TRANSLATION = (
         form_data.youtube.translation
-        if form_data.youtube != None
+        if form_data.youtube is not None
         else app.state.YOUTUBE_LOADER_TRANSLATION
     )

     return {
         "status": True,
-        "pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
+        "pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
         "chunk": {
-            "chunk_size": app.state.CHUNK_SIZE,
-            "chunk_overlap": app.state.CHUNK_OVERLAP,
+            "chunk_size": app.state.config.CHUNK_SIZE,
+            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
         },
-        "web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+        "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
         "youtube": {
-            "language": app.state.YOUTUBE_LOADER_LANGUAGE,
+            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
             "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
         },
     }
...
...
@@ -398,7 +406,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
 async def get_rag_template(user=Depends(get_current_user)):
     return {
         "status": True,
-        "template": app.state.RAG_TEMPLATE,
+        "template": app.state.config.RAG_TEMPLATE,
     }
...
...
@@ -406,10 +414,10 @@ async def get_rag_template(user=Depends(get_current_user)):
 async def get_query_settings(user=Depends(get_admin_user)):
     return {
         "status": True,
-        "template": app.state.RAG_TEMPLATE,
-        "k": app.state.TOP_K,
-        "r": app.state.RELEVANCE_THRESHOLD,
-        "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
+        "template": app.state.config.RAG_TEMPLATE,
+        "k": app.state.config.TOP_K,
+        "r": app.state.config.RELEVANCE_THRESHOLD,
+        "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
     }
...
...
@@ -424,16 +432,20 @@ class QuerySettingsForm(BaseModel):
 async def update_query_settings(
     form_data: QuerySettingsForm, user=Depends(get_admin_user)
 ):
-    app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
-    app.state.TOP_K = form_data.k if form_data.k else 4
-    app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
-    app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False
+    app.state.config.RAG_TEMPLATE = (
+        form_data.template if form_data.template else RAG_TEMPLATE
+    )
+    app.state.config.TOP_K = form_data.k if form_data.k else 4
+    app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
+    app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
+        form_data.hybrid if form_data.hybrid else False
+    )
     return {
         "status": True,
-        "template": app.state.RAG_TEMPLATE,
-        "k": app.state.TOP_K,
-        "r": app.state.RELEVANCE_THRESHOLD,
-        "hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
+        "template": app.state.config.RAG_TEMPLATE,
+        "k": app.state.config.TOP_K,
+        "r": app.state.config.RELEVANCE_THRESHOLD,
+        "hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
     }
...
...
@@ -451,21 +463,23 @@ def query_doc_handler(
     user=Depends(get_current_user),
 ):
     try:
-        if app.state.ENABLE_RAG_HYBRID_SEARCH:
+        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
             return query_doc_with_hybrid_search(
                 collection_name=form_data.collection_name,
                 query=form_data.query,
                 embedding_function=app.state.EMBEDDING_FUNCTION,
-                k=form_data.k if form_data.k else app.state.TOP_K,
+                k=form_data.k if form_data.k else app.state.config.TOP_K,
                 reranking_function=app.state.sentence_transformer_rf,
-                r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
+                r=(
+                    form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
+                ),
             )
         else:
            return query_doc(
                collection_name=form_data.collection_name,
                query=form_data.query,
                embedding_function=app.state.EMBEDDING_FUNCTION,
-                k=form_data.k if form_data.k else app.state.TOP_K,
+                k=form_data.k if form_data.k else app.state.config.TOP_K,
            )
     except Exception as e:
         log.exception(e)
...
...
@@ -489,21 +503,23 @@ def query_collection_handler(
     user=Depends(get_current_user),
 ):
     try:
-        if app.state.ENABLE_RAG_HYBRID_SEARCH:
+        if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
             return query_collection_with_hybrid_search(
                 collection_names=form_data.collection_names,
                 query=form_data.query,
                 embedding_function=app.state.EMBEDDING_FUNCTION,
-                k=form_data.k if form_data.k else app.state.TOP_K,
+                k=form_data.k if form_data.k else app.state.config.TOP_K,
                 reranking_function=app.state.sentence_transformer_rf,
-                r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
+                r=(
+                    form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
+                ),
             )
         else:
             return query_collection(
                 collection_names=form_data.collection_names,
                 query=form_data.query,
                 embedding_function=app.state.EMBEDDING_FUNCTION,
-                k=form_data.k if form_data.k else app.state.TOP_K,
+                k=form_data.k if form_data.k else app.state.config.TOP_K,
             )
     except Exception as e:
...
...
@@ -520,7 +536,7 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
         loader = YoutubeLoader.from_youtube_url(
             form_data.url,
             add_video_info=True,
-            language=app.state.YOUTUBE_LOADER_LANGUAGE,
+            language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
             translation=app.state.YOUTUBE_LOADER_TRANSLATION,
         )
         data = loader.load()
...
...
@@ -548,7 +564,8 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
     try:
         loader = get_web_loader(
-            form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
+            form_data.url,
+            verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
         )
         data = loader.load()
...
...
@@ -604,8 +621,8 @@ def resolve_hostname(hostname):
 def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:

     text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=app.state.CHUNK_SIZE,
-        chunk_overlap=app.state.CHUNK_OVERLAP,
+        chunk_size=app.state.config.CHUNK_SIZE,
+        chunk_overlap=app.state.config.CHUNK_OVERLAP,
         add_start_index=True,
     )
...
...
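The `CHUNK_SIZE`/`CHUNK_OVERLAP` settings being rewired above control how `RecursiveCharacterTextSplitter` cuts documents before embedding. As a minimal pure-Python illustration of what the two numbers mean (this is not the langchain implementation, which also splits on separator boundaries; it only shows the sliding-window idea):

```python
def split_text(text, chunk_size, chunk_overlap):
    """Naive fixed-width splitter: each chunk starts chunk_size - chunk_overlap
    characters after the previous one, so consecutive chunks share
    chunk_overlap characters of context."""
    step = chunk_size - chunk_overlap
    return [text[i : i + chunk_size] for i in range(0, len(text), step)]


chunks = split_text("abcdefghij", chunk_size=4, chunk_overlap=2)
print(chunks)  # ['abcd', 'cdef', 'efgh', 'ghij', 'ij']
```

Larger overlap preserves more cross-chunk context for retrieval at the cost of more (and more redundant) chunks, which is why both knobs are exposed through the `/config/update` endpoint.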
@@ -622,8 +639,8 @@ def store_text_in_vector_db(
     text, metadata, collection_name, overwrite: bool = False
 ) -> bool:
     text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size=app.state.CHUNK_SIZE,
-        chunk_overlap=app.state.CHUNK_OVERLAP,
+        chunk_size=app.state.config.CHUNK_SIZE,
+        chunk_overlap=app.state.config.CHUNK_OVERLAP,
         add_start_index=True,
     )
     docs = text_splitter.create_documents([text], metadatas=[metadata])
...
...
@@ -646,11 +663,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
         collection = CHROMA_CLIENT.create_collection(name=collection_name)

         embedding_func = get_embedding_function(
-            app.state.RAG_EMBEDDING_ENGINE,
-            app.state.RAG_EMBEDDING_MODEL,
+            app.state.config.RAG_EMBEDDING_ENGINE,
+            app.state.config.RAG_EMBEDDING_MODEL,
             app.state.sentence_transformer_ef,
-            app.state.OPENAI_API_KEY,
-            app.state.OPENAI_API_BASE_URL,
+            app.state.config.OPENAI_API_KEY,
+            app.state.config.OPENAI_API_BASE_URL,
         )

         embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
...
...
@@ -724,7 +741,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
     ]

     if file_ext == "pdf":
-        loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
+        loader = PyPDFLoader(
+            file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
+        )
     elif file_ext == "csv":
         loader = CSVLoader(file_path)
     elif file_ext == "rst":
...
...
backend/apps/web/main.py
...
...
@@ -21,20 +21,24 @@ from config import (
USER_PERMISSIONS
,
WEBHOOK_URL
,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
,
JWT_EXPIRES_IN
,
AppConfig
,
)
app
=
FastAPI
()
origins
=
[
"*"
]
app
.
state
.
ENABLE_SIGNUP
=
ENABLE_SIGNUP
app
.
state
.
JWT_EXPIRES_IN
=
"-1"
app
.
state
.
config
=
AppConfig
()
app
.
state
.
DEFAULT_MODELS
=
DEFAULT_MODELS
app
.
state
.
DEFAULT_PROMPT_SUGGESTIONS
=
DEFAULT_PROMPT_SUGGESTIONS
app
.
state
.
DEFAULT_USER_ROLE
=
DEFAULT_USER_ROLE
app
.
state
.
USER_PERMISSIONS
=
USER_PERMISSIONS
app
.
state
.
WEBHOOK_URL
=
WEBHOOK_URL
app
.
state
.
config
.
ENABLE_SIGNUP
=
ENABLE_SIGNUP
app
.
state
.
config
.
JWT_EXPIRES_IN
=
JWT_EXPIRES_IN
app
.
state
.
config
.
DEFAULT_MODELS
=
DEFAULT_MODELS
app
.
state
.
config
.
DEFAULT_PROMPT_SUGGESTIONS
=
DEFAULT_PROMPT_SUGGESTIONS
app
.
state
.
config
.
DEFAULT_USER_ROLE
=
DEFAULT_USER_ROLE
app
.
state
.
config
.
USER_PERMISSIONS
=
USER_PERMISSIONS
app
.
state
.
config
.
WEBHOOK_URL
=
WEBHOOK_URL
app
.
state
.
AUTH_TRUSTED_EMAIL_HEADER
=
WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app
.
add_middleware
(
...
...
@@ -61,6 +65,6 @@ async def get_status():
     return {
         "status": True,
         "auth": WEBUI_AUTH,
-        "default_models": app.state.DEFAULT_MODELS,
-        "default_prompt_suggestions": app.state.DEFAULT_PROMPT_SUGGESTIONS,
+        "default_models": app.state.config.DEFAULT_MODELS,
+        "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
     }
backend/apps/web/routers/auths.py
...
...
@@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm):
     if user:
         token = create_token(
             data={"id": user.id},
-            expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
+            expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
         )

         return {
...
@@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm):
 @router.post("/signup", response_model=SigninResponse)
 async def signup(request: Request, form_data: SignupForm):
-    if not request.app.state.ENABLE_SIGNUP and WEBUI_AUTH:
+    if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
         raise HTTPException(
             status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
         )
...
...
@@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm):
         role = (
             "admin"
             if Users.get_num_users() == 0
-            else request.app.state.DEFAULT_USER_ROLE
+            else request.app.state.config.DEFAULT_USER_ROLE
         )
         hashed = get_password_hash(form_data.password)
         user = Auths.insert_new_auth(
...
...
@@ -194,13 +194,13 @@ async def signup(request: Request, form_data: SignupForm):
         if user:
             token = create_token(
                 data={"id": user.id},
-                expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
+                expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
             )
             # response.set_cookie(key='token', value=token, httponly=True)

-            if request.app.state.WEBHOOK_URL:
+            if request.app.state.config.WEBHOOK_URL:
                 post_webhook(
-                    request.app.state.WEBHOOK_URL,
+                    request.app.state.config.WEBHOOK_URL,
                     WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
                     {
                         "action": "signup",
...
@@ -276,13 +276,13 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
@
router
.
get
(
"/signup/enabled"
,
response_model
=
bool
)
async
def
get_sign_up_status
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
return
request
.
app
.
state
.
ENABLE_SIGNUP
return
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
@
router
.
get
(
"/signup/enabled/toggle"
,
response_model
=
bool
)
async
def
toggle_sign_up
(
request
:
Request
,
user
=
Depends
(
get_admin_user
)):
request
.
app
.
state
.
ENABLE_SIGNUP
=
not
request
.
app
.
state
.
ENABLE_SIGNUP
return
request
.
app
.
state
.
ENABLE_SIGNUP
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
=
not
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
return
request
.
app
.
state
.
config
.
ENABLE_SIGNUP
############################
...
...
@@ -292,7 +292,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
 @router.get("/signup/user/role")
 async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
-    return request.app.state.DEFAULT_USER_ROLE
+    return request.app.state.config.DEFAULT_USER_ROLE


 class UpdateRoleForm(BaseModel):
...
...
@@ -304,8 +304,8 @@ async def update_default_user_role(
     request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
 ):
     if form_data.role in ["pending", "user", "admin"]:
-        request.app.state.DEFAULT_USER_ROLE = form_data.role
-    return request.app.state.DEFAULT_USER_ROLE
+        request.app.state.config.DEFAULT_USER_ROLE = form_data.role
+    return request.app.state.config.DEFAULT_USER_ROLE


 ############################
...
...
@@ -315,7 +315,7 @@ async def update_default_user_role(
 @router.get("/token/expires")
 async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
-    return request.app.state.JWT_EXPIRES_IN
+    return request.app.state.config.JWT_EXPIRES_IN


 class UpdateJWTExpiresDurationForm(BaseModel):
...
...
@@ -332,10 +332,10 @@ async def update_token_expires_duration(
     # Check if the input string matches the pattern
     if re.match(pattern, form_data.duration):
-        request.app.state.JWT_EXPIRES_IN = form_data.duration
-        return request.app.state.JWT_EXPIRES_IN
+        request.app.state.config.JWT_EXPIRES_IN = form_data.duration
+        return request.app.state.config.JWT_EXPIRES_IN
     else:
-        return request.app.state.JWT_EXPIRES_IN
+        return request.app.state.config.JWT_EXPIRES_IN


 ############################
...
...
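The last hunk above validates `form_data.duration` against a regex before persisting it; the `pattern` variable itself is defined elsewhere in auths.py and is not shown in this diff. A hypothetical pattern in the same spirit (an integer count plus a time unit, or `-1` for "never expire") might look like:

```python
import re

# Hypothetical pattern; the actual `pattern` used by
# update_token_expires_duration is not visible in this diff.
DURATION_PATTERN = r"^(-1|0|(-?\d+(\.\d+)?)(s|m|h|d|w))$"


def is_valid_duration(duration: str) -> bool:
    # Accept "-1", "0", or a number followed by s/m/h/d/w.
    return re.match(DURATION_PATTERN, duration) is not None


print(is_valid_duration("30m"))   # True
print(is_valid_duration("7d"))    # True
print(is_valid_duration("-1"))    # True
print(is_valid_duration("soon"))  # False
```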
backend/apps/web/routers/chats.py

```diff
@@ -58,7 +58,7 @@ async def delete_all_user_chats(request: Request, user=Depends(get_current_user)
     if (
         user.role == "user"
-        and not request.app.state.USER_PERMISSIONS["chat"]["deletion"]
+        and not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]
     ):
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
@@ -266,7 +266,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
         result = Chats.delete_chat_by_id(id)
         return result
     else:
-        if not request.app.state.USER_PERMISSIONS["chat"]["deletion"]:
+        if not request.app.state.config.USER_PERMISSIONS["chat"]["deletion"]:
             raise HTTPException(
                 status_code=status.HTTP_401_UNAUTHORIZED,
                 detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
```
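Both chat-deletion endpoints above gate plain users on the nested `USER_PERMISSIONS` dict while admins pass through unconditionally. A minimal standalone sketch of that gate (helper name is illustrative, not the router code itself):

```python
# Shape matches the USER_PERMISSIONS config introduced in this commit.
USER_PERMISSIONS = {"chat": {"deletion": False}}


def can_delete_chat(role: str, permissions: dict) -> bool:
    # Admins bypass the gate; a plain "user" needs chat.deletion enabled.
    return role == "admin" or bool(permissions["chat"]["deletion"])


print(can_delete_chat("user", USER_PERMISSIONS))   # False
print(can_delete_chat("admin", USER_PERMISSIONS))  # True
```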
backend/apps/web/routers/configs.py

```diff
@@ -44,8 +44,8 @@ class SetDefaultSuggestionsForm(BaseModel):
 async def set_global_default_models(
     request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
 ):
-    request.app.state.DEFAULT_MODELS = form_data.models
-    return request.app.state.DEFAULT_MODELS
+    request.app.state.config.DEFAULT_MODELS = form_data.models
+    return request.app.state.config.DEFAULT_MODELS


 @router.post("/default/suggestions", response_model=List[PromptSuggestion])
@@ -55,5 +55,5 @@ async def set_global_default_suggestions(
     user=Depends(get_admin_user),
 ):
     data = form_data.model_dump()
-    request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
-    return request.app.state.DEFAULT_PROMPT_SUGGESTIONS
+    request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
+    return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS
```
backend/apps/web/routers/users.py

```diff
@@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
 @router.get("/permissions/user")
 async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
-    return request.app.state.USER_PERMISSIONS
+    return request.app.state.config.USER_PERMISSIONS


 @router.post("/permissions/user")
 async def update_user_permissions(
     request: Request, form_data: dict, user=Depends(get_admin_user)
 ):
-    request.app.state.USER_PERMISSIONS = form_data
-    return request.app.state.USER_PERMISSIONS
+    request.app.state.config.USER_PERMISSIONS = form_data
+    return request.app.state.config.USER_PERMISSIONS


 ############################
```
backend/config.py

```diff
@@ -5,6 +5,7 @@ import chromadb
 from chromadb import Settings
 from base64 import b64encode
 from bs4 import BeautifulSoup
+from typing import TypeVar, Generic, Union
 from pathlib import Path
 import json
@@ -17,7 +18,6 @@ import shutil
 from secrets import token_bytes
 from constants import ERROR_MESSAGES

 ####################################
 # Load .env file
 ####################################
@@ -71,7 +71,6 @@ for source in log_sources:
     log.setLevel(SRC_LOG_LEVELS["CONFIG"])

 WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
 if WEBUI_NAME != "Open WebUI":
     WEBUI_NAME += " (Open WebUI)"
```
```diff
@@ -161,16 +160,6 @@ CHANGELOG = changelog_json
 WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100")

-####################################
-# WEBUI_AUTH (Required for security)
-####################################
-
-WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
-WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None)
-
 ####################################
 # DATA/FRONTEND BUILD DIR
 ####################################
```
```diff
@@ -184,6 +173,108 @@ try:
 except:
     CONFIG_DATA = {}

+
+####################################
+# Config helpers
+####################################
+
+
+def save_config():
+    try:
+        with open(f"{DATA_DIR}/config.json", "w") as f:
+            json.dump(CONFIG_DATA, f, indent="\t")
+    except Exception as e:
+        log.exception(e)
+
+
+def get_config_value(config_path: str):
+    path_parts = config_path.split(".")
+    cur_config = CONFIG_DATA
+    for key in path_parts:
+        if key in cur_config:
+            cur_config = cur_config[key]
+        else:
+            return None
+    return cur_config
+
+
+T = TypeVar("T")
+
+
+class PersistentConfig(Generic[T]):
+    def __init__(self, env_name: str, config_path: str, env_value: T):
+        self.env_name = env_name
+        self.config_path = config_path
+        self.env_value = env_value
+        self.config_value = get_config_value(config_path)
+        if self.config_value is not None:
+            log.info(f"'{env_name}' loaded from config.json")
+            self.value = self.config_value
+        else:
+            self.value = env_value
+
+    def __str__(self):
+        return str(self.value)
+
+    @property
+    def __dict__(self):
+        raise TypeError(
+            "PersistentConfig object cannot be converted to dict, use config_get or .value instead."
+        )
+
+    def __getattribute__(self, item):
+        if item == "__dict__":
+            raise TypeError(
+                "PersistentConfig object cannot be converted to dict, use config_get or .value instead."
+            )
+        return super().__getattribute__(item)
+
+    def save(self):
+        # Don't save if the value is the same as the env value and the config value
+        if self.env_value == self.value:
+            if self.config_value == self.value:
+                return
+        log.info(f"Saving '{self.env_name}' to config.json")
+        path_parts = self.config_path.split(".")
+        config = CONFIG_DATA
+        for key in path_parts[:-1]:
+            if key not in config:
+                config[key] = {}
+            config = config[key]
+        config[path_parts[-1]] = self.value
+        save_config()
+        self.config_value = self.value
+
+
+class AppConfig:
+    _state: dict[str, PersistentConfig]
+
+    def __init__(self):
+        super().__setattr__("_state", {})
+
+    def __setattr__(self, key, value):
+        if isinstance(value, PersistentConfig):
+            self._state[key] = value
+        else:
+            self._state[key].value = value
+            self._state[key].save()
+
+    def __getattr__(self, key):
+        return self._state[key].value
+
+
+####################################
+# WEBUI_AUTH (Required for security)
+####################################
+
+WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
+WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None)
+
+JWT_EXPIRES_IN = PersistentConfig(
+    "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
+)

 ####################################
 # Static DIR
 ####################################
```
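The `PersistentConfig`/`AppConfig` pair added above is the core of this commit: env-var defaults that can be overridden at runtime through plain attribute assignment, with the new value written to `config.json`. A self-contained sketch of the same pattern (standalone names and a temp-file path, not the module's actual globals):

```python
import json
import os
import tempfile

CONFIG_PATH = os.path.join(tempfile.gettempdir(), "demo_config.json")
CONFIG_DATA = {}


def save_config():
    with open(CONFIG_PATH, "w") as f:
        json.dump(CONFIG_DATA, f, indent="\t")


def get_config_value(path):
    cur = CONFIG_DATA
    for key in path.split("."):
        if not isinstance(cur, dict) or key not in cur:
            return None
        cur = cur[key]
    return cur


class PersistentConfig:
    def __init__(self, env_name, config_path, env_value):
        self.env_name = env_name
        self.config_path = config_path
        self.env_value = env_value
        saved = get_config_value(config_path)
        # Saved JSON value wins over the environment default.
        self.value = saved if saved is not None else env_value

    def save(self):
        cfg = CONFIG_DATA
        parts = self.config_path.split(".")
        for key in parts[:-1]:
            cfg = cfg.setdefault(key, {})
        cfg[parts[-1]] = self.value
        save_config()


class AppConfig:
    def __init__(self):
        super().__setattr__("_state", {})

    def __setattr__(self, key, value):
        if isinstance(value, PersistentConfig):
            self._state[key] = value  # first assignment registers the config
        else:
            self._state[key].value = value  # later assignments persist it
            self._state[key].save()

    def __getattr__(self, key):
        return self._state[key].value


config = AppConfig()
config.ENABLE_SIGNUP = PersistentConfig("ENABLE_SIGNUP", "ui.enable_signup", True)
print(config.ENABLE_SIGNUP)                 # True (env default, nothing saved yet)
config.ENABLE_SIGNUP = False                # plain assignment persists to JSON
print(CONFIG_DATA["ui"]["enable_signup"])   # False
```

This is why the routers in this diff can simply assign `request.app.state.config.ENABLE_SIGNUP = ...` and have the change survive restarts.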
```diff
@@ -318,7 +409,9 @@ OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
 OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL
 OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
+OLLAMA_BASE_URLS = PersistentConfig(
+    "OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
+)

 ####################################
 # OPENAI_API
```
```diff
@@ -335,7 +428,9 @@ OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
 OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY
 OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")]
+OPENAI_API_KEYS = PersistentConfig(
+    "OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS
+)

 OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
 OPENAI_API_BASE_URLS = (
```
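Both hunks above keep the same parsing step: a single env var holding one or more values separated by `;`, falling back to the singular variable when empty, and only then wrapping the resulting list in `PersistentConfig`. The split logic in isolation (helper name is illustrative):

```python
def parse_semicolon_list(raw: str, fallback: str) -> list:
    # An empty string means "not set": fall back to the single-value variable.
    raw = raw if raw != "" else fallback
    return [item.strip() for item in raw.split(";")]


print(parse_semicolon_list("http://a:11434; http://b:11434", "http://localhost:11434"))
# ['http://a:11434', 'http://b:11434']
print(parse_semicolon_list("", "http://localhost:11434"))
# ['http://localhost:11434']
```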
```diff
@@ -346,37 +441,42 @@ OPENAI_API_BASE_URLS = [
     url.strip() if url != "" else "https://api.openai.com/v1"
     for url in OPENAI_API_BASE_URLS.split(";")
 ]
+OPENAI_API_BASE_URLS = PersistentConfig(
+    "OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
+)

 OPENAI_API_KEY = ""

 try:
-    OPENAI_API_KEY = OPENAI_API_KEYS[
-        OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
+    OPENAI_API_KEY = OPENAI_API_KEYS.value[
+        OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
     ]
 except:
     pass

 OPENAI_API_BASE_URL = "https://api.openai.com/v1"

 ####################################
 # WEBUI
 ####################################

-ENABLE_SIGNUP = (
-    False
-    if WEBUI_AUTH == False
-    else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
-)
+ENABLE_SIGNUP = PersistentConfig(
+    "ENABLE_SIGNUP",
+    "ui.enable_signup",
+    (
+        False
+        if not WEBUI_AUTH
+        else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
+    ),
+)

-DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None)
+DEFAULT_MODELS = PersistentConfig(
+    "DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
+)

-DEFAULT_PROMPT_SUGGESTIONS = (
-    CONFIG_DATA["ui"]["prompt_suggestions"]
-    if "ui" in CONFIG_DATA
-    and "prompt_suggestions" in CONFIG_DATA["ui"]
-    and type(CONFIG_DATA["ui"]["prompt_suggestions"]) is list
-    else [
+DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
+    "DEFAULT_PROMPT_SUGGESTIONS",
+    "ui.prompt_suggestions",
+    [
         {
             "title": ["Help me study", "vocabulary for a college entrance exam"],
             "content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
         },
```
```diff
@@ -404,23 +504,40 @@ DEFAULT_PROMPT_SUGGESTIONS = (
         {
             "title": ["Overcome procrastination", "give me tips"],
             "content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
         },
-    ]
+    ],
 )

-DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending")
+DEFAULT_USER_ROLE = PersistentConfig(
+    "DEFAULT_USER_ROLE",
+    "ui.default_user_role",
+    os.getenv("DEFAULT_USER_ROLE", "pending"),
+)

 USER_PERMISSIONS_CHAT_DELETION = (
     os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
 )

-USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}
+USER_PERMISSIONS = PersistentConfig(
+    "USER_PERMISSIONS",
+    "ui.user_permissions",
+    {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}},
+)

-ENABLE_MODEL_FILTER = os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true"
+ENABLE_MODEL_FILTER = PersistentConfig(
+    "ENABLE_MODEL_FILTER",
+    "model_filter.enable",
+    os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
+)

 MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
-MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
+MODEL_FILTER_LIST = PersistentConfig(
+    "MODEL_FILTER_LIST",
+    "model_filter.list",
+    [model.strip() for model in MODEL_FILTER_LIST.split(";")],
+)

-WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "")
+WEBHOOK_URL = PersistentConfig(
+    "WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
+)

 ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
```
```diff
@@ -458,26 +575,45 @@ else:
     CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"

 # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)

-RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5"))
-RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0"))
+RAG_TOP_K = PersistentConfig(
+    "RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5"))
+)
+RAG_RELEVANCE_THRESHOLD = PersistentConfig(
+    "RAG_RELEVANCE_THRESHOLD",
+    "rag.relevance_threshold",
+    float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")),
+)

-ENABLE_RAG_HYBRID_SEARCH = (
-    os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
-)
+ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
+    "ENABLE_RAG_HYBRID_SEARCH",
+    "rag.enable_hybrid_search",
+    os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
+)

-ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
-    os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true"
-)
+ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
+    "ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
+    "rag.enable_web_loader_ssl_verification",
+    os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
+)

-RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
+RAG_EMBEDDING_ENGINE = PersistentConfig(
+    "RAG_EMBEDDING_ENGINE",
+    "rag.embedding_engine",
+    os.environ.get("RAG_EMBEDDING_ENGINE", ""),
+)

-PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true"
+PDF_EXTRACT_IMAGES = PersistentConfig(
+    "PDF_EXTRACT_IMAGES",
+    "rag.pdf_extract_images",
+    os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true",
+)

-RAG_EMBEDDING_MODEL = os.environ.get(
-    "RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
-)
-log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
+RAG_EMBEDDING_MODEL = PersistentConfig(
+    "RAG_EMBEDDING_MODEL",
+    "rag.embedding_model",
+    os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"),
+)
+log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}"),

 RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
     os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
```
```diff
@@ -487,9 +623,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )

-RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
-if not RAG_RERANKING_MODEL == "":
-    log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
+RAG_RERANKING_MODEL = PersistentConfig(
+    "RAG_RERANKING_MODEL",
+    "rag.reranking_model",
+    os.environ.get("RAG_RERANKING_MODEL", ""),
+)
+if RAG_RERANKING_MODEL.value != "":
+    log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}"),

 RAG_RERANKING_MODEL_AUTO_UPDATE = (
     os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
```
```diff
@@ -527,9 +667,14 @@ if USE_CUDA.lower() == "true":
 else:
     DEVICE_TYPE = "cpu"

-CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500"))
-CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100"))
+CHUNK_SIZE = PersistentConfig(
+    "CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1500"))
+)
+CHUNK_OVERLAP = PersistentConfig(
+    "CHUNK_OVERLAP",
+    "rag.chunk_overlap",
+    int(os.environ.get("CHUNK_OVERLAP", "100")),
+)

 DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
 <context>
```
```diff
@@ -545,16 +690,32 @@ And answer according to the language of the user's question.
 Given the context information, answer the query.
 Query: [query]"""

-RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
+RAG_TEMPLATE = PersistentConfig(
+    "RAG_TEMPLATE",
+    "rag.template",
+    os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE),
+)

-RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
-RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
+RAG_OPENAI_API_BASE_URL = PersistentConfig(
+    "RAG_OPENAI_API_BASE_URL",
+    "rag.openai_api_base_url",
+    os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
+)
+RAG_OPENAI_API_KEY = PersistentConfig(
+    "RAG_OPENAI_API_KEY",
+    "rag.openai_api_key",
+    os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
+)

 ENABLE_RAG_LOCAL_WEB_FETCH = (
     os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
 )

-YOUTUBE_LOADER_LANGUAGE = os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(",")
+YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
+    "YOUTUBE_LOADER_LANGUAGE",
+    "rag.youtube_loader_language",
+    os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
+)

 ####################################
 # Transcribe
```
```diff
@@ -571,34 +732,78 @@ WHISPER_MODEL_AUTO_UPDATE = (
 # Images
 ####################################

-IMAGE_GENERATION_ENGINE = os.getenv("IMAGE_GENERATION_ENGINE", "")
+IMAGE_GENERATION_ENGINE = PersistentConfig(
+    "IMAGE_GENERATION_ENGINE",
+    "image_generation.engine",
+    os.getenv("IMAGE_GENERATION_ENGINE", ""),
+)

-ENABLE_IMAGE_GENERATION = (
-    os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true"
-)
+ENABLE_IMAGE_GENERATION = PersistentConfig(
+    "ENABLE_IMAGE_GENERATION",
+    "image_generation.enable",
+    os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
+)

-AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
+AUTOMATIC1111_BASE_URL = PersistentConfig(
+    "AUTOMATIC1111_BASE_URL",
+    "image_generation.automatic1111.base_url",
+    os.getenv("AUTOMATIC1111_BASE_URL", ""),
+)

-COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
+COMFYUI_BASE_URL = PersistentConfig(
+    "COMFYUI_BASE_URL",
+    "image_generation.comfyui.base_url",
+    os.getenv("COMFYUI_BASE_URL", ""),
+)

-IMAGES_OPENAI_API_BASE_URL = os.getenv(
-    "IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL
-)
-IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY)
+IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
+    "IMAGES_OPENAI_API_BASE_URL",
+    "image_generation.openai.api_base_url",
+    os.getenv("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
+)
+IMAGES_OPENAI_API_KEY = PersistentConfig(
+    "IMAGES_OPENAI_API_KEY",
+    "image_generation.openai.api_key",
+    os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
+)

-IMAGE_SIZE = os.getenv("IMAGE_SIZE", "512x512")
+IMAGE_SIZE = PersistentConfig(
+    "IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
+)

-IMAGE_STEPS = int(os.getenv("IMAGE_STEPS", 50))
+IMAGE_STEPS = PersistentConfig(
+    "IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50))
+)

-IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "")
+IMAGE_GENERATION_MODEL = PersistentConfig(
+    "IMAGE_GENERATION_MODEL",
+    "image_generation.model",
+    os.getenv("IMAGE_GENERATION_MODEL", ""),
+)

 ####################################
 # Audio
 ####################################

-AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
-AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY)
-AUDIO_OPENAI_API_MODEL = os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1")
-AUDIO_OPENAI_API_VOICE = os.getenv("AUDIO_OPENAI_API_VOICE", "alloy")
+AUDIO_OPENAI_API_BASE_URL = PersistentConfig(
+    "AUDIO_OPENAI_API_BASE_URL",
+    "audio.openai.api_base_url",
+    os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
+)
+AUDIO_OPENAI_API_KEY = PersistentConfig(
+    "AUDIO_OPENAI_API_KEY",
+    "audio.openai.api_key",
+    os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY),
+)
+AUDIO_OPENAI_API_MODEL = PersistentConfig(
+    "AUDIO_OPENAI_API_MODEL",
+    "audio.openai.api_model",
+    os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"),
+)
+AUDIO_OPENAI_API_VOICE = PersistentConfig(
+    "AUDIO_OPENAI_API_VOICE",
+    "audio.openai.api_voice",
+    os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"),
+)

 ####################################
 # LiteLLM
```
backend/main.py

```diff
+from contextlib import asynccontextmanager
 from bs4 import BeautifulSoup
 import json
 import markdown
@@ -58,6 +59,7 @@ from config import (
     SRC_LOG_LEVELS,
     WEBHOOK_URL,
     ENABLE_ADMIN_EXPORT,
+    AppConfig,
 )

 from constants import ERROR_MESSAGES
```
```diff
@@ -92,12 +94,25 @@ https://github.com/open-webui/open-webui
 """
 )

-app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)

-app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
-app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    if ENABLE_LITELLM:
+        asyncio.create_task(start_litellm_background())
+    yield
+    if ENABLE_LITELLM:
+        await shutdown_litellm_background()

-app.state.WEBHOOK_URL = WEBHOOK_URL
+
+app = FastAPI(
+    docs_url="/docs" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan
+)
+
+app.state.config = AppConfig()
+app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
+app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
+
+app.state.config.WEBHOOK_URL = WEBHOOK_URL

 origins = ["*"]
```
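The hunk above moves LiteLLM startup/shutdown from the deprecated `@app.on_event("startup")` / `@app.on_event("shutdown")` hooks into a single `lifespan` async context manager passed to `FastAPI(...)`. FastAPI enters this context once on startup and exits it on shutdown; the pattern reduced to the stdlib so it runs standalone (placeholder state, not the real LiteLLM tasks):

```python
import asyncio
from contextlib import asynccontextmanager


class AppState:
    """Stand-in for app.state; the real code starts/stops LiteLLM here."""
    litellm_running = False


state = AppState()


@asynccontextmanager
async def lifespan(state: AppState):
    state.litellm_running = True       # startup side of the context manager
    try:
        yield
    finally:
        state.litellm_running = False  # shutdown side


async def serve():
    async with lifespan(state):
        assert state.litellm_running   # "server is running" section
    return state.litellm_running


print(asyncio.run(serve()))  # False
```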
```diff
@@ -129,12 +144,12 @@ class RAGMiddleware(BaseHTTPMiddleware):
                 data["messages"], citations = rag_messages(
                     docs=data["docs"],
                     messages=data["messages"],
-                    template=rag_app.state.RAG_TEMPLATE,
+                    template=rag_app.state.config.RAG_TEMPLATE,
                     embedding_function=rag_app.state.EMBEDDING_FUNCTION,
-                    k=rag_app.state.TOP_K,
+                    k=rag_app.state.config.TOP_K,
                     reranking_function=rag_app.state.sentence_transformer_rf,
-                    r=rag_app.state.RELEVANCE_THRESHOLD,
-                    hybrid_search=rag_app.state.ENABLE_RAG_HYBRID_SEARCH,
+                    r=rag_app.state.config.RELEVANCE_THRESHOLD,
+                    hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                 )
                 del data["docs"]
```
```diff
@@ -211,12 +226,6 @@ async def check_url(request: Request, call_next):
     return response


-@app.on_event("startup")
-async def on_startup():
-    if ENABLE_LITELLM:
-        asyncio.create_task(start_litellm_background())
-
-
 app.mount("/api/v1", webui_app)
 app.mount("/litellm/api", litellm_app)
```
```diff
@@ -243,9 +252,9 @@ async def get_app_config():
         "version": VERSION,
         "auth": WEBUI_AUTH,
         "default_locale": default_locale,
-        "images": images_app.state.ENABLED,
-        "default_models": webui_app.state.DEFAULT_MODELS,
-        "default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS,
+        "images": images_app.state.config.ENABLED,
+        "default_models": webui_app.state.config.DEFAULT_MODELS,
+        "default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
         "trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
         "admin_export_enabled": ENABLE_ADMIN_EXPORT,
     }
```
```diff
@@ -254,8 +263,8 @@ async def get_app_config():
 @app.get("/api/config/model/filter")
 async def get_model_filter_config(user=Depends(get_admin_user)):
     return {
-        "enabled": app.state.ENABLE_MODEL_FILTER,
-        "models": app.state.MODEL_FILTER_LIST,
+        "enabled": app.state.config.ENABLE_MODEL_FILTER,
+        "models": app.state.config.MODEL_FILTER_LIST,
     }
```
```diff
@@ -268,28 +277,28 @@ class ModelFilterConfigForm(BaseModel):
 async def update_model_filter_config(
     form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
 ):
-    app.state.ENABLE_MODEL_FILTER = form_data.enabled
-    app.state.MODEL_FILTER_LIST = form_data.models
+    app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
+    app.state.config.MODEL_FILTER_LIST = form_data.models

-    ollama_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
-    ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
+    ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
+    ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST

-    openai_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
-    openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
+    openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
+    openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST

-    litellm_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
-    litellm_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
+    litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
+    litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST

     return {
-        "enabled": app.state.ENABLE_MODEL_FILTER,
-        "models": app.state.MODEL_FILTER_LIST,
+        "enabled": app.state.config.ENABLE_MODEL_FILTER,
+        "models": app.state.config.MODEL_FILTER_LIST,
     }


 @app.get("/api/webhook")
 async def get_webhook_url(user=Depends(get_admin_user)):
     return {
-        "url": app.state.WEBHOOK_URL,
+        "url": app.state.config.WEBHOOK_URL,
     }
```
```diff
@@ -299,12 +308,12 @@ class UrlForm(BaseModel):
 @app.post("/api/webhook")
 async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
-    app.state.WEBHOOK_URL = form_data.url
-    webui_app.state.WEBHOOK_URL = app.state.WEBHOOK_URL
+    app.state.config.WEBHOOK_URL = form_data.url
+    webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
     return {
-        "url": app.state.WEBHOOK_URL,
+        "url": app.state.config.WEBHOOK_URL,
     }
```
```diff
@@ -381,9 +390,3 @@ else:
     log.warning(
         f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
     )
-
-
-@app.on_event("shutdown")
-async def shutdown_event():
-    if ENABLE_LITELLM:
-        await shutdown_litellm_background()
```
docker-compose.api.yaml

```diff
-version: '3.8'
-
 services:
   ollama:
     # Expose Ollama API outside the container stack
```

docker-compose.data.yaml

```diff
-version: '3.8'
-
 services:
   ollama:
     volumes:
```

docker-compose.gpu.yaml

```diff
-version: '3.8'
-
 services:
   ollama:
     # GPU support
```

docker-compose.yaml

```diff
-version: '3.8'
-
 services:
   ollama:
     volumes:
```