Unverified Commit 92c98eda authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #1781 from open-webui/dev

0.1.122
parents 1092ee9c 85df019c
......@@ -25,11 +25,11 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
class Document(Model):
collection_name = CharField(unique=True)
name = CharField(unique=True)
title = CharField()
filename = CharField()
title = TextField()
filename = TextField()
content = TextField(null=True)
user_id = CharField()
timestamp = DateField()
timestamp = BigIntegerField()
class Meta:
database = DB
......
......@@ -20,7 +20,7 @@ class Modelfile(Model):
tag_name = CharField(unique=True)
user_id = CharField()
modelfile = TextField()
timestamp = DateField()
timestamp = BigIntegerField()
class Meta:
database = DB
......
......@@ -19,9 +19,9 @@ import json
class Prompt(Model):
command = CharField(unique=True)
user_id = CharField()
title = CharField()
title = TextField()
content = TextField()
timestamp = DateField()
timestamp = BigIntegerField()
class Meta:
database = DB
......
......@@ -35,7 +35,7 @@ class ChatIdTag(Model):
tag_name = CharField()
chat_id = CharField()
user_id = CharField()
timestamp = DateField()
timestamp = BigIntegerField()
class Meta:
database = DB
......
......@@ -18,8 +18,12 @@ class User(Model):
name = CharField()
email = CharField()
role = CharField()
profile_image_url = CharField()
timestamp = DateField()
profile_image_url = TextField()
last_active_at = BigIntegerField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
api_key = CharField(null=True, unique=True)
class Meta:
......@@ -32,7 +36,11 @@ class UserModel(BaseModel):
email: str
role: str = "pending"
profile_image_url: str
timestamp: int # timestamp in epoch
last_active_at: int # timestamp in epoch
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
api_key: Optional[str] = None
......@@ -73,7 +81,9 @@ class UsersTable:
"email": email,
"role": role,
"profile_image_url": profile_image_url,
"timestamp": int(time.time()),
"last_active_at": int(time.time()),
"created_at": int(time.time()),
"updated_at": int(time.time()),
}
)
result = User.create(**user.model_dump())
......@@ -137,6 +147,16 @@ class UsersTable:
except:
return None
def update_user_last_active_by_id(self, id: str) -> Optional[UserModel]:
try:
query = User.update(last_active_at=int(time.time())).where(User.id == id)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try:
query = User.update(**updated).where(User.id == id)
......
import logging
from fastapi import Request
from fastapi import Depends, HTTPException, status
......
......@@ -36,15 +36,49 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"])
router = APIRouter()
############################
# GetChats
# GetChatList
############################
@router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats(
@router.get("/list", response_model=List[ChatTitleIdResponse])
async def get_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
return Chats.get_chat_list_by_user_id(user.id, skip, limit)
############################
# DeleteAllChats
############################
@router.delete("/", response_model=bool)
async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
if (
user.role == "user"
and not request.app.state.USER_PERMISSIONS["chat"]["deletion"]
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Chats.delete_chats_by_user_id(user.id)
return result
############################
# GetUserChatList
############################
@router.get("/list/user/{user_id}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_user_id(
user_id: str, user=Depends(get_admin_user), skip: int = 0, limit: int = 50
):
return Chats.get_chat_list_by_user_id(user_id, skip, limit)
############################
......@@ -53,22 +87,22 @@ async def get_user_chats(
@router.get("/archived", response_model=List[ChatTitleIdResponse])
async def get_archived_user_chats(
async def get_archived_session_user_chat_list(
user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
return Chats.get_archived_chat_lists_by_user_id(user.id, skip, limit)
return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
############################
# GetAllChats
# GetChats
############################
@router.get("/all", response_model=List[ChatResponse])
async def get_all_user_chats(user=Depends(get_current_user)):
async def get_user_chats(user=Depends(get_current_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats_by_user_id(user.id)
for chat in Chats.get_chats_by_user_id(user.id)
]
......@@ -86,7 +120,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
)
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats()
for chat in Chats.get_chats()
]
......@@ -107,45 +141,6 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
)
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChatsByTags
############################
@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
async def get_user_chats_by_tag_name(
tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
]
chats = Chats.get_chat_lists_by_chat_ids(chat_ids, skip, limit)
if len(chats) == 0:
Tags.delete_tag_by_tag_name_and_user_id(tag_name, user.id)
return chats
############################
# GetChatById
############################
......@@ -193,10 +188,11 @@ async def update_chat_by_id(
@router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
if (
user.role == "user"
and not request.app.state.USER_PERMISSIONS["chat"]["deletion"]
):
if user.role == "admin":
result = Chats.delete_chat_by_id(id)
return result
else:
if not request.app.state.USER_PERMISSIONS["chat"]["deletion"]:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
......@@ -303,6 +299,45 @@ async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
)
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChatsByTags
############################
@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
async def get_user_chat_list_by_tag_name(
tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
]
chats = Chats.get_chat_list_by_chat_ids(chat_ids, skip, limit)
if len(chats) == 0:
Tags.delete_tag_by_tag_name_and_user_id(tag_name, user.id)
return chats
############################
# GetChatTagsById
############################
......@@ -383,24 +418,3 @@ async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
############################
# DeleteAllChats
############################
@router.delete("/", response_model=bool)
async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
if (
user.role == "user"
and not request.app.state.USER_PERMISSIONS["chat"]["deletion"]
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Chats.delete_chats_by_user_id(user.id)
return result
from fastapi import APIRouter, UploadFile, File, Response
from fastapi import Depends, HTTPException, status
from peewee import SqliteDatabase
from starlette.responses import StreamingResponse, FileResponse
from pydantic import BaseModel
......@@ -7,7 +8,7 @@ from pydantic import BaseModel
from fpdf import FPDF
import markdown
from apps.web.internal.db import DB
from utils.utils import get_admin_user
from utils.misc import calculate_sha256, get_gravatar_url
......@@ -96,8 +97,13 @@ async def download_db(user=Depends(get_admin_user)):
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if not isinstance(DB, SqliteDatabase):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DB_NOT_SQLITE,
)
return FileResponse(
f"{DATA_DIR}/webui.db",
DB.database,
media_type="application/octet-stream",
filename="webui.db",
)
......@@ -71,6 +71,9 @@ except ImportError:
log.warning("dotenv not installed, skipping...")
WEBUI_NAME = os.environ.get("WEBUI_NAME", "Open WebUI")
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
WEBUI_FAVICON_URL = "https://openwebui.com/favicon.png"
####################################
......@@ -195,9 +198,6 @@ if CUSTOM_NAME:
except Exception as e:
log.exception(e)
pass
else:
if WEBUI_NAME != "Open WebUI":
WEBUI_NAME += " (Open WebUI)"
####################################
......@@ -220,7 +220,7 @@ Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
# Docs DIR
####################################
DOCS_DIR = f"{DATA_DIR}/docs"
DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
......@@ -375,8 +375,7 @@ USER_PERMISSIONS_CHAT_DELETION = (
USER_PERMISSIONS = {"chat": {"deletion": USER_PERMISSIONS_CHAT_DELETION}}
MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", "False").lower() == "true"
ENABLE_MODEL_FILTER = os.environ.get("ENABLE_MODEL_FILTER", "False").lower() == "true"
MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
......@@ -418,19 +417,57 @@ if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
####################################
CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
CHROMA_TENANT = os.environ.get("CHROMA_TENANT", chromadb.DEFAULT_TENANT)
CHROMA_DATABASE = os.environ.get("CHROMA_DATABASE", chromadb.DEFAULT_DATABASE)
CHROMA_HTTP_HOST = os.environ.get("CHROMA_HTTP_HOST", "")
CHROMA_HTTP_PORT = int(os.environ.get("CHROMA_HTTP_PORT", "8000"))
# Comma-separated list of header=value pairs
CHROMA_HTTP_HEADERS = os.environ.get("CHROMA_HTTP_HEADERS", "")
if CHROMA_HTTP_HEADERS:
CHROMA_HTTP_HEADERS = dict(
[pair.split("=") for pair in CHROMA_HTTP_HEADERS.split(",")]
)
else:
CHROMA_HTTP_HEADERS = None
CHROMA_HTTP_SSL = os.environ.get("CHROMA_HTTP_SSL", "false").lower() == "true"
# this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (sentence-transformers/all-MiniLM-L6-v2)
RAG_TOP_K = int(os.environ.get("RAG_TOP_K", "5"))
RAG_RELEVANCE_THRESHOLD = float(os.environ.get("RAG_RELEVANCE_THRESHOLD", "0.0"))
ENABLE_RAG_HYBRID_SEARCH = (
os.environ.get("ENABLE_RAG_HYBRID_SEARCH", "").lower() == "true"
)
RAG_EMBEDDING_ENGINE = os.environ.get("RAG_EMBEDDING_ENGINE", "")
PDF_EXTRACT_IMAGES = os.environ.get("PDF_EXTRACT_IMAGES", "False").lower() == "true"
RAG_EMBEDDING_MODEL = os.environ.get(
"RAG_EMBEDDING_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
)
log.info(f"Embedding model set: {RAG_EMBEDDING_MODEL}"),
RAG_EMBEDDING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_EMBEDDING_MODEL_AUTO_UPDATE", "").lower() == "true"
)
RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
RAG_RERANKING_MODEL = os.environ.get("RAG_RERANKING_MODEL", "")
if not RAG_RERANKING_MODEL == "":
log.info(f"Reranking model set: {RAG_RERANKING_MODEL}"),
RAG_RERANKING_MODEL_AUTO_UPDATE = (
os.environ.get("RAG_RERANKING_MODEL_AUTO_UPDATE", "").lower() == "true"
)
RAG_RERANKING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_RERANKING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
# device type embedding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
USE_CUDA = os.environ.get("USE_CUDA_DOCKER", "false")
......@@ -439,16 +476,28 @@ if USE_CUDA.lower() == "true":
else:
DEVICE_TYPE = "cpu"
CHROMA_CLIENT = chromadb.PersistentClient(
if CHROMA_HTTP_HOST != "":
CHROMA_CLIENT = chromadb.HttpClient(
host=CHROMA_HTTP_HOST,
port=CHROMA_HTTP_PORT,
headers=CHROMA_HTTP_HEADERS,
ssl=CHROMA_HTTP_SSL,
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
)
else:
CHROMA_CLIENT = chromadb.PersistentClient(
path=CHROMA_DATA_PATH,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
)
CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100
tenant=CHROMA_TENANT,
database=CHROMA_DATABASE,
)
CHUNK_SIZE = int(os.environ.get("CHUNK_SIZE", "1500"))
CHUNK_OVERLAP = int(os.environ.get("CHUNK_OVERLAP", "100"))
RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
DEFAULT_RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
[context]
</context>
......@@ -462,6 +511,8 @@ And answer according to the language of the user's question.
Given the context information, answer the query.
Query: [query]"""
RAG_TEMPLATE = os.environ.get("RAG_TEMPLATE", DEFAULT_RAG_TEMPLATE)
RAG_OPENAI_API_BASE_URL = os.getenv("RAG_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL)
RAG_OPENAI_API_KEY = os.getenv("RAG_OPENAI_API_KEY", OPENAI_API_KEY)
......@@ -480,18 +531,25 @@ WHISPER_MODEL_AUTO_UPDATE = (
# Images
####################################
IMAGE_GENERATION_ENGINE = os.getenv("IMAGE_GENERATION_ENGINE", "")
ENABLE_IMAGE_GENERATION = (
os.environ.get("ENABLE_IMAGE_GENERATION", "").lower() == "true"
)
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
COMFYUI_BASE_URL = os.getenv("COMFYUI_BASE_URL", "")
IMAGES_OPENAI_API_BASE_URL = os.getenv(
"IMAGES_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL
)
IMAGES_OPENAI_API_KEY = os.getenv("IMAGES_OPENAI_API_KEY", OPENAI_API_KEY)
IMAGE_SIZE = os.getenv("IMAGE_SIZE", "512x512")
IMAGE_STEPS = int(os.getenv("IMAGE_STEPS", 50))
IMAGE_GENERATION_MODEL = os.getenv("IMAGE_GENERATION_MODEL", "")
####################################
# Audio
......@@ -504,7 +562,17 @@ AUDIO_OPENAI_API_KEY = os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY)
# LiteLLM
####################################
ENABLE_LITELLM = os.environ.get("ENABLE_LITELLM", "True").lower() == "true"
LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365"))
if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535:
raise ValueError("Invalid port number for LITELLM_PROXY_PORT")
LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1")
####################################
# Database
####################################
DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db")
......@@ -69,3 +69,5 @@ class ERROR_MESSAGES(str, Enum):
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
DB_NOT_SQLITE = "This feature is only available when running with SQLite databases."
......@@ -47,7 +47,8 @@ from config import (
FRONTEND_BUILD_DIR,
CACHE_DIR,
STATIC_DIR,
MODEL_FILTER_ENABLED,
ENABLE_LITELLM,
ENABLE_MODEL_FILTER,
MODEL_FILTER_LIST,
GLOBAL_LOG_LEVEL,
SRC_LOG_LEVELS,
......@@ -89,7 +90,7 @@ https://github.com/open-webui/open-webui
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.WEBHOOK_URL = WEBHOOK_URL
......@@ -116,15 +117,14 @@ class RAGMiddleware(BaseHTTPMiddleware):
if "docs" in data:
data = {**data}
data["messages"] = rag_messages(
data["docs"],
data["messages"],
rag_app.state.RAG_TEMPLATE,
rag_app.state.TOP_K,
rag_app.state.RAG_EMBEDDING_ENGINE,
rag_app.state.RAG_EMBEDDING_MODEL,
rag_app.state.sentence_transformer_ef,
rag_app.state.OPENAI_API_KEY,
rag_app.state.OPENAI_API_BASE_URL,
docs=data["docs"],
messages=data["messages"],
template=rag_app.state.RAG_TEMPLATE,
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.ENABLE_RAG_HYBRID_SEARCH,
)
del data["docs"]
......@@ -176,6 +176,7 @@ async def check_url(request: Request, call_next):
@app.on_event("startup")
async def on_startup():
if ENABLE_LITELLM:
asyncio.create_task(start_litellm_background())
......@@ -215,7 +216,7 @@ async def get_app_config():
@app.get("/api/config/model/filter")
async def get_model_filter_config(user=Depends(get_admin_user)):
return {
"enabled": app.state.MODEL_FILTER_ENABLED,
"enabled": app.state.ENABLE_MODEL_FILTER,
"models": app.state.MODEL_FILTER_LIST,
}
......@@ -229,20 +230,20 @@ class ModelFilterConfigForm(BaseModel):
async def update_model_filter_config(
form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
):
app.state.MODEL_FILTER_ENABLED = form_data.enabled
app.state.ENABLE_MODEL_FILTER = form_data.enabled
app.state.MODEL_FILTER_LIST = form_data.models
ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
ollama_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
openai_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
litellm_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
litellm_app.state.ENABLE_MODEL_FILTER = app.state.ENABLE_MODEL_FILTER
litellm_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
return {
"enabled": app.state.MODEL_FILTER_ENABLED,
"enabled": app.state.ENABLE_MODEL_FILTER,
"models": app.state.MODEL_FILTER_LIST,
}
......@@ -326,4 +327,5 @@ app.mount(
@app.on_event("shutdown")
async def shutdown_event():
if ENABLE_LITELLM:
await shutdown_litellm_background()
......@@ -15,6 +15,8 @@ requests
aiohttp
peewee
peewee-migrate
psycopg2-binary
pymysql
bcrypt
litellm==1.35.17
......
......@@ -6,6 +6,7 @@ cd "$SCRIPT_DIR" || exit
KEY_FILE=.webui_secret_key
PORT="${PORT:-8080}"
HOST="${HOST:-0.0.0.0}"
if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then
echo "No WEBUI_SECRET_KEY provided"
......@@ -29,4 +30,4 @@ if [ "$USE_CUDA_DOCKER" = "true" ]; then
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib"
fi
WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" exec uvicorn main:app --host 0.0.0.0 --port "$PORT" --forwarded-allow-ips '*'
WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" exec uvicorn main:app --host "$HOST" --port "$PORT" --forwarded-allow-ips '*'
......@@ -7,7 +7,7 @@ SET "SCRIPT_DIR=%~dp0"
cd /d "%SCRIPT_DIR%" || exit /b
SET "KEY_FILE=.webui_secret_key"
SET "PORT=%PORT:8080%"
IF "%PORT%"=="" SET PORT=8080
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%"
......
......@@ -89,6 +89,8 @@ def get_current_user(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
else:
Users.update_user_last_active_by_id(user.id)
return user
else:
raise HTTPException(
......@@ -99,11 +101,15 @@ def get_current_user(
def get_current_user_by_api_key(api_key: str):
user = Users.get_user_by_api_key(api_key)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
else:
Users.update_user_last_active_by_id(user.id)
return user
......
......@@ -2,7 +2,12 @@
echo "Warning: This will remove all containers and volumes, including persistent data. Do you want to continue? [Y/N]"
read ans
if [ "$ans" == "Y" ] || [ "$ans" == "y" ]; then
command docker-compose 2>/dev/null
if [ "$?" == "0" ]; then
docker-compose down -v
else
docker compose down -v
fi
else
echo "Operation cancelled."
fi
import { defineConfig } from 'cypress';
export default defineConfig({
e2e: {
baseUrl: 'http://localhost:8080'
},
video: true
});
// eslint-disable-next-line @typescript-eslint/triple-slash-reference
/// <reference path="../support/index.d.ts" />
// These tests run through the chat flow.
describe('Settings', () => {
// Wait for 2 seconds after all tests to fix an issue with Cypress's video recording missing the last few frames
after(() => {
// eslint-disable-next-line cypress/no-unnecessary-waiting
cy.wait(2000);
});
beforeEach(() => {
// Login as the admin user
cy.loginAdmin();
// Visit the home page
cy.visit('/');
});
context('Ollama', () => {
it('user can select a model', () => {
// Click on the model selector
cy.get('button[aria-label="Select a model"]').click();
// Select the first model
cy.get('div[role="option"][data-value]').first().click();
});
it('user can perform text chat', () => {
// Click on the model selector
cy.get('button[aria-label="Select a model"]').click();
// Select the first model
cy.get('div[role="option"][data-value]').first().click();
// Type a message
cy.get('#chat-textarea').type('Hi, what can you do? A single sentence only please.', {
force: true
});
// Send the message
cy.get('button[type="submit"]').click();
// User's message should be visible
cy.get('.chat-user').should('exist');
// Wait for the response
cy.get('.chat-assistant', { timeout: 120_000 }) // .chat-assistant is created after the first token is received
.find('div[aria-label="Generation Info"]', { timeout: 120_000 }) // Generation Info is created after the stop token is received
.should('exist');
});
});
});
// eslint-disable-next-line @typescript-eslint/triple-slash-reference
/// <reference path="../support/index.d.ts" />
import { adminUser } from '../support/e2e';
// These tests assume the following defaults:
// 1. No users exist in the database or that the test admin user is an admin
// 2. Language is set to English
// 3. The default role for new users is 'pending'
describe('Registration and Login', () => {
// Wait for 2 seconds after all tests to fix an issue with Cypress's video recording missing the last few frames
after(() => {
// eslint-disable-next-line cypress/no-unnecessary-waiting
cy.wait(2000);
});
beforeEach(() => {
cy.visit('/');
});
it('should register a new user as pending', () => {
const userName = `Test User - ${Date.now()}`;
const userEmail = `cypress-${Date.now()}@example.com`;
// Toggle from sign in to sign up
cy.contains('Sign up').click();
// Fill out the form
cy.get('input[autocomplete="name"]').type(userName);
cy.get('input[autocomplete="email"]').type(userEmail);
cy.get('input[type="password"]').type('password');
// Submit the form
cy.get('button[type="submit"]').click();
// Wait until the user is redirected to the home page
cy.contains(userName);
// Expect the user to be pending
cy.contains('Check Again');
});
it('can login with the admin user', () => {
// Fill out the form
cy.get('input[autocomplete="email"]').type(adminUser.email);
cy.get('input[type="password"]').type(adminUser.password);
// Submit the form
cy.get('button[type="submit"]').click();
// Wait until the user is redirected to the home page
cy.contains(adminUser.name);
// Dismiss the changelog dialog if it is visible
cy.getAllLocalStorage().then((ls) => {
if (!ls['version']) {
cy.get('button').contains("Okay, Let's Go!").click();
}
});
});
});
// eslint-disable-next-line @typescript-eslint/triple-slash-reference
/// <reference path="../support/index.d.ts" />
import { adminUser } from '../support/e2e';
// These tests run through the various settings pages, ensuring that the user can interact with them as expected
describe('Settings', () => {
// Wait for 2 seconds after all tests to fix an issue with Cypress's video recording missing the last few frames
after(() => {
// eslint-disable-next-line cypress/no-unnecessary-waiting
cy.wait(2000);
});
beforeEach(() => {
// Login as the admin user
cy.loginAdmin();
// Visit the home page
cy.visit('/');
// Open the sidebar if it is not already open
cy.get('[aria-label="Open sidebar"]').then(() => {
cy.get('button[id="sidebar-toggle-button"]').click();
});
// Click on the profile link
cy.get('button').contains(adminUser.name).click();
// Click on the settings link
cy.get('button').contains('Settings').click();
});
context('General', () => {
it('user can open the General modal and hit save', () => {
cy.get('button').contains('General').click();
cy.get('button').contains('Save').click();
});
});
context('Connections', () => {
it('user can open the Connections modal and hit save', () => {
cy.get('button').contains('Connections').click();
cy.get('button').contains('Save').click();
});
});
context('Models', () => {
it('user can open the Models modal', () => {
cy.get('button').contains('Models').click();
});
});
context('Interface', () => {
it('user can open the Interface modal and hit save', () => {
cy.get('button').contains('Interface').click();
cy.get('button').contains('Save').click();
});
});
context('Audio', () => {
it('user can open the Audio modal and hit save', () => {
cy.get('button').contains('Audio').click();
cy.get('button').contains('Save').click();
});
});
context('Images', () => {
it('user can open the Images modal and hit save', () => {
cy.get('button').contains('Images').click();
// Currently fails because the backend requires a valid URL
// cy.get('button').contains('Save').click();
});
});
context('Chats', () => {
it('user can open the Chats modal', () => {
cy.get('button').contains('Chats').click();
});
});
context('Account', () => {
it('user can open the Account modal and hit save', () => {
cy.get('button').contains('Account').click();
cy.get('button').contains('Save').click();
});
});
context('About', () => {
it('user can open the About modal', () => {
cy.get('button').contains('About').click();
});
});
});
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment