Commit ca8fd8af authored by Jannik Streidl's avatar Jannik Streidl
Browse files

Merge branch 'dockerfile-optimisation' of...

Merge branch 'dockerfile-optimisation' of https://github.com/jannikstdl/open-webui into dockerfile-optimisation
parents d0d01c95 f669c0e7
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# use build args in the docker build commmand with --build-arg="BUILDARG=true" # use build args in the docker build commmand with --build-arg="BUILDARG=true"
ARG USE_CUDA=false ARG USE_CUDA=false
ARG USE_OLLAMA=false ARG USE_OLLAMA=false
# Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default) # Tested with cu117 for CUDA 11 and cu121 for CUDA 12 (default)
ARG USE_CUDA_VER=cu121 ARG USE_CUDA_VER=cu121
# any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers # any sentence transformer model; models to use can be found at https://huggingface.co/models?library=sentence-transformers
# Leaderboard: https://huggingface.co/spaces/mteb/leaderboard # Leaderboard: https://huggingface.co/spaces/mteb/leaderboard
......
...@@ -81,6 +81,12 @@ async def check_url(request: Request, call_next): ...@@ -81,6 +81,12 @@ async def check_url(request: Request, call_next):
return response return response
@app.head("/")
@app.get("/")
async def get_status():
return {"status": True}
@app.get("/urls") @app.get("/urls")
async def get_ollama_api_urls(user=Depends(get_admin_user)): async def get_ollama_api_urls(user=Depends(get_admin_user)):
return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS} return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
......
from peewee import * from peewee import *
from peewee_migrate import Router
from config import SRC_LOG_LEVELS, DATA_DIR from config import SRC_LOG_LEVELS, DATA_DIR
import os import os
import logging import logging
...@@ -16,4 +17,6 @@ else: ...@@ -16,4 +17,6 @@ else:
DB = SqliteDatabase(f"{DATA_DIR}/webui.db") DB = SqliteDatabase(f"{DATA_DIR}/webui.db")
DB.connect() router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log)
router.run()
DB.connect(reuse_if_open=True)
"""Peewee migrations -- 001_initial_schema.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Auth(pw.Model):
id = pw.CharField(max_length=255, unique=True)
email = pw.CharField(max_length=255)
password = pw.CharField(max_length=255)
active = pw.BooleanField()
class Meta:
table_name = "auth"
@migrator.create_model
class Chat(pw.Model):
id = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.CharField()
chat = pw.TextField()
timestamp = pw.DateField()
class Meta:
table_name = "chat"
@migrator.create_model
class ChatIdTag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
tag_name = pw.CharField(max_length=255)
chat_id = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
timestamp = pw.DateField()
class Meta:
table_name = "chatidtag"
@migrator.create_model
class Document(pw.Model):
id = pw.AutoField()
collection_name = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255, unique=True)
title = pw.CharField()
filename = pw.CharField()
content = pw.TextField(null=True)
user_id = pw.CharField(max_length=255)
timestamp = pw.DateField()
class Meta:
table_name = "document"
@migrator.create_model
class Modelfile(pw.Model):
id = pw.AutoField()
tag_name = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
modelfile = pw.TextField()
timestamp = pw.DateField()
class Meta:
table_name = "modelfile"
@migrator.create_model
class Prompt(pw.Model):
id = pw.AutoField()
command = pw.CharField(max_length=255, unique=True)
user_id = pw.CharField(max_length=255)
title = pw.CharField()
content = pw.TextField()
timestamp = pw.DateField()
class Meta:
table_name = "prompt"
@migrator.create_model
class Tag(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
user_id = pw.CharField(max_length=255)
data = pw.TextField(null=True)
class Meta:
table_name = "tag"
@migrator.create_model
class User(pw.Model):
id = pw.CharField(max_length=255, unique=True)
name = pw.CharField(max_length=255)
email = pw.CharField(max_length=255)
role = pw.CharField(max_length=255)
profile_image_url = pw.CharField(max_length=255)
timestamp = pw.DateField()
class Meta:
table_name = "user"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("user")
migrator.remove_model("tag")
migrator.remove_model("prompt")
migrator.remove_model("modelfile")
migrator.remove_model("document")
migrator.remove_model("chatidtag")
migrator.remove_model("chat")
migrator.remove_model("auth")
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"chat", share_id=pw.CharField(max_length=255, null=True, unique=True)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("chat", "share_id")
"""Peewee migrations -- 002_add_local_sharing.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
migrator.add_fields(
"user", api_key=pw.CharField(max_length=255, null=True, unique=True)
)
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_fields("user", "api_key")
# Database Migrations
This directory contains all the database migrations for the web app.
Migrations are done using the [`peewee-migrate`](https://github.com/klen/peewee_migrate) library.
Migrations are automatically ran at app startup.
## Creating a migration
Have you made a change to the schema of an existing model?
You will need to create a migration file to ensure that existing databases are updated for backwards compatibility.
1. Have a database file (`webui.db`) that has the old schema prior to any of your changes.
2. Make your changes to the models.
3. From the `backend` directory, run the following command:
```bash
pw_migrate create --auto --auto-source apps.web.models --database sqlite:///${SQLITE_DB} --directory apps/web/internal/migrations ${MIGRATION_NAME}
```
- `$SQLITE_DB` should be the path to the database file.
- `$MIGRATION_NAME` should be a descriptive name for the migration.
4. The migration file will be created in the `apps/web/internal/migrations` directory.
...@@ -20,6 +20,7 @@ from config import ( ...@@ -20,6 +20,7 @@ from config import (
ENABLE_SIGNUP, ENABLE_SIGNUP,
USER_PERMISSIONS, USER_PERMISSIONS,
WEBHOOK_URL, WEBHOOK_URL,
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
) )
app = FastAPI() app = FastAPI()
...@@ -34,7 +35,7 @@ app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS ...@@ -34,7 +35,7 @@ app.state.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE app.state.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
app.state.USER_PERMISSIONS = USER_PERMISSIONS app.state.USER_PERMISSIONS = USER_PERMISSIONS
app.state.WEBHOOK_URL = WEBHOOK_URL app.state.WEBHOOK_URL = WEBHOOK_URL
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
......
...@@ -47,6 +47,10 @@ class Token(BaseModel): ...@@ -47,6 +47,10 @@ class Token(BaseModel):
token_type: str token_type: str
class ApiKey(BaseModel):
api_key: Optional[str] = None
class UserResponse(BaseModel): class UserResponse(BaseModel):
id: str id: str
email: str email: str
...@@ -123,6 +127,28 @@ class AuthsTable: ...@@ -123,6 +127,28 @@ class AuthsTable:
except: except:
return None return None
def authenticate_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_api_key: {api_key}")
# if no api_key, return None
if not api_key:
return None
try:
user = Users.get_user_by_api_key(api_key)
return user if user else None
except:
return False
def authenticate_user_by_trusted_header(self, email: str) -> Optional[UserModel]:
log.info(f"authenticate_user_by_trusted_header: {email}")
try:
auth = Auth.get(Auth.email == email, Auth.active == True)
if auth:
user = Users.get_user_by_id(auth.id)
return user
except:
return None
def update_user_password_by_id(self, id: str, new_password: str) -> bool: def update_user_password_by_id(self, id: str, new_password: str) -> bool:
try: try:
query = Auth.update(password=new_password).where(Auth.id == id) query = Auth.update(password=new_password).where(Auth.id == id)
......
...@@ -20,6 +20,7 @@ class Chat(Model): ...@@ -20,6 +20,7 @@ class Chat(Model):
title = CharField() title = CharField()
chat = TextField() # Save Chat JSON as Text chat = TextField() # Save Chat JSON as Text
timestamp = DateField() timestamp = DateField()
share_id = CharField(null=True, unique=True)
class Meta: class Meta:
database = DB database = DB
...@@ -31,6 +32,7 @@ class ChatModel(BaseModel): ...@@ -31,6 +32,7 @@ class ChatModel(BaseModel):
title: str title: str
chat: str chat: str
timestamp: int # timestamp in epoch timestamp: int # timestamp in epoch
share_id: Optional[str] = None
#################### ####################
...@@ -52,6 +54,7 @@ class ChatResponse(BaseModel): ...@@ -52,6 +54,7 @@ class ChatResponse(BaseModel):
title: str title: str
chat: dict chat: dict
timestamp: int # timestamp in epoch timestamp: int # timestamp in epoch
share_id: Optional[str] = None # id of the chat to be shared
class ChatTitleIdResponse(BaseModel): class ChatTitleIdResponse(BaseModel):
...@@ -95,6 +98,71 @@ class ChatTable: ...@@ -95,6 +98,71 @@ class ChatTable:
except: except:
return None return None
def insert_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
# Get the existing chat to share
chat = Chat.get(Chat.id == chat_id)
# Check if the chat is already shared
if chat.share_id:
return self.get_chat_by_id_and_user_id(chat.share_id, "shared")
# Create a new chat with the same data, but with a new ID
shared_chat = ChatModel(
**{
"id": str(uuid.uuid4()),
"user_id": f"shared-{chat_id}",
"title": chat.title,
"chat": chat.chat,
"timestamp": int(time.time()),
}
)
shared_result = Chat.create(**shared_chat.model_dump())
# Update the original chat with the share_id
result = (
Chat.update(share_id=shared_chat.id).where(Chat.id == chat_id).execute()
)
return shared_chat if (shared_result and result) else None
def update_shared_chat_by_chat_id(self, chat_id: str) -> Optional[ChatModel]:
try:
print("update_shared_chat_by_id")
chat = Chat.get(Chat.id == chat_id)
print(chat)
query = Chat.update(
title=chat.title,
chat=chat.chat,
).where(Chat.id == chat.share_id)
query.execute()
chat = Chat.get(Chat.id == chat.share_id)
return ChatModel(**model_to_dict(chat))
except:
return None
def delete_shared_chat_by_chat_id(self, chat_id: str) -> bool:
try:
query = Chat.delete().where(Chat.user_id == f"shared-{chat_id}")
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
def update_chat_share_id_by_id(
self, id: str, share_id: Optional[str]
) -> Optional[ChatModel]:
try:
query = Chat.update(
share_id=share_id,
).where(Chat.id == id)
query.execute()
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
except:
return None
def get_chat_lists_by_user_id( def get_chat_lists_by_user_id(
self, user_id: str, skip: int = 0, limit: int = 50 self, user_id: str, skip: int = 0, limit: int = 50
) -> List[ChatModel]: ) -> List[ChatModel]:
...@@ -131,6 +199,13 @@ class ChatTable: ...@@ -131,6 +199,13 @@ class ChatTable:
.order_by(Chat.timestamp.desc()) .order_by(Chat.timestamp.desc())
] ]
def get_chat_by_id(self, id: str) -> Optional[ChatModel]:
try:
chat = Chat.get(Chat.id == id)
return ChatModel(**model_to_dict(chat))
except:
return None
def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]: def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
try: try:
chat = Chat.get(Chat.id == id, Chat.user_id == user_id) chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
...@@ -149,12 +224,15 @@ class ChatTable: ...@@ -149,12 +224,15 @@ class ChatTable:
query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id)) query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
query.execute() # Remove the rows, return number of rows removed. query.execute() # Remove the rows, return number of rows removed.
return True return True and self.delete_shared_chat_by_chat_id(id)
except: except:
return False return False
def delete_chats_by_user_id(self, user_id: str) -> bool: def delete_chats_by_user_id(self, user_id: str) -> bool:
try: try:
self.delete_shared_chats_by_user_id(user_id)
query = Chat.delete().where(Chat.user_id == user_id) query = Chat.delete().where(Chat.user_id == user_id)
query.execute() # Remove the rows, return number of rows removed. query.execute() # Remove the rows, return number of rows removed.
...@@ -162,5 +240,19 @@ class ChatTable: ...@@ -162,5 +240,19 @@ class ChatTable:
except: except:
return False return False
def delete_shared_chats_by_user_id(self, user_id: str) -> bool:
try:
shared_chat_ids = [
f"shared-{chat.id}"
for chat in Chat.select().where(Chat.user_id == user_id)
]
query = Chat.delete().where(Chat.user_id << shared_chat_ids)
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Chats = ChatTable(DB) Chats = ChatTable(DB)
...@@ -20,6 +20,7 @@ class User(Model): ...@@ -20,6 +20,7 @@ class User(Model):
role = CharField() role = CharField()
profile_image_url = CharField() profile_image_url = CharField()
timestamp = DateField() timestamp = DateField()
api_key = CharField(null=True, unique=True)
class Meta: class Meta:
database = DB database = DB
...@@ -32,6 +33,7 @@ class UserModel(BaseModel): ...@@ -32,6 +33,7 @@ class UserModel(BaseModel):
role: str = "pending" role: str = "pending"
profile_image_url: str = "/user.png" profile_image_url: str = "/user.png"
timestamp: int # timestamp in epoch timestamp: int # timestamp in epoch
api_key: Optional[str] = None
#################### ####################
...@@ -82,6 +84,13 @@ class UsersTable: ...@@ -82,6 +84,13 @@ class UsersTable:
except: except:
return None return None
def get_user_by_api_key(self, api_key: str) -> Optional[UserModel]:
try:
user = User.get(User.api_key == api_key)
return UserModel(**model_to_dict(user))
except:
return None
def get_user_by_email(self, email: str) -> Optional[UserModel]: def get_user_by_email(self, email: str) -> Optional[UserModel]:
try: try:
user = User.get(User.email == email) user = User.get(User.email == email)
...@@ -149,5 +158,21 @@ class UsersTable: ...@@ -149,5 +158,21 @@ class UsersTable:
except: except:
return False return False
def update_user_api_key_by_id(self, id: str, api_key: str) -> str:
try:
query = User.update(api_key=api_key).where(User.id == id)
result = query.execute()
return True if result == 1 else False
except:
return False
def get_user_api_key_by_id(self, id: str) -> Optional[str]:
try:
user = User.get(User.id == id)
return user.api_key
except:
return None
Users = UsersTable(DB) Users = UsersTable(DB)
from fastapi import Response, Request from fastapi import Request
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union
from fastapi import APIRouter, status from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import time
import uuid
import re import re
import uuid
from apps.web.models.auths import ( from apps.web.models.auths import (
SigninForm, SigninForm,
...@@ -17,6 +14,7 @@ from apps.web.models.auths import ( ...@@ -17,6 +14,7 @@ from apps.web.models.auths import (
UserResponse, UserResponse,
SigninResponse, SigninResponse,
Auths, Auths,
ApiKey,
) )
from apps.web.models.users import Users from apps.web.models.users import Users
...@@ -25,10 +23,12 @@ from utils.utils import ( ...@@ -25,10 +23,12 @@ from utils.utils import (
get_current_user, get_current_user,
get_admin_user, get_admin_user,
create_token, create_token,
create_api_key,
) )
from utils.misc import parse_duration, validate_email_format from utils.misc import parse_duration, validate_email_format
from utils.webhook import post_webhook from utils.webhook import post_webhook
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from config import WEBUI_AUTH_TRUSTED_EMAIL_HEADER
router = APIRouter() router = APIRouter()
...@@ -79,6 +79,8 @@ async def update_profile( ...@@ -79,6 +79,8 @@ async def update_profile(
async def update_password( async def update_password(
form_data: UpdatePasswordForm, session_user=Depends(get_current_user) form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
): ):
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
if session_user: if session_user:
user = Auths.authenticate_user(session_user.email, form_data.password) user = Auths.authenticate_user(session_user.email, form_data.password)
...@@ -98,7 +100,22 @@ async def update_password( ...@@ -98,7 +100,22 @@ async def update_password(
@router.post("/signin", response_model=SigninResponse) @router.post("/signin", response_model=SigninResponse)
async def signin(request: Request, form_data: SigninForm): async def signin(request: Request, form_data: SigninForm):
user = Auths.authenticate_user(form_data.email.lower(), form_data.password) if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
if not Users.get_user_by_email(trusted_email.lower()):
await signup(
request,
SignupForm(
email=trusted_email, password=str(uuid.uuid4()), name=trusted_email
),
)
user = Auths.authenticate_user_by_trusted_header(trusted_email)
else:
user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
if user: if user:
token = create_token( token = create_token(
data={"id": user.id}, data={"id": user.id},
...@@ -249,3 +266,40 @@ async def update_token_expires_duration( ...@@ -249,3 +266,40 @@ async def update_token_expires_duration(
return request.app.state.JWT_EXPIRES_IN return request.app.state.JWT_EXPIRES_IN
else: else:
return request.app.state.JWT_EXPIRES_IN return request.app.state.JWT_EXPIRES_IN
############################
# API Key
############################
# create api key
@router.post("/api_key", response_model=ApiKey)
async def create_api_key_(user=Depends(get_current_user)):
api_key = create_api_key()
success = Users.update_user_api_key_by_id(user.id, api_key)
if success:
return {
"api_key": api_key,
}
else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_API_KEY_ERROR)
# delete api key
@router.delete("/api_key", response_model=bool)
async def delete_api_key(user=Depends(get_current_user)):
success = Users.update_user_api_key_by_id(user.id, None)
return success
# get api key
@router.get("/api_key", response_model=ApiKey)
async def get_api_key(user=Depends(get_current_user)):
api_key = Users.get_user_api_key_by_id(user.id)
if api_key:
return {
"api_key": api_key,
}
else:
raise HTTPException(404, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
...@@ -189,6 +189,78 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_ ...@@ -189,6 +189,78 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
return result return result
############################
# ShareChatById
############################
@router.post("/{id}/share", response_model=Optional[ChatResponse])
async def share_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
if chat.share_id:
shared_chat = Chats.update_shared_chat_by_chat_id(chat.id)
return ChatResponse(
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
)
shared_chat = Chats.insert_shared_chat_by_chat_id(chat.id)
if not shared_chat:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DEFAULT(),
)
return ChatResponse(
**{**shared_chat.model_dump(), "chat": json.loads(shared_chat.chat)}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################
# DeletedSharedChatById
############################
@router.delete("/{id}/share", response_model=Optional[bool])
async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
if not chat.share_id:
return False
result = Chats.delete_shared_chat_by_chat_id(id)
update_result = Chats.update_chat_share_id_by_id(id, None)
return result and update_result != None
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################
# GetSharedChatById
############################
@router.get("/share/{share_id}", response_model=Optional[ChatResponse])
async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id(share_id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
############################ ############################
# GetChatTagsById # GetChatTagsById
############################ ############################
......
...@@ -367,6 +367,9 @@ WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100") ...@@ -367,6 +367,9 @@ WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100")
#################################### ####################################
WEBUI_AUTH = True WEBUI_AUTH = True
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
)
#################################### ####################################
# WEBUI_SECRET_KEY # WEBUI_SECRET_KEY
......
...@@ -20,6 +20,7 @@ class ERROR_MESSAGES(str, Enum): ...@@ -20,6 +20,7 @@ class ERROR_MESSAGES(str, Enum):
ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now." ENV_VAR_NOT_FOUND = "Required environment variable not found. Terminating now."
CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance." CREATE_USER_ERROR = "Oops! Something went wrong while creating your account. Please try again later. If the issue persists, contact support for assistance."
DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot." DELETE_USER_ERROR = "Oops! Something went wrong. We encountered an issue while trying to delete the user. Please give it another shot."
EMAIL_MISMATCH = "Uh-oh! This email does not match the email your provider is registered with. Please check your email and try again."
EMAIL_TAKEN = "Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew." EMAIL_TAKEN = "Uh-oh! This email is already registered. Sign in with your existing account or choose another email to start anew."
USERNAME_TAKEN = ( USERNAME_TAKEN = (
"Uh-oh! This username is already registered. Please choose another username." "Uh-oh! This username is already registered. Please choose another username."
...@@ -36,6 +37,7 @@ class ERROR_MESSAGES(str, Enum): ...@@ -36,6 +37,7 @@ class ERROR_MESSAGES(str, Enum):
INVALID_PASSWORD = ( INVALID_PASSWORD = (
"The password provided is incorrect. Please check for typos and try again." "The password provided is incorrect. Please check for typos and try again."
) )
INVALID_TRUSTED_HEADER = "Your provider has not provided a trusted header. Please contact your administrator for assistance."
UNAUTHORIZED = "401 Unauthorized" UNAUTHORIZED = "401 Unauthorized"
ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance." ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
ACTION_PROHIBITED = ( ACTION_PROHIBITED = (
...@@ -58,7 +60,8 @@ class ERROR_MESSAGES(str, Enum): ...@@ -58,7 +60,8 @@ class ERROR_MESSAGES(str, Enum):
RATE_LIMIT_EXCEEDED = "API rate limit exceeded" RATE_LIMIT_EXCEEDED = "API rate limit exceeded"
MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found" MODEL_NOT_FOUND = lambda name="": f"Model '{name}' was not found"
OPENAI_NOT_FOUND = lambda name="": f"OpenAI API was not found" OPENAI_NOT_FOUND = lambda name="": "OpenAI API was not found"
OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama" OLLAMA_NOT_FOUND = "WebUI could not connect to Ollama"
CREATE_API_KEY_ERROR = "Oops! Something went wrong while creating your API key. Please try again later. If the issue persists, contact support for assistance."
EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding." EMPTY_CONTENT = "The content provided is empty. Please ensure that there is text or data present before proceeding."
...@@ -62,6 +62,21 @@ class SPAStaticFiles(StaticFiles): ...@@ -62,6 +62,21 @@ class SPAStaticFiles(StaticFiles):
raise ex raise ex
print(
f"""
___ __ __ _ _ _ ___
/ _ \ _ __ ___ _ __ \ \ / /__| |__ | | | |_ _|
| | | | '_ \ / _ \ '_ \ \ \ /\ / / _ \ '_ \| | | || |
| |_| | |_) | __/ | | | \ V V / __/ |_) | |_| || |
\___/| .__/ \___|_| |_| \_/\_/ \___|_.__/ \___/|___|
|_|
v{VERSION} - building the best open-source AI user interface.
https://github.com/open-webui/open-webui
"""
)
app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None) app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
...@@ -179,6 +194,7 @@ async def get_app_config(): ...@@ -179,6 +194,7 @@ async def get_app_config():
"images": images_app.state.ENABLED, "images": images_app.state.ENABLED,
"default_models": webui_app.state.DEFAULT_MODELS, "default_models": webui_app.state.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS, "default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS,
"trusted_header_auth": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
} }
......
...@@ -14,6 +14,7 @@ uuid ...@@ -14,6 +14,7 @@ uuid
requests requests
aiohttp aiohttp
peewee peewee
peewee-migrate
bcrypt bcrypt
litellm==1.30.7 litellm==1.30.7
......
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends from fastapi import HTTPException, status, Depends
from apps.web.models.users import Users from apps.web.models.users import Users
from pydantic import BaseModel from pydantic import BaseModel
from typing import Union, Optional from typing import Union, Optional
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
...@@ -8,6 +10,7 @@ from passlib.context import CryptContext ...@@ -8,6 +10,7 @@ from passlib.context import CryptContext
from datetime import datetime, timedelta from datetime import datetime, timedelta
import requests import requests
import jwt import jwt
import uuid
import logging import logging
import config import config
...@@ -58,6 +61,11 @@ def extract_token_from_auth_header(auth_header: str): ...@@ -58,6 +61,11 @@ def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :] return auth_header[len("Bearer ") :]
def create_api_key():
key = str(uuid.uuid4()).replace("-", "")
return f"sk-{key}"
def get_http_authorization_cred(auth_header: str): def get_http_authorization_cred(auth_header: str):
try: try:
scheme, credentials = auth_header.split(" ") scheme, credentials = auth_header.split(" ")
...@@ -69,6 +77,10 @@ def get_http_authorization_cred(auth_header: str): ...@@ -69,6 +77,10 @@ def get_http_authorization_cred(auth_header: str):
def get_current_user( def get_current_user(
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security), auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
): ):
# auth by api key
if auth_token.credentials.startswith("sk-"):
return get_current_user_by_api_key(auth_token.credentials)
# auth by jwt token
data = decode_token(auth_token.credentials) data = decode_token(auth_token.credentials)
if data != None and "id" in data: if data != None and "id" in data:
user = Users.get_user_by_id(data["id"]) user = Users.get_user_by_id(data["id"])
...@@ -85,6 +97,16 @@ def get_current_user( ...@@ -85,6 +97,16 @@ def get_current_user(
) )
def get_current_user_by_api_key(api_key: str):
user = Users.get_user_by_api_key(api_key)
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
return user
def get_verified_user(user=Depends(get_current_user)): def get_verified_user(user=Depends(get_current_user)):
if user.role not in {"user", "admin"}: if user.role not in {"user", "admin"}:
raise HTTPException( raise HTTPException(
......
...@@ -318,3 +318,78 @@ export const updateJWTExpiresDuration = async (token: string, duration: string) ...@@ -318,3 +318,78 @@ export const updateJWTExpiresDuration = async (token: string, duration: string)
return res; return res;
}; };
export const createAPIKey = async (token: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/auths/api_key`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res.api_key;
};
export const getAPIKey = async (token: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/auths/api_key`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res.api_key;
};
export const deleteAPIKey = async (token: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/auths/api_key`, {
method: 'DELETE',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${token}`
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.catch((err) => {
console.log(err);
error = err.detail;
return null;
});
if (error) {
throw error;
}
return res;
};
...@@ -218,6 +218,102 @@ export const getChatById = async (token: string, id: string) => { ...@@ -218,6 +218,102 @@ export const getChatById = async (token: string, id: string) => {
return res; return res;
}; };
export const getChatByShareId = async (token: string, share_id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/share/${share_id}`, {
method: 'GET',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const shareChatById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/share`, {
method: 'POST',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const deleteSharedChatById = async (token: string, id: string) => {
let error = null;
const res = await fetch(`${WEBUI_API_BASE_URL}/chats/${id}/share`, {
method: 'DELETE',
headers: {
Accept: 'application/json',
'Content-Type': 'application/json',
...(token && { authorization: `Bearer ${token}` })
}
})
.then(async (res) => {
if (!res.ok) throw await res.json();
return res.json();
})
.then((json) => {
return json;
})
.catch((err) => {
error = err;
console.log(err);
return null;
});
if (error) {
throw error;
}
return res;
};
export const updateChatById = async (token: string, id: string, chat: object) => { export const updateChatById = async (token: string, id: string, chat: object) => {
let error = null; let error = null;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment