Unverified Commit 437d7ff6 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #897 from open-webui/main

dev
parents 02f364bf 81eceb48
...@@ -5,12 +5,7 @@ import uuid ...@@ -5,12 +5,7 @@ import uuid
from peewee import * from peewee import *
from apps.web.models.users import UserModel, Users from apps.web.models.users import UserModel, Users
from utils.utils import ( from utils.utils import verify_password
verify_password,
get_password_hash,
bearer_scheme,
create_token,
)
from apps.web.internal.db import DB from apps.web.internal.db import DB
...@@ -63,6 +58,15 @@ class SigninForm(BaseModel): ...@@ -63,6 +58,15 @@ class SigninForm(BaseModel):
password: str password: str
class ProfileImageUrlForm(BaseModel):
profile_image_url: str
class UpdateProfileForm(BaseModel):
profile_image_url: str
name: str
class UpdatePasswordForm(BaseModel): class UpdatePasswordForm(BaseModel):
password: str password: str
new_password: str new_password: str
......
...@@ -60,23 +60,23 @@ class ChatTitleIdResponse(BaseModel): ...@@ -60,23 +60,23 @@ class ChatTitleIdResponse(BaseModel):
class ChatTable: class ChatTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
db.create_tables([Chat]) db.create_tables([Chat])
def insert_new_chat(self, user_id: str, def insert_new_chat(self, user_id: str, form_data: ChatForm) -> Optional[ChatModel]:
form_data: ChatForm) -> Optional[ChatModel]:
id = str(uuid.uuid4()) id = str(uuid.uuid4())
chat = ChatModel( chat = ChatModel(
**{ **{
"id": id, "id": id,
"user_id": user_id, "user_id": user_id,
"title": form_data.chat["title"] if "title" in "title": (
form_data.chat else "New Chat", form_data.chat["title"] if "title" in form_data.chat else "New Chat"
),
"chat": json.dumps(form_data.chat), "chat": json.dumps(form_data.chat),
"timestamp": int(time.time()), "timestamp": int(time.time()),
}) }
)
result = Chat.create(**chat.model_dump()) result = Chat.create(**chat.model_dump())
return chat if result else None return chat if result else None
...@@ -109,25 +109,43 @@ class ChatTable: ...@@ -109,25 +109,43 @@ class ChatTable:
except: except:
return None return None
def get_chat_lists_by_user_id(self, def get_chat_lists_by_user_id(
user_id: str, self, user_id: str, skip: int = 0, limit: int = 50
skip: int = 0, ) -> List[ChatModel]:
limit: int = 50) -> List[ChatModel]:
return [ return [
ChatModel(**model_to_dict(chat)) for chat in Chat.select().where( ChatModel(**model_to_dict(chat))
Chat.user_id == user_id).order_by(Chat.timestamp.desc()) for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.timestamp.desc())
# .limit(limit) # .limit(limit)
# .offset(skip) # .offset(skip)
] ]
def get_chat_lists_by_chat_ids(
self, chat_ids: List[str], skip: int = 0, limit: int = 50
) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.id.in_(chat_ids))
.order_by(Chat.timestamp.desc())
]
def get_all_chats(self) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select().order_by(Chat.timestamp.desc())
]
def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]: def get_all_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [ return [
ChatModel(**model_to_dict(chat)) for chat in Chat.select().where( ChatModel(**model_to_dict(chat))
Chat.user_id == user_id).order_by(Chat.timestamp.desc()) for chat in Chat.select()
.where(Chat.user_id == user_id)
.order_by(Chat.timestamp.desc())
] ]
def get_chat_by_id_and_user_id(self, id: str, def get_chat_by_id_and_user_id(self, id: str, user_id: str) -> Optional[ChatModel]:
user_id: str) -> Optional[ChatModel]:
try: try:
chat = Chat.get(Chat.id == id, Chat.user_id == user_id) chat = Chat.get(Chat.id == id, Chat.user_id == user_id)
return ChatModel(**model_to_dict(chat)) return ChatModel(**model_to_dict(chat))
...@@ -142,8 +160,7 @@ class ChatTable: ...@@ -142,8 +160,7 @@ class ChatTable:
def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool: def delete_chat_by_id_and_user_id(self, id: str, user_id: str) -> bool:
try: try:
query = Chat.delete().where((Chat.id == id) query = Chat.delete().where((Chat.id == id) & (Chat.user_id == user_id))
& (Chat.user_id == user_id))
query.execute() # Remove the rows, return number of rows removed. query.execute() # Remove the rows, return number of rows removed.
return True return True
......
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from apps.web.internal.db import DB
import json
####################
# Documents DB Schema
####################
class Document(Model):
collection_name = CharField(unique=True)
name = CharField(unique=True)
title = CharField()
filename = CharField()
content = TextField(null=True)
user_id = CharField()
timestamp = DateField()
class Meta:
database = DB
class DocumentModel(BaseModel):
collection_name: str
name: str
title: str
filename: str
content: Optional[str] = None
user_id: str
timestamp: int # timestamp in epoch
####################
# Forms
####################
class DocumentResponse(BaseModel):
collection_name: str
name: str
title: str
filename: str
content: Optional[dict] = None
user_id: str
timestamp: int # timestamp in epoch
class DocumentUpdateForm(BaseModel):
name: str
title: str
class DocumentForm(DocumentUpdateForm):
collection_name: str
filename: str
content: Optional[str] = None
class DocumentsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Document])
def insert_new_doc(
self, user_id: str, form_data: DocumentForm
) -> Optional[DocumentModel]:
document = DocumentModel(
**{
**form_data.model_dump(),
"user_id": user_id,
"timestamp": int(time.time()),
}
)
try:
result = Document.create(**document.model_dump())
if result:
return document
else:
return None
except:
return None
def get_doc_by_name(self, name: str) -> Optional[DocumentModel]:
try:
document = Document.get(Document.name == name)
return DocumentModel(**model_to_dict(document))
except:
return None
def get_docs(self) -> List[DocumentModel]:
return [
DocumentModel(**model_to_dict(doc))
for doc in Document.select()
# .limit(limit).offset(skip)
]
def update_doc_by_name(
self, name: str, form_data: DocumentUpdateForm
) -> Optional[DocumentModel]:
try:
query = Document.update(
title=form_data.title,
name=form_data.name,
timestamp=int(time.time()),
).where(Document.name == name)
query.execute()
doc = Document.get(Document.name == form_data.name)
return DocumentModel(**model_to_dict(doc))
except Exception as e:
print(e)
return None
def update_doc_content_by_name(
self, name: str, updated: dict
) -> Optional[DocumentModel]:
try:
doc = self.get_doc_by_name(name)
doc_content = json.loads(doc.content if doc.content else "{}")
doc_content = {**doc_content, **updated}
query = Document.update(
content=json.dumps(doc_content),
timestamp=int(time.time()),
).where(Document.name == name)
query.execute()
doc = Document.get(Document.name == name)
return DocumentModel(**model_to_dict(doc))
except Exception as e:
print(e)
return None
def delete_doc_by_name(self, name: str) -> bool:
try:
query = Document.delete().where((Document.name == name))
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Documents = DocumentsTable(DB)
from pydantic import BaseModel
from typing import List, Union, Optional
from peewee import *
from playhouse.shortcuts import model_to_dict
import json
import uuid
import time
from apps.web.internal.db import DB
####################
# Tag DB Schema
####################
class Tag(Model):
id = CharField(unique=True)
name = CharField()
user_id = CharField()
data = TextField(null=True)
class Meta:
database = DB
class ChatIdTag(Model):
id = CharField(unique=True)
tag_name = CharField()
chat_id = CharField()
user_id = CharField()
timestamp = DateField()
class Meta:
database = DB
class TagModel(BaseModel):
id: str
name: str
user_id: str
data: Optional[str] = None
class ChatIdTagModel(BaseModel):
id: str
tag_name: str
chat_id: str
user_id: str
timestamp: int
####################
# Forms
####################
class ChatIdTagForm(BaseModel):
tag_name: str
chat_id: str
class TagChatIdsResponse(BaseModel):
chat_ids: List[str]
class ChatTagsResponse(BaseModel):
tags: List[str]
class TagTable:
def __init__(self, db):
self.db = db
db.create_tables([Tag, ChatIdTag])
def insert_new_tag(self, name: str, user_id: str) -> Optional[TagModel]:
id = str(uuid.uuid4())
tag = TagModel(**{"id": id, "user_id": user_id, "name": name})
try:
result = Tag.create(**tag.model_dump())
if result:
return tag
else:
return None
except Exception as e:
return None
def get_tag_by_name_and_user_id(
self, name: str, user_id: str
) -> Optional[TagModel]:
try:
tag = Tag.get(Tag.name == name, Tag.user_id == user_id)
return TagModel(**model_to_dict(tag))
except Exception as e:
return None
def add_tag_to_chat(
self, user_id: str, form_data: ChatIdTagForm
) -> Optional[ChatIdTagModel]:
tag = self.get_tag_by_name_and_user_id(form_data.tag_name, user_id)
if tag == None:
tag = self.insert_new_tag(form_data.tag_name, user_id)
id = str(uuid.uuid4())
chatIdTag = ChatIdTagModel(
**{
"id": id,
"user_id": user_id,
"chat_id": form_data.chat_id,
"tag_name": tag.name,
"timestamp": int(time.time()),
}
)
try:
result = ChatIdTag.create(**chatIdTag.model_dump())
if result:
return chatIdTag
else:
return None
except:
return None
def get_tags_by_user_id(self, user_id: str) -> List[TagModel]:
tag_names = [
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
for chat_id_tag in ChatIdTag.select()
.where(ChatIdTag.user_id == user_id)
.order_by(ChatIdTag.timestamp.desc())
]
return [
TagModel(**model_to_dict(tag))
for tag in Tag.select().where(Tag.name.in_(tag_names))
]
def get_tags_by_chat_id_and_user_id(
self, chat_id: str, user_id: str
) -> List[TagModel]:
tag_names = [
ChatIdTagModel(**model_to_dict(chat_id_tag)).tag_name
for chat_id_tag in ChatIdTag.select()
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.chat_id == chat_id))
.order_by(ChatIdTag.timestamp.desc())
]
return [
TagModel(**model_to_dict(tag))
for tag in Tag.select().where(Tag.name.in_(tag_names))
]
def get_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> Optional[ChatIdTagModel]:
return [
ChatIdTagModel(**model_to_dict(chat_id_tag))
for chat_id_tag in ChatIdTag.select()
.where((ChatIdTag.user_id == user_id) & (ChatIdTag.tag_name == tag_name))
.order_by(ChatIdTag.timestamp.desc())
]
def count_chat_ids_by_tag_name_and_user_id(
self, tag_name: str, user_id: str
) -> int:
return (
ChatIdTag.select()
.where((ChatIdTag.tag_name == tag_name) & (ChatIdTag.user_id == user_id))
.count()
)
def delete_tag_by_tag_name_and_chat_id_and_user_id(
self, tag_name: str, chat_id: str, user_id: str
) -> bool:
try:
query = ChatIdTag.delete().where(
(ChatIdTag.tag_name == tag_name)
& (ChatIdTag.chat_id == chat_id)
& (ChatIdTag.user_id == user_id)
)
res = query.execute() # Remove the rows, return number of rows removed.
print(res)
tag_count = self.count_chat_ids_by_tag_name_and_user_id(tag_name, user_id)
if tag_count == 0:
# Remove tag item from Tag col as well
query = Tag.delete().where(
(Tag.name == tag_name) & (Tag.user_id == user_id)
)
query.execute() # Remove the rows, return number of rows removed.
return True
except Exception as e:
print("delete_tag", e)
return False
def delete_tags_by_chat_id_and_user_id(self, chat_id: str, user_id: str) -> bool:
tags = self.get_tags_by_chat_id_and_user_id(chat_id, user_id)
for tag in tags:
self.delete_tag_by_tag_name_and_chat_id_and_user_id(
tag.tag_name, chat_id, user_id
)
return True
Tags = TagTable(DB)
...@@ -65,7 +65,7 @@ class UsersTable: ...@@ -65,7 +65,7 @@ class UsersTable:
"name": name, "name": name,
"email": email, "email": email,
"role": role, "role": role,
"profile_image_url": get_gravatar_url(email), "profile_image_url": "/user.png",
"timestamp": int(time.time()), "timestamp": int(time.time()),
} }
) )
...@@ -92,7 +92,8 @@ class UsersTable: ...@@ -92,7 +92,8 @@ class UsersTable:
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [ return [
UserModel(**model_to_dict(user)) UserModel(**model_to_dict(user))
for user in User.select().limit(limit).offset(skip) for user in User.select()
# .limit(limit).offset(skip)
] ]
def get_num_users(self) -> Optional[int]: def get_num_users(self) -> Optional[int]:
...@@ -108,6 +109,20 @@ class UsersTable: ...@@ -108,6 +109,20 @@ class UsersTable:
except: except:
return None return None
def update_user_profile_image_url_by_id(
self, id: str, profile_image_url: str
) -> Optional[UserModel]:
try:
query = User.update(profile_image_url=profile_image_url).where(
User.id == id
)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]: def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try: try:
query = User.update(**updated).where(User.id == id) query = User.update(**updated).where(User.id == id)
......
...@@ -3,14 +3,16 @@ from fastapi import Depends, FastAPI, HTTPException, status ...@@ -3,14 +3,16 @@ from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union from typing import List, Union
from fastapi import APIRouter from fastapi import APIRouter, status
from pydantic import BaseModel from pydantic import BaseModel
import time import time
import uuid import uuid
import re
from apps.web.models.auths import ( from apps.web.models.auths import (
SigninForm, SigninForm,
SignupForm, SignupForm,
UpdateProfileForm,
UpdatePasswordForm, UpdatePasswordForm,
UserResponse, UserResponse,
SigninResponse, SigninResponse,
...@@ -18,8 +20,13 @@ from apps.web.models.auths import ( ...@@ -18,8 +20,13 @@ from apps.web.models.auths import (
) )
from apps.web.models.users import Users from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token from utils.utils import (
from utils.misc import get_gravatar_url, validate_email_format get_password_hash,
get_current_user,
get_admin_user,
create_token,
)
from utils.misc import parse_duration, validate_email_format
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
...@@ -40,14 +47,37 @@ async def get_session_user(user=Depends(get_current_user)): ...@@ -40,14 +47,37 @@ async def get_session_user(user=Depends(get_current_user)):
} }
############################
# Update Profile
############################
@router.post("/update/profile", response_model=UserResponse)
async def update_profile(
form_data: UpdateProfileForm, session_user=Depends(get_current_user)
):
if session_user:
user = Users.update_user_by_id(
session_user.id,
{"profile_image_url": form_data.profile_image_url, "name": form_data.name},
)
if user:
return user
else:
raise HTTPException(400, detail=ERROR_MESSAGES.DEFAULT())
else:
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_CRED)
############################ ############################
# Update Password # Update Password
############################ ############################
@router.post("/update/password", response_model=bool) @router.post("/update/password", response_model=bool)
async def update_password(form_data: UpdatePasswordForm, async def update_password(
session_user=Depends(get_current_user)): form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
):
if session_user: if session_user:
user = Auths.authenticate_user(session_user.email, form_data.password) user = Auths.authenticate_user(session_user.email, form_data.password)
...@@ -66,10 +96,13 @@ async def update_password(form_data: UpdatePasswordForm, ...@@ -66,10 +96,13 @@ async def update_password(form_data: UpdatePasswordForm,
@router.post("/signin", response_model=SigninResponse) @router.post("/signin", response_model=SigninResponse)
async def signin(form_data: SigninForm): async def signin(request: Request, form_data: SigninForm):
user = Auths.authenticate_user(form_data.email.lower(), form_data.password) user = Auths.authenticate_user(form_data.email.lower(), form_data.password)
if user: if user:
token = create_token(data={"email": user.email}) token = create_token(
data={"id": user.id},
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
)
return { return {
"token": token, "token": token,
...@@ -91,17 +124,35 @@ async def signin(form_data: SigninForm): ...@@ -91,17 +124,35 @@ async def signin(form_data: SigninForm):
@router.post("/signup", response_model=SigninResponse) @router.post("/signup", response_model=SigninResponse)
async def signup(request: Request, form_data: SignupForm): async def signup(request: Request, form_data: SignupForm):
if request.app.state.ENABLE_SIGNUP: if not request.app.state.ENABLE_SIGNUP:
if validate_email_format(form_data.email.lower()): raise HTTPException(
if not Users.get_user_by_email(form_data.email.lower()): status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
)
if not validate_email_format(form_data.email.lower()):
raise HTTPException(
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
)
if Users.get_user_by_email(form_data.email.lower()):
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
try: try:
role = "admin" if Users.get_num_users() == 0 else "pending" role = (
"admin"
if Users.get_num_users() == 0
else request.app.state.DEFAULT_USER_ROLE
)
hashed = get_password_hash(form_data.password) hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth(form_data.email.lower(), user = Auths.insert_new_auth(
hashed, form_data.name, role) form_data.email.lower(), hashed, form_data.name, role
)
if user: if user:
token = create_token(data={"email": user.email}) token = create_token(
data={"id": user.id},
expires_delta=parse_duration(request.app.state.JWT_EXPIRES_IN),
)
# response.set_cookie(key='token', value=token, httponly=True) # response.set_cookie(key='token', value=token, httponly=True)
return { return {
...@@ -114,18 +165,9 @@ async def signup(request: Request, form_data: SignupForm): ...@@ -114,18 +165,9 @@ async def signup(request: Request, form_data: SignupForm):
"profile_image_url": user.profile_image_url, "profile_image_url": user.profile_image_url,
} }
else: else:
raise HTTPException( raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
except Exception as err: except Exception as err:
raise HTTPException(500, raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
detail=ERROR_MESSAGES.DEFAULT(err))
else:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
else:
raise HTTPException(400,
detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
else:
raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
############################ ############################
...@@ -134,23 +176,64 @@ async def signup(request: Request, form_data: SignupForm): ...@@ -134,23 +176,64 @@ async def signup(request: Request, form_data: SignupForm):
@router.get("/signup/enabled", response_model=bool) @router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_current_user)): async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
if user.role == "admin":
return request.app.state.ENABLE_SIGNUP return request.app.state.ENABLE_SIGNUP
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
@router.get("/signup/enabled/toggle", response_model=bool) @router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_current_user)): async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
if user.role == "admin":
request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
return request.app.state.ENABLE_SIGNUP return request.app.state.ENABLE_SIGNUP
############################
# Default User Role
############################
@router.get("/signup/user/role")
async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
return request.app.state.DEFAULT_USER_ROLE
class UpdateRoleForm(BaseModel):
role: str
@router.post("/signup/user/role")
async def update_default_user_role(
request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
):
if form_data.role in ["pending", "user", "admin"]:
request.app.state.DEFAULT_USER_ROLE = form_data.role
return request.app.state.DEFAULT_USER_ROLE
############################
# JWT Expiration
############################
@router.get("/token/expires")
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
return request.app.state.JWT_EXPIRES_IN
class UpdateJWTExpiresDurationForm(BaseModel):
duration: str
@router.post("/token/expires/update")
async def update_token_expires_duration(
request: Request,
form_data: UpdateJWTExpiresDurationForm,
user=Depends(get_admin_user),
):
pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
# Check if the input string matches the pattern
if re.match(pattern, form_data.duration):
request.app.state.JWT_EXPIRES_IN = form_data.duration
return request.app.state.JWT_EXPIRES_IN
else: else:
raise HTTPException( return request.app.state.JWT_EXPIRES_IN
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
from fastapi import Depends, Request, HTTPException, status from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
from utils.utils import get_current_user from utils.utils import get_current_user, get_admin_user
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
...@@ -16,8 +16,15 @@ from apps.web.models.chats import ( ...@@ -16,8 +16,15 @@ from apps.web.models.chats import (
Chats, Chats,
) )
from utils.utils import (
bearer_scheme, ) from apps.web.models.tags import (
TagModel,
ChatIdTagModel,
ChatIdTagForm,
ChatTagsResponse,
Tags,
)
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
...@@ -29,7 +36,8 @@ router = APIRouter() ...@@ -29,7 +36,8 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats( async def get_user_chats(
user=Depends(get_current_user), skip: int = 0, limit: int = 50): user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
return Chats.get_chat_lists_by_user_id(user.id, skip, limit) return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
...@@ -41,9 +49,21 @@ async def get_user_chats( ...@@ -41,9 +49,21 @@ async def get_user_chats(
@router.get("/all", response_model=List[ChatResponse]) @router.get("/all", response_model=List[ChatResponse])
async def get_all_user_chats(user=Depends(get_current_user)): async def get_all_user_chats(user=Depends(get_current_user)):
return [ return [
ChatResponse(**{ ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
**chat.model_dump(), "chat": json.loads(chat.chat) for chat in Chats.get_all_chats_by_user_id(user.id)
}) for chat in Chats.get_all_chats_by_user_id(user.id) ]
############################
# GetAllChatsInDB
############################
@router.get("/all/db", response_model=List[ChatResponse])
async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats()
] ]
...@@ -54,8 +74,50 @@ async def get_all_user_chats(user=Depends(get_current_user)): ...@@ -54,8 +74,50 @@ async def get_all_user_chats(user=Depends(get_current_user)):
@router.post("/new", response_model=Optional[ChatResponse]) @router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)): async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
try:
chat = Chats.insert_new_chat(user.id, form_data) chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetAllTags
############################
@router.get("/tags/all", response_model=List[TagModel])
async def get_all_tags(user=Depends(get_current_user)):
try:
tags = Tags.get_tags_by_user_id(user.id)
return tags
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# GetChatsByTags
############################
@router.get("/tags/tag/{tag_name}", response_model=List[ChatTitleIdResponse])
async def get_user_chats_by_tag_name(
tag_name: str, user=Depends(get_current_user), skip: int = 0, limit: int = 50
):
chat_ids = [
chat_id_tag.chat_id
for chat_id_tag in Tags.get_chat_ids_by_tag_name_and_user_id(tag_name, user.id)
]
print(chat_ids)
return Chats.get_chat_lists_by_chat_ids(chat_ids, skip, limit)
############################ ############################
...@@ -68,12 +130,11 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)): ...@@ -68,12 +130,11 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
return ChatResponse(**{ return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
**chat.model_dump(), "chat": json.loads(chat.chat)
})
else: else:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, raise HTTPException(
detail=ERROR_MESSAGES.NOT_FOUND) status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
############################ ############################
...@@ -82,17 +143,15 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)): ...@@ -82,17 +143,15 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
@router.post("/{id}", response_model=Optional[ChatResponse]) @router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(id: str, async def update_chat_by_id(
form_data: ChatForm, id: str, form_data: ChatForm, user=Depends(get_current_user)
user=Depends(get_current_user)): ):
chat = Chats.get_chat_by_id_and_user_id(id, user.id) chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat: if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat} updated_chat = {**json.loads(chat.chat), **form_data.chat}
chat = Chats.update_chat_by_id(id, updated_chat) chat = Chats.update_chat_by_id(id, updated_chat)
return ChatResponse(**{ return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
**chat.model_dump(), "chat": json.loads(chat.chat)
})
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
...@@ -106,11 +165,103 @@ async def update_chat_by_id(id: str, ...@@ -106,11 +165,103 @@ async def update_chat_by_id(id: str,
@router.delete("/{id}", response_model=bool) @router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(id: str, user=Depends(get_current_user)): async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
if (
user.role == "user"
and not request.app.state.USER_PERMISSIONS["chat"]["deletion"]
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Chats.delete_chat_by_id_and_user_id(id, user.id) result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result return result
############################
# GetChatTagsById
############################
@router.get("/{id}/tags", response_model=List[TagModel])
async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
if tags != None:
return tags
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
############################
# AddChatTagById
############################
@router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
async def add_chat_tag_by_id(
id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
):
tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
if form_data.tag_name not in tags:
tag = Tags.add_tag_to_chat(user.id, form_data)
if tag:
return tag
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# DeleteChatTagById
############################
@router.delete("/{id}/tags", response_model=Optional[bool])
async def delete_chat_tag_by_id(
id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
):
result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
form_data.tag_name, id, user.id
)
if result:
return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
############################
# DeleteAllChatTagsById
############################
@router.delete("/{id}/tags/all", response_model=Optional[bool])
async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
if result:
return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
)
############################ ############################
# DeleteAllChats # DeleteAllChats
############################ ############################
......
...@@ -10,7 +10,7 @@ import uuid ...@@ -10,7 +10,7 @@ import uuid
from apps.web.models.users import Users from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token from utils.utils import get_password_hash, get_current_user, get_admin_user, create_token
from utils.misc import get_gravatar_url, validate_email_format from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
...@@ -21,20 +21,35 @@ class SetDefaultModelsForm(BaseModel): ...@@ -21,20 +21,35 @@ class SetDefaultModelsForm(BaseModel):
models: str models: str
class PromptSuggestion(BaseModel):
title: List[str]
content: str
class SetDefaultSuggestionsForm(BaseModel):
suggestions: List[PromptSuggestion]
############################ ############################
# SetDefaultModels # SetDefaultModels
############################ ############################
@router.post("/default/models", response_model=str) @router.post("/default/models", response_model=str)
async def set_global_default_models(request: Request, async def set_global_default_models(
form_data: SetDefaultModelsForm, request: Request, form_data: SetDefaultModelsForm, user=Depends(get_admin_user)
user=Depends(get_current_user)): ):
if user.role == "admin":
request.app.state.DEFAULT_MODELS = form_data.models request.app.state.DEFAULT_MODELS = form_data.models
return request.app.state.DEFAULT_MODELS return request.app.state.DEFAULT_MODELS
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, @router.post("/default/suggestions", response_model=List[PromptSuggestion])
) async def set_global_default_suggestions(
request: Request,
form_data: SetDefaultSuggestionsForm,
user=Depends(get_admin_user),
):
data = form_data.model_dump()
request.app.state.DEFAULT_PROMPT_SUGGESTIONS = data["suggestions"]
return request.app.state.DEFAULT_PROMPT_SUGGESTIONS
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.documents import (
Documents,
DocumentForm,
DocumentUpdateForm,
DocumentModel,
DocumentResponse,
)
from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES
router = APIRouter()
############################
# GetDocuments
############################
@router.get("/", response_model=List[DocumentResponse])
async def get_documents(user=Depends(get_current_user)):
docs = [
DocumentResponse(
**{
**doc.model_dump(),
"content": json.loads(doc.content if doc.content else "{}"),
}
)
for doc in Documents.get_docs()
]
return docs
############################
# CreateNewDoc
############################
@router.post("/create", response_model=Optional[DocumentResponse])
async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
doc = Documents.get_doc_by_name(form_data.name)
if doc == None:
doc = Documents.insert_new_doc(user.id, form_data)
if doc:
return DocumentResponse(
**{
**doc.model_dump(),
"content": json.loads(doc.content if doc.content else "{}"),
}
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.FILE_EXISTS,
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.NAME_TAG_TAKEN,
)
############################
# GetDocByName
############################
@router.get("/name/{name}", response_model=Optional[DocumentResponse])
async def get_doc_by_name(name: str, user=Depends(get_current_user)):
doc = Documents.get_doc_by_name(name)
if doc:
return DocumentResponse(
**{
**doc.model_dump(),
"content": json.loads(doc.content if doc.content else "{}"),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# TagDocByName
############################
class TagItem(BaseModel):
name: str
class TagDocumentForm(BaseModel):
name: str
tags: List[dict]
@router.post("/name/{name}/tags", response_model=Optional[DocumentResponse])
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
if doc:
return DocumentResponse(
**{
**doc.model_dump(),
"content": json.loads(doc.content if doc.content else "{}"),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateDocByName
############################
@router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
):
doc = Documents.update_doc_by_name(name, form_data)
if doc:
return DocumentResponse(
**{
**doc.model_dump(),
"content": json.loads(doc.content if doc.content else "{}"),
}
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.NAME_TAG_TAKEN,
)
############################
# DeleteDocByName
############################
@router.delete("/name/{name}/delete", response_model=bool)
async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
result = Documents.delete_doc_by_name(name)
return result
...@@ -13,7 +13,7 @@ from apps.web.models.modelfiles import ( ...@@ -13,7 +13,7 @@ from apps.web.models.modelfiles import (
ModelfileResponse, ModelfileResponse,
) )
from utils.utils import get_current_user from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
...@@ -37,13 +37,7 @@ async def get_modelfiles(skip: int = 0, ...@@ -37,13 +37,7 @@ async def get_modelfiles(skip: int = 0,
@router.post("/create", response_model=Optional[ModelfileResponse]) @router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, async def create_new_modelfile(form_data: ModelfileForm,
user=Depends(get_current_user)): user=Depends(get_admin_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data) modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile: if modelfile:
...@@ -91,12 +85,7 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, ...@@ -91,12 +85,7 @@ async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
@router.post("/update", response_model=Optional[ModelfileResponse]) @router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
user=Depends(get_current_user)): user=Depends(get_admin_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile: if modelfile:
updated_modelfile = { updated_modelfile = {
...@@ -127,12 +116,6 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm, ...@@ -127,12 +116,6 @@ async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
@router.delete("/delete", response_model=bool) @router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm, async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_current_user)): user=Depends(get_admin_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name) result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result return result
...@@ -8,7 +8,7 @@ import json ...@@ -8,7 +8,7 @@ import json
from apps.web.models.prompts import Prompts, PromptForm, PromptModel from apps.web.models.prompts import Prompts, PromptForm, PromptModel
from utils.utils import get_current_user from utils.utils import get_current_user, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
...@@ -29,25 +29,17 @@ async def get_prompts(user=Depends(get_current_user)): ...@@ -29,25 +29,17 @@ async def get_prompts(user=Depends(get_current_user)):
@router.post("/create", response_model=Optional[PromptModel]) @router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)): async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
prompt = Prompts.get_prompt_by_command(form_data.command) prompt = Prompts.get_prompt_by_command(form_data.command)
if prompt == None: if prompt == None:
prompt = Prompts.insert_new_prompt(user.id, form_data) prompt = Prompts.insert_new_prompt(user.id, form_data)
if prompt: if prompt:
return prompt return prompt
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(), detail=ERROR_MESSAGES.DEFAULT(),
) )
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.COMMAND_TAKEN, detail=ERROR_MESSAGES.COMMAND_TAKEN,
...@@ -79,14 +71,8 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)): ...@@ -79,14 +71,8 @@ async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
@router.post("/command/{command}/update", response_model=Optional[PromptModel]) @router.post("/command/{command}/update", response_model=Optional[PromptModel])
async def update_prompt_by_command( async def update_prompt_by_command(
command: str, form_data: PromptForm, user=Depends(get_current_user) command: str, form_data: PromptForm, user=Depends(get_admin_user)
): ):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data) prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt: if prompt:
return prompt return prompt
...@@ -103,12 +89,6 @@ async def update_prompt_by_command( ...@@ -103,12 +89,6 @@ async def update_prompt_by_command(
@router.delete("/command/{command}/delete", response_model=bool) @router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command(command: str, user=Depends(get_current_user)): async def delete_prompt_by_command(command: str, user=Depends(get_admin_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Prompts.delete_prompt_by_command(f"/{command}") result = Prompts.delete_prompt_by_command(f"/{command}")
return result return result
from fastapi import Response from fastapi import Response, Request
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
...@@ -11,7 +11,7 @@ import uuid ...@@ -11,7 +11,7 @@ import uuid
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths from apps.web.models.auths import Auths
from utils.utils import get_current_user, get_password_hash from utils.utils import get_current_user, get_password_hash, get_admin_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
...@@ -22,33 +22,39 @@ router = APIRouter() ...@@ -22,33 +22,39 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel]) @router.get("/", response_model=List[UserModel])
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)): async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_admin_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
return Users.get_users(skip, limit) return Users.get_users(skip, limit)
############################
# User Permissions
############################
@router.get("/permissions/user")
async def get_user_permissions(request: Request, user=Depends(get_admin_user)):
return request.app.state.USER_PERMISSIONS
@router.post("/permissions/user")
async def update_user_permissions(
request: Request, form_data: dict, user=Depends(get_admin_user)
):
request.app.state.USER_PERMISSIONS = form_data
return request.app.state.USER_PERMISSIONS
############################ ############################
# UpdateUserRole # UpdateUserRole
############################ ############################
@router.post("/update/role", response_model=Optional[UserModel]) @router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role( async def update_user_role(form_data: UserRoleUpdateForm, user=Depends(get_admin_user)):
form_data: UserRoleUpdateForm, user=Depends(get_current_user)
):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user.id != form_data.id: if user.id != form_data.id:
return Users.update_user_role_by_id(form_data.id, form_data.role) return Users.update_user_role_by_id(form_data.id, form_data.role)
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED, detail=ERROR_MESSAGES.ACTION_PROHIBITED,
...@@ -62,14 +68,8 @@ async def update_user_role( ...@@ -62,14 +68,8 @@ async def update_user_role(
@router.post("/{user_id}/update", response_model=Optional[UserModel]) @router.post("/{user_id}/update", response_model=Optional[UserModel])
async def update_user_by_id( async def update_user_by_id(
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_current_user) user_id: str, form_data: UserUpdateForm, session_user=Depends(get_admin_user)
): ):
if session_user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user = Users.get_user_by_id(user_id) user = Users.get_user_by_id(user_id)
if user: if user:
...@@ -98,13 +98,12 @@ async def update_user_by_id( ...@@ -98,13 +98,12 @@ async def update_user_by_id(
if updated_user: if updated_user:
return updated_user return updated_user
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(), detail=ERROR_MESSAGES.DEFAULT(),
) )
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND, detail=ERROR_MESSAGES.USER_NOT_FOUND,
...@@ -117,25 +116,19 @@ async def update_user_by_id( ...@@ -117,25 +116,19 @@ async def update_user_by_id(
@router.delete("/{user_id}", response_model=bool) @router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(user_id: str, user=Depends(get_current_user)): async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
if user.role == "admin":
if user.id != user_id: if user.id != user_id:
result = Auths.delete_auth_by_id(user_id) result = Auths.delete_auth_by_id(user_id)
if result: if result:
return True return True
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DELETE_USER_ERROR, detail=ERROR_MESSAGES.DELETE_USER_ERROR,
) )
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED, detail=ERROR_MESSAGES.ACTION_PROHIBITED,
) )
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
...@@ -9,9 +9,11 @@ import os ...@@ -9,9 +9,11 @@ import os
import aiohttp import aiohttp
import json import json
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256, get_gravatar_url
from config import OLLAMA_API_BASE_URL, DATA_DIR, UPLOAD_DIR
from constants import ERROR_MESSAGES
from config import OLLAMA_API_BASE_URL
router = APIRouter() router = APIRouter()
...@@ -40,10 +42,7 @@ def parse_huggingface_url(hf_url): ...@@ -40,10 +42,7 @@ def parse_huggingface_url(hf_url):
return None return None
async def download_file_stream(url, async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
file_path,
file_name,
chunk_size=1024 * 1024):
done = False done = False
if os.path.exists(file_path): if os.path.exists(file_path):
...@@ -57,8 +56,7 @@ async def download_file_stream(url, ...@@ -57,8 +56,7 @@ async def download_file_stream(url,
async with aiohttp.ClientSession(timeout=timeout) as session: async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as response: async with session.get(url, headers=headers) as response:
total_size = int(response.headers.get("content-length", total_size = int(response.headers.get("content-length", 0)) + current_size
0)) + current_size
with open(file_path, "ab+") as file: with open(file_path, "ab+") as file:
async for data in response.content.iter_chunked(chunk_size): async for data in response.content.iter_chunked(chunk_size):
...@@ -91,13 +89,14 @@ async def download_file_stream(url, ...@@ -91,13 +89,14 @@ async def download_file_stream(url,
@router.get("/download") @router.get("/download")
async def download(url: str, ): async def download(
url: str,
):
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
file_name = parse_huggingface_url(url) file_name = parse_huggingface_url(url)
if file_name: if file_name:
os.makedirs("./uploads", exist_ok=True) file_path = f"{UPLOAD_DIR}/{file_name}"
file_path = os.path.join("./uploads", f"{file_name}")
return StreamingResponse( return StreamingResponse(
download_file_stream(url, file_path, file_name), download_file_stream(url, file_path, file_name),
...@@ -108,25 +107,29 @@ async def download(url: str, ): ...@@ -108,25 +107,29 @@ async def download(url: str, ):
@router.post("/upload") @router.post("/upload")
async def upload(file: UploadFile = File(...)): def upload(file: UploadFile = File(...)):
os.makedirs("./uploads", exist_ok=True) file_path = f"{UPLOAD_DIR}/{file.filename}"
file_path = os.path.join("./uploads", file.filename)
async def file_write_stream(): # Save file in chunks
total = 0 with open(file_path, "wb+") as f:
total_size = file.size for chunk in file.file:
chunk_size = 1024 * 1024 f.write(chunk)
done = False def file_process_stream():
total_size = os.path.getsize(file_path)
chunk_size = 1024 * 1024
try: try:
with open(file_path, "wb+") as f: with open(file_path, "rb") as f:
while True: total = 0
chunk = file.file.read(chunk_size) done = False
while not done:
chunk = f.read(chunk_size)
if not chunk: if not chunk:
break done = True
f.write(chunk) continue
total += len(chunk) total += len(chunk)
done = total_size == total
progress = round((total / total_size) * 100, 2) progress = round((total / total_size) * 100, 2)
res = { res = {
...@@ -134,7 +137,6 @@ async def upload(file: UploadFile = File(...)): ...@@ -134,7 +137,6 @@ async def upload(file: UploadFile = File(...)):
"total": total_size, "total": total_size,
"completed": total, "completed": total,
} }
yield f"data: {json.dumps(res)}\n\n" yield f"data: {json.dumps(res)}\n\n"
if done: if done:
...@@ -152,14 +154,21 @@ async def upload(file: UploadFile = File(...)): ...@@ -152,14 +154,21 @@ async def upload(file: UploadFile = File(...)):
"name": file.filename, "name": file.filename,
} }
os.remove(file_path) os.remove(file_path)
yield f"data: {json.dumps(res)}\n\n" yield f"data: {json.dumps(res)}\n\n"
else: else:
raise "Ollama: Could not create blob, Please try again." raise Exception(
"Ollama: Could not create blob, Please try again."
)
except Exception as e: except Exception as e:
res = {"error": str(e)} res = {"error": str(e)}
yield f"data: {json.dumps(res)}\n\n" yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_write_stream(), return StreamingResponse(file_process_stream(), media_type="text/event-stream")
media_type="text/event-stream")
@router.get("/gravatar")
async def get_gravatar(
email: str,
):
return get_gravatar_url(email)
from dotenv import load_dotenv, find_dotenv
import os import os
import chromadb import chromadb
from chromadb import Settings from chromadb import Settings
from base64 import b64encode
from bs4 import BeautifulSoup
from pathlib import Path
import json
import markdown
import requests
import shutil
from secrets import token_bytes from secrets import token_bytes
from base64 import b64encode
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from pathlib import Path try:
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv("../.env"))
except ImportError:
print("dotenv not installed, skipping...")
WEBUI_NAME = "Open WebUI"
shutil.copyfile("../build/favicon.png", "./static/favicon.png")
####################################
# ENV (dev,test,prod)
####################################
ENV = os.environ.get("ENV", "dev")
try:
with open(f"../package.json", "r") as f:
PACKAGE_DATA = json.load(f)
except:
PACKAGE_DATA = {"version": "0.0.0"}
VERSION = PACKAGE_DATA["version"]
# Function to parse each section
def parse_section(section):
items = []
for li in section.find_all("li"):
# Extract raw HTML string
raw_html = str(li)
# Extract text without HTML tags
text = li.get_text(separator=" ", strip=True)
load_dotenv(find_dotenv("../.env")) # Split into title and content
parts = text.split(": ", 1)
title = parts[0].strip() if len(parts) > 1 else ""
content = parts[1].strip() if len(parts) > 1 else text
items.append({"title": title, "content": content, "raw": raw_html})
return items
try:
with open("../CHANGELOG.md", "r") as file:
changelog_content = file.read()
except:
changelog_content = ""
# Convert markdown content to HTML
html_content = markdown.markdown(changelog_content)
# Parse the HTML content
soup = BeautifulSoup(html_content, "html.parser")
# Initialize JSON structure
changelog_json = {}
# Iterate over each version
for version in soup.find_all("h2"):
version_number = version.get_text().strip().split(" - ")[0][1:-1] # Remove brackets
date = version.get_text().strip().split(" - ")[1]
version_data = {"date": date}
# Find the next sibling that is a h3 tag (section title)
current = version.find_next_sibling()
print(current)
while current and current.name != "h2":
if current.name == "h3":
section_title = current.get_text().lower() # e.g., "added", "fixed"
section_items = parse_section(current.find_next_sibling("ul"))
version_data[section_title] = section_items
# Move to the next element
current = current.find_next_sibling()
changelog_json[version_number] = version_data
CHANGELOG = changelog_json
#################################### ####################################
# File Upload # CUSTOM_NAME
#################################### ####################################
CUSTOM_NAME = os.environ.get("CUSTOM_NAME", "")
if CUSTOM_NAME:
try:
r = requests.get(f"https://api.openwebui.com/api/v1/custom/{CUSTOM_NAME}")
data = r.json()
if r.ok:
if "logo" in data:
url = (
f"https://api.openwebui.com{data['logo']}"
if data["logo"][0] == "/"
else data["logo"]
)
UPLOAD_DIR = "./data/uploads" r = requests.get(url, stream=True)
if r.status_code == 200:
with open("./static/favicon.png", "wb") as f:
r.raw.decode_content = True
shutil.copyfileobj(r.raw, f)
WEBUI_NAME = data["name"]
except Exception as e:
print(e)
pass
####################################
# DATA/FRONTEND BUILD DIR
####################################
DATA_DIR = str(Path(os.getenv("DATA_DIR", "./data")).resolve())
FRONTEND_BUILD_DIR = str(Path(os.getenv("FRONTEND_BUILD_DIR", "../build")))
try:
with open(f"{DATA_DIR}/config.json", "r") as f:
CONFIG_DATA = json.load(f)
except:
CONFIG_DATA = {}
####################################
# File Upload DIR
####################################
UPLOAD_DIR = f"{DATA_DIR}/uploads"
Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True) Path(UPLOAD_DIR).mkdir(parents=True, exist_ok=True)
#################################### ####################################
# ENV (dev,test,prod) # Cache DIR
#################################### ####################################
ENV = os.environ.get("ENV", "dev") CACHE_DIR = f"{DATA_DIR}/cache"
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Docs DIR
####################################
DOCS_DIR = f"{DATA_DIR}/docs"
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
#################################### ####################################
# OLLAMA_API_BASE_URL # OLLAMA_API_BASE_URL
...@@ -54,11 +187,50 @@ OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "") ...@@ -54,11 +187,50 @@ OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
if OPENAI_API_BASE_URL == "": if OPENAI_API_BASE_URL == "":
OPENAI_API_BASE_URL = "https://api.openai.com/v1" OPENAI_API_BASE_URL = "https://api.openai.com/v1"
####################################
# WEBUI
####################################
ENABLE_SIGNUP = os.environ.get("ENABLE_SIGNUP", True)
DEFAULT_MODELS = os.environ.get("DEFAULT_MODELS", None)
DEFAULT_PROMPT_SUGGESTIONS = (
CONFIG_DATA["ui"]["prompt_suggestions"]
if "ui" in CONFIG_DATA
and "prompt_suggestions" in CONFIG_DATA["ui"]
and type(CONFIG_DATA["ui"]["prompt_suggestions"]) is list
else [
{
"title": ["Help me study", "vocabulary for a college entrance exam"],
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
},
{
"title": ["Give me ideas", "for what to do with my kids' art"],
"content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.",
},
{
"title": ["Tell me a fun fact", "about the Roman Empire"],
"content": "Tell me a random fun fact about the Roman Empire",
},
{
"title": ["Show me a code snippet", "of a website's sticky header"],
"content": "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
},
]
)
DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending")
USER_PERMISSIONS = {"chat": {"deletion": True}}
#################################### ####################################
# WEBUI_VERSION # WEBUI_VERSION
#################################### ####################################
WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.50") WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.100")
#################################### ####################################
# WEBUI_AUTH (Required for security) # WEBUI_AUTH (Required for security)
...@@ -67,22 +239,62 @@ WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.50") ...@@ -67,22 +239,62 @@ WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.50")
WEBUI_AUTH = True WEBUI_AUTH = True
#################################### ####################################
# WEBUI_JWT_SECRET_KEY # WEBUI_SECRET_KEY
#################################### ####################################
WEBUI_JWT_SECRET_KEY = os.environ.get("WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t") WEBUI_SECRET_KEY = os.environ.get(
"WEBUI_SECRET_KEY",
os.environ.get(
"WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t"
), # DEPRECATED: remove at next major version
)
if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "": if WEBUI_AUTH and WEBUI_SECRET_KEY == "":
raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND) raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
#################################### ####################################
# RAG # RAG
#################################### ####################################
CHROMA_DATA_PATH = "./data/vector_db" CHROMA_DATA_PATH = f"{DATA_DIR}/vector_db"
EMBED_MODEL = "all-MiniLM-L6-v2" # this uses the model defined in the Dockerfile ENV variable. If you dont use docker or docker based deployments such as k8s, the default embedding model will be used (all-MiniLM-L6-v2)
RAG_EMBEDDING_MODEL = os.environ.get("RAG_EMBEDDING_MODEL", "all-MiniLM-L6-v2")
# device type ebbeding models - "cpu" (default), "cuda" (nvidia gpu required) or "mps" (apple silicon) - choosing this right can lead to better performance
RAG_EMBEDDING_MODEL_DEVICE_TYPE = os.environ.get(
"RAG_EMBEDDING_MODEL_DEVICE_TYPE", "cpu"
)
CHROMA_CLIENT = chromadb.PersistentClient( CHROMA_CLIENT = chromadb.PersistentClient(
path=CHROMA_DATA_PATH, settings=Settings(allow_reset=True) path=CHROMA_DATA_PATH,
settings=Settings(allow_reset=True, anonymized_telemetry=False),
) )
CHUNK_SIZE = 1500 CHUNK_SIZE = 1500
CHUNK_OVERLAP = 100 CHUNK_OVERLAP = 100
RAG_TEMPLATE = """Use the following context as your learned knowledge, inside <context></context> XML tags.
<context>
[context]
</context>
When answer to user:
- If you don't know, just say that you don't know.
- If you don't know when you are not sure, ask for clarification.
Avoid mentioning that you obtained the information from the context.
And answer according to the language of the user's question.
Given the context information, answer the query.
Query: [query]"""
####################################
# Transcribe
####################################
WHISPER_MODEL = os.getenv("WHISPER_MODEL", "base")
WHISPER_MODEL_DIR = os.getenv("WHISPER_MODEL_DIR", f"{CACHE_DIR}/whisper/models")
####################################
# Images
####################################
AUTOMATIC1111_BASE_URL = os.getenv("AUTOMATIC1111_BASE_URL", "")
...@@ -18,6 +18,9 @@ class ERROR_MESSAGES(str, Enum): ...@@ -18,6 +18,9 @@ class ERROR_MESSAGES(str, Enum):
"Uh-oh! This username is already registered. Please choose another username." "Uh-oh! This username is already registered. Please choose another username."
) )
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string." COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
INVALID_TOKEN = ( INVALID_TOKEN = (
"Your session has expired or the token is invalid. Please sign in again." "Your session has expired or the token is invalid. Please sign in again."
) )
...@@ -39,3 +42,8 @@ class ERROR_MESSAGES(str, Enum): ...@@ -39,3 +42,8 @@ class ERROR_MESSAGES(str, Enum):
USER_NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/"
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature." API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
MALICIOUS = "Unusual activities detected, please try again in a few minutes." MALICIOUS = "Unusual activities detected, please try again in a few minutes."
PANDOC_NOT_INSTALLED = "Pandoc is not installed on the server. Please contact your administrator for assistance."
INCORRECT_FORMAT = (
lambda err="": f"Invalid format. Please use the correct format{err if err else ''}"
)
{
"ui": {
"prompt_suggestions": [
{
"title": [
"Help me study",
"vocabulary for a college entrance exam"
],
"content": "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option."
},
{
"title": [
"Give me ideas",
"for what to do with my kids' art"
],
"content": "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter."
},
{
"title": [
"Tell me a fun fact",
"about the Roman Empire"
],
"content": "Tell me a random fun fact about the Roman Empire"
},
{
"title": [
"Show me a code snippet",
"of a website's sticky header"
],
"content": "Show me a code snippet of a website's sticky header in CSS and JavaScript."
}
]
}
}
\ No newline at end of file
uvicorn main:app --port 8080 --host 0.0.0.0 --forwarded-allow-ips '*' --reload PORT="${PORT:-8080}"
\ No newline at end of file uvicorn main:app --port $PORT --host 0.0.0.0 --forwarded-allow-ips '*' --reload
\ No newline at end of file
from bs4 import BeautifulSoup
import json
import markdown
import time import time
from fastapi import FastAPI, Request from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from fastapi import HTTPException from fastapi import HTTPException
...@@ -10,11 +14,13 @@ from starlette.exceptions import HTTPException as StarletteHTTPException ...@@ -10,11 +14,13 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from apps.ollama.main import app as ollama_app from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app from apps.openai.main import app as openai_app
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.web.main import app as webui_app from apps.web.main import app as webui_app
from apps.rag.main import app as rag_app
from config import ENV from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
class SPAStaticFiles(StaticFiles): class SPAStaticFiles(StaticFiles):
...@@ -55,7 +61,35 @@ app.mount("/api/v1", webui_app) ...@@ -55,7 +61,35 @@ app.mount("/api/v1", webui_app)
app.mount("/ollama/api", ollama_app) app.mount("/ollama/api", ollama_app)
app.mount("/openai/api", openai_app) app.mount("/openai/api", openai_app)
app.mount("/images/api/v1", images_app)
app.mount("/audio/api/v1", audio_app)
app.mount("/rag/api/v1", rag_app) app.mount("/rag/api/v1", rag_app)
app.mount("/", SPAStaticFiles(directory="../build", html=True), name="spa-static-files") @app.get("/api/config")
async def get_app_config():
return {
"status": True,
"name": WEBUI_NAME,
"version": VERSION,
"images": images_app.state.ENABLED,
"default_models": webui_app.state.DEFAULT_MODELS,
"default_prompt_suggestions": webui_app.state.DEFAULT_PROMPT_SUGGESTIONS,
}
@app.get("/api/changelog")
async def get_app_changelog():
return CHANGELOG
app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount(
"/",
SPAStaticFiles(directory=FRONTEND_BUILD_DIR, html=True),
name="spa-static-files",
)
...@@ -22,6 +22,15 @@ chromadb ...@@ -22,6 +22,15 @@ chromadb
sentence_transformers sentence_transformers
pypdf pypdf
docx2txt docx2txt
unstructured
markdown
pypandoc
pandas
openpyxl
pyxlsb
xlrd
faster-whisper
PyJWT PyJWT
pyjwt[crypto] pyjwt[crypto]
......
#!/usr/bin/env bash #!/usr/bin/env bash
uvicorn main:app --host 0.0.0.0 --port 8080 --forwarded-allow-ips '*' SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
cd "$SCRIPT_DIR" || exit
KEY_FILE=.webui_secret_key
PORT="${PORT:-8080}"
if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then
echo No WEBUI_SECRET_KEY provided
if ! [ -e "$KEY_FILE" ]; then
echo Generating WEBUI_SECRET_KEY
# Generate a random value to use as a WEBUI_SECRET_KEY in case the user didn't provide one.
echo $(head -c 12 /dev/random | base64) > $KEY_FILE
fi
echo Loading WEBUI_SECRET_KEY from $KEY_FILE
WEBUI_SECRET_KEY=`cat $KEY_FILE`
fi
WEBUI_SECRET_KEY="$WEBUI_SECRET_KEY" exec uvicorn main:app --host 0.0.0.0 --port "$PORT" --forwarded-allow-ips '*'
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment