Unverified Commit 0ddb2b32 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #413 from ollama-webui/main

dev
parents 880f58e8 ed1d9e61
......@@ -12,7 +12,7 @@ from apps.web.internal.db import DB
import json
####################
# User DB Schema
# Modelfile DB Schema
####################
......@@ -58,13 +58,14 @@ class ModelfileResponse(BaseModel):
class ModelfilesTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Modelfile])
def insert_new_modelfile(
self, user_id: str, form_data: ModelfileForm
) -> Optional[ModelfileModel]:
self, user_id: str,
form_data: ModelfileForm) -> Optional[ModelfileModel]:
if "tagName" in form_data.modelfile:
modelfile = ModelfileModel(
**{
......@@ -72,8 +73,7 @@ class ModelfilesTable:
"tag_name": form_data.modelfile["tagName"],
"modelfile": json.dumps(form_data.modelfile),
"timestamp": int(time.time()),
}
)
})
try:
result = Modelfile.create(**modelfile.model_dump())
......@@ -87,28 +87,29 @@ class ModelfilesTable:
else:
return None
def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]:
def get_modelfile_by_tag_name(self,
tag_name: str) -> Optional[ModelfileModel]:
try:
modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
return ModelfileModel(**model_to_dict(modelfile))
except:
return None
def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]:
def get_modelfiles(self,
skip: int = 0,
limit: int = 50) -> List[ModelfileResponse]:
return [
ModelfileResponse(
**{
**model_to_dict(modelfile),
"modelfile": json.loads(modelfile.modelfile),
}
)
for modelfile in Modelfile.select()
"modelfile":
json.loads(modelfile.modelfile),
}) for modelfile in Modelfile.select()
# .limit(limit).offset(skip)
]
def update_modelfile_by_tag_name(
self, tag_name: str, modelfile: dict
) -> Optional[ModelfileModel]:
self, tag_name: str, modelfile: dict) -> Optional[ModelfileModel]:
try:
query = Modelfile.update(
modelfile=json.dumps(modelfile),
......
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from apps.web.internal.db import DB
import json
####################
# Prompts DB Schema
####################
class Prompt(Model):
command = CharField(unique=True)
user_id = CharField()
title = CharField()
content = TextField()
timestamp = DateField()
class Meta:
database = DB
class PromptModel(BaseModel):
command: str
user_id: str
title: str
content: str
timestamp: int # timestamp in epoch
####################
# Forms
####################
class PromptForm(BaseModel):
command: str
title: str
content: str
class PromptsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Prompt])
def insert_new_prompt(self, user_id: str,
form_data: PromptForm) -> Optional[PromptModel]:
prompt = PromptModel(
**{
"user_id": user_id,
"command": form_data.command,
"title": form_data.title,
"content": form_data.content,
"timestamp": int(time.time()),
})
try:
result = Prompt.create(**prompt.model_dump())
if result:
return prompt
else:
return None
except:
return None
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
try:
prompt = Prompt.get(Prompt.command == command)
return PromptModel(**model_to_dict(prompt))
except:
return None
def get_prompts(self) -> List[PromptModel]:
return [
PromptModel(**model_to_dict(prompt)) for prompt in Prompt.select()
# .limit(limit).offset(skip)
]
def update_prompt_by_command(
self, command: str,
form_data: PromptForm) -> Optional[PromptModel]:
try:
query = Prompt.update(
title=form_data.title,
content=form_data.content,
timestamp=int(time.time()),
).where(Prompt.command == command)
query.execute()
prompt = Prompt.get(Prompt.command == command)
return PromptModel(**model_to_dict(prompt))
except:
return None
def delete_prompt_by_command(self, command: str) -> bool:
try:
query = Prompt.delete().where((Prompt.command == command))
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Prompts = PromptsTable(DB)
......@@ -3,14 +3,11 @@ from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from apps.web.internal.db import DB
from apps.web.models.chats import Chats
####################
# User DB Schema
####################
......@@ -47,6 +44,13 @@ class UserRoleUpdateForm(BaseModel):
role: str
class UserUpdateForm(BaseModel):
name: str
email: str
profile_image_url: str
password: Optional[str] = None
class UsersTable:
def __init__(self, db):
self.db = db
......@@ -85,14 +89,6 @@ class UsersTable:
except:
return None
def get_user_by_token(self, token: str) -> Optional[UserModel]:
data = decode_token(token)
if data != None and "email" in data:
return self.get_user_by_email(data["email"])
else:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [
UserModel(**model_to_dict(user))
......@@ -112,6 +108,16 @@ class UsersTable:
except:
return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try:
query = User.update(**updated).where(User.id == id)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
return None
def delete_user_by_id(self, id: str) -> bool:
try:
# Delete User Chats
......
from fastapi import Response
from fastapi import Response, Request
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union
......@@ -18,16 +18,10 @@ from apps.web.models.auths import (
)
from apps.web.models.users import Users
from utils.utils import (
get_password_hash,
bearer_scheme,
create_token,
)
from utils.misc import get_gravatar_url
from utils.utils import get_password_hash, get_current_user, create_token
from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES
router = APIRouter()
############################
......@@ -36,10 +30,7 @@ router = APIRouter()
@router.get("/", response_model=UserResponse)
async def get_session_user(cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
async def get_session_user(user=Depends(get_current_user)):
return {
"id": user.id,
"email": user.email,
......@@ -47,11 +38,6 @@ async def get_session_user(cred=Depends(bearer_scheme)):
"role": user.role,
"profile_image_url": user.profile_image_url,
}
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
......@@ -60,10 +46,8 @@ async def get_session_user(cred=Depends(bearer_scheme)):
@router.post("/update/password", response_model=bool)
async def update_password(form_data: UpdatePasswordForm, cred=Depends(bearer_scheme)):
token = cred.credentials
session_user = Users.get_user_by_token(token)
async def update_password(form_data: UpdatePasswordForm,
session_user=Depends(get_current_user)):
if session_user:
user = Auths.authenticate_user(session_user.email, form_data.password)
......@@ -106,14 +90,15 @@ async def signin(form_data: SigninForm):
@router.post("/signup", response_model=SigninResponse)
async def signup(form_data: SignupForm):
async def signup(request: Request, form_data: SignupForm):
if request.app.state.ENABLE_SIGNUP:
if validate_email_format(form_data.email.lower()):
if not Users.get_user_by_email(form_data.email.lower()):
try:
role = "admin" if Users.get_num_users() == 0 else "pending"
hashed = get_password_hash(form_data.password)
user = Auths.insert_new_auth(
form_data.email.lower(), hashed, form_data.name, role
)
user = Auths.insert_new_auth(form_data.email.lower(),
hashed, form_data.name, role)
if user:
token = create_token(data={"email": user.email})
......@@ -129,8 +114,43 @@ async def signup(form_data: SignupForm):
"profile_image_url": user.profile_image_url,
}
else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
raise HTTPException(
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
except Exception as err:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
raise HTTPException(500,
detail=ERROR_MESSAGES.DEFAULT(err))
else:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
else:
raise HTTPException(400,
detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
else:
raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
############################
# ToggleSignUp
############################
@router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_current_user)):
if user.role == "admin":
return request.app.state.ENABLE_SIGNUP
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
@router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_current_user)):
if user.role == "admin":
request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
return request.app.state.ENABLE_SIGNUP
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
from fastapi import Response
from fastapi import Depends, FastAPI, HTTPException, status
from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from utils.utils import get_current_user
from fastapi import APIRouter
from pydantic import BaseModel
import json
......@@ -18,8 +17,7 @@ from apps.web.models.chats import (
)
from utils.utils import (
bearer_scheme,
)
bearer_scheme, )
from constants import ERROR_MESSAGES
router = APIRouter()
......@@ -30,17 +28,9 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
async def get_user_chats(
user=Depends(get_current_user), skip: int = 0, limit: int = 50):
return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
......@@ -49,20 +39,12 @@ async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
@router.get("/all", response_model=List[ChatResponse])
async def get_all_user_chats(cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
async def get_all_user_chats(user=Depends(get_current_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats_by_user_id(user.id)
ChatResponse(**{
**chat.model_dump(), "chat": json.loads(chat.chat)
}) for chat in Chats.get_all_chats_by_user_id(user.id)
]
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
......@@ -71,18 +53,9 @@ async def get_all_user_chats(cred=Depends(bearer_scheme)):
@router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
......@@ -91,25 +64,16 @@ async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
@router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
async def get_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
return ChatResponse(**{
**chat.model_dump(), "chat": json.loads(chat.chat)
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND)
############################
......@@ -118,27 +82,22 @@ async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
@router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
async def update_chat_by_id(id: str,
form_data: ChatForm,
user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat}
chat = Chats.update_chat_by_id(id, updated_chat)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
return ChatResponse(**{
**chat.model_dump(), "chat": json.loads(chat.chat)
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
......@@ -147,18 +106,9 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc
@router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
......@@ -167,15 +117,6 @@ async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)):
@router.delete("/", response_model=bool)
async def delete_all_user_chats(cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
async def delete_all_user_chats(user=Depends(get_current_user)):
result = Chats.delete_chats_by_user_id(user.id)
return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
from fastapi import Response, Request
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union
from fastapi import APIRouter
from pydantic import BaseModel
import time
import uuid
from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token
from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES
router = APIRouter()
class SetDefaultModelsForm(BaseModel):
models: str
############################
# SetDefaultModels
############################
@router.post("/default/models", response_model=str)
async def set_global_default_models(request: Request,
form_data: SetDefaultModelsForm,
user=Depends(get_current_user)):
if user.role == "admin":
request.app.state.DEFAULT_MODELS = form_data.models
return request.app.state.DEFAULT_MODELS
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
from fastapi import Response
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
......@@ -6,8 +5,6 @@ from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.users import Users
from apps.web.models.modelfiles import (
Modelfiles,
ModelfileForm,
......@@ -16,9 +13,7 @@ from apps.web.models.modelfiles import (
ModelfileResponse,
)
from utils.utils import (
bearer_scheme,
)
from utils.utils import get_current_user
from constants import ERROR_MESSAGES
router = APIRouter()
......@@ -29,17 +24,10 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
async def get_modelfiles(skip: int = 0,
limit: int = 50,
user=Depends(get_current_user)):
return Modelfiles.get_modelfiles(skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
......@@ -48,37 +36,28 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
@router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
async def create_new_modelfile(form_data: ModelfileForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user:
# Admin Only
if user.role == "admin":
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
"modelfile":
json.loads(modelfile.modelfile),
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
......@@ -87,32 +66,22 @@ async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_sch
@router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_current_user)):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
"modelfile":
json.loads(modelfile.modelfile),
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
......@@ -121,14 +90,13 @@ async def get_modelfile_by_tag_name(
@router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name(
form_data: ModelfileUpdateForm, cred=Depends(bearer_scheme)
):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
if user.role == "admin":
async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile:
updated_modelfile = {
......@@ -137,30 +105,19 @@ async def update_modelfile_by_tag_name(
}
modelfile = Modelfiles.update_modelfile_by_tag_name(
form_data.tag_name, updated_modelfile
)
form_data.tag_name, updated_modelfile)
return ModelfileResponse(
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
"modelfile":
json.loads(modelfile.modelfile),
})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################
......@@ -169,23 +126,13 @@ async def update_modelfile_by_tag_name(
@router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name(
form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
if user.role == "admin":
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result
else:
async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.prompts import Prompts, PromptForm, PromptModel
from utils.utils import get_current_user
from constants import ERROR_MESSAGES
router = APIRouter()
############################
# GetPrompts
############################
@router.get("/", response_model=List[PromptModel])
async def get_prompts(user=Depends(get_current_user)):
return Prompts.get_prompts()
############################
# CreateNewPrompt
############################
@router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
prompt = Prompts.get_prompt_by_command(form_data.command)
if prompt == None:
prompt = Prompts.insert_new_prompt(user.id, form_data)
if prompt:
return prompt
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.COMMAND_TAKEN,
)
############################
# GetPromptByCommand
############################
@router.get("/command/{command}", response_model=Optional[PromptModel])
async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
prompt = Prompts.get_prompt_by_command(f"/{command}")
if prompt:
return prompt
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdatePromptByCommand
############################
@router.post("/command/{command}/update", response_model=Optional[PromptModel])
async def update_prompt_by_command(
command: str, form_data: PromptForm, user=Depends(get_current_user)
):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt:
return prompt
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################
# DeletePromptByCommand
############################
@router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command(command: str, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Prompts.delete_prompt_by_command(f"/{command}")
return result
......@@ -8,15 +8,10 @@ from pydantic import BaseModel
import time
import uuid
from apps.web.models.users import UserModel, UserRoleUpdateForm, Users
from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths
from utils.utils import (
get_password_hash,
bearer_scheme,
create_token,
)
from utils.utils import get_current_user, get_password_hash
from constants import ERROR_MESSAGES
router = APIRouter()
......@@ -27,23 +22,13 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel])
async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
if user.role == "admin":
return Users.get_users(skip, limit)
else:
async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
return Users.get_users(skip, limit)
############################
......@@ -52,12 +37,15 @@ async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme))
@router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
async def update_user_role(
form_data: UserRoleUpdateForm, user=Depends(get_current_user)
):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user:
if user.role == "admin":
if user.id != form_data.id:
return Users.update_user_role_by_id(form_data.id, form_data.role)
else:
......@@ -65,15 +53,61 @@ async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_sc
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
else:
############################
# UpdateUserById
############################
@router.post("/{user_id}/update", response_model=Optional[UserModel])
async def update_user_by_id(
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_current_user)
):
if session_user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user = Users.get_user_by_id(user_id)
if user:
if form_data.email.lower() != user.email:
email_user = Users.get_user_by_email(form_data.email.lower())
if email_user:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.EMAIL_TAKEN,
)
if form_data.password:
hashed = get_password_hash(form_data.password)
print(hashed)
Auths.update_user_password_by_id(user_id, hashed)
Auths.update_email_by_id(user_id, form_data.email.lower())
updated_user = Users.update_user_by_id(
user_id,
{
"name": form_data.name,
"email": form_data.email.lower(),
"profile_image_url": form_data.profile_image_url,
},
)
if updated_user:
return updated_user
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.USER_NOT_FOUND,
)
......@@ -83,11 +117,7 @@ async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_sc
@router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(user_id: str, cred=Depends(bearer_scheme)):
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
async def delete_user_by_id(user_id: str, user=Depends(get_current_user)):
if user.role == "admin":
if user.id != user_id:
result = Auths.delete_auth_by_id(user_id)
......@@ -109,8 +139,3 @@ async def delete_user_by_id(user_id: str, cred=Depends(bearer_scheme)):
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
......@@ -9,12 +9,10 @@ import os
import aiohttp
import json
from utils.misc import calculate_sha256
from config import OLLAMA_API_BASE_URL
router = APIRouter()
......@@ -42,7 +40,10 @@ def parse_huggingface_url(hf_url):
return None
async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
async def download_file_stream(url,
file_path,
file_name,
chunk_size=1024 * 1024):
done = False
if os.path.exists(file_path):
......@@ -56,7 +57,8 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size
total_size = int(response.headers.get("content-length",
0)) + current_size
with open(file_path, "ab+") as file:
async for data in response.content.iter_chunked(chunk_size):
......@@ -89,9 +91,7 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024
@router.get("/download")
async def download(
url: str,
):
async def download(url: str, ):
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
file_name = parse_huggingface_url(url)
......@@ -161,4 +161,5 @@ async def upload(file: UploadFile = File(...)):
res = {"error": str(e)}
yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_write_stream(), media_type="text/event-stream")
return StreamingResponse(file_write_stream(),
media_type="text/event-stream")
......@@ -19,19 +19,28 @@ ENV = os.environ.get("ENV", "dev")
# OLLAMA_API_BASE_URL
####################################
OLLAMA_API_BASE_URL = os.environ.get(
"OLLAMA_API_BASE_URL", "http://localhost:11434/api"
)
OLLAMA_API_BASE_URL = os.environ.get("OLLAMA_API_BASE_URL",
"http://localhost:11434/api")
if ENV == "prod":
if OLLAMA_API_BASE_URL == "/ollama/api":
OLLAMA_API_BASE_URL = "http://host.docker.internal:11434/api"
####################################
# OPENAI_API
####################################
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
if OPENAI_API_BASE_URL == "":
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
####################################
# WEBUI_VERSION
####################################
WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.42")
WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.50")
####################################
# WEBUI_AUTH (Required for security)
......
......@@ -6,6 +6,7 @@ class MESSAGES(str, Enum):
class ERROR_MESSAGES(str, Enum):
def __str__(self) -> str:
return super().__str__()
......@@ -17,19 +18,20 @@ class ERROR_MESSAGES(str, Enum):
USERNAME_TAKEN = (
"Uh-oh! This username is already registered. Please choose another username."
)
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
INVALID_TOKEN = (
"Your session has expired or the token is invalid. Please sign in again."
)
INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again."
INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
INVALID_PASSWORD = (
"The password provided is incorrect. Please check for typos and try again."
)
UNAUTHORIZED = "401 Unauthorized"
ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
ACTION_PROHIBITED = (
"The requested action has been restricted as a security measure."
)
"The requested action has been restricted as a security measure.")
NOT_FOUND = "We could not find what you're looking for :/"
USER_NOT_FOUND = "We could not find what you're looking for :/"
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
MALICIOUS = "Unusual activities detected, please try again in a few minutes."
......@@ -6,12 +6,15 @@ from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException
from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
from apps.web.main import app as webui_app
import time
class SPAStaticFiles(StaticFiles):
async def get_response(self, path: str, scope):
try:
return await super().get_response(path, scope)
......@@ -46,5 +49,9 @@ async def check_url(request: Request, call_next):
app.mount("/api/v1", webui_app)
app.mount("/ollama/api", WSGIMiddleware(ollama_app))
app.mount("/", SPAStaticFiles(directory="../build", html=True), name="spa-static-files")
app.mount("/ollama/api", ollama_app)
app.mount("/openai/api", openai_app)
app.mount("/",
SPAStaticFiles(directory="../build", html=True),
name="spa-static-files")
......@@ -18,3 +18,5 @@ bcrypt
PyJWT
pyjwt[crypto]
black
\ No newline at end of file
#!/usr/bin/env bash
uvicorn main:app --host 0.0.0.0 --port 8080 --forwarded-allow-ips '*'
import hashlib
import re
def get_gravatar_url(email):
......@@ -21,3 +22,9 @@ def calculate_sha256(file):
for chunk in iter(lambda: file.read(8192), b""):
sha256.update(chunk)
return sha256.hexdigest()
def validate_email_format(email: str) -> bool:
if not re.match(r"[^@]+@[^@]+\.[^@]+", email):
return False
return True
from fastapi.security import HTTPBasicCredentials, HTTPBearer
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends
from apps.web.models.users import Users
from pydantic import BaseModel
from typing import Union, Optional
from constants import ERROR_MESSAGES
from passlib.context import CryptContext
from datetime import datetime, timedelta
import requests
import jwt
import logging
import config
logging.getLogger("passlib").setLevel(logging.ERROR)
JWT_SECRET_KEY = config.WEBUI_JWT_SECRET_KEY
ALGORITHM = "HS256"
......@@ -53,16 +58,18 @@ def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :]
def verify_token(request):
try:
bearer = request.headers["authorization"]
if bearer:
token = bearer[len("Bearer ") :]
decoded = jwt.decode(
token, JWT_SECRET_KEY, options={"verify_signature": False}
def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
data = decode_token(auth_token.credentials)
if data != None and "email" in data:
user = Users.get_user_by_email(data["email"])
if user is None:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
return decoded
return user
else:
return None
except Exception as e:
return None
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED,
)
File added
This image diff could not be displayed because it is too large. You can view the blob instead.
version: '3.6'
version: '3.8'
services:
ollama:
# Expose Ollama API outside the container stack
ports:
- 11434:11434
\ No newline at end of file
- ${OLLAMA_WEBAPI_PORT-11434}:11434
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment