Unverified commit 5166e92f, authored by arkohut, committed by GitHub

Merge branch 'dev' into support-py-for-run-code

parents b443d61c b6b71c08
@@ -11,7 +11,3 @@ OPENAI_API_KEY=''
 SCARF_NO_ANALYTICS=true
 DO_NOT_TRACK=true
 ANONYMIZED_TELEMETRY=false
-
-# Use locally bundled version of the LiteLLM cost map json
-# to avoid repetitive startup connections
-LITELLM_LOCAL_MODEL_COST_MAP="True"
\ No newline at end of file
@@ -11,7 +11,7 @@ jobs:
     steps:
       - name: Checkout repository
-        uses: actions/checkout@v2
+        uses: actions/checkout@v4

       - name: Check for changes in package.json
         run: |
@@ -36,7 +36,7 @@ jobs:
           echo "::set-output name=content::$CHANGELOG_ESCAPED"

       - name: Create GitHub release
-        uses: actions/github-script@v5
+        uses: actions/github-script@v7
         with:
           github-token: ${{ secrets.GITHUB_TOKEN }}
           script: |
@@ -51,7 +51,7 @@ jobs:
             console.log(`Created release ${release.data.html_url}`)

       - name: Upload package to GitHub release
-        uses: actions/upload-artifact@v3
+        uses: actions/upload-artifact@v4
         with:
           name: package
           path: .
......
name: Deploy to HuggingFace Spaces

on:
  push:
    branches:
      - dev
      - main
  workflow_dispatch:

jobs:
  check-secret:
    runs-on: ubuntu-latest
    outputs:
      token-set: ${{ steps.check-key.outputs.defined }}
    steps:
      - id: check-key
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
        if: "${{ env.HF_TOKEN != '' }}"
        run: echo "defined=true" >> $GITHUB_OUTPUT

  deploy:
    runs-on: ubuntu-latest
    needs: [check-secret]
    if: needs.check-secret.outputs.token-set == 'true'
    env:
      HF_TOKEN: ${{ secrets.HF_TOKEN }}
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Remove git history
        run: rm -rf .git

      - name: Prepend YAML front matter to README.md
        run: |
          echo "---" > temp_readme.md
          echo "title: Open WebUI" >> temp_readme.md
          echo "emoji: 🐳" >> temp_readme.md
          echo "colorFrom: purple" >> temp_readme.md
          echo "colorTo: gray" >> temp_readme.md
          echo "sdk: docker" >> temp_readme.md
          echo "app_port: 8080" >> temp_readme.md
          echo "---" >> temp_readme.md
          cat README.md >> temp_readme.md
          mv temp_readme.md README.md

      - name: Configure git
        run: |
          git config --global user.email "41898282+github-actions[bot]@users.noreply.github.com"
          git config --global user.name "github-actions[bot]"

      - name: Set up Git and push to Space
        run: |
          git init --initial-branch=main
          git lfs track "*.ttf"
          rm demo.gif
          git add .
          git commit -m "GitHub deploy: ${{ github.sha }}"
          git push --force https://open-webui:${HF_TOKEN}@huggingface.co/spaces/open-webui/open-webui main
@@ -84,6 +84,8 @@ jobs:
           outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
           cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
           cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
+          build-args: |
+            BUILD_HASH=${{ github.sha }}

       - name: Export digest
         run: |
@@ -170,7 +172,9 @@ jobs:
           outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
           cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
           cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
-          build-args: USE_CUDA=true
+          build-args: |
+            BUILD_HASH=${{ github.sha }}
+            USE_CUDA=true

       - name: Export digest
         run: |
@@ -257,7 +261,9 @@ jobs:
           outputs: type=image,name=${{ env.FULL_IMAGE_NAME }},push-by-digest=true,name-canonical=true,push=true
           cache-from: type=registry,ref=${{ steps.cache-meta.outputs.tags }}
           cache-to: type=registry,ref=${{ steps.cache-meta.outputs.tags }},mode=max
-          build-args: USE_OLLAMA=true
+          build-args: |
+            BUILD_HASH=${{ github.sha }}
+            USE_OLLAMA=true

       - name: Export digest
         run: |
......
@@ -23,7 +23,7 @@ jobs:
       - uses: actions/checkout@v4

       - name: Set up Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v4
         with:
           python-version: ${{ matrix.python-version }}
......
@@ -19,7 +19,7 @@ jobs:
         uses: actions/checkout@v4

       - name: Setup Node.js
-        uses: actions/setup-node@v3
+        uses: actions/setup-node@v4
         with:
           node-version: '20' # Or specify any other version you want to use
......
@@ -20,7 +20,11 @@ jobs:
       - name: Build and run Compose Stack
         run: |
-          docker compose --file docker-compose.yaml --file docker-compose.api.yaml up --detach --build
+          docker compose \
+            --file docker-compose.yaml \
+            --file docker-compose.api.yaml \
+            --file docker-compose.a1111-test.yaml \
+            up --detach --build

       - name: Wait for Ollama to be up
         timeout-minutes: 5
@@ -95,7 +99,7 @@ jobs:
         uses: actions/checkout@v4

       - name: Set up Python
-        uses: actions/setup-python@v2
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
......
name: Release to PyPI

on:
  push:
    branches:
      - main # or whatever branch you want to use
      - dev

jobs:
  release:
    runs-on: ubuntu-latest
    environment:
      name: pypi
      url: https://pypi.org/p/open-webui
    permissions:
      id-token: write
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - uses: actions/setup-node@v4
        with:
          node-version: 18
      - uses: actions/setup-python@v5
        with:
          python-version: 3.11
      - name: Build
        run: |
          python -m pip install --upgrade pip
          pip install build
          python -m build .
      - name: Publish package distributions to PyPI
        uses: pypa/gh-action-pypi-publish@release/v1
@@ -11,12 +11,14 @@ ARG USE_CUDA_VER=cu121
 # IMPORTANT: If you change the embedding model (sentence-transformers/all-MiniLM-L6-v2) and vice versa, you aren't able to use RAG Chat with your previous documents loaded in the WebUI! You need to re-embed them.
 ARG USE_EMBEDDING_MODEL=sentence-transformers/all-MiniLM-L6-v2
 ARG USE_RERANKING_MODEL=""
+ARG BUILD_HASH=dev-build

 # Override at your own risk - non-root configurations are untested
 ARG UID=0
 ARG GID=0

 ######## WebUI frontend ########
 FROM --platform=$BUILDPLATFORM node:21-alpine3.19 as build
+ARG BUILD_HASH

 WORKDIR /app
@@ -24,6 +26,7 @@ COPY package.json package-lock.json ./
 RUN npm ci

 COPY . .
+ENV APP_BUILD_HASH=${BUILD_HASH}
 RUN npm run build

 ######## WebUI backend ########
@@ -35,6 +38,7 @@ ARG USE_OLLAMA
 ARG USE_CUDA_VER
 ARG USE_EMBEDDING_MODEL
 ARG USE_RERANKING_MODEL
+ARG BUILD_HASH
 ARG UID
 ARG GID
@@ -59,11 +63,6 @@ ENV OPENAI_API_KEY="" \
     DO_NOT_TRACK=true \
     ANONYMIZED_TELEMETRY=false

-# Use locally bundled version of the LiteLLM cost map json
-# to avoid repetitive startup connections
-ENV LITELLM_LOCAL_MODEL_COST_MAP="True"
-
 #### Other models #########################################################
 ## whisper TTS model settings ##
 ENV WHISPER_MODEL="base" \
@@ -132,7 +131,8 @@ RUN pip3 install uv && \
     uv pip install --system -r requirements.txt --no-cache-dir && \
     python -c "import os; from sentence_transformers import SentenceTransformer; SentenceTransformer(os.environ['RAG_EMBEDDING_MODEL'], device='cpu')" && \
     python -c "import os; from faster_whisper import WhisperModel; WhisperModel(os.environ['WHISPER_MODEL'], device='cpu', compute_type='int8', download_root=os.environ['WHISPER_MODEL_DIR'])"; \
-    fi
+    fi; \
+    chown -R $UID:$GID /app/backend/data/
@@ -154,4 +154,6 @@ HEALTHCHECK CMD curl --silent --fail http://localhost:8080/health | jq -e '.stat
 USER $UID:$GID

+ENV WEBUI_BUILD_VERSION=${BUILD_HASH}
+
 CMD [ "bash", "start.sh"]
import sys
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException
from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware
import logging
from fastapi import FastAPI, Request, Depends, status, Response
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import StreamingResponse
import json
import time
import requests
from pydantic import BaseModel, ConfigDict
from typing import Optional, List
from utils.utils import get_verified_user, get_current_user, get_admin_user
from config import SRC_LOG_LEVELS, ENV
from constants import MESSAGES
import os
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["LITELLM"])
from config import (
    ENABLE_LITELLM,
    ENABLE_MODEL_FILTER,
    MODEL_FILTER_LIST,
    DATA_DIR,
    LITELLM_PROXY_PORT,
    LITELLM_PROXY_HOST,
)
import warnings
warnings.simplefilter("ignore")
from litellm.utils import get_llm_provider
import asyncio
import subprocess
import yaml
@asynccontextmanager
async def lifespan(app: FastAPI):
    log.info("startup_event")
    # TODO: Check config.yaml file and create one
    asyncio.create_task(start_litellm_background())
    yield


app = FastAPI(lifespan=lifespan)

origins = ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

LITELLM_CONFIG_DIR = f"{DATA_DIR}/litellm/config.yaml"

with open(LITELLM_CONFIG_DIR, "r") as file:
    litellm_config = yaml.safe_load(file)

app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value
app.state.ENABLE = ENABLE_LITELLM
app.state.CONFIG = litellm_config

# Global variable to store the subprocess reference
background_process = None

CONFLICT_ENV_VARS = [
    # Uvicorn uses PORT, so LiteLLM might use it as well
    "PORT",
    # LiteLLM uses DATABASE_URL for Prisma connections
    "DATABASE_URL",
]
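A minimal illustration of the environment filtering that run_background_process applies below (the PORT value is an example):

    import os

    # With PORT=8080 exported in the parent environment, the child process
    # does not inherit it, so the LiteLLM proxy binds to the --port flag
    # instead of picking up Uvicorn's PORT.
    os.environ["PORT"] = "8080"
    env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS}
    assert "PORT" not in env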
async def run_background_process(command):
    global background_process
    log.info("run_background_process")

    try:
        # Log the command to be executed
        log.info(f"Executing command: {command}")
        # Filter environment variables known to conflict with litellm
        env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS}
        # Execute the command and create a subprocess
        process = await asyncio.create_subprocess_exec(
            *command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env
        )
        background_process = process
        log.info("Subprocess started successfully.")

        # Capture STDERR for debugging purposes
        stderr_output = await process.stderr.read()
        stderr_text = stderr_output.decode().strip()
        if stderr_text:
            log.info(f"Subprocess STDERR: {stderr_text}")

        # log.info output line by line
        async for line in process.stdout:
            log.info(line.decode().strip())

        # Wait for the process to finish
        returncode = await process.wait()
        log.info(f"Subprocess exited with return code {returncode}")
    except Exception as e:
        log.error(f"Failed to start subprocess: {e}")
        raise  # Optionally re-raise the exception if you want it to propagate
async def start_litellm_background():
    log.info("start_litellm_background")
    # Command to run in the background
    command = [
        "litellm",
        "--port",
        str(LITELLM_PROXY_PORT),
        "--host",
        LITELLM_PROXY_HOST,
        "--telemetry",
        "False",
        "--config",
        LITELLM_CONFIG_DIR,
    ]

    await run_background_process(command)


async def shutdown_litellm_background():
    log.info("shutdown_litellm_background")
    global background_process
    if background_process:
        background_process.terminate()
        await background_process.wait()  # Ensure the process has terminated
        log.info("Subprocess terminated")
        background_process = None


@app.get("/")
async def get_status():
    return {"status": True}
async def restart_litellm():
    """
    Endpoint to restart the litellm background service.
    """
    log.info("Requested restart of litellm service.")
    try:
        # Shut down the existing process if it is running
        await shutdown_litellm_background()
        log.info("litellm service shutdown complete.")

        # Restart the background service
        asyncio.create_task(start_litellm_background())
        log.info("litellm service restart complete.")

        return {
            "status": "success",
            "message": "litellm service restarted successfully.",
        }
    except Exception as e:
        log.info(f"Error restarting litellm service: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        )


@app.get("/restart")
async def restart_litellm_handler(user=Depends(get_admin_user)):
    return await restart_litellm()


@app.get("/config")
async def get_config(user=Depends(get_admin_user)):
    return app.state.CONFIG


class LiteLLMConfigForm(BaseModel):
    general_settings: Optional[dict] = None
    litellm_settings: Optional[dict] = None
    model_list: Optional[List[dict]] = None
    router_settings: Optional[dict] = None

    model_config = ConfigDict(protected_namespaces=())


@app.post("/config/update")
async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
    app.state.CONFIG = form_data.model_dump(exclude_none=True)

    with open(LITELLM_CONFIG_DIR, "w") as file:
        yaml.dump(app.state.CONFIG, file)

    await restart_litellm()
    return app.state.CONFIG
@app.get("/models")
@app.get("/v1/models")
async def get_models(user=Depends(get_current_user)):
if app.state.ENABLE:
while not background_process:
await asyncio.sleep(0.1)
url = f"http://localhost:{LITELLM_PROXY_PORT}/v1"
r = None
try:
r = requests.request(method="GET", url=f"{url}/models")
r.raise_for_status()
data = r.json()
if app.state.ENABLE_MODEL_FILTER:
if user and user.role == "user":
data["data"] = list(
filter(
lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
data["data"],
)
)
return data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']}"
except:
error_detail = f"External: {e}"
return {
"data": [
{
"id": model["model_name"],
"object": "model",
"created": int(time.time()),
"owned_by": "openai",
}
for model in app.state.CONFIG["model_list"]
],
"object": "list",
}
else:
return {
"data": [],
"object": "list",
}
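A minimal illustration of the role-based model filter used above; model IDs are examples, not values from this commit:

    # Non-admin users only see models on the allowlist.
    MODEL_FILTER_LIST = ["gpt-3.5-turbo"]
    data = {"data": [{"id": "gpt-3.5-turbo"}, {"id": "gpt-4"}]}
    data["data"] = list(
        filter(lambda model: model["id"] in MODEL_FILTER_LIST, data["data"])
    )
    assert [m["id"] for m in data["data"]] == ["gpt-3.5-turbo"]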
@app.get("/model/info")
async def get_model_list(user=Depends(get_admin_user)):
return {"data": app.state.CONFIG["model_list"]}
class AddLiteLLMModelForm(BaseModel):
model_name: str
litellm_params: dict
model_config = ConfigDict(protected_namespaces=())
@app.post("/model/new")
async def add_model_to_config(
form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
):
try:
get_llm_provider(model=form_data.model_name)
app.state.CONFIG["model_list"].append(form_data.model_dump())
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
class DeleteLiteLLMModelForm(BaseModel):
id: str
@app.post("/model/delete")
async def delete_model_from_config(
form_data: DeleteLiteLLMModelForm, user=Depends(get_admin_user)
):
app.state.CONFIG["model_list"] = [
model
for model in app.state.CONFIG["model_list"]
if model["model_name"] != form_data.id
]
with open(LITELLM_CONFIG_DIR, "w") as file:
yaml.dump(app.state.CONFIG, file)
await restart_litellm()
return {"message": MESSAGES.MODEL_DELETED(form_data.id)}
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    body = await request.body()

    url = f"http://localhost:{LITELLM_PROXY_PORT}"

    target_url = f"{url}/{path}"

    headers = {}
    # headers["Authorization"] = f"Bearer {key}"
    headers["Content-Type"] = "application/json"

    r = None

    try:
        r = requests.request(
            method=request.method,
            url=target_url,
            data=body,
            headers=headers,
            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
            return StreamingResponse(
                r.iter_content(chunk_size=8192),
                status_code=r.status_code,
                headers=dict(r.headers),
            )
        else:
            response_data = r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
                res = r.json()
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except:
                error_detail = f"External: {e}"

        raise HTTPException(
            status_code=r.status_code if r else 500, detail=error_detail
        )
@@ -29,8 +29,8 @@ import time
 from urllib.parse import urlparse
 from typing import Optional, List, Union

-from apps.web.models.users import Users
+from apps.webui.models.models import Models
+from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import (
     decode_token,
@@ -39,10 +39,13 @@ from utils.utils import (
     get_admin_user,
 )

+from utils.models import get_model_id_from_custom_model_id
+
 from config import (
     SRC_LOG_LEVELS,
     OLLAMA_BASE_URLS,
+    ENABLE_OLLAMA_API,
     ENABLE_MODEL_FILTER,
     MODEL_FILTER_LIST,
     UPLOAD_DIR,
@@ -67,6 +70,7 @@ app.state.config = AppConfig()
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST

+app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
 app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.MODELS = {}
@@ -96,6 +100,21 @@ async def get_status():
     return {"status": True}

+@app.get("/config")
+async def get_config(user=Depends(get_admin_user)):
+    return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
+
+
+class OllamaConfigForm(BaseModel):
+    enable_ollama_api: Optional[bool] = None
+
+
+@app.post("/config/update")
+async def update_config(form_data: OllamaConfigForm, user=Depends(get_admin_user)):
+    app.state.config.ENABLE_OLLAMA_API = form_data.enable_ollama_api
+    return {"ENABLE_OLLAMA_API": app.state.config.ENABLE_OLLAMA_API}
+
+
 @app.get("/urls")
 async def get_ollama_api_urls(user=Depends(get_admin_user)):
     return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
@@ -156,15 +175,24 @@ def merge_models_lists(model_lists):
 async def get_all_models():
     log.info("get_all_models()")

-    tasks = [fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS]
-    responses = await asyncio.gather(*tasks)
+    if app.state.config.ENABLE_OLLAMA_API:
+        tasks = [
+            fetch_url(f"{url}/api/tags") for url in app.state.config.OLLAMA_BASE_URLS
+        ]
+        responses = await asyncio.gather(*tasks)

-    models = {
-        "models": merge_models_lists(
-            map(lambda response: response["models"] if response else None, responses)
-        )
-    }
+        models = {
+            "models": merge_models_lists(
+                map(
+                    lambda response: response["models"] if response else None, responses
+                )
+            )
+        }
+    else:
+        models = {"models": []}

     app.state.MODELS = {model["model"]: model for model in models["models"]}
     return models
@@ -278,6 +306,9 @@ async def pull_model(
     r = None

+    # Admin should be able to pull models from any source
+    payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
+
     def get_request():
         nonlocal url
         nonlocal r
@@ -305,7 +336,7 @@ async def pull_model(
             r = requests.request(
                 method="POST",
                 url=f"{url}/api/pull",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
+                data=json.dumps(payload),
                 stream=True,
             )
@@ -848,14 +879,93 @@ async def generate_chat_completion(
     user=Depends(get_verified_user),
 ):
-    if url_idx == None:
-        model = form_data.model
+    log.debug(
+        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
+            form_data.model_dump_json(exclude_none=True).encode()
+        )
+    )

-        if ":" not in model:
-            model = f"{model}:latest"
+    payload = {
+        **form_data.model_dump(exclude_none=True),
+    }

-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+    model_id = form_data.model
+    model_info = Models.get_model_by_id(model_id)
+
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            payload["options"] = {}
+
+            payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
+            payload["options"]["mirostat_eta"] = model_info.params.get(
+                "mirostat_eta", None
+            )
+            payload["options"]["mirostat_tau"] = model_info.params.get(
+                "mirostat_tau", None
+            )
+            payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
+            payload["options"]["repeat_last_n"] = model_info.params.get(
+                "repeat_last_n", None
+            )
+            payload["options"]["repeat_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+            payload["options"]["temperature"] = model_info.params.get(
+                "temperature", None
+            )
+            payload["options"]["seed"] = model_info.params.get("seed", None)
+
+            payload["options"]["stop"] = (
+                [
+                    bytes(stop, "utf-8").decode("unicode_escape")
+                    for stop in model_info.params["stop"]
+                ]
+                if model_info.params.get("stop", None)
+                else None
+            )
+
+            payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
+
+            payload["options"]["num_predict"] = model_info.params.get(
+                "max_tokens", None
+            )
+            payload["options"]["top_k"] = model_info.params.get("top_k", None)
+            payload["options"]["top_p"] = model_info.params.get("top_p", None)
+
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
+                        )
+                        break
+            else:
+                payload["messages"].insert(
+                    0,
+                    {
+                        "role": "system",
+                        "content": model_info.params.get("system", None),
+                    },
+                )
+
+    if url_idx == None:
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
+
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
@@ -865,16 +975,12 @@ async def generate_chat_completion(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")

-    r = None
-    log.debug(
-        "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
-            form_data.model_dump_json(exclude_none=True).encode()
-        )
-    )
+    print(payload)
+
+    r = None

     def get_request():
-        nonlocal form_data
+        nonlocal payload
         nonlocal r

         request_id = str(uuid.uuid4())
@@ -883,7 +989,7 @@ async def generate_chat_completion(
         def stream_content():
             try:
-                if form_data.stream:
+                if payload.get("stream", None):
                     yield json.dumps({"id": request_id, "done": False}) + "\n"

                 for chunk in r.iter_content(chunk_size=8192):
@@ -901,7 +1007,7 @@ async def generate_chat_completion(
             r = requests.request(
                 method="POST",
                 url=f"{url}/api/chat",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
+                data=json.dumps(payload),
                 stream=True,
             )
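For orientation, a hedged sketch of the payload shape the mapping above can produce for a customized model; the model name and values are illustrative, not taken from this commit:

    payload = {
        "model": "llama3:latest",  # resolved from model_info.base_model_id
        "messages": [
            # injected from params["system"] when no system message exists
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello"},
        ],
        # Per-model params map onto Ollama's options object; unset params stay None.
        "options": {"temperature": 0.7, "num_ctx": 4096, "seed": None},
    }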
@@ -957,14 +1063,62 @@ async def generate_openai_chat_completion(
     user=Depends(get_verified_user),
 ):
-    if url_idx == None:
-        model = form_data.model
+    payload = {
+        **form_data.model_dump(exclude_none=True),
+    }

-        if ":" not in model:
-            model = f"{model}:latest"
+    model_id = form_data.model
+    model_info = Models.get_model_by_id(model_id)

-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+    if model_info:
+        print(model_info)
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            payload["temperature"] = model_info.params.get("temperature", None)
+            payload["top_p"] = model_info.params.get("top_p", None)
+            payload["max_tokens"] = model_info.params.get("max_tokens", None)
+            payload["frequency_penalty"] = model_info.params.get(
+                "frequency_penalty", None
+            )
+            payload["seed"] = model_info.params.get("seed", None)
+            payload["stop"] = (
+                [
+                    bytes(stop, "utf-8").decode("unicode_escape")
+                    for stop in model_info.params["stop"]
+                ]
+                if model_info.params.get("stop", None)
+                else None
+            )
+
+        if model_info.params.get("system", None):
+            # Check if the payload already has a system message
+            # If not, add a system message to the payload
+            if payload.get("messages"):
+                for message in payload["messages"]:
+                    if message.get("role") == "system":
+                        message["content"] = (
+                            model_info.params.get("system", None) + message["content"]
+                        )
+                        break
+            else:
+                payload["messages"].insert(
+                    0,
+                    {
+                        "role": "system",
+                        "content": model_info.params.get("system", None),
+                    },
+                )
+
+    if url_idx == None:
+        if ":" not in payload["model"]:
+            payload["model"] = f"{payload['model']}:latest"
+
+        if payload["model"] in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
@@ -977,7 +1131,7 @@ async def generate_openai_chat_completion(
     r = None

     def get_request():
-        nonlocal form_data
+        nonlocal payload
         nonlocal r

         request_id = str(uuid.uuid4())
@@ -986,7 +1140,7 @@ async def generate_openai_chat_completion(
         def stream_content():
             try:
-                if form_data.stream:
+                if payload.get("stream"):
                     yield json.dumps(
                         {"request_id": request_id, "done": False}
                     ) + "\n"
@@ -1006,7 +1160,7 @@ async def generate_openai_chat_completion(
             r = requests.request(
                 method="POST",
                 url=f"{url}/v1/chat/completions",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
+                data=json.dumps(payload),
                 stream=True,
             )
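The stop-sequence handling shared by these handlers decodes escape sequences entered as literal text in the model params, so a typed "\n" becomes a real newline. A minimal illustration:

    stops = ["\\n", "###"]  # as stored in params: a literal backslash plus "n"
    decoded = [bytes(stop, "utf-8").decode("unicode_escape") for stop in stops]
    assert decoded == ["\n", "###"]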
......
@@ -10,8 +10,8 @@ import logging
 from pydantic import BaseModel

-from apps.web.models.users import Users
+from apps.webui.models.models import Models
+from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import (
     decode_token,
@@ -53,7 +53,6 @@ app.state.config = AppConfig()
 app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
 app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
-
 app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
 app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
@@ -199,14 +198,20 @@ async def fetch_url(url, key):
 def merge_models_lists(model_lists):
-    log.info(f"merge_models_lists {model_lists}")
+    log.debug(f"merge_models_lists {model_lists}")
     merged_list = []

     for idx, models in enumerate(model_lists):
         if models is not None and "error" not in models:
             merged_list.extend(
                 [
-                    {**model, "urlIdx": idx}
+                    {
+                        **model,
+                        "name": model.get("name", model["id"]),
+                        "owned_by": "openai",
+                        "openai": model,
+                        "urlIdx": idx,
+                    }
                     for model in models
                     if "api.openai.com"
                     not in app.state.config.OPENAI_API_BASE_URLS[idx]
@@ -232,7 +237,7 @@ async def get_all_models():
         ]
         responses = await asyncio.gather(*tasks)

-        log.info(f"get_all_models:responses() {responses}")
+        log.debug(f"get_all_models:responses() {responses}")

         models = {
             "data": merge_models_lists(
@@ -249,7 +254,7 @@ async def get_all_models():
             )
         }

-        log.info(f"models: {models}")
+        log.debug(f"models: {models}")

         app.state.MODELS = {model["id"]: model for model in models["data"]}

         return models
@@ -310,31 +315,93 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
     body = await request.body()
     # TODO: Remove below after gpt-4-vision fix from Open AI
     # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)

+    payload = None
+
     try:
-        body = body.decode("utf-8")
-        body = json.loads(body)
+        if "chat/completions" in path:
+            body = body.decode("utf-8")
+            body = json.loads(body)

-        idx = app.state.MODELS[body.get("model")]["urlIdx"]
+            payload = {**body}
+
+            model_id = body.get("model")
+            model_info = Models.get_model_by_id(model_id)
+
+            if model_info:
+                print(model_info)
+                if model_info.base_model_id:
+                    payload["model"] = model_info.base_model_id
+
+                model_info.params = model_info.params.model_dump()
+
+                if model_info.params:
+                    payload["temperature"] = model_info.params.get("temperature", None)
+                    payload["top_p"] = model_info.params.get("top_p", None)
+                    payload["max_tokens"] = model_info.params.get("max_tokens", None)
+                    payload["frequency_penalty"] = model_info.params.get(
+                        "frequency_penalty", None
+                    )
+                    payload["seed"] = model_info.params.get("seed", None)
+                    payload["stop"] = (
+                        [
+                            bytes(stop, "utf-8").decode("unicode_escape")
+                            for stop in model_info.params["stop"]
+                        ]
+                        if model_info.params.get("stop", None)
+                        else None
+                    )
+
+                if model_info.params.get("system", None):
+                    # Check if the payload already has a system message
+                    # If not, add a system message to the payload
+                    if payload.get("messages"):
+                        for message in payload["messages"]:
+                            if message.get("role") == "system":
+                                message["content"] = (
+                                    model_info.params.get("system", None)
+                                    + message["content"]
+                                )
+                                break
+                    else:
+                        payload["messages"].insert(
+                            0,
+                            {
+                                "role": "system",
+                                "content": model_info.params.get("system", None),
+                            },
+                        )
+            else:
+                pass
+
+            print(app.state.MODELS)
+            model = app.state.MODELS[payload.get("model")]
+
+            idx = model["urlIdx"]
+
+            if "pipeline" in model and model.get("pipeline"):
+                payload["user"] = {"name": user.name, "id": user.id}
+                payload["title"] = (
+                    True
+                    if payload["stream"] == False and payload["max_tokens"] == 50
+                    else False
+                )

-        # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
-        # This is a workaround until OpenAI fixes the issue with this model
-        if body.get("model") == "gpt-4-vision-preview":
-            if "max_tokens" not in body:
-                body["max_tokens"] = 4000
-            log.debug("Modified body_dict:", body)
+            # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
+            # This is a workaround until OpenAI fixes the issue with this model
+            if payload.get("model") == "gpt-4-vision-preview":
+                if "max_tokens" not in payload:
+                    payload["max_tokens"] = 4000
+                log.debug("Modified payload:", payload)

-        # Fix for ChatGPT calls failing because the num_ctx key is in body
-        if "num_ctx" in body:
-            # If 'num_ctx' is in the dictionary, delete it
-            # Leaving it there generates an error with the
-            # OpenAI API (Feb 2024)
-            del body["num_ctx"]
-
-        # Convert the modified body back to JSON
-        body = json.dumps(body)
+            # Convert the modified body back to JSON
+            payload = json.dumps(payload)
+
     except json.JSONDecodeError as e:
         log.error("Error loading request body into a dictionary:", e)

+    print(payload)
+
     url = app.state.config.OPENAI_API_BASE_URLS[idx]
     key = app.state.config.OPENAI_API_KEYS[idx]
@@ -353,7 +420,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
         r = requests.request(
             method=request.method,
             url=target_url,
-            data=payload if payload else body,
+            data=payload if payload else body,
             headers=headers,
             stream=True,
         )
......
@@ -46,7 +46,7 @@ import json
 import sentence_transformers

-from apps.web.models.documents import (
+from apps.webui.models.documents import (
     Documents,
     DocumentForm,
     DocumentResponse,
......
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from apps.web.internal.db import DB
import json
####################
# Modelfile DB Schema
####################
class Modelfile(Model):
    tag_name = CharField(unique=True)
    user_id = CharField()
    modelfile = TextField()
    timestamp = BigIntegerField()

    class Meta:
        database = DB


class ModelfileModel(BaseModel):
    tag_name: str
    user_id: str
    modelfile: str
    timestamp: int  # timestamp in epoch


####################
# Forms
####################


class ModelfileForm(BaseModel):
    modelfile: dict


class ModelfileTagNameForm(BaseModel):
    tag_name: str


class ModelfileUpdateForm(ModelfileForm, ModelfileTagNameForm):
    pass


class ModelfileResponse(BaseModel):
    tag_name: str
    user_id: str
    modelfile: dict
    timestamp: int  # timestamp in epoch
class ModelfilesTable:
    def __init__(self, db):
        self.db = db
        self.db.create_tables([Modelfile])

    def insert_new_modelfile(
        self, user_id: str, form_data: ModelfileForm
    ) -> Optional[ModelfileModel]:
        if "tagName" in form_data.modelfile:
            modelfile = ModelfileModel(
                **{
                    "user_id": user_id,
                    "tag_name": form_data.modelfile["tagName"],
                    "modelfile": json.dumps(form_data.modelfile),
                    "timestamp": int(time.time()),
                }
            )

            try:
                result = Modelfile.create(**modelfile.model_dump())
                if result:
                    return modelfile
                else:
                    return None
            except:
                return None
        else:
            return None

    def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]:
        try:
            modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
            return ModelfileModel(**model_to_dict(modelfile))
        except:
            return None

    def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]:
        return [
            ModelfileResponse(
                **{
                    **model_to_dict(modelfile),
                    "modelfile": json.loads(modelfile.modelfile),
                }
            )
            for modelfile in Modelfile.select()
            # .limit(limit).offset(skip)
        ]

    def update_modelfile_by_tag_name(
        self, tag_name: str, modelfile: dict
    ) -> Optional[ModelfileModel]:
        try:
            query = Modelfile.update(
                modelfile=json.dumps(modelfile),
                timestamp=int(time.time()),
            ).where(Modelfile.tag_name == tag_name)
            query.execute()

            modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
            return ModelfileModel(**model_to_dict(modelfile))
        except:
            return None

    def delete_modelfile_by_tag_name(self, tag_name: str) -> bool:
        try:
            query = Modelfile.delete().where((Modelfile.tag_name == tag_name))
            query.execute()  # Remove the rows, return number of rows removed.
            return True
        except:
            return False


Modelfiles = ModelfilesTable(DB)
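A hedged usage sketch of the (now removed) table above; note that insert_new_modelfile requires a "tagName" key in the modelfile dict and returns None otherwise. Values are illustrative:

    form = ModelfileForm(modelfile={"tagName": "my-model", "content": "FROM llama2"})
    created = Modelfiles.insert_new_modelfile("user-1", form)
    if created:
        print(created.tag_name)  # "my-model"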
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.modelfiles import (
    Modelfiles,
    ModelfileForm,
    ModelfileTagNameForm,
    ModelfileUpdateForm,
    ModelfileResponse,
)
from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
############################
# GetModelfiles
############################
@router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(
    skip: int = 0, limit: int = 50, user=Depends(get_current_user)
):
    return Modelfiles.get_modelfiles(skip, limit)
############################
# CreateNewModelfile
############################
@router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)):
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
############################
# GetModelfileByTagName
############################
@router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, user=Depends(get_current_user)
):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateModelfileByTagName
############################
@router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(
form_data: ModelfileUpdateForm, user=Depends(get_admin_user)
):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
updated_modelfile = {
**json.loads(modelfile.modelfile),
**form_data.modelfile,
}
modelfile = Modelfiles.update_modelfile_by_tag_name(
form_data.tag_name, updated_modelfile
)
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################
# DeleteModelfileByTagName
############################
@router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(
    form_data: ModelfileTagNameForm, user=Depends(get_admin_user)
):
    result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
    return result
+import json
+
 from peewee import *
 from peewee_migrate import Router
 from playhouse.db_url import connect
-from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL
+from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
 import os
 import logging

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["DB"])

+
+class JSONField(TextField):
+    def db_value(self, value):
+        return json.dumps(value)
+
+    def python_value(self, value):
+        if value is not None:
+            return json.loads(value)
+
+
 # Check if the file exists
 if os.path.exists(f"{DATA_DIR}/ollama.db"):
     # Rename the file
@@ -18,6 +30,10 @@ else:
     DB = connect(DATABASE_URL)
     log.info(f"Connected to a {DB.__class__.__name__} database.")

-router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log)
+router = Router(
+    DB,
+    migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
+    logger=log,
+)
 router.run()
 DB.connect(reuse_if_open=True)
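For context, a hedged sketch of how the new JSONField can be attached to a peewee model; the Example model is hypothetical, not part of this commit:

    class Example(Model):  # hypothetical model for illustration only
        data = JSONField()  # stored as JSON text, returned as a dict/list

        class Meta:
            database = DB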