Unverified Commit 9bcd4ce5 authored by Timothy Jaeryang Baek, committed by GitHub

Merge pull request #3559 from open-webui/dev

0.3.8
parents 824966ad b38abf23
......@@ -10,7 +10,8 @@ node_modules
vite.config.js.timestamp-*
vite.config.ts.timestamp-*
__pycache__
.env
.idea
venv
_old
uploads
.ipynb_checkpoints
......
......@@ -35,6 +35,10 @@ jobs:
done
echo "Service is up!"
- name: Delete Docker build cache
run: |
docker builder prune --all --force
- name: Preload Ollama model
run: |
docker exec ollama ollama pull qwen:0.5b-chat-v1.5-q2_K
......@@ -43,7 +47,7 @@ jobs:
uses: cypress-io/github-action@v6
with:
browser: chrome
wait-on: "http://localhost:3000"
wait-on: 'http://localhost:3000'
config: baseUrl=http://localhost:3000
- uses: actions/upload-artifact@v4
......@@ -67,6 +71,28 @@ jobs:
path: compose-logs.txt
if-no-files-found: ignore
# pytest:
# name: Run Backend Tests
# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@v4
# - name: Set up Python
# uses: actions/setup-python@v4
# with:
# python-version: ${{ matrix.python-version }}
# - name: Install dependencies
# run: |
# python -m pip install --upgrade pip
# pip install -r backend/requirements.txt
# - name: pytest run
# run: |
# ls -al
# cd backend
# PYTHONPATH=. pytest . -o log_cli=true -o log_cli_level=INFO
migration_test:
name: Run Migration Tests
runs-on: ubuntu-latest
......@@ -126,11 +152,11 @@ jobs:
cd backend
uvicorn main:app --port "8080" --forwarded-allow-ips '*' &
UVICORN_PID=$!
# Wait up to 20 seconds for the server to start
for i in {1..20}; do
# Wait up to 40 seconds for the server to start
for i in {1..40}; do
curl -s http://localhost:8080/api/config > /dev/null && break
sleep 1
if [ $i -eq 20 ]; then
if [ $i -eq 40 ]; then
echo "Server failed to start"
kill -9 $UVICORN_PID
exit 1
......@@ -171,7 +197,7 @@ jobs:
fi
# Check that the service reconnects to Postgres after the connection is closed
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has failed before postgres reconnect check"
exit 1
......@@ -183,7 +209,7 @@ jobs:
cur = conn.cursor(); \
cur.execute('SELECT pg_terminate_backend(psa.pid) FROM pg_stat_activity psa WHERE datname = current_database() AND pid <> pg_backend_pid();')"
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health)
status_code=$(curl --write-out %{http_code} -s --output /dev/null http://localhost:8081/health/db)
if [[ "$status_code" -ne 200 ]] ; then
echo "Server has not reconnected to postgres after connection was closed: returned status $status_code"
exit 1
......
......@@ -306,3 +306,4 @@ dist
# cypress artifacts
cypress/videos
cypress/screenshots
.vscode/settings.json
......@@ -5,6 +5,33 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [0.3.8] - 2024-07-09
### Added
- **💬 Chat Controls**: Easily adjust parameters for each chat session, offering more precise control over your interactions.
- **📌 Pinned Chats**: Support for pinned chats, allowing you to keep important conversations easily accessible.
- **📄 Apache Tika Integration**: Added support for using Apache Tika as a document loader, enhancing document processing capabilities.
- **🛠️ Custom Environment for OpenID Claims**: Allows setting custom claims for OpenID, providing more flexibility in user authentication.
- **🔧 Enhanced Tools & Functions API**: Introduced 'event_emitter' and 'event_call'; you can now also add citations for better documentation and tracking (a hedged sketch follows this list). Detailed documentation will be provided on our documentation website.
- **↔️ Sideways Scrolling in Settings**: Settings tabs container now supports horizontal scrolling for easier navigation.
- **🌑 Darker OLED Theme**: Includes a new, darker OLED theme and improved styling for the light theme, enhancing visual appeal.
- **🌐 Language Updates**: Updated translations for Indonesian, German, French, and Catalan languages, expanding accessibility.
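
The 'event_emitter' entry above names the hook but not its payload. A minimal sketch, assuming a tool coroutine receives an `__event_emitter__` callable and that status events carry a `{"type", "data"}` dict — both assumptions, not confirmed by this diff:

```python
# Hedged sketch: '__event_emitter__' and the payload shape are assumptions.
async def example_tool(query: str, __event_emitter__=None) -> str:
    if __event_emitter__:
        # Surface a progress event to the chat UI (event shape assumed).
        await __event_emitter__(
            {"type": "status", "data": {"description": f"Searching {query}", "done": False}}
        )
    return f"Results for {query}"
```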
### Fixed
- **⏰ OpenAI Streaming Timeout**: Resolved issues with OpenAI streaming response using the 'AIOHTTP_CLIENT_TIMEOUT' setting, ensuring reliable performance.
- **💡 User Valves**: Fixed malfunctioning user valves, ensuring proper functionality.
- **🔄 Collapsible Components**: Addressed issues with collapsible components not working, restoring expected behavior.
### Changed
- **🗃️ Database Backend**: Switched from Peewee to SQLAlchemy for improved concurrency support, enhancing database performance.
- **🔤 Primary Font Styling**: Updated primary font to Archivo for better visual consistency.
- **🔄 Font Change for Windows**: Replaced Arimo with Inter font for Windows users, improving readability.
- **🚀 Lazy Loading**: Implemented lazy loading for 'faster_whisper' and 'sentence_transformers' to reduce startup memory usage (a sketch of the pattern follows this list).
- **📋 Task Generation Payload**: Task generations now include only the "task" field in the body instead of "title".
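
The lazy-loading change is visible later in this diff: the `from faster_whisper import WhisperModel` import moves inside `transcribe`. A minimal sketch of the pattern, with a hypothetical `transcribe_file` helper and assuming the package is installed:

```python
def transcribe_file(path: str) -> str:
    # Deferred import: the heavy model dependency is loaded on first use,
    # not at server startup.
    from faster_whisper import WhisperModel

    model = WhisperModel("base", device="cpu", compute_type="int8")
    segments, _info = model.transcribe(path)
    return " ".join(segment.text for segment in segments)
```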
## [0.3.7] - 2024-06-29
### Added
......
# A generic, single database configuration.
[alembic]
# path to migration scripts
script_location = migrations
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires python>=3.9 or the backports.zoneinfo library.
# Any required deps can be installed by adding `alembic[tz]` to the pip requirements
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to migrations/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:migrations/versions
# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os # Use os.pathsep. Default configuration used for new projects.
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# sqlalchemy.url = REPLACE_WITH_DATABASE_URL
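
Since `sqlalchemy.url` stays commented out here, presumably `migrations/env.py` injects the URL at runtime. A minimal sketch of that pattern, assuming the project's `config.DATABASE_URL` is importable (the env.py itself is not shown in this diff):

```python
# Hypothetical excerpt of migrations/env.py (not part of this diff).
from alembic import context
from config import DATABASE_URL  # assumed project import

alembic_config = context.config
# Escape '%' so configparser interpolation does not mangle the URL.
alembic_config.set_main_option("sqlalchemy.url", DATABASE_URL.replace("%", "%%"))
```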
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the exec runner, execute a binary
# hooks = ruff
# ruff.type = exec
# ruff.executable = %(here)s/.venv/bin/ruff
# ruff.options = --fix REVISION_SCRIPT_FILENAME
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
......@@ -14,7 +14,6 @@ from fastapi import (
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel
import uuid
......@@ -277,6 +276,8 @@ def transcribe(
f.close()
if app.state.config.STT_ENGINE == "":
from faster_whisper import WhisperModel
whisper_kwargs = {
"model_size_or_path": WHISPER_MODEL,
"device": whisper_device_type,
......
......@@ -12,7 +12,6 @@ from fastapi import (
Form,
)
from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from constants import ERROR_MESSAGES
from utils.utils import (
......
......@@ -25,6 +25,7 @@ from utils.task import prompt_template
from config import (
SRC_LOG_LEVELS,
ENABLE_OPENAI_API,
AIOHTTP_CLIENT_TIMEOUT,
OPENAI_API_BASE_URLS,
OPENAI_API_KEYS,
CACHE_DIR,
......@@ -463,7 +464,9 @@ async def generate_chat_completion(
streaming = False
try:
session = aiohttp.ClientSession(trust_env=True)
session = aiohttp.ClientSession(
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
)
r = await session.request(
method="POST",
url=f"{url}/chat/completions",
......
......@@ -48,8 +48,6 @@ import mimetypes
import uuid
import json
import sentence_transformers
from apps.webui.models.documents import (
Documents,
DocumentForm,
......@@ -93,6 +91,8 @@ from config import (
SRC_LOG_LEVELS,
UPLOAD_DIR,
DOCS_DIR,
CONTENT_EXTRACTION_ENGINE,
TIKA_SERVER_URL,
RAG_TOP_K,
RAG_RELEVANCE_THRESHOLD,
RAG_EMBEDDING_ENGINE,
......@@ -148,6 +148,9 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
app.state.config.CHUNK_SIZE = CHUNK_SIZE
app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
......@@ -190,6 +193,8 @@ def update_embedding_model(
update_model: bool = False,
):
if embedding_model and app.state.config.RAG_EMBEDDING_ENGINE == "":
import sentence_transformers
app.state.sentence_transformer_ef = sentence_transformers.SentenceTransformer(
get_model_path(embedding_model, update_model),
device=DEVICE_TYPE,
......@@ -204,6 +209,8 @@ def update_reranking_model(
update_model: bool = False,
):
if reranking_model:
import sentence_transformers
app.state.sentence_transformer_rf = sentence_transformers.CrossEncoder(
get_model_path(reranking_model, update_model),
device=DEVICE_TYPE,
......@@ -388,6 +395,10 @@ async def get_rag_config(user=Depends(get_admin_user)):
return {
"status": True,
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"content_extraction": {
"engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": app.state.config.TIKA_SERVER_URL,
},
"chunk": {
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
......@@ -417,6 +428,11 @@ async def get_rag_config(user=Depends(get_admin_user)):
}
class ContentExtractionConfig(BaseModel):
engine: str = ""
tika_server_url: Optional[str] = None
class ChunkParamUpdateForm(BaseModel):
chunk_size: int
chunk_overlap: int
......@@ -450,6 +466,7 @@ class WebConfig(BaseModel):
class ConfigUpdateForm(BaseModel):
pdf_extract_images: Optional[bool] = None
content_extraction: Optional[ContentExtractionConfig] = None
chunk: Optional[ChunkParamUpdateForm] = None
youtube: Optional[YoutubeLoaderConfig] = None
web: Optional[WebConfig] = None
......@@ -463,6 +480,11 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
else app.state.config.PDF_EXTRACT_IMAGES
)
if form_data.content_extraction is not None:
log.info(f"Updating text settings: {form_data.content_extraction}")
app.state.config.CONTENT_EXTRACTION_ENGINE = form_data.content_extraction.engine
app.state.config.TIKA_SERVER_URL = form_data.content_extraction.tika_server_url
if form_data.chunk is not None:
app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
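
For reference, a hypothetical request body for `update_rag_config` that enables Tika extraction; the field names come from the form models shown below, while the values and server URL are illustrative:

```python
# Illustrative payload only; values are assumptions, field names match the forms.
payload = {
    "content_extraction": {"engine": "tika", "tika_server_url": "http://tika:9998"},
    "chunk": {"chunk_size": 1500, "chunk_overlap": 100},
}
```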
......@@ -499,6 +521,10 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
return {
"status": True,
"pdf_extract_images": app.state.config.PDF_EXTRACT_IMAGES,
"content_extraction": {
"engine": app.state.config.CONTENT_EXTRACTION_ENGINE,
"tika_server_url": app.state.config.TIKA_SERVER_URL,
},
"chunk": {
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
......@@ -978,13 +1004,49 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
return True
except Exception as e:
log.exception(e)
if e.__class__.__name__ == "UniqueConstraintError":
return True
log.exception(e)
return False
class TikaLoader:
def __init__(self, file_path, mime_type=None):
self.file_path = file_path
self.mime_type = mime_type
def load(self) -> List[Document]:
with open(self.file_path, "rb") as f:
data = f.read()
if self.mime_type is not None:
headers = {"Content-Type": self.mime_type}
else:
headers = {}
endpoint = app.state.config.TIKA_SERVER_URL
if not endpoint.endswith("/"):
endpoint += "/"
endpoint += "tika/text"
r = requests.put(endpoint, data=data, headers=headers)
if r.ok:
raw_metadata = r.json()
text = raw_metadata.get("X-TIKA:content", "<No text content found>")
if "Content-Type" in raw_metadata:
headers["Content-Type"] = raw_metadata["Content-Type"]
log.info("Tika extracted text: %s", text)
return [Document(page_content=text, metadata=headers)]
else:
raise Exception(f"Error calling Tika: {r.reason}")
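
A standalone sketch of what `TikaLoader.load` does on the wire, assuming a Tika server at `http://localhost:9998` and an illustrative file path:

```python
import requests

TIKA_URL = "http://localhost:9998/"  # assumed local Tika instance

with open("uploads/report.pdf", "rb") as f:  # illustrative path
    r = requests.put(
        TIKA_URL + "tika/text",
        data=f.read(),
        headers={"Content-Type": "application/pdf"},
    )
r.raise_for_status()
# Tika's /tika/text endpoint returns JSON; "X-TIKA:content" holds the text.
print(r.json().get("X-TIKA:content", "")[:200])
```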
def get_loader(filename: str, file_content_type: str, file_path: str):
file_ext = filename.split(".")[-1].lower()
known_type = True
......@@ -1035,6 +1097,17 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"msg",
]
if (
app.state.config.CONTENT_EXTRACTION_ENGINE == "tika"
and app.state.config.TIKA_SERVER_URL
):
if file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
):
loader = TextLoader(file_path, autodetect_encoding=True)
else:
loader = TikaLoader(file_path, file_content_type)
else:
if file_ext == "pdf":
loader = PyPDFLoader(
file_path, extract_images=app.state.config.PDF_EXTRACT_IMAGES
......
......@@ -294,15 +294,17 @@ def get_rag_context(
extracted_collections.extend(collection_names)
context_string = ""
contexts = []
citations = []
for context in relevant_contexts:
try:
if "documents" in context:
context_string += "\n\n".join(
contexts.append(
"\n\n".join(
[text for text in context["documents"][0] if text is not None]
)
)
if "metadatas" in context:
citations.append(
......@@ -315,9 +317,7 @@ def get_rag_context(
except Exception as e:
log.exception(e)
context_string = context_string.strip()
return context_string, citations
return contexts, citations
def get_model_path(model: str, update_model: bool = False):
......@@ -442,8 +442,6 @@ from langchain_core.documents import BaseDocumentCompressor, Document
from langchain_core.callbacks import Callbacks
from langchain_core.pydantic_v1 import Extra
from sentence_transformers import util
class RerankCompressor(BaseDocumentCompressor):
embedding_function: Any
......@@ -468,6 +466,8 @@ class RerankCompressor(BaseDocumentCompressor):
[(query, doc.page_content) for doc in documents]
)
else:
from sentence_transformers import util
query_embedding = self.embedding_function(query)
document_embedding = self.embedding_function(
[doc.page_content for doc in documents]
......
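When no explicit reranking function is set, `RerankCompressor` above falls back to embedding the query and documents separately and scoring them with `sentence_transformers.util`. A sketch of that scoring step using random tensors — the embedding size and the `cos_sim` choice are assumptions, since the diff truncates before the scoring call:

```python
import torch
from sentence_transformers import util

query_embedding = torch.randn(384)         # illustrative embedding size
document_embeddings = torch.randn(5, 384)  # five candidate documents

# Cosine similarity of the query against every document: shape (1, 5).
scores = util.cos_sim(query_embedding, document_embeddings)[0]
print(scores.tolist())
```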
import os
import logging
import json
from contextlib import contextmanager
from peewee import *
from peewee_migrate import Router
from apps.webui.internal.wrappers import register_connection
from typing import Optional, Any
from typing_extensions import Self
from sqlalchemy import create_engine, types, Dialect
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session
from sqlalchemy.sql.type_api import _T
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
class JSONField(TextField):
class JSONField(types.TypeDecorator):
impl = types.Text
cache_ok = True
def process_bind_param(self, value: Optional[_T], dialect: Dialect) -> Any:
return json.dumps(value)
def process_result_value(self, value: Optional[_T], dialect: Dialect) -> Any:
if value is not None:
return json.loads(value)
def copy(self, **kw: Any) -> Self:
return JSONField(self.impl.length)
def db_value(self, value):
return json.dumps(value)
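
The new `JSONField` is a SQLAlchemy `TypeDecorator` that serializes values into a TEXT column. A self-contained sketch of the same pattern against an in-memory SQLite database; the names here are illustrative, not from the repo:

```python
import json
from typing import Any, Optional

from sqlalchemy import Column, String, create_engine, types
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

class JSONText(types.TypeDecorator):
    """Store Python values as JSON in a TEXT column (mirrors JSONField above)."""

    impl = types.Text
    cache_ok = True

    def process_bind_param(self, value: Optional[Any], dialect) -> Any:
        return json.dumps(value)

    def process_result_value(self, value: Optional[Any], dialect) -> Any:
        return json.loads(value) if value is not None else None

Base = declarative_base()

class Item(Base):
    __tablename__ = "item"
    id = Column(String, primary_key=True)
    meta = Column(JSONText)

engine = create_engine("sqlite://")  # in-memory database for the demo
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Item(id="1", meta={"tags": ["a", "b"]}))
    session.commit()
    assert session.get(Item, "1").meta == {"tags": ["a", "b"]}
```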
......@@ -30,25 +51,60 @@ else:
pass
# The `register_connection` function encapsulates the logic for setting up
# the database connection based on the connection string, while `connect`
# is a Peewee-specific method to manage the connection state and avoid errors
# when a connection is already open.
try:
DB = register_connection(DATABASE_URL)
log.info(f"Connected to a {DB.__class__.__name__} database.")
except Exception as e:
# Workaround to handle the peewee migration
# This is required to ensure the peewee migration is handled before the alembic migration
def handle_peewee_migration(DATABASE_URL):
try:
# Replace the postgresql:// with postgres:// and %40 with @ in the DATABASE_URL
db = register_connection(
DATABASE_URL.replace("postgresql://", "postgres://").replace("%40", "@")
)
migrate_dir = BACKEND_DIR / "apps" / "webui" / "internal" / "migrations"
router = Router(db, logger=log, migrate_dir=migrate_dir)
router.run()
db.close()
# check if db connection has been closed
except Exception as e:
log.error(f"Failed to initialize the database connection: {e}")
raise
router = Router(
DB,
migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
logger=log,
finally:
# Properly closing the database connection
if db and not db.is_closed():
db.close()
# Assert that the db connection has been closed
assert db.is_closed(), "Database connection is still open."
handle_peewee_migration(DATABASE_URL)
SQLALCHEMY_DATABASE_URL = DATABASE_URL
if "sqlite" in SQLALCHEMY_DATABASE_URL:
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
else:
engine = create_engine(SQLALCHEMY_DATABASE_URL, pool_pre_ping=True)
SessionLocal = sessionmaker(
autocommit=False, autoflush=False, bind=engine, expire_on_commit=False
)
router.run()
try:
DB.connect(reuse_if_open=True)
except OperationalError as e:
log.info(f"Failed to connect to database again due to: {e}")
pass
Base = declarative_base()
Session = scoped_session(SessionLocal)
# Dependency
def get_session():
db = SessionLocal()
try:
yield db
finally:
db.close()
get_db = contextmanager(get_session)
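
A minimal sketch of how the two access paths above are consumed; the route and app name are illustrative, not from this diff:

```python
from fastapi import Depends, FastAPI
from sqlalchemy import text

demo_app = FastAPI()  # illustrative app, not the webui app

@demo_app.get("/db-ping")
def db_ping(db=Depends(get_session)):
    # FastAPI drives the generator dependency: it yields a Session for the
    # request, and the finally-block closes it after the response.
    return {"ok": db.execute(text("SELECT 1")).scalar() == 1}

# Outside a request, the contextmanager wrapper gives the same lifecycle:
with get_db() as db:
    db.execute(text("SELECT 1"))
```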
"""Peewee migrations -- 017_add_user_oauth_sub.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
......@@ -21,7 +18,6 @@ Some examples (model - class or model name)::
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
......
# Database Migrations
This directory contains all the database migrations for the web app.
Migrations are done using the [`peewee-migrate`](https://github.com/klen/peewee_migrate) library.
Migrations are automatically run at app startup.
## Creating a migration
Have you made a change to the schema of an existing model?
You will need to create a migration file to ensure that existing databases are updated for backwards compatibility.
1. Have a database file (`webui.db`) that has the old schema prior to any of your changes.
2. Make your changes to the models.
3. From the `backend` directory, run the following command:
```bash
pw_migrate create --auto --auto-source apps.webui.models --database sqlite:///${SQLITE_DB} --directory apps/webui/internal/migrations ${MIGRATION_NAME}
```
- `$SQLITE_DB` should be the path to the database file.
- `$MIGRATION_NAME` should be a descriptive name for the migration.
4. The migration file will be created in the `apps/webui/internal/migrations` directory.
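
For example, with a local development database the command might look like `pw_migrate create --auto --auto-source apps.webui.models --database sqlite:///data/webui.db --directory apps/webui/internal/migrations add_user_oauth_sub` (the database path and migration name here are illustrative).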
......@@ -3,7 +3,7 @@ from fastapi.routing import APIRoute
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from sqlalchemy.orm import Session
from apps.webui.routers import (
auths,
users,
......@@ -19,8 +19,13 @@ from apps.webui.routers import (
functions,
)
from apps.webui.models.functions import Functions
from apps.webui.models.models import Models
from apps.webui.utils import load_function_module_by_id
from utils.misc import stream_message_template
from utils.task import prompt_template
from config import (
WEBUI_BUILD_HASH,
......@@ -39,6 +44,8 @@ from config import (
WEBUI_BANNERS,
ENABLE_COMMUNITY_SHARING,
AppConfig,
OAUTH_USERNAME_CLAIM,
OAUTH_PICTURE_CLAIM,
)
import inspect
......@@ -74,6 +81,9 @@ app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.MODELS = {}
app.state.TOOLS = {}
app.state.FUNCTIONS = {}
......@@ -129,7 +139,6 @@ async def get_pipe_models():
function_module = app.state.FUNCTIONS[pipe.id]
if hasattr(function_module, "valves") and hasattr(function_module, "Valves"):
print(f"Getting valves for {pipe.id}")
valves = Functions.get_function_valves_by_id(pipe.id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
......@@ -181,6 +190,77 @@ async def get_pipe_models():
async def generate_function_chat_completion(form_data, user):
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:
form_data["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
if model_info.params:
if model_info.params.get("temperature", None) is not None:
form_data["temperature"] = float(model_info.params.get("temperature"))
if model_info.params.get("top_p", None):
form_data["top_p"] = int(model_info.params.get("top_p", None))
if model_info.params.get("max_tokens", None):
form_data["max_tokens"] = int(model_info.params.get("max_tokens", None))
if model_info.params.get("frequency_penalty", None):
form_data["frequency_penalty"] = int(
model_info.params.get("frequency_penalty", None)
)
if model_info.params.get("seed", None):
form_data["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None):
form_data["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if form_data.get("messages"):
for message in form_data["messages"]:
if message.get("role") == "system":
message["content"] = system + message["content"]
break
else:
form_data["messages"].insert(
0,
{
"role": "system",
"content": system,
},
)
else:
pass
async def job():
pipe_id = form_data["model"]
if "." in pipe_id:
......@@ -259,6 +339,9 @@ async def generate_function_chat_completion(form_data, user):
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
if isinstance(line, dict):
line = f"data: {json.dumps(line)}"
try:
line = line.decode("utf-8")
except:
......
from pydantic import BaseModel
from typing import List, Union, Optional
import time
from typing import Optional
import uuid
import logging
from peewee import *
from sqlalchemy import String, Column, Boolean, Text
from apps.webui.models.users import UserModel, Users
from utils.utils import verify_password
from apps.webui.internal.db import DB
from apps.webui.internal.db import Base, get_db
from config import SRC_LOG_LEVELS
......@@ -20,14 +19,13 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Auth(Model):
id = CharField(unique=True)
email = CharField()
password = TextField()
active = BooleanField()
class Auth(Base):
__tablename__ = "auth"
class Meta:
database = DB
id = Column(String, primary_key=True)
email = Column(String)
password = Column(Text)
active = Column(Boolean)
class AuthModel(BaseModel):
......@@ -94,9 +92,6 @@ class AddUserForm(SignupForm):
class AuthsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Auth])
def insert_new_auth(
self,
......@@ -107,6 +102,8 @@ class AuthsTable:
role: str = "pending",
oauth_sub: Optional[str] = None,
) -> Optional[UserModel]:
with get_db() as db:
log.info("insert_new_auth")
id = str(uuid.uuid4())
......@@ -114,12 +111,16 @@ class AuthsTable:
auth = AuthModel(
**{"id": id, "email": email, "password": password, "active": True}
)
result = Auth.create(**auth.model_dump())
result = Auth(**auth.model_dump())
db.add(result)
user = Users.insert_new_user(
id, name, email, profile_image_url, role, oauth_sub
)
db.commit()
db.refresh(result)
if result and user:
return user
else:
......@@ -128,7 +129,9 @@ class AuthsTable:
def authenticate_user(self, email: str, password: str) -> Optional[UserModel]:
log.info(f"authenticate_user: {email}")
try:
auth = Auth.get(Auth.email == email, Auth.active == True)
with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first()
if auth:
if verify_password(password, auth.password):
user = Users.get_user_by_id(auth.id)
......@@ -155,7 +158,8 @@ class AuthsTable:
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}")
try:
auth = Auth.get(Auth.email == email, Auth.active == True)
with get_db() as db:
auth = db.query(Auth).filter_by(email=email, active=True).first()
if auth:
user = Users.get_user_by_id(auth.id)
return user
......@@ -164,31 +168,34 @@ class AuthsTable:
def update_user_password_by_id(self, id: str, new_password: str) -> bool:
try:
query = Auth.update(password=new_password).where(Auth.id == id)
result = query.execute()
with get_db() as db:
result = (
db.query(Auth).filter_by(id=id).update({"password": new_password})
)
db.commit()
return True if result == 1 else False
except:
return False
def update_email_by_id(self, id: str, email: str) -> bool:
try:
query = Auth.update(email=email).where(Auth.id == id)
result = query.execute()
with get_db() as db:
result = db.query(Auth).filter_by(id=id).update({"email": email})
db.commit()
return True if result == 1 else False
except:
return False
def delete_auth_by_id(self, id: str) -> bool:
try:
with get_db() as db:
# Delete User
result = Users.delete_user_by_id(id)
if result:
# Delete Auth
query = Auth.delete().where(Auth.id == id)
query.execute() # Remove the rows, return number of rows removed.
db.query(Auth).filter_by(id=id).delete()
db.commit()
return True
else:
......@@ -197,4 +204,4 @@ class AuthsTable:
return False
Auths = AuthsTable(DB)
Auths = AuthsTable()
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional
from peewee import *
from playhouse.shortcuts import model_to_dict
import json
import uuid
import time
from apps.webui.internal.db import DB
from sqlalchemy import Column, String, BigInteger, Boolean, Text
from apps.webui.internal.db import Base, get_db
####################
# Chat DB Schema
####################
class Chat(Model):
id = CharField(unique=True)
user_id = CharField()
title = TextField()
chat = TextField() # Save Chat JSON as Text
class Chat(Base):
__tablename__ = "chat"
created_at = BigIntegerField()
updated_at = BigIntegerField()
id = Column(String, primary_key=True)
user_id = Column(String)
title = Column(Text)
chat = Column(Text) # Save Chat JSON as Text
share_id = CharField(null=True, unique=True)
archived = BooleanField(default=False)
created_at = Column(BigInteger)
updated_at = Column(BigInteger)
class Meta:
database = DB
share_id = Column(Text, unique=True, nullable=True)
archived = Column(Boolean, default=False)
class ChatModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: str
user_id: str
title: str
......@@ -75,18 +77,19 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable:
def __init__(self, db):
self.db = db
db.create_tables([Chat])
def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
with get_db() as db:
id = str(uuid.uuid4())
chat = ChatModel(
**{
"id": id,
"user_id": user_id,
"title": (
form_data.chat["title"] if "title" in form_data.chat else "New Chat"
form_data.chat["title"]
if "title" in form_data.chat
else "New Chat"
),
"chat": json.dumps(form_data.chat),
"created_at": int(time.time()),
......@@ -94,26 +97,32 @@ class ChatTable:
}
)
result = Chat.create(**chat.model_dump())
return chat if result else None
result = Chat(**chat.model_dump())
db.add(result)
db.commit()
db.refresh(result)
return ChatModel.model_validate(result) if result else None
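
The `ChatModel.model_validate(result)` call works because of `model_config = ConfigDict(from_attributes=True)` on the Pydantic model above, which lets Pydantic v2 read fields off ORM objects. A minimal sketch of that mechanism, with illustrative names:

```python
from pydantic import BaseModel, ConfigDict

class ChatStub(BaseModel):
    model_config = ConfigDict(from_attributes=True)
    id: str
    title: str

class FakeRow:
    """Stand-in for a SQLAlchemy ORM row (illustrative)."""
    id = "abc123"
    title = "New Chat"

print(ChatStub.model_validate(FakeRow()))  # id='abc123' title='New Chat'
```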
def update_chat_by_id(self, id: str, chat: dict) -> Optional[ChatModel]:
try:
query = Chat.update(
chat=json.dumps(chat),
title=chat["title"] if "title" in chat else "New Chat",
updated_at=int(time.time()),
).where(Chat.id == id)
query.execute()
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
except:
with get_db() as db:
chat_obj = db.get(Chat, id)
chat_obj.chat = json.dumps(chat)
chat_obj.title = chat["title"] if "title" in chat else "New Chat"
chat_obj.updated_at = int(time.time())
db.commit()
db.refresh(chat_obj)
return ChatModel.model_validate(chat_obj)
except Exception as e:
return None
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
with get_db() as db:
# Get the existing chat to share
chat = Chat.get(Chat.id == chat_id)
chat = db.get(Chat, chat_id)
# Check if the chat is already shared
if chat.share_id:
return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
......@@ -128,36 +137,42 @@ class ChatTable:
"updated_at": int(time.time()),
}
)
shared_result = Chat.create(**shared_chat.model_dump())
shared_result = Chat(**shared_chat.model_dump())
db.add(shared_result)
db.commit()
db.refresh(shared_result)
# Update the original chat with the share_id
result = (
Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute()
db.query(Chat)
.filter_by(id=chat_id)
.update({"share_id": shared_chat.id})
)
db.commit()
return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
try:
with get_db() as db:
print("update_shared_chat_by_id")
chat = Chat.get(Chat.id == chat_id)
chat = db.get(Chat, chat_id)
print(chat)
chat.title = chat.title
chat.chat = chat.chat
db.commit()
db.refresh(chat)
query = Chat.update(
title=chat.title,
chat=chat.chat,
).where(Chat.id == chat.share_id)
query.execute()
chat = Chat.get(Chat.id == chat.share_id)
return ChatModel(**model_to_dict(chat))
return self.get_chat_by_id(chat.share_id)
except:
return None
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
try:
query = Chat.delete().where(Chat.user_id == f"shared-{chat_id}")
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
db.query(Chat).filter_by(user_id=f"shared-{chat_id}").delete()
db.commit()
return True
except:
......@@ -167,40 +182,33 @@ class ChatTable:
self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]:
try:
query = Chat.update(
share_id=share_id,
).where(Chat.id == id)
query.execute()
with get_db() as db:
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
chat = db.get(Chat, id)
chat.share_id = share_id
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except:
return None
def toggle_chat_archive_by_id(self, id: str) -> Optional[ChatModel]:
try:
chat = self.get_chat_by_id(id)
query = Chat.update(
archived=(not chat.archived),
).where(Chat.id == id)
with get_db() as db:
query.execute()
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
chat = db.get(Chat, id)
chat.archived = not chat.archived
db.commit()
db.refresh(chat)
return ChatModel.model_validate(chat)
except:
return None
def archive_all_chats_by_user_id(self, user_id: str) -> bool:
try:
chats = self.get_chats_by_user_id(user_id)
for chat in chats:
query = Chat.update(
archived=True,
).where(Chat.id == chat.id)
query.execute()
with get_db() as db:
db.query(Chat).filter_by(user_id=user_id).update({"archived": True})
db.commit()
return True
except:
return False
......@@ -208,15 +216,16 @@ class ChatTable:
def get_archived_chat_list_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == True)
.where(Chat.user_id == user_id)
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
# .limit(limit).offset(skip)
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_user_id(
self,
......@@ -225,92 +234,97 @@ class ChatTable:
skip: int = 0,
limit: int = 50,
) -> List[ChatModel]:
if include_archived:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
else:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == False)
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit)
# .offset(skip)
]
with get_db() as db:
query = db.query(Chat).filter_by(user_id=user_id)
if not include_archived:
query = query.filter_by(archived=False)
all_chats = (
query.order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip)
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_list_by_chat_ids(
self, chat_ids: List[str], skip: int = 0, limit: int = 50
) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == False)
.where(Chat.id.in_(chat_ids))
with get_db() as db:
all_chats = (
db.query(Chat)
.filter(Chat.id.in_(chat_ids))
.filter_by(archived=False)
.order_by(Chat.updated_at.desc())
]
.all()
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try:
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
with get_db() as db:
chat = db.get(Chat, id)
return ChatModel.model_validate(chat)
except:
return None
def get_chat_by_share_id(self, id: str) -> Optional[ChatModel]:
try:
chat = Chat.get(Chat.share_id == id)
with get_db() as db:
chat = db.query(Chat).filter_by(share_id=id).first()
if chat:
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
return self.get_chat_by_id(id)
else:
return None
except:
except Exception as e:
return None
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try:
chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
return ChatModel(**model_to_dict(chat))
with get_db() as db:
chat = db.query(Chat).filter_by(id=id, user_id=user_id).first()
return ChatModel.model_validate(chat)
except:
return None
def get_chats(self, skip: int = 0, limit: int = 50) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select().order_by(Chat.updated_at.desc())
with get_db() as db:
all_chats = (
db.query(Chat)
# .limit(limit).offset(skip)
]
.order_by(Chat.updated_at.desc())
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.user_id == user_id)
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id)
.order_by(Chat.updated_at.desc())
# .limit(limit).offset(skip)
]
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == True)
.where(Chat.user_id == user_id)
with get_db() as db:
all_chats = (
db.query(Chat)
.filter_by(user_id=user_id, archived=True)
.order_by(Chat.updated_at.desc())
]
)
return [ChatModel.model_validate(chat) for chat in all_chats]
def delete_chat_by_id(self, id: str) -> bool:
try:
query = Chat.delete().where((Chat.id == id))
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
db.query(Chat).filter_by(id=id).delete()
db.commit()
return True and self.delete_shared_chat_by_chat_id(id)
except:
......@@ -318,8 +332,10 @@ class ChatTable:
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try:
query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
db.query(Chat).filter_by(id=id, user_id=user_id).delete()
db.commit()
return True and self.delete_shared_chat_by_chat_id(id)
except:
......@@ -328,10 +344,12 @@ class ChatTable:
def delete_chats_by_user_id(self, user_id: str) -> bool:
try:
with get_db() as db:
self.delete_shared_chats_by_user_id(user_id)
query = Chat.delete().where(Chat.user_id == user_id)
query.execute() # Remove the rows, return number of rows removed.
db.query(Chat).filter_by(user_id=user_id).delete()
db.commit()
return True
except:
......@@ -339,17 +357,18 @@ class ChatTable:
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try:
shared_chat_ids = [
f"shared-{chat.id}"
for chat in Chat.select().where(Chat.user_id == user_id)
]
query = Chat.delete().where(Chat.user_id << shared_chat_ids)
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
chats_by_user = db.query(Chat).filter_by(user_id=user_id).all()
shared_chat_ids = [f"shared-{chat.id}" for chat in chats_by_user]
db.query(Chat).filter(Chat.user_id.in_(shared_chat_ids)).delete()
db.commit()
return True
except:
return False
Chats = ChatTable(DB)
Chats = ChatTable()
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
from pydantic import BaseModel, ConfigDict
from typing import List, Optional
import time
import logging
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from sqlalchemy import String, Column, BigInteger, Text
from apps.webui.internal.db import DB
from apps.webui.internal.db import Base, get_db
import json
......@@ -22,20 +19,21 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Document(Model):
collection_name = CharField(unique=True)
name = CharField(unique=True)
title = TextField()
filename = TextField()
content = TextField(null=True)
user_id = CharField()
timestamp = BigIntegerField()
class Document(Base):
__tablename__ = "document"
class Meta:
database = DB
collection_name = Column(String, primary_key=True)
name = Column(String, unique=True)
title = Column(Text)
filename = Column(Text)
content = Column(Text, nullable=True)
user_id = Column(String)
timestamp = Column(BigInteger)
class DocumentModel(BaseModel):
model_config = ConfigDict(from_attributes=True)
collection_name: str
name: str
title: str
......@@ -72,13 +70,12 @@ class DocumentForm(DocumentUpdateForm):
class DocumentsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Document])
def insert_new_doc(
self, user_id: str, form_data: DocumentForm
) -> Optional[DocumentModel]:
with get_db() as db:
document = DocumentModel(
**{
**form_data.model_dump(),
......@@ -88,9 +85,12 @@ class DocumentsTable:
)
try:
result = Document.create(**document.model_dump())
result = Document(**document.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return document
return DocumentModel.model_validate(result)
else:
return None
except:
......@@ -98,31 +98,35 @@ class DocumentsTable:
def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
try:
document = Document.get(Document.name == name)
return DocumentModel(**model_to_dict(document))
with get_db() as db:
document = db.query(Document).filter_by(name=name).first()
return DocumentModel.model_validate(document) if document else None
except:
return None
def get_docs(self) -> List[DocumentModel]:
with get_db() as db:
return [
DocumentModel(**model_to_dict(doc))
for doc in Document.select()
# .limit(limit).offset(skip)
DocumentModel.model_validate(doc) for doc in db.query(Document).all()
]
def update_doc_by_name(
self, name: str, form_data: DocumentUpdateForm
) -> Optional[DocumentModel]:
try:
query = Document.update(
title=form_data.title,
name=form_data.name,
timestamp=int(time.time()),
).where(Document.name == name)
query.execute()
doc = Document.get(Document.name == form_data.name)
return DocumentModel(**model_to_dict(doc))
with get_db() as db:
db.query(Document).filter_by(name=name).update(
{
"title": form_data.title,
"name": form_data.name,
"timestamp": int(time.time()),
}
)
db.commit()
return self.get_doc_by_name(form_data.name)
except Exception as e:
log.exception(e)
return None
......@@ -135,26 +139,29 @@ class DocumentsTable:
doc_content = json.loads(doc.content if doc.content else "{}")
doc_content = {**doc_content, **updated}
query = Document.update(
content=json.dumps(doc_content),
timestamp=int(time.time()),
).where(Document.name == name)
query.execute()
with get_db() as db:
doc = Document.get(Document.name == name)
return DocumentModel(**model_to_dict(doc))
db.query(Document).filter_by(name=name).update(
{
"content": json.dumps(doc_content),
"timestamp": int(time.time()),
}
)
db.commit()
return self.get_doc_by_name(name)
except Exception as e:
log.exception(e)
return None
def delete_doc_by_name(self, name: str) -> bool:
try:
query = Document.delete().where((Document.name == name))
query.execute() # Remove the rows, return number of rows removed.
with get_db() as db:
db.query(Document).filter_by(name=name).delete()
db.commit()
return True
except:
return False
Documents = DocumentsTable(DB)
Documents = DocumentsTable()
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
from sqlalchemy import Column, String, BigInteger, Text
from apps.webui.internal.db import JSONField, Base, get_db
import json
......@@ -18,15 +19,14 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class File(Model):
id = CharField(unique=True)
user_id = CharField()
filename = TextField()
meta = JSONField()
created_at = BigIntegerField()
class File(Base):
__tablename__ = "file"
class Meta:
database = DB
id = Column(String, primary_key=True)
user_id = Column(String)
filename = Column(Text)
meta = Column(JSONField)
created_at = Column(BigInteger)
class FileModel(BaseModel):
......@@ -36,6 +36,8 @@ class FileModel(BaseModel):
meta: dict
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
......@@ -57,11 +59,10 @@ class FileForm(BaseModel):
class FilesTable:
def __init__(self, db):
self.db = db
self.db.create_tables([File])
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
with get_db() as db:
file = FileModel(
**{
**form_data.model_dump(),
......@@ -71,9 +72,12 @@ class FilesTable:
)
try:
result = File.create(**file.model_dump())
result = File(**file.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return file
return FileModel.model_validate(result)
else:
return None
except Exception as e:
......@@ -81,32 +85,42 @@ class FilesTable:
return None
def get_file_by_id(self, id: str) -> Optional[FileModel]:
with get_db() as db:
try:
file = File.get(File.id == id)
return FileModel(**model_to_dict(file))
file = db.get(File, id)
return FileModel.model_validate(file)
except:
return None
def get_files(self) -> List[FileModel]:
return [FileModel(**model_to_dict(file)) for file in File.select()]
with get_db() as db:
return [FileModel.model_validate(file) for file in db.query(File).all()]
def delete_file_by_id(self, id: str) -> bool:
with get_db() as db:
try:
query = File.delete().where((File.id == id))
query.execute() # Remove the rows, return number of rows removed.
db.query(File).filter_by(id=id).delete()
db.commit()
return True
except:
return False
def delete_all_files(self) -> bool:
with get_db() as db:
try:
query = File.delete()
query.execute() # Remove the rows, return number of rows removed.
db.query(File).delete()
db.commit()
return True
except:
return False
Files = FilesTable(DB)
Files = FilesTable()
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
from sqlalchemy import Column, String, Text, BigInteger, Boolean
from apps.webui.internal.db import JSONField, Base, get_db
from apps.webui.models.users import Users
import json
......@@ -21,21 +22,20 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
class Function(Model):
id = CharField(unique=True)
user_id = CharField()
name = TextField()
type = TextField()
content = TextField()
meta = JSONField()
valves = JSONField()
is_active = BooleanField(default=False)
is_global = BooleanField(default=False)
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Function(Base):
__tablename__ = "function"
class Meta:
database = DB
id = Column(String, primary_key=True)
user_id = Column(String)
name = Column(Text)
type = Column(Text)
content = Column(Text)
meta = Column(JSONField)
valves = Column(JSONField)
is_active = Column(Boolean)
is_global = Column(Boolean)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class FunctionMeta(BaseModel):
......@@ -55,6 +55,8 @@ class FunctionModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
......@@ -85,13 +87,11 @@ class FunctionValves(BaseModel):
class FunctionsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Function])
def insert_new_function(
self, user_id: str, type: str, form_data: FunctionForm
) -> Optional[FunctionModel]:
function = FunctionModel(
**{
**form_data.model_dump(),
......@@ -103,9 +103,13 @@ class FunctionsTable:
)
try:
result = Function.create(**function.model_dump())
with get_db() as db:
result = Function(**function.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return function
return FunctionModel.model_validate(result)
else:
return None
except Exception as e:
......@@ -114,52 +118,60 @@ class FunctionsTable:
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
try:
function = Function.get(Function.id == id)
return FunctionModel(**model_to_dict(function))
with get_db() as db:
function = db.get(Function, id)
return FunctionModel.model_validate(function)
except:
return None
def get_functions(self, active_only=False) -> List[FunctionModel]:
with get_db() as db:
if active_only:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select().where(Function.is_active == True)
FunctionModel.model_validate(function)
for function in db.query(Function).filter_by(is_active=True).all()
]
else:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select()
FunctionModel.model_validate(function)
for function in db.query(Function).all()
]
def get_functions_by_type(
self, type: str, active_only=False
) -> List[FunctionModel]:
with get_db() as db:
if active_only:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select().where(
Function.type == type, Function.is_active == True
)
FunctionModel.model_validate(function)
for function in db.query(Function)
.filter_by(type=type, is_active=True)
.all()
]
else:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select().where(Function.type == type)
FunctionModel.model_validate(function)
for function in db.query(Function).filter_by(type=type).all()
]
def get_global_filter_functions(self) -> List[FunctionModel]:
with get_db() as db:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select().where(
Function.type == "filter",
Function.is_active == True,
Function.is_global == True,
)
FunctionModel.model_validate(function)
for function in db.query(Function)
.filter_by(type="filter", is_active=True, is_global=True)
.all()
]
def get_function_valves_by_id(self, id: str) -> Optional[dict]:
with get_db() as db:
try:
function = Function.get(Function.id == id)
function = db.get(Function, id)
return function.valves if function.valves else {}
except Exception as e:
print(f"An error occurred: {e}")
......@@ -168,24 +180,25 @@ class FunctionsTable:
def update_function_valves_by_id(
self, id: str, valves: dict
) -> Optional[FunctionValves]:
with get_db() as db:
try:
query = Function.update(
**{"valves": valves},
updated_at=int(time.time()),
).where(Function.id == id)
query.execute()
function = Function.get(Function.id == id)
return FunctionValves(**model_to_dict(function))
function = db.get(Function, id)
function.valves = valves
function.updated_at = int(time.time())
db.commit()
db.refresh(function)
return self.get_function_by_id(id)
except:
return None
def get_user_valves_by_id_and_user_id(
self, id: str, user_id: str
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "functions" and "valves" settings
if "functions" not in user_settings:
......@@ -201,9 +214,10 @@ class FunctionsTable:
def update_user_valves_by_id_and_user_id(
self, id: str, user_id: str, valves: dict
) -> Optional[dict]:
try:
user = Users.get_user_by_id(user_id)
user_settings = user.settings.model_dump()
user_settings = user.settings.model_dump() if user.settings else {}
# Check if user has "functions" and "valves" settings
if "functions" not in user_settings:
......@@ -214,8 +228,7 @@ class FunctionsTable:
user_settings["functions"]["valves"][id] = valves
# Update the user settings in the database
query = Users.update_user_by_id(user_id, {"settings": user_settings})
query.execute()
Users.update_user_by_id(user_id, {"settings": user_settings})
return user_settings["functions"]["valves"][id]
except Exception as e:
......@@ -223,39 +236,44 @@ class FunctionsTable:
return None
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
with get_db() as db:
try:
query = Function.update(
db.query(Function).filter_by(id=id).update(
{
**updated,
updated_at=int(time.time()),
).where(Function.id == id)
query.execute()
function = Function.get(Function.id == id)
return FunctionModel(**model_to_dict(function))
"updated_at": int(time.time()),
}
)
db.commit()
return self.get_function_by_id(id)
except:
return None
def deactivate_all_functions(self) -> Optional[bool]:
with get_db() as db:
try:
query = Function.update(
**{"is_active": False},
updated_at=int(time.time()),
db.query(Function).update(
{
"is_active": False,
"updated_at": int(time.time()),
}
)
query.execute()
db.commit()
return True
except:
return None
def delete_function_by_id(self, id: str) -> bool:
with get_db() as db:
try:
query = Function.delete().where((Function.id == id))
query.execute() # Remove the rows, return number of rows removed.
db.query(Function).filter_by(id=id).delete()
db.commit()
return True
except:
return False
Functions = FunctionsTable(DB)
Functions = FunctionsTable()
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from pydantic import BaseModel, ConfigDict
from typing import List, Union, Optional
from apps.webui.internal.db import DB
from apps.webui.models.chats import Chats
from sqlalchemy import Column, String, BigInteger, Text
from apps.webui.internal.db import Base, get_db
import time
import uuid
......@@ -14,15 +13,14 @@ import uuid
####################
class Memory(Model):
id = CharField(unique=True)
user_id = CharField()
content = TextField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Memory(Base):
__tablename__ = "memory"
class Meta:
database = DB
id = Column(String, primary_key=True)
user_id = Column(String)
content = Column(Text)
updated_at = Column(BigInteger)
created_at = Column(BigInteger)
class MemoryModel(BaseModel):
......@@ -32,6 +30,8 @@ class MemoryModel(BaseModel):
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
model_config = ConfigDict(from_attributes=True)
####################
# Forms
......@@ -39,15 +39,14 @@ class MemoryModel(BaseModel):
class MemoriesTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Memory])
def insert_new_memory(
self,
user_id: str,
content: str,
) -> Optional[MemoryModel]:
with get_db() as db:
id = str(uuid.uuid4())
memory = MemoryModel(
......@@ -59,9 +58,12 @@ class MemoriesTable:
"updated_at": int(time.time()),
}
)
result = Memory.create(**memory.model_dump())
result = Memory(**memory.model_dump())
db.add(result)
db.commit()
db.refresh(result)
if result:
return memory
return MemoryModel.model_validate(result)
else:
return None
......@@ -70,40 +72,50 @@ class MemoriesTable:
id: str,
content: str,
) -> Optional[MemoryModel]:
with get_db() as db:
try:
memory = Memory.get(Memory.id == id)
memory.content = content
memory.updated_at = int(time.time())
memory.save()
return MemoryModel(**model_to_dict(memory))
db.query(Memory).filter_by(id=id).update(
{"content": content, "updated_at": int(time.time())}
)
db.commit()
return self.get_memory_by_id(id)
except:
return None
def get_memories(self) -> List[MemoryModel]:
with get_db() as db:
try:
memories = Memory.select()
return [MemoryModel(**model_to_dict(memory)) for memory in memories]
memories = db.query(Memory).all()
return [MemoryModel.model_validate(memory) for memory in memories]
except:
return None
def get_memories_by_user_id(self, user_id: str) -> List[MemoryModel]:
with get_db() as db:
try:
memories = Memory.select().where(Memory.user_id == user_id)
return [MemoryModel(**model_to_dict(memory)) for memory in memories]
memories = db.query(Memory).filter_by(user_id=user_id).all()
return [MemoryModel.model_validate(memory) for memory in memories]
except:
return None
def get_memory_by_id(self, id) -> Optional[MemoryModel]:
def get_memory_by_id(self, id: str) -> Optional[MemoryModel]:
with get_db() as db:
try:
memory = Memory.get(Memory.id == id)
return MemoryModel(**model_to_dict(memory))
memory = db.get(Memory, id)
return MemoryModel.model_validate(memory)
except:
return None
def delete_memory_by_id(self, id: str) -> bool:
with get_db() as db:
try:
query = Memory.delete().where(Memory.id == id)
query.execute() # Remove the rows, return number of rows removed.
db.query(Memory).filter_by(id=id).delete()
db.commit()
return True
......@@ -111,22 +123,26 @@ class MemoriesTable:
return False
def delete_memories_by_user_id(self, user_id: str) -> bool:
with get_db() as db:
try:
query = Memory.delete().where(Memory.user_id == user_id)
query.execute()
db.query(Memory).filter_by(user_id=user_id).delete()
db.commit()
return True
except:
return False
def delete_memory_by_id_and_user_id(self, id: str, user_id: str) -> bool:
with get_db() as db:
try:
query = Memory.delete().where(Memory.id == id, Memory.user_id == user_id)
query.execute()
db.query(Memory).filter_by(id=id, user_id=user_id).delete()
db.commit()
return True
except:
return False
Memories = MemoriesTable(DB)
Memories = MemoriesTable()