"ppstructure/vscode:/vscode.git/clone" did not exist on "6c19d15a571a4cbc83461b749e8056e99992af6e"
Commit 4aab4609 authored by Jun Siang Cheah

Merge remote-tracking branch 'upstream/dev' into feat/oauth

parents 4ff17acc a2ea6b1b
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class File(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
filename = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "file"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("file")
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Function(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
type = pw.TextField()
content = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "function"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("function")
from contextvars import ContextVar
from peewee import *
from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError
import logging
from playhouse.db_url import connect, parse
from playhouse.shortcuts import ReconnectMixin
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["DB"])
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
db_state = ContextVar("db_state", default=db_state_default.copy())
class PeeweeConnectionState(object):
def __init__(self, **kwargs):
super().__setattr__("_state", db_state)
super().__init__(**kwargs)
def __setattr__(self, name, value):
self._state.get()[name] = value
def __getattr__(self, name):
value = self._state.get()[name]
return value
class CustomReconnectMixin(ReconnectMixin):
reconnect_errors = (
# psycopg2
(OperationalError, "termin"),
(InterfaceError, "closed"),
# peewee
(PeeWeeInterfaceError, "closed"),
)
class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
pass
def register_connection(db_url):
db = connect(db_url)
if isinstance(db, PostgresqlDatabase):
# The parsed URL resolved to PostgreSQL; enable autoconnect and connection reuse on the initial connection
db.autoconnect = True
db.reuse_if_open = True
log.info("Connected to PostgreSQL database")
# Get the connection details
connection = parse(db_url)
# Use our custom database class that supports reconnection
db = ReconnectingPostgresqlDatabase(
connection["database"],
user=connection["user"],
password=connection["password"],
host=connection["host"],
port=connection["port"],
)
db.connect(reuse_if_open=True)
elif isinstance(db, SqliteDatabase):
# Enable autoconnect for SQLite databases, managed by Peewee
db.autoconnect = True
db.reuse_if_open = True
log.info("Connected to SQLite database")
else:
raise ValueError("Unsupported database connection")
return db
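For the ContextVar-based connection state above to take effect, the application is expected to attach a PeeweeConnectionState instance to the database returned by register_connection, following the usual FastAPI + Peewee recipe. A minimal sketch of that wiring (the DATABASE_URL fallback is an illustrative assumption):

import os

# Illustrative assumption: default to a local SQLite file when DATABASE_URL is unset.
DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite:///data/webui.db")

DB = register_connection(DATABASE_URL)
# Replace Peewee's thread-local connection bookkeeping with the ContextVar-backed
# state so that each asyncio task gets its own connection slot.
DB._state = PeeweeConnectionState()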
@@ -14,7 +14,12 @@ from apps.webui.routers import (
    configs,
    memories,
    utils,
+    files,
+    functions,
)
+from apps.webui.models.functions import Functions
+from apps.webui.utils import load_function_module_by_id
from config import (
    WEBUI_BUILD_HASH,
    SHOW_ADMIN_DETAILS,
@@ -27,6 +32,7 @@ from config import (
    USER_PERMISSIONS,
    WEBHOOK_URL,
    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
+    WEBUI_AUTH_TRUSTED_NAME_HEADER,
    JWT_EXPIRES_IN,
    WEBUI_BANNERS,
    ENABLE_COMMUNITY_SHARING,
@@ -42,6 +48,7 @@ app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
+app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
@@ -59,7 +66,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {}
app.state.TOOLS = {}
+app.state.FUNCTIONS = {}
app.add_middleware(
    CORSMiddleware,
@@ -69,17 +76,21 @@ app.add_middleware(
    allow_headers=["*"],
)
-app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"])
-app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])
+app.include_router(files.router, prefix="/files", tags=["files"])
+app.include_router(tools.router, prefix="/tools", tags=["tools"])
+app.include_router(functions.router, prefix="/functions", tags=["functions"])
+app.include_router(configs.router, prefix="/configs", tags=["configs"])
app.include_router(utils.router, prefix="/utils", tags=["utils"])
@@ -91,3 +102,58 @@ async def get_status():
        "default_models": app.state.config.DEFAULT_MODELS,
        "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
    }
async def get_pipe_models():
pipes = Functions.get_functions_by_type("pipe")
pipe_models = []
for pipe in pipes:
# Check if function is already loaded
if pipe.id not in app.state.FUNCTIONS:
function_module, function_type = load_function_module_by_id(pipe.id)
app.state.FUNCTIONS[pipe.id] = function_module
else:
function_module = app.state.FUNCTIONS[pipe.id]
# Check if function is a manifold
if hasattr(function_module, "type"):
if function_module.type == "manifold":
manifold_pipes = []
# Check if pipes is a function or a list
if callable(function_module.pipes):
manifold_pipes = function_module.pipes()
else:
manifold_pipes = function_module.pipes
for p in manifold_pipes:
manifold_pipe_id = f'{pipe.id}.{p["id"]}'
manifold_pipe_name = p["name"]
if hasattr(function_module, "name"):
manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"
pipe_models.append(
{
"id": manifold_pipe_id,
"name": manifold_pipe_name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": {"type": pipe.type},
}
)
else:
pipe_models.append(
{
"id": pipe.id,
"name": pipe.name,
"object": "model",
"created": pipe.created_at,
"owned_by": "openai",
"pipe": {"type": "pipe"},
}
)
return pipe_models
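get_pipe_models relies on a small contract from the loaded module: an optional type attribute equal to "manifold", a pipes attribute that is either a list or a callable returning dicts with "id" and "name" keys, and a pipe callable that is invoked once completions are routed to the resulting model. A hypothetical module satisfying that contract might look like this sketch (the class body is an illustrative assumption, not code from this commit):

class Pipe:
    def __init__(self):
        # "manifold" tells get_pipe_models to expand this module into several models.
        self.type = "manifold"
        self.name = "demo/"

    def pipes(self):
        # Each entry becomes one model with the id "<function_id>.<id>".
        return [
            {"id": "small", "name": "Small"},
            {"id": "large", "name": "Large"},
        ]

    def pipe(self, body: dict):
        # Called with the chat completion payload; returning a string becomes
        # a single assistant message in the response.
        return f"Echo from {body['model']}: {body['messages'][-1]['content']}"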
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Files DB Schema
####################
class File(Model):
id = CharField(unique=True)
user_id = CharField()
filename = TextField()
meta = JSONField()
created_at = BigIntegerField()
class Meta:
database = DB
class FileModel(BaseModel):
id: str
user_id: str
filename: str
meta: dict
created_at: int # timestamp in epoch
####################
# Forms
####################
class FileModelResponse(BaseModel):
id: str
user_id: str
filename: str
meta: dict
created_at: int # timestamp in epoch
class FileForm(BaseModel):
id: str
filename: str
meta: dict = {}
class FilesTable:
def __init__(self, db):
self.db = db
self.db.create_tables([File])
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
file = FileModel(
**{
**form_data.model_dump(),
"user_id": user_id,
"created_at": int(time.time()),
}
)
try:
result = File.create(**file.model_dump())
if result:
return file
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
def get_file_by_id(self, id: str) -> Optional[FileModel]:
try:
file = File.get(File.id == id)
return FileModel(**model_to_dict(file))
except:
return None
def get_files(self) -> List[FileModel]:
return [FileModel(**model_to_dict(file)) for file in File.select()]
def delete_file_by_id(self, id: str) -> bool:
try:
query = File.delete().where((File.id == id))
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
def delete_all_files(self) -> bool:
try:
query = File.delete()
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Files = FilesTable(DB)
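A quick usage sketch of the table wrapper above (the id, filename, and meta values are made up for illustration):

# Hypothetical usage of the Files wrapper defined above.
created = Files.insert_new_file(
    "user-123",
    FileForm(
        id="0b9f2d6c-example",
        filename="0b9f2d6c-example_report.pdf",
        meta={
            "content_type": "application/pdf",
            "size": 1024,
            "path": "uploads/0b9f2d6c-example_report.pdf",
        },
    ),
)
if created:
    print(Files.get_file_by_id(created.id).filename)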
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Functions DB Schema
####################
class Function(Model):
id = CharField(unique=True)
user_id = CharField()
name = TextField()
type = TextField()
content = TextField()
meta = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
class FunctionMeta(BaseModel):
description: Optional[str] = None
class FunctionModel(BaseModel):
id: str
user_id: str
name: str
type: str
content: str
meta: FunctionMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
####################
# Forms
####################
class FunctionResponse(BaseModel):
id: str
user_id: str
type: str
name: str
meta: FunctionMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class FunctionForm(BaseModel):
id: str
name: str
content: str
meta: FunctionMeta
class FunctionsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Function])
def insert_new_function(
self, user_id: str, type: str, form_data: FunctionForm
) -> Optional[FunctionModel]:
function = FunctionModel(
**{
**form_data.model_dump(),
"user_id": user_id,
"type": type,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try:
result = Function.create(**function.model_dump())
if result:
return function
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
try:
function = Function.get(Function.id == id)
return FunctionModel(**model_to_dict(function))
except:
return None
def get_functions(self) -> List[FunctionModel]:
return [
FunctionModel(**model_to_dict(function)) for function in Function.select()
]
def get_functions_by_type(self, type: str) -> List[FunctionModel]:
return [
FunctionModel(**model_to_dict(function))
for function in Function.select().where(Function.type == type)
]
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
try:
query = Function.update(
**updated,
updated_at=int(time.time()),
).where(Function.id == id)
query.execute()
function = Function.get(Function.id == id)
return FunctionModel(**model_to_dict(function))
except:
return None
def delete_function_by_id(self, id: str) -> bool:
try:
query = Function.delete().where((Function.id == id))
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Functions = FunctionsTable(DB)
@@ -26,6 +26,7 @@ class User(Model):
    api_key = CharField(null=True, unique=True)
    settings = JSONField(null=True)
+    info = JSONField(null=True)
    oauth_sub = TextField(null=True, unique=True)
@@ -52,6 +53,7 @@ class UserModel(BaseModel):
    api_key: Optional[str] = None
    settings: Optional[UserSettings] = None
+    info: Optional[dict] = None
    oauth_sub: Optional[str] = None
...
@@ -2,6 +2,7 @@ import logging
from fastapi import Request, UploadFile, File
from fastapi import Depends, HTTPException, status
+from fastapi.responses import Response
from fastapi import APIRouter
from pydantic import BaseModel
@@ -35,6 +36,7 @@ from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from config import (
    WEBUI_AUTH,
    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
+    WEBUI_AUTH_TRUSTED_NAME_HEADER,
)
router = APIRouter()
@@ -45,7 +47,21 @@ router = APIRouter()
@router.get("/", response_model=UserResponse)
-async def get_session_user(user=Depends(get_current_user)):
+async def get_session_user(
+    request: Request, response: Response, user=Depends(get_current_user)
+):
+    token = create_token(
+        data={"id": user.id},
+        expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
+    )
+    # Set the cookie token
+    response.set_cookie(
+        key="token",
+        value=token,
+        httponly=True,  # Ensures the cookie is not accessible via JavaScript
+    )
    return {
        "id": user.id,
        "email": user.email,
@@ -106,17 +122,22 @@ async def update_password(
@router.post("/signin", response_model=SigninResponse)
-async def signin(request: Request, form_data: SigninForm):
+async def signin(request: Request, response: Response, form_data: SigninForm):
    if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
        if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
            raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
        trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
+        trusted_name = trusted_email
+        if WEBUI_AUTH_TRUSTED_NAME_HEADER:
+            trusted_name = request.headers.get(
+                WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
+            )
        if not Users.get_user_by_email(trusted_email.lower()):
            await signup(
                request,
                SignupForm(
-                    email=trusted_email, password=str(uuid.uuid4()), name=trusted_email
+                    email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
                ),
            )
        user = Auths.authenticate_user_by_trusted_header(trusted_email)
@@ -145,6 +166,13 @@ async def signin(request: Request, form_data: SigninForm):
            expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
        )
+        # Set the cookie token
+        response.set_cookie(
+            key="token",
+            value=token,
+            httponly=True,  # Ensures the cookie is not accessible via JavaScript
+        )
        return {
            "token": token,
            "token_type": "Bearer",
@@ -164,7 +192,7 @@ async def signin(request: Request, form_data: SigninForm):
@router.post("/signup", response_model=SigninResponse)
-async def signup(request: Request, form_data: SignupForm):
+async def signup(request: Request, response: Response, form_data: SignupForm):
    if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
        raise HTTPException(
            status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
@@ -200,6 +228,13 @@ async def signup(request: Request, form_data: SignupForm):
            )
            # response.set_cookie(key='token', value=token, httponly=True)
+            # Set the cookie token
+            response.set_cookie(
+                key="token",
+                value=token,
+                httponly=True,  # Ensures the cookie is not accessible via JavaScript
+            )
            if request.app.state.config.WEBHOOK_URL:
                post_webhook(
                    request.app.state.config.WEBHOOK_URL,
...
from fastapi import (
Depends,
FastAPI,
HTTPException,
status,
Request,
UploadFile,
File,
Form,
)
from datetime import datetime, timedelta
from typing import List, Union, Optional
from pathlib import Path
from fastapi import APIRouter
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
from pydantic import BaseModel
import json
from apps.webui.models.files import (
Files,
FileForm,
FileModel,
FileModelResponse,
)
from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES
from importlib import util
import os
import uuid
import shutil
import logging
import re
from config import SRC_LOG_LEVELS, UPLOAD_DIR
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
# Upload File
############################
@router.post("/")
def upload_file(
file: UploadFile = File(...),
user=Depends(get_verified_user),
):
log.info(f"file.content_type: {file.content_type}")
try:
unsanitized_filename = file.filename
filename = os.path.basename(unsanitized_filename)
# replace filename with uuid
id = str(uuid.uuid4())
filename = f"{id}_{filename}"
file_path = f"{UPLOAD_DIR}/{filename}"
contents = file.file.read()
with open(file_path, "wb") as f:
f.write(contents)
file = Files.insert_new_file(
user.id,
FileForm(
**{
"id": id,
"filename": filename,
"meta": {
"content_type": file.content_type,
"size": len(contents),
"path": file_path,
},
}
),
)
if file:
return file
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
)
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################
# List Files
############################
@router.get("/", response_model=List[FileModel])
async def list_files(user=Depends(get_verified_user)):
files = Files.get_files()
return files
############################
# Delete All Files
############################
@router.delete("/all")
async def delete_all_files(user=Depends(get_admin_user)):
result = Files.delete_all_files()
if result:
folder = f"{UPLOAD_DIR}"
try:
# Check if the directory exists
if os.path.exists(folder):
# Iterate over all the files and directories in the specified directory
for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path) # Remove the file or link
elif os.path.isdir(file_path):
shutil.rmtree(file_path) # Remove the directory
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
else:
print(f"The directory {folder} does not exist")
except Exception as e:
print(f"Failed to process the directory {folder}. Reason: {e}")
return {"message": "All files deleted successfully"}
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
)
############################
# Get File By Id
############################
@router.get("/{id}", response_model=Optional[FileModel])
async def get_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file:
return file
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# Get File Content By Id
############################
@router.get("/{id}/content", response_model=Optional[FileModel])
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file:
file_path = Path(file.meta["path"])
# Check if the file already exists in the cache
if file_path.is_file():
print(f"file_path: {file_path}")
return FileResponse(file_path)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# Delete File By Id
############################
@router.delete("/{id}")
async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
file = Files.get_file_by_id(id)
if file:
result = Files.delete_file_by_id(id)
if result:
return {"message": "File deleted successfully"}
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error deleting file"),
)
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
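Assuming the webui app is mounted under /api/v1 as elsewhere in the project (an assumption for this sketch; the mount point is not part of this diff), the upload and content routes above could be exercised roughly like this:

import requests

BASE = "http://localhost:8080/api/v1/files"    # assumed mount point and port
HEADERS = {"Authorization": "Bearer <token>"}  # any verified user's token

# Upload: the route stores the file under UPLOAD_DIR and records it in the file table.
with open("report.pdf", "rb") as fh:
    created = requests.post(f"{BASE}/", headers=HEADERS, files={"file": fh}).json()

# Fetch the stored bytes back through the /{id}/content route.
content = requests.get(f"{BASE}/{created['id']}/content", headers=HEADERS).content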
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.functions import (
Functions,
FunctionForm,
FunctionModel,
FunctionResponse,
)
from apps.webui.utils import load_function_module_by_id
from utils.utils import get_verified_user, get_admin_user
from constants import ERROR_MESSAGES
from importlib import util
import os
from pathlib import Path
from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
router = APIRouter()
############################
# GetFunctions
############################
@router.get("/", response_model=List[FunctionResponse])
async def get_functions(user=Depends(get_verified_user)):
return Functions.get_functions()
############################
# ExportFunctions
############################
@router.get("/export", response_model=List[FunctionModel])
async def export_functions(user=Depends(get_admin_user)):
return Functions.get_functions()
############################
# CreateNewFunction
############################
@router.post("/create", response_model=Optional[FunctionResponse])
async def create_new_function(
request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
):
if not form_data.id.isidentifier():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only alphanumeric characters and underscores are allowed in the id",
)
form_data.id = form_data.id.lower()
function = Functions.get_function_by_id(form_data.id)
if function is None:
function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
try:
with open(function_path, "w") as function_file:
function_file.write(form_data.content)
function_module, function_type = load_function_module_by_id(form_data.id)
FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[form_data.id] = function_module
function = Functions.insert_new_function(user.id, function_type, form_data)
function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
function_cache_dir.mkdir(parents=True, exist_ok=True)
if function:
return function
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
)
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ID_TAKEN,
)
############################
# GetFunctionById
############################
@router.get("/id/{id}", response_model=Optional[FunctionModel])
async def get_function_by_id(id: str, user=Depends(get_admin_user)):
function = Functions.get_function_by_id(id)
if function:
return function
else:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateFunctionById
############################
@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
async def update_function_by_id(
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
):
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
try:
with open(function_path, "w") as function_file:
function_file.write(form_data.content)
function_module, function_type = load_function_module_by_id(id)
FUNCTIONS = request.app.state.FUNCTIONS
FUNCTIONS[id] = function_module
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
print(updated)
function = Functions.update_function_by_id(id, updated)
if function:
return function
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################
# DeleteFunctionById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_function_by_id(
request: Request, id: str, user=Depends(get_admin_user)
):
result = Functions.delete_function_by_id(id)
if result:
FUNCTIONS = request.app.state.FUNCTIONS
if id in FUNCTIONS:
del FUNCTIONS[id]
# delete the function file
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
os.remove(function_path)
return result
@@ -15,8 +15,9 @@ from constants import ERROR_MESSAGES
from importlib import util
import os
+from pathlib import Path
-from config import DATA_DIR
+from config import DATA_DIR, CACHE_DIR
TOOLS_DIR = f"{DATA_DIR}/tools"
@@ -79,6 +80,9 @@ async def create_new_toolkit(
    specs = get_tools_specs(TOOLS[form_data.id])
    toolkit = Tools.insert_new_tool(user.id, form_data, specs)
+    tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
+    tool_cache_dir.mkdir(parents=True, exist_ok=True)
    if toolkit:
        return toolkit
    else:
...
@@ -115,6 +115,52 @@ async def update_user_settings_by_session_user(
)
############################
# GetUserInfoBySessionUser
############################
@router.get("/user/info", response_model=Optional[dict])
async def get_user_info_by_session_user(user=Depends(get_verified_user)):
user = Users.get_user_by_id(user.id)
if user:
return user.info
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
############################
# UpdateUserInfoBySessionUser
############################
@router.post("/user/info/update", response_model=Optional[dict])
async def update_user_info_by_session_user(
form_data: dict, user=Depends(get_verified_user)
):
user = Users.get_user_by_id(user.id)
if user:
if user.info is None:
user.info = {}
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
if user:
return user.info
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
############################
# GetUserById
############################
...
from importlib import util
import os
-from config import TOOLS_DIR
+from config import TOOLS_DIR, FUNCTIONS_DIR
def load_toolkit_module_by_id(toolkit_id):
@@ -21,3 +21,25 @@ def load_toolkit_module_by_id(toolkit_id):
        # Move the file to the error folder
        os.rename(toolkit_path, f"{toolkit_path}.error")
        raise e
def load_function_module_by_id(function_id):
function_path = os.path.join(FUNCTIONS_DIR, f"{function_id}.py")
spec = util.spec_from_file_location(function_id, function_path)
module = util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Pipe"):
return module.Pipe(), "pipe"
elif hasattr(module, "Filter"):
return module.Filter(), "filter"
else:
raise Exception("No Function class found")
except Exception as e:
print(f"Error loading module: {function_id}")
# Move the file to the error folder
os.rename(function_path, f"{function_path}.error")
raise e
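load_function_module_by_id distinguishes modules only by whether they expose a Pipe or a Filter class. A minimal filter module that the chat middleware could load might look like the sketch below; the class body is an illustrative assumption, with inlet mirroring how the middleware invokes it:

class Filter:
    def __init__(self):
        # When True, the middleware skips its own handling of uploaded files
        # and leaves them to this filter.
        self.file_handler = False

    def inlet(self, body: dict, user: dict) -> dict:
        # Runs before the request reaches the model; here we just annotate
        # the latest user message.
        body["messages"][-1]["content"] += f"\n\n(reviewed for {user['name']})"
        return body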
@@ -294,6 +294,7 @@ WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
    "WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
+WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
JWT_EXPIRES_IN = PersistentConfig(
    "JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
)
@@ -505,6 +506,14 @@ TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
+####################################
+# Functions DIR
+####################################
+FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
+Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# LITELLM_CONFIG
####################################
@@ -554,7 +563,14 @@ OLLAMA_API_BASE_URL = os.environ.get(
)
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
-AIOHTTP_CLIENT_TIMEOUT = int(os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "300"))
+AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "300")
+if AIOHTTP_CLIENT_TIMEOUT == "":
+    AIOHTTP_CLIENT_TIMEOUT = None
+else:
+    AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
K8S_FLAG = os.environ.get("K8S_FLAG", "")
USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")
@@ -1034,6 +1050,18 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
    os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
)
# You can provide a list of your own websites to filter after performing a web search.
# This ensures the highest level of safety and reliability of the information sources.
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
"RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
"rag.rag.web.search.domain.filter_list",
[
# "wikipedia.com",
# "wikimedia.org",
# "wikidata.org",
],
)
SEARXNG_QUERY_URL = PersistentConfig(
    "SEARXNG_QUERY_URL",
    "rag.web.search.searxng_query_url",
@@ -1139,6 +1167,30 @@ COMFYUI_BASE_URL = PersistentConfig(
    os.getenv("COMFYUI_BASE_URL", ""),
)
COMFYUI_CFG_SCALE = PersistentConfig(
"COMFYUI_CFG_SCALE",
"image_generation.comfyui.cfg_scale",
os.getenv("COMFYUI_CFG_SCALE", ""),
)
COMFYUI_SAMPLER = PersistentConfig(
"COMFYUI_SAMPLER",
"image_generation.comfyui.sampler",
os.getenv("COMFYUI_SAMPLER", ""),
)
COMFYUI_SCHEDULER = PersistentConfig(
"COMFYUI_SCHEDULER",
"image_generation.comfyui.scheduler",
os.getenv("COMFYUI_SCHEDULER", ""),
)
COMFYUI_SD3 = PersistentConfig(
"COMFYUI_SD3",
"image_generation.comfyui.sd3",
os.environ.get("COMFYUI_SD3", "").lower() == "true",
)
IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
    "IMAGES_OPENAI_API_BASE_URL",
    "image_generation.openai.api_base_url",
...
@@ -15,9 +15,11 @@ import requests
import mimetypes
import shutil
import os
+import uuid
import inspect
import asyncio
+from fastapi.concurrency import run_in_threadpool
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
@@ -46,16 +48,19 @@ from apps.openai.main import (
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
-from apps.webui.main import app as webui_app
+from apps.webui.main import app as webui_app, get_pipe_models
from pydantic import BaseModel
-from typing import List, Optional
+from typing import List, Optional, Iterator, Generator, Union
from apps.webui.models.auths import Auths
from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
+from apps.webui.models.functions import Functions
from apps.webui.models.users import Users
+from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
from apps.webui.utils import load_toolkit_module_by_id
from utils.misc import parse_duration
@@ -72,7 +77,11 @@ from utils.task import (
    search_query_generation_template,
    tools_function_calling_generation_template,
)
-from utils.misc import get_last_user_message, add_or_update_system_message
+from utils.misc import (
+    get_last_user_message,
+    add_or_update_system_message,
+    stream_message_template,
+)
from apps.rag.utils import get_rag_context, rag_template
@@ -85,6 +94,7 @@ from config import (
    VERSION,
    CHANGELOG,
    FRONTEND_BUILD_DIR,
+    UPLOAD_DIR,
    CACHE_DIR,
    STATIC_DIR,
    ENABLE_OPENAI_API,
...@@ -184,7 +194,16 @@ app.state.MODELS = {} ...@@ -184,7 +194,16 @@ app.state.MODELS = {}
origins = ["*"] origins = ["*"]
async def get_function_call_response(messages, tool_id, template, task_model_id, user): ##################################
#
# ChatCompletion Middleware
#
##################################
async def get_function_call_response(
messages, files, tool_id, template, task_model_id, user
):
tool = Tools.get_tool_by_id(tool_id) tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2) tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs) content = tools_function_calling_generation_template(template, tools_specs)
...@@ -222,9 +241,7 @@ async def get_function_call_response(messages, tool_id, template, task_model_id, ...@@ -222,9 +241,7 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
response = None response = None
try: try:
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":
response = await generate_ollama_chat_completion( response = await generate_ollama_chat_completion(payload, user=user)
OpenAIChatCompletionForm(**payload), user=user
)
else: else:
response = await generate_openai_chat_completion(payload, user=user) response = await generate_openai_chat_completion(payload, user=user)
...@@ -247,6 +264,7 @@ async def get_function_call_response(messages, tool_id, template, task_model_id, ...@@ -247,6 +264,7 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
result = json.loads(content) result = json.loads(content)
print(result) print(result)
citation = None
# Call the function # Call the function
if "name" in result: if "name" in result:
if tool_id in webui_app.state.TOOLS: if tool_id in webui_app.state.TOOLS:
...@@ -255,76 +273,170 @@ async def get_function_call_response(messages, tool_id, template, task_model_id, ...@@ -255,76 +273,170 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
toolkit_module = load_toolkit_module_by_id(tool_id) toolkit_module = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module webui_app.state.TOOLS[tool_id] = toolkit_module
file_handler = False
# check if toolkit_module has file_handler self variable
if hasattr(toolkit_module, "file_handler"):
file_handler = True
print("file_handler: ", file_handler)
function = getattr(toolkit_module, result["name"]) function = getattr(toolkit_module, result["name"])
function_result = None function_result = None
try: try:
# Get the signature of the function # Get the signature of the function
sig = inspect.signature(function) sig = inspect.signature(function)
# Check if '__user__' is a parameter of the function params = result["parameters"]
if "__user__" in sig.parameters: if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included # Call the function with the '__user__' parameter included
function_result = function( params = {
**{ **params,
**result["parameters"], "__user__": {
"__user__": { "id": user.id,
"id": user.id, "email": user.email,
"email": user.email, "name": user.name,
"name": user.name, "role": user.role,
"role": user.role, },
}, }
}
) if "__messages__" in sig.parameters:
# Call the function with the '__messages__' parameter included
params = {
**params,
"__messages__": messages,
}
if "__files__" in sig.parameters:
# Call the function with the '__files__' parameter included
params = {
**params,
"__files__": files,
}
if "__model__" in sig.parameters:
# Call the function with the '__model__' parameter included
params = {
**params,
"__model__": model,
}
if "__id__" in sig.parameters:
# Call the function with the '__id__' parameter included
params = {
**params,
"__id__": tool_id,
}
if inspect.iscoroutinefunction(function):
function_result = await function(**params)
else: else:
# Call the function without modifying the parameters function_result = function(**params)
function_result = function(**result["parameters"])
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
citation = {
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
"document": [function_result],
"metadata": [{"source": result["name"]}],
}
except Exception as e: except Exception as e:
print(e) print(e)
# Add the function result to the system prompt # Add the function result to the system prompt
if function_result: if function_result is not None:
return function_result return function_result, citation, file_handler
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
return None return None, None, False
class ChatCompletionMiddleware(BaseHTTPMiddleware): class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
return_citations = False data_items = []
if request.method == "POST" and ( show_citations = False
"/ollama/api/chat" in request.url.path citations = []
or "/chat/completions" in request.url.path
if request.method == "POST" and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
): ):
log.debug(f"request.url.path: {request.url.path}") log.debug(f"request.url.path: {request.url.path}")
# Read the original request body # Read the original request body
body = await request.body() body = await request.body()
# Decode body to string
body_str = body.decode("utf-8") body_str = body.decode("utf-8")
# Parse string to JSON
data = json.loads(body_str) if body_str else {} data = json.loads(body_str) if body_str else {}
user = get_current_user( user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization")) request,
get_http_authorization_cred(request.headers.get("Authorization")),
) )
# Flag to skip RAG completions if file_handler is present in tools/functions
# Remove the citations from the body skip_files = False
return_citations = data.get("citations", False) if data.get("citations"):
if "citations" in data: show_citations = True
del data["citations"] del data["citations"]
# Set the task model model_id = data["model"]
task_model_id = data["model"] if model_id not in app.state.MODELS:
if task_model_id not in app.state.MODELS:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found", detail="Model not found",
) )
model = app.state.MODELS[model_id]
# Check if the model has any filters
if "info" in model and "meta" in model["info"]:
for filter_id in model["info"]["meta"].get("filterIds", []):
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type = load_function_module_by_id(
filter_id
)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
try:
if hasattr(function_module, "inlet"):
inlet = function_module.inlet
if inspect.iscoroutinefunction(inlet):
data = await inlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
)
else:
data = inlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
# Check if the user has a custom task model # Set the task model
# If the user has a custom task model, use that model task_model_id = data["model"]
# Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama": if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if ( if (
app.state.config.TASK_MODEL app.state.config.TASK_MODEL
...@@ -347,55 +459,71 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -347,55 +459,71 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
for tool_id in data["tool_ids"]: for tool_id in data["tool_ids"]:
print(tool_id) print(tool_id)
try: try:
response = await get_function_call_response( response, citation, file_handler = (
messages=data["messages"], await get_function_call_response(
tool_id=tool_id, messages=data["messages"],
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, files=data.get("files", []),
task_model_id=task_model_id, tool_id=tool_id,
user=user, template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
) )
if response: print(file_handler)
if isinstance(response, str):
context += ("\n" if context != "" else "") + response context += ("\n" if context != "" else "") + response
if citation:
citations.append(citation)
show_citations = True
if file_handler:
skip_files = True
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
del data["tool_ids"] del data["tool_ids"]
print(f"tool_context: {context}") print(f"tool_context: {context}")
# If docs field is present, generate RAG completions # If files field is present, generate RAG completions
if "docs" in data: # If skip_files is True, skip the RAG completions
data = {**data} if "files" in data:
rag_context, citations = get_rag_context( if not skip_files:
docs=data["docs"], data = {**data}
messages=data["messages"], rag_context, rag_citations = get_rag_context(
embedding_function=rag_app.state.EMBEDDING_FUNCTION, files=data["files"],
k=rag_app.state.config.TOP_K, messages=data["messages"],
reranking_function=rag_app.state.sentence_transformer_rf, embedding_function=rag_app.state.EMBEDDING_FUNCTION,
r=rag_app.state.config.RELEVANCE_THRESHOLD, k=rag_app.state.config.TOP_K,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH, reranking_function=rag_app.state.sentence_transformer_rf,
) r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
if rag_context:
context += ("\n" if context != "" else "") + rag_context
log.debug(f"rag_context: {rag_context}, citations: {citations}")
if rag_context: if rag_citations:
context += ("\n" if context != "" else "") + rag_context citations.extend(rag_citations)
del data["docs"] del data["files"]
log.debug(f"rag_context: {rag_context}, citations: {citations}") if show_citations and len(citations) > 0:
data_items.append({"citations": citations})
if context != "": if context != "":
system_prompt = rag_template( system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt rag_app.state.config.RAG_TEMPLATE, context, prompt
) )
print(system_prompt) print(system_prompt)
data["messages"] = add_or_update_system_message( data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"] system_prompt, data["messages"]
) )
modified_body_bytes = json.dumps(data).encode("utf-8") modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one # Replace the request body with the modified one
request._body = modified_body_bytes request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length # Set custom header to ensure content-length matches new body length
...@@ -408,43 +536,54 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware): ...@@ -408,43 +536,54 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
], ],
] ]
response = await call_next(request) response = await call_next(request)
if return_citations:
# Inject the citations into the response
if isinstance(response, StreamingResponse): if isinstance(response, StreamingResponse):
# If it's a streaming response, inject it as SSE event or NDJSON line # If it's a streaming response, inject it as SSE event or NDJSON line
content_type = response.headers.get("Content-Type") content_type = response.headers.get("Content-Type")
if "text/event-stream" in content_type: if "text/event-stream" in content_type:
return StreamingResponse( return StreamingResponse(
self.openai_stream_wrapper(response.body_iterator, citations), self.openai_stream_wrapper(response.body_iterator, data_items),
) )
if "application/x-ndjson" in content_type: if "application/x-ndjson" in content_type:
return StreamingResponse( return StreamingResponse(
self.ollama_stream_wrapper(response.body_iterator, citations), self.ollama_stream_wrapper(response.body_iterator, data_items),
) )
else:
return response
# If it's not a chat completion request, just pass it through
response = await call_next(request)
return response return response
async def _receive(self, body: bytes): async def _receive(self, body: bytes):
return {"type": "http.request", "body": body, "more_body": False} return {"type": "http.request", "body": body, "more_body": False}
async def openai_stream_wrapper(self, original_generator, citations): async def openai_stream_wrapper(self, original_generator, data_items):
yield f"data: {json.dumps({'citations': citations})}\n\n" for item in data_items:
yield f"data: {json.dumps(item)}\n\n"
async for data in original_generator: async for data in original_generator:
yield data yield data
async def ollama_stream_wrapper(self, original_generator, citations): async def ollama_stream_wrapper(self, original_generator, data_items):
yield f"{json.dumps({'citations': citations})}\n" for item in data_items:
yield f"{json.dumps(item)}\n"
async for data in original_generator: async for data in original_generator:
yield data yield data
app.add_middleware(ChatCompletionMiddleware)
##################################
#
# Pipeline Middleware
#
##################################
def filter_pipeline(payload, user):
-    user = {"id": user.id, "name": user.name, "role": user.role}
+    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
    model_id = payload["model"]
    filters = [
        model
@@ -532,7 +671,8 @@ class PipelineMiddleware(BaseHTTPMiddleware):
        data = json.loads(body_str) if body_str else {}
        user = get_current_user(
-            get_http_authorization_cred(request.headers.get("Authorization"))
+            request,
+            get_http_authorization_cred(request.headers.get("Authorization")),
        )
        try:
@@ -600,7 +740,6 @@ async def update_embedding_function(request: Request, call_next):
app.mount("/ws", socket_app)
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)
...@@ -614,17 +753,18 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION ...@@ -614,17 +753,18 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
async def get_all_models(): async def get_all_models():
pipe_models = []
openai_models = [] openai_models = []
ollama_models = [] ollama_models = []
pipe_models = await get_pipe_models()
if app.state.config.ENABLE_OPENAI_API: if app.state.config.ENABLE_OPENAI_API:
openai_models = await get_openai_models() openai_models = await get_openai_models()
openai_models = openai_models["data"] openai_models = openai_models["data"]
if app.state.config.ENABLE_OLLAMA_API: if app.state.config.ENABLE_OLLAMA_API:
ollama_models = await get_ollama_models() ollama_models = await get_ollama_models()
ollama_models = [ ollama_models = [
{ {
"id": model["model"], "id": model["model"],
...@@ -637,9 +777,9 @@ async def get_all_models(): ...@@ -637,9 +777,9 @@ async def get_all_models():
for model in ollama_models["models"] for model in ollama_models["models"]
] ]
models = openai_models + ollama_models models = pipe_models + openai_models + ollama_models
custom_models = Models.get_all_models()
custom_models = Models.get_all_models()
for custom_model in custom_models: for custom_model in custom_models:
if custom_model.base_model_id == None: if custom_model.base_model_id == None:
for model in models: for model in models:
...@@ -702,6 +842,253 @@ async def get_models(user=Depends(get_verified_user)): ...@@ -702,6 +842,253 @@ async def get_models(user=Depends(get_verified_user)):
return {"data": models} return {"data": models}
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
print(model)
pipe = model.get("pipe")
if pipe:
form_data["user"] = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
async def job():
pipe_id = form_data["model"]
if "." in pipe_id:
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
print(pipe_id)
pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
if form_data["stream"]:
async def stream_content():
if inspect.iscoroutinefunction(pipe):
res = await pipe(body=form_data)
else:
res = pipe(body=form_data)
if isinstance(res, str):
message = stream_message_template(form_data["model"], res)
yield f"data: {json.dumps(message)}\n\n"
if isinstance(res, Iterator):
for line in res:
if isinstance(line, BaseModel):
line = line.model_dump_json()
line = f"data: {line}"
try:
line = line.decode("utf-8")
except:
pass
if line.startswith("data:"):
yield f"{line}\n\n"
else:
line = stream_message_template(form_data["model"], line)
yield f"data: {json.dumps(line)}\n\n"
if isinstance(res, str) or isinstance(res, Generator):
finish_message = {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"delta": {},
"logprobs": None,
"finish_reason": "stop",
}
],
}
yield f"data: {json.dumps(finish_message)}\n\n"
yield f"data: [DONE]"
return StreamingResponse(
stream_content(), media_type="text/event-stream"
)
else:
if inspect.iscoroutinefunction(pipe):
res = await pipe(body=form_data)
else:
res = pipe(body=form_data)
if isinstance(res, dict):
return res
elif isinstance(res, BaseModel):
return res.model_dump()
else:
message = ""
if isinstance(res, str):
message = res
if isinstance(res, Generator):
for stream in res:
message = f"{message}{stream}"
return {
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
"object": "chat.completion",
"created": int(time.time()),
"model": form_data["model"],
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": message,
},
"logprobs": None,
"finish_reason": "stop",
}
],
}
return await job()
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(form_data, user=user)
else:
return await generate_openai_chat_completion(form_data, user=user)
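For reference, the pipe branch above looks up webui_app.state.FUNCTIONS[pipe_id].pipe and accepts a string, dict, BaseModel, Generator, or Iterator as the return value. A minimal sketch of a compatible function module is shown below; the module name "echo_pipe" and its registration are assumptions for illustration only.

# Hypothetical function module; assumed to be loaded by load_function_module_by_id
# and registered as webui_app.state.FUNCTIONS["echo_pipe"].
from typing import Generator


def pipe(body: dict) -> Generator[str, None, None]:
    # Pull the last user message out of the OpenAI-style request body.
    last_user_message = next(
        (m["content"] for m in reversed(body.get("messages", [])) if m.get("role") == "user"),
        "",
    )
    # Each yielded string is wrapped into a chat.completion.chunk by the endpoint
    # (via stream_message_template) when "stream" is true, or concatenated otherwise.
    yield "Echo: "
    yield last_user_message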
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
data = form_data
model_id = data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json={
"user": {"id": user.id, "name": user.name, "role": user.role},
"body": data,
},
)
r.raise_for_status()
data = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
else:
pass
# Check if the model has any filters
if "info" in model and "meta" in model["info"]:
for filter_id in model["info"]["meta"].get("filterIds", []):
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type = load_function_module_by_id(
filter_id
)
webui_app.state.FUNCTIONS[filter_id] = function_module
try:
if hasattr(function_module, "outlet"):
outlet = function_module.outlet
if inspect.iscoroutinefunction(outlet):
data = await outlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
)
else:
data = outlet(
data,
{
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
return data
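The outlet loop above calls function_module.outlet(data, user_dict), awaiting it when it is a coroutine, and expects the (possibly modified) body back. A minimal sketch of such a filter function follows; the footer text is purely illustrative.

# Hypothetical filter function module with an "outlet" hook, matching the call
# made by /api/chat/completed above.
def outlet(body: dict, user: dict) -> dict:
    # Tag the last assistant message so the change is visible in the returned body.
    for message in reversed(body.get("messages", [])):
        if message.get("role") == "assistant":
            message["content"] += f"\n\n(filtered for {user.get('name', 'unknown')})"
            break
    return body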
##################################
#
# Task Endpoints
#
##################################
# TODO: Refactor task API endpoints below into a separate file
@app.get("/api/task/config") @app.get("/api/task/config")
async def get_task_config(user=Depends(get_verified_user)): async def get_task_config(user=Depends(get_verified_user)):
return { return {
...@@ -780,7 +1167,12 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): ...@@ -780,7 +1167,12 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
content = title_generation_template( content = title_generation_template(
template, form_data["prompt"], user.model_dump() template,
form_data["prompt"],
{
"name": user.name,
"location": user.info.get("location") if user.info else None,
},
) )
payload = { payload = {
...@@ -792,7 +1184,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): ...@@ -792,7 +1184,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
"title": True, "title": True,
} }
print(payload) log.debug(payload)
try: try:
payload = filter_pipeline(payload, user) payload = filter_pipeline(payload, user)
...@@ -803,9 +1195,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)): ...@@ -803,9 +1195,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
) )
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion( return await generate_ollama_chat_completion(payload, user=user)
OpenAIChatCompletionForm(**payload), user=user
)
else: else:
return await generate_openai_chat_completion(payload, user=user) return await generate_openai_chat_completion(payload, user=user)
...@@ -846,7 +1236,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) ...@@ -846,7 +1236,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
content = search_query_generation_template( content = search_query_generation_template(
template, form_data["prompt"], user.model_dump() template, form_data["prompt"], {"name": user.name}
) )
payload = { payload = {
...@@ -868,9 +1258,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user) ...@@ -868,9 +1258,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
) )
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion( return await generate_ollama_chat_completion(payload, user=user)
OpenAIChatCompletionForm(**payload), user=user
)
else: else:
return await generate_openai_chat_completion(payload, user=user) return await generate_openai_chat_completion(payload, user=user)
...@@ -909,7 +1297,12 @@ Message: """{{prompt}}""" ...@@ -909,7 +1297,12 @@ Message: """{{prompt}}"""
''' '''
content = title_generation_template( content = title_generation_template(
template, form_data["prompt"], user.model_dump() template,
form_data["prompt"],
{
"name": user.name,
"location": user.info.get("location") if user.info else None,
},
) )
payload = { payload = {
...@@ -921,7 +1314,7 @@ Message: """{{prompt}}""" ...@@ -921,7 +1314,7 @@ Message: """{{prompt}}"""
"task": True, "task": True,
} }
print(payload) log.debug(payload)
try: try:
payload = filter_pipeline(payload, user) payload = filter_pipeline(payload, user)
...@@ -932,9 +1325,7 @@ Message: """{{prompt}}""" ...@@ -932,9 +1325,7 @@ Message: """{{prompt}}"""
) )
if model["owned_by"] == "ollama": if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion( return await generate_ollama_chat_completion(payload, user=user)
OpenAIChatCompletionForm(**payload), user=user
)
else: else:
return await generate_openai_chat_completion(payload, user=user) return await generate_openai_chat_completion(payload, user=user)
...@@ -967,8 +1358,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ ...@@ -967,8 +1358,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try: try:
context = await get_function_call_response( context, citation, file_handler = await get_function_call_response(
form_data["messages"], form_data["tool_id"], template, model_id, user form_data["messages"],
form_data.get("files", []),
form_data["tool_id"],
template,
model_id,
user,
) )
return context return context
except Exception as e: except Exception as e:
...@@ -978,94 +1374,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_ ...@@ -978,94 +1374,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
) )
@app.post("/api/chat/completions") ##################################
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)): #
model_id = form_data["model"] # Pipelines Endpoints
if model_id not in app.state.MODELS: #
raise HTTPException( ##################################
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
print(model)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**form_data), user=user
)
else:
return await generate_openai_chat_completion(form_data, user=user)
# TODO: Refactor pipelines API endpoints below into a separate file
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
data = form_data
model_id = data["model"]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
print(model_id)
if model_id in app.state.MODELS:
model = app.state.MODELS[model_id]
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/outlet",
headers=headers,
json={
"user": {"id": user.id, "name": user.name, "role": user.role},
"body": data,
},
)
r.raise_for_status()
data = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
else:
pass
return data
@app.get("/api/pipelines/list") @app.get("/api/pipelines/list")
...@@ -1388,6 +1704,13 @@ async def update_pipeline_valves( ...@@ -1388,6 +1704,13 @@ async def update_pipeline_valves(
) )
##################################
#
# Config Endpoints
#
##################################
@app.get("/api/config") @app.get("/api/config")
async def get_app_config(): async def get_app_config():
# Checking and Handling the Absence of 'ui' in CONFIG_DATA # Checking and Handling the Absence of 'ui' in CONFIG_DATA
...@@ -1457,6 +1780,9 @@ async def update_model_filter_config( ...@@ -1457,6 +1780,9 @@ async def update_model_filter_config(
} }
# TODO: webhook endpoint should be under config endpoints
@app.get("/api/webhook") @app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)): async def get_webhook_url(user=Depends(get_admin_user)):
return { return {
......
@@ -3,7 +3,9 @@ import hashlib
import json
import re
from datetime import timedelta
-from typing import Optional, List
+from typing import Optional, List, Tuple
import uuid
import time

def get_last_user_message(messages: List[dict]) -> str:
@@ -28,6 +30,21 @@ def get_last_assistant_message(messages: List[dict]) -> str:
return None
def get_system_message(messages: List[dict]) -> dict:
for message in messages:
if message["role"] == "system":
return message
return None
def remove_system_message(messages: List[dict]) -> List[dict]:
return [message for message in messages if message["role"] != "system"]
def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
return get_system_message(messages), remove_system_message(messages)
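A short usage sketch of the new helpers above (the message contents are illustrative):

messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "user", "content": "Hi"},
]
system_message, rest = pop_system_message(messages)
# system_message["content"] == "You are terse."; rest contains only the user turn.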
def add_or_update_system_message(content: str, messages: List[dict]):
"""
Adds a new system message at the beginning of the messages list
@@ -47,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]):
return messages
def stream_message_template(model: str, message: str):
return {
"id": f"{model}-{str(uuid.uuid4())}",
"object": "chat.completion.chunk",
"created": int(time.time()),
"model": model,
"choices": [
{
"index": 0,
"delta": {"content": message},
"logprobs": None,
"finish_reason": None,
}
],
}
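A quick usage sketch, assuming the helper above is in scope: each chunk produced this way is serialized and written out as a server-sent event line by the streaming branch of the completions endpoint.

import json

chunk = stream_message_template("my-model", "Hello")
# chunk["choices"][0]["delta"]["content"] == "Hello"
sse_line = f"data: {json.dumps(chunk)}\n\n"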
def get_gravatar_url(email):
# Trim leading and trailing whitespace from
# an email address and force all characters
...
@@ -6,24 +6,34 @@ from typing import Optional
def prompt_template(
-template: str, user_name: str = None, current_location: str = None
+template: str, user_name: str = None, user_location: str = None
) -> str:
# Get the current date
current_date = datetime.now()
# Format the date to YYYY-MM-DD
formatted_date = current_date.strftime("%Y-%m-%d")
formatted_time = current_date.strftime("%I:%M:%S %p")
-# Replace {{CURRENT_DATE}} in the template with the formatted date
template = template.replace("{{CURRENT_DATE}}", formatted_date)
template = template.replace("{{CURRENT_TIME}}", formatted_time)
template = template.replace(
"{{CURRENT_DATETIME}}", f"{formatted_date} {formatted_time}"
)
if user_name:
# Replace {{USER_NAME}} in the template with the user's name
template = template.replace("{{USER_NAME}}", user_name)
else:
# Replace {{USER_NAME}} in the template with "Unknown"
template = template.replace("{{USER_NAME}}", "Unknown")
-if current_location:
-# Replace {{CURRENT_LOCATION}} in the template with the current location
-template = template.replace("{{CURRENT_LOCATION}}", current_location)
+if user_location:
+# Replace {{USER_LOCATION}} in the template with the current location
+template = template.replace("{{USER_LOCATION}}", user_location)
else:
# Replace {{USER_LOCATION}} in the template with "Unknown"
template = template.replace("{{USER_LOCATION}}", "Unknown")
return template
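A small usage sketch of the renamed placeholders (illustrative values only): with the new else branches, a missing name or location is rendered as "Unknown" instead of leaving the placeholder in the prompt.

template = (
    "Today is {{CURRENT_DATE}} at {{CURRENT_TIME}}. "
    "You are assisting {{USER_NAME}} from {{USER_LOCATION}}."
)
# With only a name supplied, {{USER_LOCATION}} is rendered as "Unknown".
print(prompt_template(template, user_name="Ada"))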
@@ -61,7 +71,7 @@ def title_generation_template(
template = prompt_template(
template,
**(
-{"user_name": user.get("name"), "current_location": user.get("location")}
+{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
@@ -104,7 +114,7 @@ def search_query_generation_template(
template = prompt_template(
template,
**(
-{"user_name": user.get("name"), "current_location": user.get("location")}
+{"user_name": user.get("name"), "user_location": user.get("location")}
if user
else {}
),
...
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
-from fastapi import HTTPException, status, Depends
+from fastapi import HTTPException, status, Depends, Request
from apps.webui.models.users import Users
@@ -24,7 +24,7 @@ ALGORITHM = "HS256"
# Auth Utils
##############
-bearer_security = HTTPBearer()
+bearer_security = HTTPBearer(auto_error=False)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
@@ -75,13 +75,26 @@ def get_http_authorization_cred(auth_header: str):
def get_current_user(
request: Request,
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
token = None
if auth_token is not None:
token = auth_token.credentials
if token is None and "token" in request.cookies:
token = request.cookies.get("token")
if token is None:
raise HTTPException(status_code=403, detail="Not authenticated")
# auth by api key
-if auth_token.credentials.startswith("sk-"):
-return get_current_user_by_api_key(auth_token.credentials)
+if token.startswith("sk-"):
+return get_current_user_by_api_key(token)
# auth by jwt token
-data = decode_token(auth_token.credentials)
+data = decode_token(token)
if data != None and "id" in data:
user = Users.get_user_by_id(data["id"])
if user is None:
...
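With auto_error=False and the cookie fallback above, a request that carries only the "token" cookie set at sign-in is now accepted; the Authorization header still takes precedence when present. A rough client-side sketch follows (the base URL and endpoint are assumptions for illustration):

import requests

resp = requests.get(
    "http://localhost:8080/api/models",            # hypothetical local instance
    cookies={"token": "<jwt issued at sign-in>"},  # no Authorization header needed
)
print(resp.status_code)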
{
"name": "open-webui",
-"version": "0.3.4",
+"version": "0.3.5",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "open-webui",
-"version": "0.3.4",
+"version": "0.3.5",
"dependencies": {
"@codemirror/lang-javascript": "^6.2.2",
"@codemirror/lang-python": "^6.1.6",
@@ -16,6 +16,7 @@
"async": "^3.2.5",
"bits-ui": "^0.19.7",
"codemirror": "^6.0.1",
"crc-32": "^1.2.2",
"dayjs": "^1.11.10",
"eventsource-parser": "^1.1.2",
"file-saver": "^2.0.5",
@@ -28,11 +29,12 @@
"katex": "^0.16.9",
"marked": "^9.1.0",
"mermaid": "^10.9.1",
-"pyodide": "^0.26.0-alpha.4",
-"socket.io-client": "^4.7.5",
+"pyodide": "^0.26.1",
+"socket.io-client": "^4.2.0",
"sortablejs": "^1.15.2",
"svelte-sonner": "^0.3.19",
"tippy.js": "^6.3.7",
"turndown": "^7.2.0",
"uuid": "^9.0.1"
},
"devDependencies": {
@@ -999,6 +1001,11 @@
"svelte": ">=3 <5"
}
},
"node_modules/@mixmark-io/domino": {
"version": "2.2.0",
"resolved": "https://registry.npmjs.org/@mixmark-io/domino/-/domino-2.2.0.tgz",
"integrity": "sha512-Y28PR25bHXUg88kCV7nivXrP2Nj2RueZ3/l/jdx6J9f8J4nsEGcgX0Qe6lt7Pa+J79+kPiJU3LguR6O/6zrLOw=="
},
"node_modules/@nodelib/fs.scandir": { "node_modules/@nodelib/fs.scandir": {
"version": "2.1.5", "version": "2.1.5",
"resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz",
...@@ -2266,11 +2273,6 @@ ...@@ -2266,11 +2273,6 @@
"dev": true, "dev": true,
"optional": true "optional": true
}, },
"node_modules/base-64": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/base-64/-/base-64-1.0.0.tgz",
"integrity": "sha512-kwDPIFCGx0NZHog36dj+tHiwP4QMzsZ3AgMViUBKI0+V5n4U0ufTCUMhnQ04diaRI8EX/QcPfql7zlhZ7j4zgg=="
},
"node_modules/base64-js": { "node_modules/base64-js": {
"version": "1.5.1", "version": "1.5.1",
"resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz", "resolved": "https://registry.npmjs.org/base64-js/-/base64-js-1.5.1.tgz",
...@@ -3063,6 +3065,17 @@ ...@@ -3063,6 +3065,17 @@
"layout-base": "^1.0.0" "layout-base": "^1.0.0"
} }
}, },
"node_modules/crc-32": {
"version": "1.2.2",
"resolved": "https://registry.npmjs.org/crc-32/-/crc-32-1.2.2.tgz",
"integrity": "sha512-ROmzCKrTnOwybPcJApAA6WBWij23HVfGVNKqqrZpuyZOHqK2CwHSvpGuyt/UNNvaIjEd8X5IFGp4Mh+Ie1IHJQ==",
"bin": {
"crc32": "bin/crc32.njs"
},
"engines": {
"node": ">=0.8"
}
},
"node_modules/crelt": { "node_modules/crelt": {
"version": "1.0.6", "version": "1.0.6",
"resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz", "resolved": "https://registry.npmjs.org/crelt/-/crelt-1.0.6.tgz",
...@@ -3984,37 +3997,17 @@ ...@@ -3984,37 +3997,17 @@
} }
}, },
"node_modules/engine.io-client": { "node_modules/engine.io-client": {
"version": "6.5.3", "version": "6.5.4",
"resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.5.3.tgz", "resolved": "https://registry.npmjs.org/engine.io-client/-/engine.io-client-6.5.4.tgz",
"integrity": "sha512-9Z0qLB0NIisTRt1DZ/8U2k12RJn8yls/nXMZLn+/N8hANT3TcYjKFKcwbw5zFQiN4NTde3TSY9zb79e1ij6j9Q==", "integrity": "sha512-GeZeeRjpD2qf49cZQ0Wvh/8NJNfeXkXXcoGh+F77oEAgo9gUHwT1fCRxSNU+YEEaysOJTnsFHmM5oAcPy4ntvQ==",
"dependencies": { "dependencies": {
"@socket.io/component-emitter": "~3.1.0", "@socket.io/component-emitter": "~3.1.0",
"debug": "~4.3.1", "debug": "~4.3.1",
"engine.io-parser": "~5.2.1", "engine.io-parser": "~5.2.1",
"ws": "~8.11.0", "ws": "~8.17.1",
"xmlhttprequest-ssl": "~2.0.0" "xmlhttprequest-ssl": "~2.0.0"
} }
}, },
"node_modules/engine.io-client/node_modules/ws": {
"version": "8.11.0",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.11.0.tgz",
"integrity": "sha512-HPG3wQd9sNQoT9xHyNCXoDUa+Xw/VevmY9FoHyQ+g+rrMn4j6FB4np7Z0OhdTgjx6MgQLK7jwSy1YecU1+4Asg==",
"engines": {
"node": ">=10.0.0"
},
"peerDependencies": {
"bufferutil": "^4.0.1",
"utf-8-validate": "^5.0.2"
},
"peerDependenciesMeta": {
"bufferutil": {
"optional": true
},
"utf-8-validate": {
"optional": true
}
}
},
"node_modules/engine.io-parser": { "node_modules/engine.io-parser": {
"version": "5.2.2", "version": "5.2.2",
"resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.2.tgz", "resolved": "https://registry.npmjs.org/engine.io-parser/-/engine.io-parser-5.2.2.tgz",
...@@ -7551,11 +7544,10 @@ ...@@ -7551,11 +7544,10 @@
} }
}, },
"node_modules/pyodide": { "node_modules/pyodide": {
"version": "0.26.0-alpha.4", "version": "0.26.1",
"resolved": "https://registry.npmjs.org/pyodide/-/pyodide-0.26.0-alpha.4.tgz", "resolved": "https://registry.npmjs.org/pyodide/-/pyodide-0.26.1.tgz",
"integrity": "sha512-Ixuczq99DwhQlE+Bt0RaS6Ln9MHSZOkbU6iN8azwaeorjHtr7ukaxh+FeTxViFrp2y+ITyKgmcobY+JnBPcULw==", "integrity": "sha512-P+Gm88nwZqY7uBgjbQH8CqqU6Ei/rDn7pS1t02sNZsbyLJMyE2OVXjgNuqVT3KqYWnyGREUN0DbBUCJqk8R0ew==",
"dependencies": { "dependencies": {
"base-64": "^1.0.0",
"ws": "^8.5.0" "ws": "^8.5.0"
}, },
"engines": { "engines": {
...@@ -9065,6 +9057,14 @@ ...@@ -9065,6 +9057,14 @@
"node": "*" "node": "*"
} }
}, },
"node_modules/turndown": {
"version": "7.2.0",
"resolved": "https://registry.npmjs.org/turndown/-/turndown-7.2.0.tgz",
"integrity": "sha512-eCZGBN4nNNqM9Owkv9HAtWRYfLA4h909E/WGAWWBpmB275ehNhZyk87/Tpvjbp0jjNl9XwCsbe6bm6CqFsgD+A==",
"dependencies": {
"@mixmark-io/domino": "^2.2.0"
}
},
"node_modules/tweetnacl": { "node_modules/tweetnacl": {
"version": "0.14.5", "version": "0.14.5",
"resolved": "https://registry.npmjs.org/tweetnacl/-/tweetnacl-0.14.5.tgz", "resolved": "https://registry.npmjs.org/tweetnacl/-/tweetnacl-0.14.5.tgz",
...@@ -10382,9 +10382,9 @@ ...@@ -10382,9 +10382,9 @@
"integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==" "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ=="
}, },
"node_modules/ws": { "node_modules/ws": {
"version": "8.17.0", "version": "8.17.1",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.17.0.tgz", "resolved": "https://registry.npmjs.org/ws/-/ws-8.17.1.tgz",
"integrity": "sha512-uJq6108EgZMAl20KagGkzCKfMEjxmKvZHG7Tlq0Z6nOky7YF7aq4mOx6xK8TJ/i1LeK4Qus7INktacctDgY8Ow==", "integrity": "sha512-6XQFvXTkbfUOZOKKILFG1PDK2NDQs4azKQl26T0YS5CxqWLgXajbPZ+h4gZekJyRqFU8pvnbAbbs/3TgRPy+GQ==",
"engines": { "engines": {
"node": ">=10.0.0" "node": ">=10.0.0"
}, },
......
{
"name": "open-webui",
-"version": "0.3.4",
+"version": "0.3.5",
"private": true,
"scripts": {
"dev": "npm run pyodide:fetch && vite dev --host",
@@ -56,6 +56,7 @@
"async": "^3.2.5",
"bits-ui": "^0.19.7",
"codemirror": "^6.0.1",
"crc-32": "^1.2.2",
"dayjs": "^1.11.10",
"eventsource-parser": "^1.1.2",
"file-saver": "^2.0.5",
@@ -68,11 +69,12 @@
"katex": "^0.16.9",
"marked": "^9.1.0",
"mermaid": "^10.9.1",
-"pyodide": "^0.26.0-alpha.4",
-"socket.io-client": "^4.7.5",
+"pyodide": "^0.26.1",
+"socket.io-client": "^4.2.0",
"sortablejs": "^1.15.2",
"svelte-sonner": "^0.3.19",
"tippy.js": "^6.3.7",
"turndown": "^7.2.0",
"uuid": "^9.0.1"
}
}