Unverified Commit edbe5673 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge branch 'dev' into main

parents 1114f962 8b0144cd
## Pull Request Checklist
# Pull Request Checklist
- [ ] **Target branch:** Pull requests should target the `dev` branch.
- [ ] **Description:** Briefly describe the changes in this pull request.
......@@ -7,32 +7,46 @@
- [ ] **Dependencies:** Are there any new dependencies? Have you updated the dependency versions in the documentation?
- [ ] **Testing:** Have you written and run sufficient tests for the changes?
- [ ] **Code Review:** Have you self-reviewed your code and addressed any coding standard issues?
---
## Description
[Insert a brief description of the changes made in this pull request, including any relevant motivation and impact.]
---
### Changelog Entry
- [ ] **Label title:** Ensure the pull request title is labeled properly using one of the following:
- **BREAKING CHANGE**: Significant changes that may affect compatibility
- **build**: Changes that affect the build system or external dependencies
- **ci**: Changes to our continuous integration processes or workflows
- **chore**: Refactor, cleanup, or other non-functional code changes
- **docs**: Documentation update or addition
- **feat**: Introduces a new feature or enhancement to the codebase
- **fix**: Bug fix or error correction
- **i18n**: Internationalization or localization changes
- **perf**: Performance improvement
- **refactor**: Code restructuring for better maintainability, readability, or scalability
- **style**: Changes that do not affect the meaning of the code (white-space, formatting, missing semi-colons, etc.)
- **test**: Adding missing tests or correcting existing tests
- **WIP**: Work in progress, a temporary label for incomplete or ongoing work
# Changelog Entry
### Description
- [Briefly describe the changes made in this pull request, including any relevant motivation and impact.]
### Added
- [List any new features, functionalities, or additions]
### Fixed
- [List any fixes, corrections, or bug fixes]
### Changed
- [List any changes, updates, refactorings, or optimizations]
### Deprecated
- [List any deprecated functionality or features that have been removed]
### Removed
- [List any removed features, files, or deprecated functionalities]
- [List any removed features, files, or functionalities]
### Fixed
- [List any fixes, corrections, or bug fixes]
### Security
......@@ -40,12 +54,11 @@
### Breaking Changes
- [List any breaking changes affecting compatibility or functionality]
- **BREAKING CHANGE**: [List any breaking changes affecting compatibility or functionality]
---
### Additional Information
- [Insert any additional context, notes, or explanations for the changes]
- [Reference any related issues, commits, or other relevant information]
- [Reference any related issues, commits, or other relevant information]
# Ignore files for PNPM, NPM and YARN
pnpm-lock.yaml
package-lock.json
yarn.lock
kubernetes/
# Copy of .gitignore
.DS_Store
node_modules
/build
......@@ -6,11 +14,299 @@ node_modules
.env
.env.*
!.env.example
vite.config.js.timestamp-*
vite.config.ts.timestamp-*
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# Ignore files for PNPM, NPM and YARN
pnpm-lock.yaml
package-lock.json
yarn.lock
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# Logs
logs
*.log
npm-debug.log*
yarn-debug.log*
yarn-error.log*
lerna-debug.log*
.pnpm-debug.log*
# Diagnostic reports (https://nodejs.org/api/report.html)
report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
# Runtime data
pids
*.pid
*.seed
*.pid.lock
# Directory for instrumented libs generated by jscoverage/JSCover
lib-cov
# Coverage directory used by tools like istanbul
coverage
*.lcov
# nyc test coverage
.nyc_output
# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
.grunt
# Bower dependency directory (https://bower.io/)
bower_components
# node-waf configuration
.lock-wscript
# Compiled binary addons (https://nodejs.org/api/addons.html)
build/Release
# Dependency directories
node_modules/
jspm_packages/
# Snowpack dependency directory (https://snowpack.dev/)
web_modules/
# TypeScript cache
*.tsbuildinfo
# Optional npm cache directory
.npm
# Optional eslint cache
.eslintcache
# Optional stylelint cache
.stylelintcache
# Microbundle cache
.rpt2_cache/
.rts2_cache_cjs/
.rts2_cache_es/
.rts2_cache_umd/
# Optional REPL history
.node_repl_history
# Output of 'npm pack'
*.tgz
# Yarn Integrity file
.yarn-integrity
# dotenv environment variable files
.env
.env.development.local
.env.test.local
.env.production.local
.env.local
# parcel-bundler cache (https://parceljs.org/)
.cache
.parcel-cache
# Next.js build output
.next
out
# Nuxt.js build / generate output
.nuxt
dist
# Gatsby files
.cache/
# Comment in the public line in if your project uses Gatsby and not Next.js
# https://nextjs.org/blog/next-9-1#public-directory-support
# public
# vuepress build output
.vuepress/dist
# vuepress v2.x temp and cache directory
.temp
.cache
# Docusaurus cache and generated files
.docusaurus
# Serverless directories
.serverless/
# FuseBox cache
.fusebox/
# DynamoDB Local files
.dynamodb/
# TernJS port file
.tern-port
# Stores VSCode versions used for testing VSCode extensions
.vscode-test
# yarn v2
.yarn/cache
.yarn/unplugged
.yarn/build-state.yml
.yarn/install-state.gz
.pnp.*
# Ignore kubernetes files
kubernetes
\ No newline at end of file
# cypress artifacts
cypress/videos
cypress/screenshots
......@@ -82,7 +82,7 @@ RUN echo -n 00000000-0000-0000-0000-000000000000 > $HOME/.cache/chroma/telemetry
RUN if [ "$USE_OLLAMA" = "true" ]; then \
apt-get update && \
# Install pandoc and netcat
apt-get install -y --no-install-recommends pandoc netcat-openbsd && \
apt-get install -y --no-install-recommends pandoc netcat-openbsd curl && \
# for RAG OCR
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
# install helper tools
......@@ -94,7 +94,7 @@ RUN if [ "$USE_OLLAMA" = "true" ]; then \
else \
apt-get update && \
# Install pandoc and netcat
apt-get install -y --no-install-recommends pandoc netcat-openbsd && \
apt-get install -y --no-install-recommends pandoc netcat-openbsd curl && \
# for RAG OCR
apt-get install -y --no-install-recommends ffmpeg libsm6 libxext6 && \
# cleanup
......@@ -134,4 +134,6 @@ COPY ./backend .
EXPOSE 8080
HEALTHCHECK CMD curl --fail http://localhost:8080 || exit 1
CMD [ "bash", "start.sh"]
......@@ -45,6 +45,7 @@ from config import (
AUDIO_OPENAI_API_KEY,
AUDIO_OPENAI_API_MODEL,
AUDIO_OPENAI_API_VOICE,
AppConfig,
)
log = logging.getLogger(__name__)
......@@ -59,11 +60,11 @@ app.add_middleware(
allow_headers=["*"],
)
app.state.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
app.state.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
app.state.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE
app.state.config = AppConfig()
app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE
# setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
......@@ -83,10 +84,10 @@ class OpenAIConfigUpdateForm(BaseModel):
@app.get("/config")
async def get_openai_config(user=Depends(get_admin_user)):
return {
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.OPENAI_API_MODEL,
"OPENAI_API_VOICE": app.state.OPENAI_API_VOICE,
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
"OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
}
......@@ -97,17 +98,17 @@ async def update_openai_config(
if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_BASE_URL = form_data.url
app.state.OPENAI_API_KEY = form_data.key
app.state.OPENAI_API_MODEL = form_data.model
app.state.OPENAI_API_VOICE = form_data.speaker
app.state.config.OPENAI_API_BASE_URL = form_data.url
app.state.config.OPENAI_API_KEY = form_data.key
app.state.config.OPENAI_API_MODEL = form_data.model
app.state.config.OPENAI_API_VOICE = form_data.speaker
return {
"status": True,
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.OPENAI_API_MODEL,
"OPENAI_API_VOICE": app.state.OPENAI_API_VOICE,
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
"OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
}
......@@ -124,13 +125,13 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path)
headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
r = None
try:
r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URL}/audio/speech",
url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech",
data=body,
headers=headers,
stream=True,
......
......@@ -42,6 +42,7 @@ from config import (
IMAGE_GENERATION_MODEL,
IMAGE_SIZE,
IMAGE_STEPS,
AppConfig,
)
......@@ -60,26 +61,31 @@ app.add_middleware(
allow_headers=["*"],
)
app.state.ENGINE = IMAGE_GENERATION_ENGINE
app.state.ENABLED = ENABLE_IMAGE_GENERATION
app.state.config = AppConfig()
app.state.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.ENGINE = IMAGE_GENERATION_ENGINE
app.state.config.ENABLED = ENABLE_IMAGE_GENERATION
app.state.MODEL = IMAGE_GENERATION_MODEL
app.state.config.OPENAI_API_BASE_URL = IMAGES_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = IMAGES_OPENAI_API_KEY
app.state.config.MODEL = IMAGE_GENERATION_MODEL
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.IMAGE_SIZE = IMAGE_SIZE
app.state.IMAGE_STEPS = IMAGE_STEPS
app.state.config.IMAGE_SIZE = IMAGE_SIZE
app.state.config.IMAGE_STEPS = IMAGE_STEPS
@app.get("/config")
async def get_config(request: Request, user=Depends(get_admin_user)):
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
return {
"engine": app.state.config.ENGINE,
"enabled": app.state.config.ENABLED,
}
class ConfigUpdateForm(BaseModel):
......@@ -89,9 +95,12 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update")
async def update_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.ENGINE = form_data.engine
app.state.ENABLED = form_data.enabled
return {"engine": app.state.ENGINE, "enabled": app.state.ENABLED}
app.state.config.ENGINE = form_data.engine
app.state.config.ENABLED = form_data.enabled
return {
"engine": app.state.config.ENGINE,
"enabled": app.state.config.ENABLED,
}
class EngineUrlUpdateForm(BaseModel):
......@@ -102,8 +111,8 @@ class EngineUrlUpdateForm(BaseModel):
@app.get("/url")
async def get_engine_url(user=Depends(get_admin_user)):
return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
}
......@@ -113,29 +122,29 @@ async def update_engine_url(
):
if form_data.AUTOMATIC1111_BASE_URL == None:
app.state.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
app.state.config.AUTOMATIC1111_BASE_URL = AUTOMATIC1111_BASE_URL
else:
url = form_data.AUTOMATIC1111_BASE_URL.strip("/")
try:
r = requests.head(url)
app.state.AUTOMATIC1111_BASE_URL = url
app.state.config.AUTOMATIC1111_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
if form_data.COMFYUI_BASE_URL == None:
app.state.COMFYUI_BASE_URL = COMFYUI_BASE_URL
app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
else:
url = form_data.COMFYUI_BASE_URL.strip("/")
try:
r = requests.head(url)
app.state.COMFYUI_BASE_URL = url
app.state.config.COMFYUI_BASE_URL = url
except Exception as e:
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
return {
"AUTOMATIC1111_BASE_URL": app.state.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.COMFYUI_BASE_URL,
"AUTOMATIC1111_BASE_URL": app.state.config.AUTOMATIC1111_BASE_URL,
"COMFYUI_BASE_URL": app.state.config.COMFYUI_BASE_URL,
"status": True,
}
......@@ -148,8 +157,8 @@ class OpenAIConfigUpdateForm(BaseModel):
@app.get("/openai/config")
async def get_openai_config(user=Depends(get_admin_user)):
return {
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
}
......@@ -160,13 +169,13 @@ async def update_openai_config(
if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.OPENAI_API_BASE_URL = form_data.url
app.state.OPENAI_API_KEY = form_data.key
app.state.config.OPENAI_API_BASE_URL = form_data.url
app.state.config.OPENAI_API_KEY = form_data.key
return {
"status": True,
"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.OPENAI_API_KEY,
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
}
......@@ -176,7 +185,7 @@ class ImageSizeUpdateForm(BaseModel):
@app.get("/size")
async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_SIZE": app.state.IMAGE_SIZE}
return {"IMAGE_SIZE": app.state.config.IMAGE_SIZE}
@app.post("/size/update")
......@@ -185,9 +194,9 @@ async def update_image_size(
):
pattern = r"^\d+x\d+$" # Regular expression pattern
if re.match(pattern, form_data.size):
app.state.IMAGE_SIZE = form_data.size
app.state.config.IMAGE_SIZE = form_data.size
return {
"IMAGE_SIZE": app.state.IMAGE_SIZE,
"IMAGE_SIZE": app.state.config.IMAGE_SIZE,
"status": True,
}
else:
......@@ -203,7 +212,7 @@ class ImageStepsUpdateForm(BaseModel):
@app.get("/steps")
async def get_image_size(user=Depends(get_admin_user)):
return {"IMAGE_STEPS": app.state.IMAGE_STEPS}
return {"IMAGE_STEPS": app.state.config.IMAGE_STEPS}
@app.post("/steps/update")
......@@ -211,9 +220,9 @@ async def update_image_size(
form_data: ImageStepsUpdateForm, user=Depends(get_admin_user)
):
if form_data.steps >= 0:
app.state.IMAGE_STEPS = form_data.steps
app.state.config.IMAGE_STEPS = form_data.steps
return {
"IMAGE_STEPS": app.state.IMAGE_STEPS,
"IMAGE_STEPS": app.state.config.IMAGE_STEPS,
"status": True,
}
else:
......@@ -226,14 +235,14 @@ async def update_image_size(
@app.get("/models")
def get_models(user=Depends(get_current_user)):
try:
if app.state.ENGINE == "openai":
if app.state.config.ENGINE == "openai":
return [
{"id": "dall-e-2", "name": "DALL·E 2"},
{"id": "dall-e-3", "name": "DALL·E 3"},
]
elif app.state.ENGINE == "comfyui":
elif app.state.config.ENGINE == "comfyui":
r = requests.get(url=f"{app.state.COMFYUI_BASE_URL}/object_info")
r = requests.get(url=f"{app.state.config.COMFYUI_BASE_URL}/object_info")
info = r.json()
return list(
......@@ -245,7 +254,7 @@ def get_models(user=Depends(get_current_user)):
else:
r = requests.get(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/sd-models"
)
models = r.json()
return list(
......@@ -255,23 +264,29 @@ def get_models(user=Depends(get_current_user)):
)
)
except Exception as e:
app.state.ENABLED = False
app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
@app.get("/models/default")
async def get_default_model(user=Depends(get_admin_user)):
try:
if app.state.ENGINE == "openai":
return {"model": app.state.MODEL if app.state.MODEL else "dall-e-2"}
elif app.state.ENGINE == "comfyui":
return {"model": app.state.MODEL if app.state.MODEL else ""}
if app.state.config.ENGINE == "openai":
return {
"model": (
app.state.config.MODEL if app.state.config.MODEL else "dall-e-2"
)
}
elif app.state.config.ENGINE == "comfyui":
return {"model": (app.state.config.MODEL if app.state.config.MODEL else "")}
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
)
options = r.json()
return {"model": options["sd_model_checkpoint"]}
except Exception as e:
app.state.ENABLED = False
app.state.config.ENABLED = False
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.DEFAULT(e))
......@@ -280,20 +295,20 @@ class UpdateModelForm(BaseModel):
def set_model_handler(model: str):
if app.state.ENGINE == "openai":
app.state.MODEL = model
return app.state.MODEL
if app.state.ENGINE == "comfyui":
app.state.MODEL = model
return app.state.MODEL
if app.state.config.ENGINE in ["openai", "comfyui"]:
app.state.config.MODEL = model
return app.state.config.MODEL
else:
r = requests.get(url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options")
r = requests.get(
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options"
)
options = r.json()
if model != options["sd_model_checkpoint"]:
options["sd_model_checkpoint"] = model
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/options", json=options
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/options",
json=options,
)
return options
......@@ -382,26 +397,32 @@ def generate_image(
user=Depends(get_current_user),
):
width, height = tuple(map(int, app.state.IMAGE_SIZE.split("x")))
width, height = tuple(map(int, app.state.config.IMAGE_SIZE).split("x"))
r = None
try:
if app.state.ENGINE == "openai":
if app.state.config.ENGINE == "openai":
headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
data = {
"model": app.state.MODEL if app.state.MODEL != "" else "dall-e-2",
"model": (
app.state.config.MODEL
if app.state.config.MODEL != ""
else "dall-e-2"
),
"prompt": form_data.prompt,
"n": form_data.n,
"size": form_data.size if form_data.size else app.state.IMAGE_SIZE,
"size": (
form_data.size if form_data.size else app.state.config.IMAGE_SIZE
),
"response_format": "b64_json",
}
r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URL}/images/generations",
url=f"{app.state.config.OPENAI_API_BASE_URL}/images/generations",
json=data,
headers=headers,
)
......@@ -421,7 +442,7 @@ def generate_image(
return images
elif app.state.ENGINE == "comfyui":
elif app.state.config.ENGINE == "comfyui":
data = {
"prompt": form_data.prompt,
......@@ -430,19 +451,19 @@ def generate_image(
"n": form_data.n,
}
if app.state.IMAGE_STEPS != None:
data["steps"] = app.state.IMAGE_STEPS
if app.state.config.IMAGE_STEPS is not None:
data["steps"] = app.state.config.IMAGE_STEPS
if form_data.negative_prompt != None:
if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
data = ImageGenerationPayload(**data)
res = comfyui_generate_image(
app.state.MODEL,
app.state.config.MODEL,
data,
user.id,
app.state.COMFYUI_BASE_URL,
app.state.config.COMFYUI_BASE_URL,
)
log.debug(f"res: {res}")
......@@ -469,14 +490,14 @@ def generate_image(
"height": height,
}
if app.state.IMAGE_STEPS != None:
data["steps"] = app.state.IMAGE_STEPS
if app.state.config.IMAGE_STEPS is not None:
data["steps"] = app.state.config.IMAGE_STEPS
if form_data.negative_prompt != None:
if form_data.negative_prompt is not None:
data["negative_prompt"] = form_data.negative_prompt
r = requests.post(
url=f"{app.state.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
url=f"{app.state.config.AUTOMATIC1111_BASE_URL}/sdapi/v1/txt2img",
json=data,
)
......
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException
from fastapi.routing import APIRoute
......@@ -46,7 +47,16 @@ import asyncio
import subprocess
import yaml
app = FastAPI()
@asynccontextmanager
async def lifespan(app: FastAPI):
log.info("startup_event")
# TODO: Check config.yaml file and create one
asyncio.create_task(start_litellm_background())
yield
app = FastAPI(lifespan=lifespan)
origins = ["*"]
......@@ -141,13 +151,6 @@ async def shutdown_litellm_background():
background_process = None
@app.on_event("startup")
async def startup_event():
log.info("startup_event")
# TODO: Check config.yaml file and create one
asyncio.create_task(start_litellm_background())
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
......
......@@ -46,6 +46,7 @@ from config import (
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
UPLOAD_DIR,
AppConfig,
)
from utils.misc import calculate_sha256
......@@ -61,11 +62,12 @@ app.add_middleware(
allow_headers=["*"],
)
app.state.config = AppConfig()
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}
......@@ -96,7 +98,7 @@ async def get_status():
@app.get("/urls")
async def get_ollama_api_urls(user=Depends(get_admin_user)):
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
class UrlUpdateForm(BaseModel):
......@@ -105,10 +107,10 @@ class UrlUpdateForm(BaseModel):
@app.post("/urls/update")
async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
app.state.OLLAMA_BASE_URLS = form_data.urls
app.state.config.OLLAMA_BASE_URLS = form_data.urls
log.info(f"app.state.OLLAMA_BASE_URLS: {app.state.OLLAMA_BASE_URLS}")
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
log.info(f"app.state.config.OLLAMA_BASE_URLS: {app.state.config.OLLAMA_BASE_URLS}")
return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
@app.get("/cancel/{request_id}")
......@@ -153,7 +155,7 @@ def merge_models_lists(model_lists):
async def get_all_models():
log.info("get_all_models()")
tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS]
tasks = [fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS]
responses = await asyncio.gather(*tasks)
models = {
......@@ -186,7 +188,7 @@ async def get_ollama_tags(
return models
return models
else:
url = app.state.OLLAMA_BASE_URLS[url_idx]
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
try:
r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status()
......@@ -216,7 +218,9 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
if url_idx == None:
# returns lowest version
tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS]
tasks = [
fetch_url(f"{url}/api/version") for url in app.state.config.OLLAMA_BASE_URLS
]
responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses))
......@@ -235,7 +239,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
)
else:
url = app.state.OLLAMA_BASE_URLS[url_idx]
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
try:
r = requests.request(method="GET", url=f"{url}/api/version")
r.raise_for_status()
......@@ -267,7 +271,7 @@ class ModelNameForm(BaseModel):
async def pull_model(
form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
):
url = app.state.OLLAMA_BASE_URLS[url_idx]
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
......@@ -355,7 +359,7 @@ async def push_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.debug(f"url: {url}")
r = None
......@@ -417,7 +421,7 @@ async def create_model(
form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
):
log.debug(f"form_data: {form_data}")
url = app.state.OLLAMA_BASE_URLS[url_idx]
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
......@@ -490,7 +494,7 @@ async def copy_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.source),
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
......@@ -537,7 +541,7 @@ async def delete_model(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
......@@ -577,7 +581,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
)
url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
url = app.state.OLLAMA_BASE_URLS[url_idx]
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
......@@ -634,7 +638,7 @@ async def generate_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
......@@ -684,7 +688,7 @@ def generate_ollama_embeddings(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
try:
......@@ -753,7 +757,7 @@ async def generate_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
......@@ -856,7 +860,7 @@ async def generate_chat_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
......@@ -965,7 +969,7 @@ async def generate_openai_chat_completion(
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.OLLAMA_BASE_URLS[url_idx]
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
......@@ -1064,7 +1068,7 @@ async def get_openai_models(
}
else:
url = app.state.OLLAMA_BASE_URLS[url_idx]
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
try:
r = requests.request(method="GET", url=f"{url}/api/tags")
r.raise_for_status()
......@@ -1198,7 +1202,7 @@ async def download_model(
if url_idx == None:
url_idx = 0
url = app.state.OLLAMA_BASE_URLS[url_idx]
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
file_name = parse_huggingface_url(form_data.url)
......@@ -1217,7 +1221,7 @@ async def download_model(
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
if url_idx == None:
url_idx = 0
ollama_url = app.state.OLLAMA_BASE_URLS[url_idx]
ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
file_path = f"{UPLOAD_DIR}/{file.filename}"
......@@ -1282,7 +1286,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
# if url_idx == None:
# url_idx = 0
# url = app.state.OLLAMA_BASE_URLS[url_idx]
# url = app.state.config.OLLAMA_BASE_URLS[url_idx]
# file_location = os.path.join(UPLOAD_DIR, file.filename)
# total_size = file.size
......@@ -1319,7 +1323,7 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
async def deprecated_proxy(
path: str, request: Request, user=Depends(get_verified_user)
):
url = app.state.OLLAMA_BASE_URLS[0]
url = app.state.config.OLLAMA_BASE_URLS[0]
target_url = f"{url}/{path}"
body = await request.body()
......
......@@ -26,6 +26,7 @@ from config import (
CACHE_DIR,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
AppConfig,
)
from typing import List, Optional
......@@ -45,11 +46,13 @@ app.add_middleware(
allow_headers=["*"],
)
app.state.config = AppConfig()
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
app.state.MODELS = {}
......@@ -75,32 +78,32 @@ class KeysUpdateForm(BaseModel):
@app.get("/urls")
async def get_openai_urls(user=Depends(get_admin_user)):
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.post("/urls/update")
async def update_openai_urls(form_data: UrlsUpdateForm, user=Depends(get_admin_user)):
await get_all_models()
app.state.OPENAI_API_BASE_URLS = form_data.urls
return {"OPENAI_API_BASE_URLS": app.state.OPENAI_API_BASE_URLS}
app.state.config.OPENAI_API_BASE_URLS = form_data.urls
return {"OPENAI_API_BASE_URLS": app.state.config.OPENAI_API_BASE_URLS}
@app.get("/keys")
async def get_openai_keys(user=Depends(get_admin_user)):
return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
@app.post("/keys/update")
async def update_openai_key(form_data: KeysUpdateForm, user=Depends(get_admin_user)):
app.state.OPENAI_API_KEYS = form_data.keys
return {"OPENAI_API_KEYS": app.state.OPENAI_API_KEYS}
app.state.config.OPENAI_API_KEYS = form_data.keys
return {"OPENAI_API_KEYS": app.state.config.OPENAI_API_KEYS}
@app.post("/audio/speech")
async def speech(request: Request, user=Depends(get_verified_user)):
idx = None
try:
idx = app.state.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
idx = app.state.config.OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
body = await request.body()
name = hashlib.sha256(body).hexdigest()
......@@ -114,13 +117,15 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path)
headers = {}
headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEYS[idx]}"
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEYS[idx]}"
headers["Content-Type"] = "application/json"
if "openrouter.ai" in app.state.OPENAI_API_BASE_URLS[idx]:
headers['HTTP-Referer'] = "https://openwebui.com/"
headers['X-Title'] = "Open WebUI"
r = None
try:
r = requests.post(
url=f"{app.state.OPENAI_API_BASE_URLS[idx]}/audio/speech",
url=f"{app.state.config.OPENAI_API_BASE_URLS[idx]}/audio/speech",
data=body,
headers=headers,
stream=True,
......@@ -180,7 +185,8 @@ def merge_models_lists(model_lists):
[
{**model, "urlIdx": idx}
for model in models
if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx]
if "api.openai.com"
not in app.state.config.OPENAI_API_BASE_URLS[idx]
or "gpt" in model["id"]
]
)
......@@ -191,12 +197,15 @@ def merge_models_lists(model_lists):
async def get_all_models():
log.info("get_all_models()")
if len(app.state.OPENAI_API_KEYS) == 1 and app.state.OPENAI_API_KEYS[0] == "":
if (
len(app.state.config.OPENAI_API_KEYS) == 1
and app.state.config.OPENAI_API_KEYS[0] == ""
):
models = {"data": []}
else:
tasks = [
fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
]
responses = await asyncio.gather(*tasks)
......@@ -239,7 +248,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
return models
return models
else:
url = app.state.OPENAI_API_BASE_URLS[url_idx]
url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
r = None
......@@ -303,8 +312,8 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
except json.JSONDecodeError as e:
log.error("Error loading request body into a dictionary:", e)
url = app.state.OPENAI_API_BASE_URLS[idx]
key = app.state.OPENAI_API_KEYS[idx]
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
target_url = f"{url}/{path}"
......
......@@ -93,6 +93,7 @@ from config import (
RAG_TEMPLATE,
ENABLE_RAG_LOCAL_WEB_FETCH,
YOUTUBE_LOADER_LANGUAGE,
AppConfig,
)
from constants import ERROR_MESSAGES
......@@ -102,30 +103,32 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
app = FastAPI()
app.state.TOP_K = RAG_TOP_K
app.state.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config = AppConfig()
app.state.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
app.state.config.TOP_K = RAG_TOP_K
app.state.config.RELEVANCE_THRESHOLD = RAG_RELEVANCE_THRESHOLD
app.state.config.ENABLE_RAG_HYBRID_SEARCH = ENABLE_RAG_HYBRID_SEARCH
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.CHUNK_SIZE = CHUNK_SIZE
app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.RAG_TEMPLATE = RAG_TEMPLATE
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
app.state.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
app.state.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None
......@@ -133,7 +136,7 @@ def update_embedding_model(
embedding_model: str,
update_model: bool = False,
):
if embedding_model and app.state.RAG_EMBEDDING_ENGINE == "":
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
get_model_path(embedding_model, update_model),
device=DEVICE_TYPE,
......@@ -158,22 +161,22 @@ def update_reranking_model(
update_embedding_model(
app.state.RAG_EMBEDDING_MODEL,
app.state.config.RAG_EMBEDDING_MODEL,
RAG_EMBEDDING_MODEL_AUTO_UPDATE,
)
update_reranking_model(
app.state.RAG_RERANKING_MODEL,
app.state.config.RAG_RERANKING_MODEL,
RAG_RERANKING_MODEL_AUTO_UPDATE,
)
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
)
origins = ["*"]
......@@ -200,12 +203,12 @@ class UrlForm(CollectionNameForm):
async def get_status():
return {
"status": True,
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"template": app.state.RAG_TEMPLATE,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.RAG_RERANKING_MODEL,
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
"template": app.state.config.RAG_TEMPLATE,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
}
......@@ -213,18 +216,21 @@ async def get_status():
async def get_embedding_config(user=Depends(get_admin_user)):
return {
"status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": app.state.OPENAI_API_BASE_URL,
"key": app.state.OPENAI_API_KEY,
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
},
}
@app.get("/reranking")
async def get_reraanking_config(user=Depends(get_admin_user)):
return {"status": True, "reranking_model": app.state.RAG_RERANKING_MODEL}
return {
"status": True,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
}
class OpenAIConfigForm(BaseModel):
......@@ -243,34 +249,34 @@ async def update_embedding_config(
form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
):
log.info(
f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
f"Updating embedding model: {app.state.config.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
)
try:
app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
app.state.config.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if form_data.openai_config != None:
app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.OPENAI_API_KEY = form_data.openai_config.key
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
update_embedding_model(app.state.RAG_EMBEDDING_MODEL, True)
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL), True
app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
)
return {
"status": True,
"embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.RAG_EMBEDDING_MODEL,
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"openai_config": {
"url": app.state.OPENAI_API_BASE_URL,
"key": app.state.OPENAI_API_KEY,
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
},
}
except Exception as e:
......@@ -290,16 +296,16 @@ async def update_reranking_config(
form_data: RerankingModelUpdateForm, user=Depends(get_admin_user)
):
log.info(
f"Updating reranking model: {app.state.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
f"Updating reranking model: {app.state.config.RAG_RERANKING_MODEL} to {form_data.reranking_model}"
)
try:
app.state.RAG_RERANKING_MODEL = form_data.reranking_model
app.state.config.RAG_RERANKING_MODEL = form_data.reranking_model
update_reranking_model(app.state.RAG_RERANKING_MODEL, True)
update_reranking_model(app.state.config.RAG_RERANKING_MODEL), True
return {
"status": True,
"reranking_model": app.state.RAG_RERANKING_MODEL,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
}
except Exception as e:
log.exception(f"Problem updating reranking model: {e}")
......@@ -313,14 +319,14 @@ async def update_reranking_config(
async def get_rag_config(user=Depends(get_admin_user)):
return {
"status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"chunk": {
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": app.state.YOUTUBE_LOADER_LANGUAGE,
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
}
......@@ -345,50 +351,52 @@ class ConfigUpdateForm(BaseModel):
@app.post("/config/update")
async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_user)):
app.state.PDF_EXTRACT_IMAGES = (
app.state.config.PDF_EXTRACT_IMAGES = (
form_data.pdf_extract_images
if form_data.pdf_extract_images != None
else app.state.PDF_EXTRACT_IMAGES
if form_data.pdf_extract_images is not None
else app.state.config.PDF_EXTRACT_IMAGES
)
app.state.CHUNK_SIZE = (
form_data.chunk.chunk_size if form_data.chunk != None else app.state.CHUNK_SIZE
app.state.config.CHUNK_SIZE = (
form_data.chunk.chunk_size
if form_data.chunk is not None
else app.state.config.CHUNK_SIZE
)
app.state.CHUNK_OVERLAP = (
app.state.config.CHUNK_OVERLAP = (
form_data.chunk.chunk_overlap
if form_data.chunk != None
else app.state.CHUNK_OVERLAP
if form_data.chunk is not None
else app.state.config.CHUNK_OVERLAP
)
app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None
else app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.YOUTUBE_LOADER_LANGUAGE = (
app.state.config.YOUTUBE_LOADER_LANGUAGE = (
form_data.youtube.language
if form_data.youtube != None
else app.state.YOUTUBE_LOADER_LANGUAGE
if form_data.youtube is not None
else app.state.config.YOUTUBE_LOADER_LANGUAGE
)
app.state.YOUTUBE_LOADER_TRANSLATION = (
form_data.youtube.translation
if form_data.youtube != None
if form_data.youtube is not None
else app.state.YOUTUBE_LOADER_TRANSLATION
)
return {
"status": True,
"pdf_extract_images": app.state.PDF_EXTRACT_IMAGES,
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"chunk": {
"chunk_size": app.state.CHUNK_SIZE,
"chunk_overlap": app.state.CHUNK_OVERLAP,
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": app.state.YOUTUBE_LOADER_LANGUAGE,
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
}
......@@ -398,7 +406,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
async def get_rag_template(user=Depends(get_current_user)):
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
"template": app.state.config.RAG_TEMPLATE,
}
......@@ -406,10 +414,10 @@ async def get_rag_template(user=Depends(get_current_user)):
async def get_query_settings(user=Depends(get_admin_user)):
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
"k": app.state.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD,
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
"template": app.state.config.RAG_TEMPLATE,
"k": app.state.config.TOP_K,
"r": app.state.config.RELEVANCE_THRESHOLD,
"hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
}
......@@ -424,16 +432,20 @@ class QuerySettingsForm(BaseModel):
async def update_query_settings(
form_data: QuerySettingsForm, user=Depends(get_admin_user)
):
app.state.RAG_TEMPLATE = form_data.template if form_data.template else RAG_TEMPLATE
app.state.TOP_K = form_data.k if form_data.k else 4
app.state.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
app.state.ENABLE_RAG_HYBRID_SEARCH = form_data.hybrid if form_data.hybrid else False
app.state.config.RAG_TEMPLATE = (
form_data.template if form_data.template else RAG_TEMPLATE,
)
app.state.config.TOP_K = form_data.k if form_data.k else 4
app.state.config.RELEVANCE_THRESHOLD = form_data.r if form_data.r else 0.0
app.state.config.ENABLE_RAG_HYBRID_SEARCH = (
form_data.hybrid if form_data.hybrid else False,
)
return {
"status": True,
"template": app.state.RAG_TEMPLATE,
"k": app.state.TOP_K,
"r": app.state.RELEVANCE_THRESHOLD,
"hybrid": app.state.ENABLE_RAG_HYBRID_SEARCH,
"template": app.state.config.RAG_TEMPLATE,
"k": app.state.config.TOP_K,
"r": app.state.config.RELEVANCE_THRESHOLD,
"hybrid": app.state.config.ENABLE_RAG_HYBRID_SEARCH,
}
......@@ -451,21 +463,23 @@ def query_doc_handler(
user=Depends(get_current_user),
):
try:
if app.state.ENABLE_RAG_HYBRID_SEARCH:
if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
return query_doc_with_hybrid_search(
collection_name=form_data.collection_name,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
k=form_data.k if form_data.k else app.state.config.TOP_K,
reranking_function=app.state.sentence_transformer_rf,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
r=(
form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
),
)
else:
return query_doc(
collection_name=form_data.collection_name,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
k=form_data.k if form_data.k else app.state.config.TOP_K,
)
except Exception as e:
log.exception(e)
......@@ -489,21 +503,23 @@ def query_collection_handler(
user=Depends(get_current_user),
):
try:
if app.state.ENABLE_RAG_HYBRID_SEARCH:
if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
return query_collection_with_hybrid_search(
collection_names=form_data.collection_names,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
k=form_data.k if form_data.k else app.state.config.TOP_K,
reranking_function=app.state.sentence_transformer_rf,
r=form_data.r if form_data.r else app.state.RELEVANCE_THRESHOLD,
r=(
form_data.r if form_data.r else app.state.config.RELEVANCE_THRESHOLD
),
)
else:
return query_collection(
collection_names=form_data.collection_names,
query=form_data.query,
embedding_function=app.state.EMBEDDING_FUNCTION,
k=form_data.k if form_data.k else app.state.TOP_K,
k=form_data.k if form_data.k else app.state.config.TOP_K,
)
except Exception as e:
......@@ -520,7 +536,7 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
loader = YoutubeLoader.from_youtube_url(
form_data.url,
add_video_info=True,
language=app.state.YOUTUBE_LOADER_LANGUAGE,
language=app.state.config.YOUTUBE_LOADER_LANGUAGE,
translation=app.state.YOUTUBE_LOADER_TRANSLATION,
)
data = loader.load()
......@@ -548,7 +564,8 @@ def store_web(form_data: UrlForm, user=Depends(get_current_user)):
# "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
try:
loader = get_web_loader(
form_data.url, verify_ssl=app.state.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
form_data.url,
verify_ssl=app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
)
data = loader.load()
......@@ -604,8 +621,8 @@ def resolve_hostname(hostname):
def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP,
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
......@@ -622,8 +639,8 @@ def store_text_in_vector_db(
text, metadata, collection_name, overwrite: bool = False
) -> bool:
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=app.state.CHUNK_SIZE,
chunk_overlap=app.state.CHUNK_OVERLAP,
chunk_size=app.state.config.CHUNK_SIZE,
chunk_overlap=app.state.config.CHUNK_OVERLAP,
add_start_index=True,
)
docs = text_splitter.create_documents([text], metadatas=[metadata])
......@@ -646,11 +663,11 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
collection = CHROMA_CLIENT.create_collection(name=collection_name)
embedding_func = get_embedding_function(
app.state.RAG_EMBEDDING_ENGINE,
app.state.RAG_EMBEDDING_MODEL,
app.state.config.RAG_EMBEDDING_ENGINE,
app.state.config.RAG_EMBEDDING_MODEL,
app.state.sentence_transformer_ef,
app.state.OPENAI_API_KEY,
app.state.OPENAI_API_BASE_URL,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
)
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
......@@ -724,7 +741,9 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
]
if file_ext == "pdf":
loader = PyPDFLoader(file_path, extract_images=app.state.PDF_EXTRACT_IMAGES)
loader = PyPDFLoader(
file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
)
elif file_ext == "csv":
loader = CSVLoader(file_path)
elif file_ext == "rst":
......
......@@ -21,20 +21,24 @@ from config import (
USER_PERMISSIONS,
WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN,
AppConfig,
)
app = FastAPI()
origins = ["*"]
app.state.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.JWT_EXPIRES_IN = "-1"
app.state.config = AppConfig()
app.state.DEFAULT_MODELS = DEFAULT_MODELS
app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.USER_PERMISSIONS = USER_PERMISSIONS
app.state.WEBHOOK_URL = WEBHOOK_URL
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.add_middleware(
......@@ -61,6 +65,6 @@ async def get_status():
return {
"status": True,
"auth": WEBUI_AUTH,
"default_models": app.state.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.DEFAULT_PROMPT_SUGGESTIONS,
"default_models": app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
}
......@@ -140,7 +140,7 @@ async def signin(request: Request, form_data: SigninForm):
if user:
token = create_token(
data={"id": user.id},
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
)
return {
......@@ -163,7 +163,7 @@ async def signin(request: Request, form_data: SigninForm):
@router.post("/signup", response_model=SigninResponse)
async def signup(request: Request, form_data: SignupForm):
if not request.app.state.ENABLE_SIGNUP and WEBUI_AUTH:
if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
raise HTTPException(
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
......@@ -180,7 +180,7 @@ async def signup(request: Request, form_data: SignupForm):
role = (
"admin"
if Users.get_num_users() == 0
else request.app.state.DEFAULT_USER_ROLE
else request.app.state.config.DEFAULT_USER_ROLE
)
hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth(
......@@ -194,13 +194,13 @@ async def signup(request: Request, form_data: SignupForm):
if user:
token = create_token(
data={"id": user.id},
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
)
# response.set_cookie(key='token', value=token, httponly=True)
if request.app.state.WEBHOOK_URL:
if request.app.state.config.WEBHOOK_URL:
post_webhook(
request.app.state.WEBHOOK_URL,
request.app.state.config.WEBHOOK_URL,
WEBHOOK_MESSAGES.USER_SIGNUP(user.name),
{
"action": "signup",
......@@ -276,13 +276,13 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
@router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
return request.app.state.ENABLE_SIGNUP
return request.app.state.config.ENABLE_SIGNUP
@router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
return request.app.state.ENABLE_SIGNUP
request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP
return request.app.state.config.ENABLE_SIGNUP
############################
......@@ -292,7 +292,7 @@ async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
@router.get("/signup/user/role")
async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
return request.app.state.DEFAULT_USER_ROLE
return request.app.state.config.DEFAULT_USER_ROLE
class UpdateRoleForm(BaseModel):
......@@ -304,8 +304,8 @@ async def update_default_user_role(
request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
):
if form_data.role in ["pending", "user", "admin"]:
request.app.state.DEFAULT_USER_ROLE = form_data.role
return request.app.state.DEFAULT_USER_ROLE
request.app.state.config.DEFAULT_USER_ROLE = form_data.role
return request.app.state.config.DEFAULT_USER_ROLE
############################
......@@ -315,7 +315,7 @@ async def update_default_user_role(
@router.get("/token/expires")
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
return request.app.state.JWT_EXPIRES_IN
return request.app.state.config.JWT_EXPIRES_IN
class UpdateJWTExpiresDurationForm(BaseModel):
......@@ -332,10 +332,10 @@ async def update_token_expires_duration(
# Check if the input string matches the pattern
if re.match(pattern, form_data.duration):
request.app.state.JWT_EXPIRES_IN = form_data.duration
return request.app.state.JWT_EXPIRES_IN
request.app.state.config.JWT_EXPIRES_IN = form_data.duration
return request.app.state.config.JWT_EXPIRES_IN
else:
return request.app.state.JWT_EXPIRES_IN
return request.app.state.config.JWT_EXPIRES_IN
############################
......
......@@ -44,8 +44,8 @@ class SetDefaultSuggestionsForm(BaseModel):
async def set_global_default_models(
request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
):
request.app.state.DEFAULT_MODELS = form_data.models
return request.app.state.DEFAULT_MODELS
request.app.state.config.DEFAULT_MODELS = form_data.models
return request.app.state.config.DEFAULT_MODELS
@router.post("/default/suggestions", response_model=List[PromptSuggestion])
......@@ -55,5 +55,5 @@ async def set_global_default_suggestions(
user=Depends(get_admin_user),
):
data = form_data.model_dump()
request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
return request.app.state.DEFAULT_PROMPT_SUGGESTIONS
request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
return request.app.state.config.DEFAULT_PROMPT_SUGGESTIONS
......@@ -39,15 +39,15 @@ async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)
@router.get("/permissions/user")
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
return request.app.state.USER_PERMISSIONS
return request.app.state.config.USER_PERMISSIONS
@router.post("/permissions/user")
async def update_user_permissions(
request: Request, form_data: dict, user=Depends(get_admin_user)
):
request.app.state.USER_PERMISSIONS = form_data
return request.app.state.USER_PERMISSIONS
request.app.state.config.USER_PERMISSIONS = form_data
return request.app.state.config.USER_PERMISSIONS
############################
......
......@@ -5,6 +5,7 @@ import chromadb
from chromadb import Settings
from base64 import b64encode
from bs4 import BeautifulSoup
from typing import TypeVar, Generic, Union
from pathlib import Path
import json
......@@ -17,7 +18,6 @@ import shutil
from secrets import token_bytes
from constants import ERROR_MESSAGES
####################################
# Load .env file
####################################
......@@ -71,7 +71,6 @@ for source in log_sources:
log.setLevel(SRC_LOG_LEVELS["CONFIG"])
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
......@@ -161,16 +160,6 @@ CHANGELOG = changelog_json
WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100")
####################################
# WEBUI_AUTH (Required for security)
####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
####################################
# DATA/FRONTEND BUILD DIR
####################################
......@@ -184,6 +173,108 @@ try:
except:
CONFIG_DATA = {}
####################################
# Config helpers
####################################
def save_config():
try:
with open(f"{DATA_DIR}/config.json", "w") as f:
json.dump(CONFIG_DATA, f, indent="\t")
except Exception as e:
log.exception(e)
def get_config_value(config_path: str):
path_parts = config_path.split(".")
cur_config = CONFIG_DATA
for key in path_parts:
if key in cur_config:
cur_config = cur_config[key]
else:
return None
return cur_config
T = TypeVar("T")
class PersistentConfig(Generic[T]):
def __init__(self, env_name: str, config_path: str, env_value: T):
self.env_name = env_name
self.config_path = config_path
self.env_value = env_value
self.config_value = get_config_value(config_path)
if self.config_value is not None:
log.info(f"'{env_name}' loaded from config.json")
self.value = self.config_value
else:
self.value = env_value
def __str__(self):
return str(self.value)
@property
def __dict__(self):
raise TypeError(
"PersistentConfig object cannot be converted to dict, use config_get or .value instead."
)
def __getattribute__(self, item):
if item == "__dict__":
raise TypeError(
"PersistentConfig object cannot be converted to dict, use config_get or .value instead."
)
return super().__getattribute__(item)
def save(self):
# Don't save if the value is the same as the env value and the config value
if self.env_value == self.value:
if self.config_value == self.value:
return
log.info(f"Saving '{self.env_name}' to config.json")
path_parts = self.config_path.split(".")
config = CONFIG_DATA
for key in path_parts[:-1]:
if key not in config:
config[key] = {}
config = config[key]
config[path_parts[-1]] = self.value
save_config()
self.config_value = self.value
class AppConfig:
_state: dict[str, PersistentConfig]
def __init__(self):
super().__setattr__("_state", {})
def __setattr__(self, key, value):
if isinstance(value, PersistentConfig):
self._state[key] = value
else:
self._state[key].value = value
self._state[key].save()
def __getattr__(self, key):
return self._state[key].value
####################################
# WEBUI_AUTH (Required for security)
####################################
WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
JWT_EXPIRES_IN = PersistentConfig(
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
)
####################################
# Static DIR
####################################
......@@ -318,7 +409,9 @@ OLLAMA_BASE_URLS = os.environ.get("OLLAMA_BASE_URLS", "")
OLLAMA_BASE_URLS = OLLAMA_BASE_URLS if OLLAMA_BASE_URLS != "" else OLLAMA_BASE_URL
OLLAMA_BASE_URLS = [url.strip() for url in OLLAMA_BASE_URLS.split(";")]
OLLAMA_BASE_URLS = PersistentConfig(
"OLLAMA_BASE_URLS", "ollama.base_urls", OLLAMA_BASE_URLS
)
####################################
# OPENAI_API
......@@ -335,7 +428,9 @@ OPENAI_API_KEYS = os.environ.get("OPENAI_API_KEYS", "")
OPENAI_API_KEYS = OPENAI_API_KEYS if OPENAI_API_KEYS != "" else OPENAI_API_KEY
OPENAI_API_KEYS = [url.strip() for url in OPENAI_API_KEYS.split(";")]
OPENAI_API_KEYS = PersistentConfig(
"OPENAI_API_KEYS", "openai.api_keys", OPENAI_API_KEYS
)
OPENAI_API_BASE_URLS = os.environ.get("OPENAI_API_BASE_URLS", "")
OPENAI_API_BASE_URLS = (
......@@ -346,37 +441,42 @@ OPENAI_API_BASE_URLS = [
url.strip() if url != "" else "https://api.openai.com/v1"
for url in OPENAI_API_BASE_URLS.split(";")
]
OPENAI_API_BASE_URLS = PersistentConfig(
"OPENAI_API_BASE_URLS", "openai.api_base_urls", OPENAI_API_BASE_URLS
)
OPENAI_API_KEY = ""
try:
OPENAI_API_KEY = OPENAI_API_KEYS[
OPENAI_API_BASE_URLS.index("https://api.openai.com/v1")
OPENAI_API_KEY = OPENAI_API_KEYS.value[
OPENAI_API_BASE_URLS.value.index("https://api.openai.com/v1")
]
except:
pass
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
####################################
# WEBUI
####################################
ENABLE_SIGNUP = (
False
if WEBUI_AUTH == False
else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
ENABLE_SIGNUP = PersistentConfig(
"ENABLE_SIGNUP",
"ui.enable_signup",
(
False
if not WEBUI_AUTH
else os.environ.get("ENABLE_SIGNUP", "True").lower() == "true"
),
)
DEFAULT_MODELS = PersistentConfig(
"DEFAULT_MODELS", "ui.default_models", os.environ.get("DEFAULT_MODELS", None)
)
DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None)
DEFAULT_PROMPT_SUGGESTIONS = (
CONFIG_DATA["ui"]["prompt_suggestions"]
if "ui" in CONFIG_DATA
and "prompt_suggestions" in CONFIG_DATA["ui"]
and type(CONFIG_DATA["ui"]["prompt_suggestions"]) is list
else [
DEFAULT_PROMPT_SUGGESTIONS = PersistentConfig(
"DEFAULT_PROMPT_SUGGESTIONS",
"ui.prompt_suggestions",
[
{
"title": ["Help me study", "vocabulary for a college entrance exam"],
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
......@@ -404,23 +504,40 @@ DEFAULT_PROMPT_SUGGESTIONS = (
"title": ["Overcome procrastination", "give me tips"],
"content": "Could you start by asking me about instances when I procrastinate the most and then give me some suggestions to overcome it?",
},
]
],
)
DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending")
DEFAULT_USER_ROLE = PersistentConfig(
"DEFAULT_USER_ROLE",
"ui.default_user_role",
os.getenv("DEFAULT_USER_ROLE", "pending"),
)
USER_PERMISSIONS_CHAT_DELETION = (
os.environ.get("USER_PERMISSIONS_CHAT_DELETION", "True").lower() == "true"
)
USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}
USER_PERMISSIONS = PersistentConfig(
"USER_PERMISSIONS",
"ui.user_permissions",
{"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}},
)
ENABLE_MODEL_FILTER = os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true"
ENABLE_MODEL_FILTER = PersistentConfig(
"ENABLE_MODEL_FILTER",
"model_filter.enable",
os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true",
)
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
MODEL_FILTER_LIST = PersistentConfig(
"MODEL_FILTER_LIST",
"model_filter.list",
[model.strip() for model in MODEL_FILTER_LIST.split(";")],
)
WEBHOOK_URL = os.environ.get("WEBHOOK_URL", "")
WEBHOOK_URL = PersistentConfig(
"WEBHOOK_URL", "webhook_url", os.environ.get("WEBHOOK_URL", "")
)
ENABLE_ADMIN_EXPORT = os.environ.get("ENABLE_ADMIN_EXPORT", "True").lower() == "true"
......@@ -458,26 +575,45 @@ else:
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5"))
RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0"))
ENABLE_RAG_HYBRID_SEARCH = (
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
RAG_TOP_K = PersistentConfig(
"RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "5"))
)
RAG_RELEVANCE_THRESHOLD = PersistentConfig(
"RAG_RELEVANCE_THRESHOLD",
"rag.relevance_threshold",
float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0")),
)
ENABLE_RAG_HYBRID_SEARCH = PersistentConfig(
"ENABLE_RAG_HYBRID_SEARCH",
"rag.enable_hybrid_search",
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true",
)
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true"
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = PersistentConfig(
"ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION",
"rag.enable_web_loader_ssl_verification",
os.environ.get("ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION", "True").lower() == "true",
)
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
RAG_EMBEDDING_ENGINE = PersistentConfig(
"RAG_EMBEDDING_ENGINE",
"rag.embedding_engine",
os.environ.get("RAG_EMBEDDING_ENGINE", ""),
)
PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true"
PDF_EXTRACT_IMAGES = PersistentConfig(
"PDF_EXTRACT_IMAGES",
"rag.pdf_extract_images",
os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true",
)
RAG_EMBEDDING_MODEL = os.environ.get(
"RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
RAG_EMBEDDING_MODEL = PersistentConfig(
"RAG_EMBEDDING_MODEL",
"rag.embedding_model",
os.environ.get("RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"),
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL.value}"),
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
......@@ -487,9 +623,13 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
if not RAG_RERANKING_MODEL == "":
log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
RAG_RERANKING_MODEL = PersistentConfig(
"RAG_RERANKING_MODEL",
"rag.reranking_model",
os.environ.get("RAG_RERANKING_MODEL", ""),
)
if RAG_RERANKING_MODEL.value != "":
log.info(f"Reranking model set: {RAG_RERANKING_MODEL.value}"),
RAG_RERANKING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
......@@ -527,9 +667,14 @@ if USE_CUDA.lower() == "true":
else:
DEVICE_TYPE = "cpu"
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500"))
CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100"))
CHUNK_SIZE = PersistentConfig(
"CHUNK_SIZE", "rag.chunk_size", int(os.environ.get("CHUNK_SIZE", "1500"))
)
CHUNK_OVERLAP = PersistentConfig(
"CHUNK_OVERLAP",
"rag.chunk_overlap",
int(os.environ.get("CHUNK_OVERLAP", "100")),
)
DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
......@@ -545,16 +690,32 @@ And answer according to the language of the user's question.
Given the context information, answer the query.
Query: [query]"""
RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
RAG_TEMPLATE = PersistentConfig(
"RAG_TEMPLATE",
"rag.template",
os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE),
)
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
RAG_OPENAI_API_BASE_URL = PersistentConfig(
"RAG_OPENAI_API_BASE_URL",
"rag.openai_api_base_url",
os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
RAG_OPENAI_API_KEY = PersistentConfig(
"RAG_OPENAI_API_KEY",
"rag.openai_api_key",
os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY),
)
ENABLE_RAG_LOCAL_WEB_FETCH = (
os.getenv("ENABLE_RAG_LOCAL_WEB_FETCH", "False").lower() == "true"
)
YOUTUBE_LOADER_LANGUAGE = os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(",")
YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
"YOUTUBE_LOADER_LANGUAGE",
"rag.youtube_loader_language",
os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
)
####################################
# Transcribe
......@@ -571,34 +732,78 @@ WHISPER_MODEL_AUTO_UPDATE = (
# Images
####################################
IMAGE_GENERATION_ENGINE = os.getenv("IMAGE_GENERATION_ENGINE", "")
IMAGE_GENERATION_ENGINE = PersistentConfig(
"IMAGE_GENERATION_ENGINE",
"image_generation.engine",
os.getenv("IMAGE_GENERATION_ENGINE", ""),
)
ENABLE_IMAGE_GENERATION = (
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true"
ENABLE_IMAGE_GENERATION = PersistentConfig(
"ENABLE_IMAGE_GENERATION",
"image_generation.enable",
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true",
)
AUTOMATIC1111_BASE_URL = PersistentConfig(
"AUTOMATIC1111_BASE_URL",
"image_generation.automatic1111.base_url",
os.getenv("AUTOMATIC1111_BASE_URL", ""),
)
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
COMFYUI_BASE_URL = PersistentConfig(
"COMFYUI_BASE_URL",
"image_generation.comfyui.base_url",
os.getenv("COMFYUI_BASE_URL", ""),
)
IMAGES_OPENAI_API_BASE_URL = os.getenv(
"IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL
IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
"IMAGES_OPENAI_API_BASE_URL",
"image_generation.openai.api_base_url",
os.getenv("IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
IMAGES_OPENAI_API_KEY = PersistentConfig(
"IMAGES_OPENAI_API_KEY",
"image_generation.openai.api_key",
os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY),
)
IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY)
IMAGE_SIZE = os.getenv("IMAGE_SIZE", "512x512")
IMAGE_SIZE = PersistentConfig(
"IMAGE_SIZE", "image_generation.size", os.getenv("IMAGE_SIZE", "512x512")
)
IMAGE_STEPS = int(os.getenv("IMAGE_STEPS", 50))
IMAGE_STEPS = PersistentConfig(
"IMAGE_STEPS", "image_generation.steps", int(os.getenv("IMAGE_STEPS", 50))
)
IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "")
IMAGE_GENERATION_MODEL = PersistentConfig(
"IMAGE_GENERATION_MODEL",
"image_generation.model",
os.getenv("IMAGE_GENERATION_MODEL", ""),
)
####################################
# Audio
####################################
AUDIO_OPENAI_API_BASE_URL = os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY)
AUDIO_OPENAI_API_MODEL = os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1")
AUDIO_OPENAI_API_VOICE = os.getenv("AUDIO_OPENAI_API_VOICE", "alloy")
AUDIO_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_OPENAI_API_BASE_URL",
"audio.openai.api_base_url",
os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_OPENAI_API_KEY = PersistentConfig(
"AUDIO_OPENAI_API_KEY",
"audio.openai.api_key",
os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_OPENAI_API_MODEL = PersistentConfig(
"AUDIO_OPENAI_API_MODEL",
"audio.openai.api_model",
os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"),
)
AUDIO_OPENAI_API_VOICE = PersistentConfig(
"AUDIO_OPENAI_API_VOICE",
"audio.openai.api_voice",
os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"),
)
####################################
# LiteLLM
......
from contextlib import asynccontextmanager
from bs4 import BeautifulSoup
import json
import markdown
......@@ -58,6 +59,7 @@ from config import (
SRC_LOG_LEVELS,
WEBHOOK_URL,
ENABLE_ADMIN_EXPORT,
AppConfig,
)
from constants import ERROR_MESSAGES
......@@ -92,12 +94,25 @@ https://github.com/open-webui/open-webui
"""
)
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
@asynccontextmanager
async def lifespan(app: FastAPI):
if ENABLE_LITELLM:
asyncio.create_task(start_litellm_background())
yield
if ENABLE_LITELLM:
await shutdown_litellm_background()
app = FastAPI(
docs_url="/docs" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan
)
app.state.config = AppConfig()
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.WEBHOOK_URL = WEBHOOK_URL
app.state.config.WEBHOOK_URL = WEBHOOK_URL
origins = ["*"]
......@@ -211,12 +226,6 @@ async def check_url(request: Request, call_next):
return response
@app.on_event("startup")
async def on_startup():
if ENABLE_LITELLM:
asyncio.create_task(start_litellm_background())
app.mount("/api/v1", webui_app)
app.mount("/litellm/api", litellm_app)
......@@ -243,9 +252,9 @@ async def get_app_config():
"version": VERSION,
"auth": WEBUI_AUTH,
"default_locale": default_locale,
"images": images_app.state.ENABLED,
"default_models": webui_app.state.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS,
"images": images_app.state.config.ENABLED,
"default_models": webui_app.state.config.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"admin_export_enabled": ENABLE_ADMIN_EXPORT,
}
......@@ -254,8 +263,8 @@ async def get_app_config():
@app.get("/api/config/model/filter")
async def get_model_filter_config(user=Depends(get_admin_user)):
return {
"enabled": app.state.ENABLE_MODEL_FILTER,
"models": app.state.MODEL_FILTER_LIST,
"enabled": app.state.config.ENABLE_MODEL_FILTER,
"models": app.state.config.MODEL_FILTER_LIST,
}
......@@ -268,28 +277,28 @@ class ModelFilterConfigForm(BaseModel):
async def update_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
):
app.state.ENABLE_MODEL_FILTER = form_data.enabled
app.state.MODEL_FILTER_LIST = form_data.models
app.state.config.ENABLE_MODEL_FILTER, form_data.enabled
app.state.config.MODEL_FILTER_LIST, form_data.models
ollama_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
ollama_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
ollama_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
openai_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
openai_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
openai_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
litellm_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
litellm_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
litellm_app.state.ENABLE_MODEL_FILTER = app.state.config.ENABLE_MODEL_FILTER
litellm_app.state.MODEL_FILTER_LIST = app.state.config.MODEL_FILTER_LIST
return {
"enabled": app.state.ENABLE_MODEL_FILTER,
"models": app.state.MODEL_FILTER_LIST,
"enabled": app.state.config.ENABLE_MODEL_FILTER,
"models": app.state.config.MODEL_FILTER_LIST,
}
@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
return {
"url": app.state.WEBHOOK_URL,
"url": app.state.config.WEBHOOK_URL,
}
......@@ -299,12 +308,12 @@ class UrlForm(BaseModel):
@app.post("/api/webhook")
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
app.state.WEBHOOK_URL = form_data.url
app.state.config.WEBHOOK_URL = form_data.url
webui_app.state.WEBHOOK_URL = app.state.WEBHOOK_URL
webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
return {
"url": app.state.WEBHOOK_URL,
"url": app.state.config.WEBHOOK_URL,
}
......@@ -381,9 +390,3 @@ else:
log.warning(
f"Frontend build directory not found at '{FRONTEND_BUILD_DIR}'. Serving API only."
)
@app.on_event("shutdown")
async def shutdown_event():
if ENABLE_LITELLM:
await shutdown_litellm_background()
version: '3.8'
services:
ollama:
# Expose Ollama API outside the container stack
......
version: '3.8'
services:
ollama:
volumes:
......
version: '3.8'
services:
ollama:
# GPU support
......
version: '3.8'
services:
ollama:
volumes:
......
......@@ -17,7 +17,7 @@ If your issue or contribution pertains directly to the core Ollama technology, p
### 🚨 Reporting Issues
Noticed something off? Have an idea? Check our [Issues tab](https://github.com/open-webui/oopen-webui/issues) to see if it's already been reported or suggested. If not, feel free to open a new issue. When reporting an issue, please follow our issue templates. These templates are designed to ensure that all necessary details are provided from the start, enabling us to address your concerns more efficiently.
Noticed something off? Have an idea? Check our [Issues tab](https://github.com/open-webui/open-webui/issues) to see if it's already been reported or suggested. If not, feel free to open a new issue. When reporting an issue, please follow our issue templates. These templates are designed to ensure that all necessary details are provided from the start, enabling us to address your concerns more efficiently.
> [!IMPORTANT]
>
......@@ -54,7 +54,7 @@ Help us make Open WebUI more accessible by improving documentation, writing tuto
Help us make Open WebUI available to a wider audience. In this section, we'll guide you through the process of adding new translations to the project.
We use JSON files to store translations. You can find the existing translation files in the `src/lib/i18n/locales` directory. Each directory corresponds to a specific language, for example, `en-US` for English (US), `fr-FR` for French (France) and so on. You can refer to [ISO 639 Language Codes][http://www.lingoes.net/en/translator/langcode.htm] to find the appropriate code for a specific language.
We use JSON files to store translations. You can find the existing translation files in the `src/lib/i18n/locales` directory. Each directory corresponds to a specific language, for example, `en-US` for English (US), `fr-FR` for French (France) and so on. You can refer to [ISO 639 Language Codes](http://www.lingoes.net/en/translator/langcode.htm) to find the appropriate code for a specific language.
To add a new language:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment