"src/vscode:/vscode.git/clone" did not exist on "c7e1d3dbc7cfa563d4025dc891fa780f4cc5fdf1"
Unverified Commit 0ddb2b32 authored by Timothy Jaeryang Baek's avatar Timothy Jaeryang Baek Committed by GitHub
Browse files

Merge pull request #413 from ollama-webui/main

dev
parents 880f58e8 ed1d9e61
...@@ -12,7 +12,7 @@ from apps.web.internal.db import DB ...@@ -12,7 +12,7 @@ from apps.web.internal.db import DB
import json import json
#################### ####################
# User DB Schema # Modelfile DB Schema
#################### ####################
...@@ -58,13 +58,14 @@ class ModelfileResponse(BaseModel): ...@@ -58,13 +58,14 @@ class ModelfileResponse(BaseModel):
class ModelfilesTable: class ModelfilesTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
self.db.create_tables([Modelfile]) self.db.create_tables([Modelfile])
def insert_new_modelfile( def insert_new_modelfile(
self, user_id: str, form_data: ModelfileForm self, user_id: str,
) -> Optional[ModelfileModel]: form_data: ModelfileForm) -> Optional[ModelfileModel]:
if "tagName" in form_data.modelfile: if "tagName" in form_data.modelfile:
modelfile = ModelfileModel( modelfile = ModelfileModel(
**{ **{
...@@ -72,8 +73,7 @@ class ModelfilesTable: ...@@ -72,8 +73,7 @@ class ModelfilesTable:
"tag_name": form_data.modelfile["tagName"], "tag_name": form_data.modelfile["tagName"],
"modelfile": json.dumps(form_data.modelfile), "modelfile": json.dumps(form_data.modelfile),
"timestamp": int(time.time()), "timestamp": int(time.time()),
} })
)
try: try:
result = Modelfile.create(**modelfile.model_dump()) result = Modelfile.create(**modelfile.model_dump())
...@@ -87,28 +87,29 @@ class ModelfilesTable: ...@@ -87,28 +87,29 @@ class ModelfilesTable:
else: else:
return None return None
def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]: def get_modelfile_by_tag_name(self,
tag_name: str) -> Optional[ModelfileModel]:
try: try:
modelfile = Modelfile.get(Modelfile.tag_name == tag_name) modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
return ModelfileModel(**model_to_dict(modelfile)) return ModelfileModel(**model_to_dict(modelfile))
except: except:
return None return None
def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]: def get_modelfiles(self,
skip: int = 0,
limit: int = 50) -> List[ModelfileResponse]:
return [ return [
ModelfileResponse( ModelfileResponse(
**{ **{
**model_to_dict(modelfile), **model_to_dict(modelfile),
"modelfile": json.loads(modelfile.modelfile), "modelfile":
} json.loads(modelfile.modelfile),
) }) for modelfile in Modelfile.select()
for modelfile in Modelfile.select()
# .limit(limit).offset(skip) # .limit(limit).offset(skip)
] ]
def update_modelfile_by_tag_name( def update_modelfile_by_tag_name(
self, tag_name: str, modelfile: dict self, tag_name: str, modelfile: dict) -> Optional[ModelfileModel]:
) -> Optional[ModelfileModel]:
try: try:
query = Modelfile.update( query = Modelfile.update(
modelfile=json.dumps(modelfile), modelfile=json.dumps(modelfile),
......
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url
from apps.web.internal.db import DB
import json
####################
# Prompts DB Schema
####################
class Prompt(Model):
command = CharField(unique=True)
user_id = CharField()
title = CharField()
content = TextField()
timestamp = DateField()
class Meta:
database = DB
class PromptModel(BaseModel):
command: str
user_id: str
title: str
content: str
timestamp: int # timestamp in epoch
####################
# Forms
####################
class PromptForm(BaseModel):
command: str
title: str
content: str
class PromptsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Prompt])
def insert_new_prompt(self, user_id: str,
form_data: PromptForm) -> Optional[PromptModel]:
prompt = PromptModel(
**{
"user_id": user_id,
"command": form_data.command,
"title": form_data.title,
"content": form_data.content,
"timestamp": int(time.time()),
})
try:
result = Prompt.create(**prompt.model_dump())
if result:
return prompt
else:
return None
except:
return None
def get_prompt_by_command(self, command: str) -> Optional[PromptModel]:
try:
prompt = Prompt.get(Prompt.command == command)
return PromptModel(**model_to_dict(prompt))
except:
return None
def get_prompts(self) -> List[PromptModel]:
return [
PromptModel(**model_to_dict(prompt)) for prompt in Prompt.select()
# .limit(limit).offset(skip)
]
def update_prompt_by_command(
self, command: str,
form_data: PromptForm) -> Optional[PromptModel]:
try:
query = Prompt.update(
title=form_data.title,
content=form_data.content,
timestamp=int(time.time()),
).where(Prompt.command == command)
query.execute()
prompt = Prompt.get(Prompt.command == command)
return PromptModel(**model_to_dict(prompt))
except:
return None
def delete_prompt_by_command(self, command: str) -> bool:
try:
query = Prompt.delete().where((Prompt.command == command))
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Prompts = PromptsTable(DB)
...@@ -3,14 +3,11 @@ from peewee import * ...@@ -3,14 +3,11 @@ from peewee import *
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional from typing import List, Union, Optional
import time import time
from utils.utils import decode_token
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from apps.web.internal.db import DB from apps.web.internal.db import DB
from apps.web.models.chats import Chats from apps.web.models.chats import Chats
#################### ####################
# User DB Schema # User DB Schema
#################### ####################
...@@ -47,6 +44,13 @@ class UserRoleUpdateForm(BaseModel): ...@@ -47,6 +44,13 @@ class UserRoleUpdateForm(BaseModel):
role: str role: str
class UserUpdateForm(BaseModel):
name: str
email: str
profile_image_url: str
password: Optional[str] = None
class UsersTable: class UsersTable:
def __init__(self, db): def __init__(self, db):
self.db = db self.db = db
...@@ -85,14 +89,6 @@ class UsersTable: ...@@ -85,14 +89,6 @@ class UsersTable:
except: except:
return None return None
def get_user_by_token(self, token: str) -> Optional[UserModel]:
data = decode_token(token)
if data != None and "email" in data:
return self.get_user_by_email(data["email"])
else:
return None
def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]: def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
return [ return [
UserModel(**model_to_dict(user)) UserModel(**model_to_dict(user))
...@@ -112,6 +108,16 @@ class UsersTable: ...@@ -112,6 +108,16 @@ class UsersTable:
except: except:
return None return None
def update_user_by_id(self, id: str, updated: dict) -> Optional[UserModel]:
try:
query = User.update(**updated).where(User.id == id)
query.execute()
user = User.get(User.id == id)
return UserModel(**model_to_dict(user))
except:
return None
def delete_user_by_id(self, id: str) -> bool: def delete_user_by_id(self, id: str) -> bool:
try: try:
# Delete User Chats # Delete User Chats
......
from fastapi import Response from fastapi import Response, Request
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union from typing import List, Union
...@@ -18,16 +18,10 @@ from apps.web.models.auths import ( ...@@ -18,16 +18,10 @@ from apps.web.models.auths import (
) )
from apps.web.models.users import Users from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token
from utils.utils import ( from utils.misc import get_gravatar_url, validate_email_format
get_password_hash,
bearer_scheme,
create_token,
)
from utils.misc import get_gravatar_url
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
############################ ############################
...@@ -36,22 +30,14 @@ router = APIRouter() ...@@ -36,22 +30,14 @@ router = APIRouter()
@router.get("/", response_model=UserResponse) @router.get("/", response_model=UserResponse)
async def get_session_user(cred=Depends(bearer_scheme)): async def get_session_user(user=Depends(get_current_user)):
token = cred.credentials return {
user = Users.get_user_by_token(token) "id": user.id,
if user: "email": user.email,
return { "name": user.name,
"id": user.id, "role": user.role,
"email": user.email, "profile_image_url": user.profile_image_url,
"name": user.name, }
"role": user.role,
"profile_image_url": user.profile_image_url,
}
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
...@@ -60,10 +46,8 @@ async def get_session_user(cred=Depends(bearer_scheme)): ...@@ -60,10 +46,8 @@ async def get_session_user(cred=Depends(bearer_scheme)):
@router.post("/update/password", response_model=bool) @router.post("/update/password", response_model=bool)
async def update_password(form_data: UpdatePasswordForm, cred=Depends(bearer_scheme)): async def update_password(form_data: UpdatePasswordForm,
token = cred.credentials session_user=Depends(get_current_user)):
session_user = Users.get_user_by_token(token)
if session_user: if session_user:
user = Auths.authenticate_user(session_user.email, form_data.password) user = Auths.authenticate_user(session_user.email, form_data.password)
...@@ -106,31 +90,67 @@ async def signin(form_data: SigninForm): ...@@ -106,31 +90,67 @@ async def signin(form_data: SigninForm):
@router.post("/signup", response_model=SigninResponse) @router.post("/signup", response_model=SigninResponse)
async def signup(form_data: SignupForm): async def signup(request: Request, form_data: SignupForm):
if not Users.get_user_by_email(form_data.email.lower()): if request.app.state.ENABLE_SIGNUP:
try: if validate_email_format(form_data.email.lower()):
role = "admin" if Users.get_num_users() == 0 else "pending" if not Users.get_user_by_email(form_data.email.lower()):
hashed = get_password_hash(form_data.password) try:
user = Auths.insert_new_auth( role = "admin" if Users.get_num_users() == 0 else "pending"
form_data.email.lower(), hashed, form_data.name, role hashed = get_password_hash(form_data.password)
) user = Auths.insert_new_auth(form_data.email.lower(),
hashed, form_data.name, role)
if user:
token = create_token(data={"email": user.email}) if user:
# response.set_cookie(key='token', value=token, httponly=True) token = create_token(data={"email": user.email})
# response.set_cookie(key='token', value=token, httponly=True)
return {
"token": token, return {
"token_type": "Bearer", "token": token,
"id": user.id, "token_type": "Bearer",
"email": user.email, "id": user.id,
"name": user.name, "email": user.email,
"role": user.role, "name": user.name,
"profile_image_url": user.profile_image_url, "role": user.role,
} "profile_image_url": user.profile_image_url,
}
else:
raise HTTPException(
500, detail=ERROR_MESSAGES.CREATE_USER_ERROR)
except Exception as err:
raise HTTPException(500,
detail=ERROR_MESSAGES.DEFAULT(err))
else: else:
raise HTTPException(500, detail=ERROR_MESSAGES.CREATE_USER_ERROR) raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
except Exception as err: else:
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err)) raise HTTPException(400,
detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT)
else:
raise HTTPException(400, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
############################
# ToggleSignUp
############################
@router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_current_user)):
if user.role == "admin":
return request.app.state.ENABLE_SIGNUP
else: else:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN) raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
@router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_current_user)):
if user.role == "admin":
request.app.state.ENABLE_SIGNUP = not request.app.state.ENABLE_SIGNUP
return request.app.state.ENABLE_SIGNUP
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
from fastapi import Response from fastapi import Depends, Request, HTTPException, status
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
from utils.utils import get_current_user
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
...@@ -18,8 +17,7 @@ from apps.web.models.chats import ( ...@@ -18,8 +17,7 @@ from apps.web.models.chats import (
) )
from utils.utils import ( from utils.utils import (
bearer_scheme, bearer_scheme, )
)
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
...@@ -30,17 +28,9 @@ router = APIRouter() ...@@ -30,17 +28,9 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): async def get_user_chats(
token = cred.credentials user=Depends(get_current_user), skip: int = 0, limit: int = 50):
user = Users.get_user_by_token(token) return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
if user:
return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
...@@ -49,20 +39,12 @@ async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch ...@@ -49,20 +39,12 @@ async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
@router.get("/all", response_model=List[ChatResponse]) @router.get("/all", response_model=List[ChatResponse])
async def get_all_user_chats(cred=Depends(bearer_scheme)): async def get_all_user_chats(user=Depends(get_current_user)):
token = cred.credentials return [
user = Users.get_user_by_token(token) ChatResponse(**{
**chat.model_dump(), "chat": json.loads(chat.chat)
if user: }) for chat in Chats.get_all_chats_by_user_id(user.id)
return [ ]
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats_by_user_id(user.id)
]
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
...@@ -71,18 +53,9 @@ async def get_all_user_chats(cred=Depends(bearer_scheme)): ...@@ -71,18 +53,9 @@ async def get_all_user_chats(cred=Depends(bearer_scheme)):
@router.post("/new", response_model=Optional[ChatResponse]) @router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)): async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
token = cred.credentials chat = Chats.insert_new_chat(user.id, form_data)
user = Users.get_user_by_token(token) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
if user:
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
...@@ -91,25 +64,16 @@ async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)): ...@@ -91,25 +64,16 @@ async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
@router.get("/{id}", response_model=Optional[ChatResponse]) @router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)): async def get_chat_by_id(id: str, user=Depends(get_current_user)):
token = cred.credentials chat = Chats.get_chat_by_id_and_user_id(id, user.id)
user = Users.get_user_by_token(token)
if chat:
if user: return ChatResponse(**{
chat = Chats.get_chat_by_id_and_user_id(id, user.id) **chat.model_dump(), "chat": json.loads(chat.chat)
})
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else: else:
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND)
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
...@@ -118,26 +82,21 @@ async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)): ...@@ -118,26 +82,21 @@ async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
@router.post("/{id}", response_model=Optional[ChatResponse]) @router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_scheme)): async def update_chat_by_id(id: str,
token = cred.credentials form_data: ChatForm,
user = Users.get_user_by_token(token) user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if user: if chat:
chat = Chats.get_chat_by_id_and_user_id(id, user.id) updated_chat = {**json.loads(chat.chat), **form_data.chat}
if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat} chat = Chats.update_chat_by_id(id, updated_chat)
return ChatResponse(**{
chat = Chats.update_chat_by_id(id, updated_chat) **chat.model_dump(), "chat": json.loads(chat.chat)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) })
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
...@@ -147,18 +106,9 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc ...@@ -147,18 +106,9 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc
@router.delete("/{id}", response_model=bool) @router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)): async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
token = cred.credentials result = Chats.delete_chat_by_id_and_user_id(id, user.id)
user = Users.get_user_by_token(token) return result
if user:
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
...@@ -167,15 +117,6 @@ async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)): ...@@ -167,15 +117,6 @@ async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)):
@router.delete("/", response_model=bool) @router.delete("/", response_model=bool)
async def delete_all_user_chats(cred=Depends(bearer_scheme)): async def delete_all_user_chats(user=Depends(get_current_user)):
token = cred.credentials result = Chats.delete_chats_by_user_id(user.id)
user = Users.get_user_by_token(token) return result
if user:
result = Chats.delete_chats_by_user_id(user.id)
return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
from fastapi import Response, Request
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union
from fastapi import APIRouter
from pydantic import BaseModel
import time
import uuid
from apps.web.models.users import Users
from utils.utils import get_password_hash, get_current_user, create_token
from utils.misc import get_gravatar_url, validate_email_format
from constants import ERROR_MESSAGES
router = APIRouter()
class SetDefaultModelsForm(BaseModel):
models: str
############################
# SetDefaultModels
############################
@router.post("/default/models", response_model=str)
async def set_global_default_models(request: Request,
form_data: SetDefaultModelsForm,
user=Depends(get_current_user)):
if user.role == "admin":
request.app.state.DEFAULT_MODELS = form_data.models
return request.app.state.DEFAULT_MODELS
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
from fastapi import Response
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
...@@ -6,8 +5,6 @@ from typing import List, Union, Optional ...@@ -6,8 +5,6 @@ from typing import List, Union, Optional
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.web.models.users import Users
from apps.web.models.modelfiles import ( from apps.web.models.modelfiles import (
Modelfiles, Modelfiles,
ModelfileForm, ModelfileForm,
...@@ -16,9 +13,7 @@ from apps.web.models.modelfiles import ( ...@@ -16,9 +13,7 @@ from apps.web.models.modelfiles import (
ModelfileResponse, ModelfileResponse,
) )
from utils.utils import ( from utils.utils import get_current_user
bearer_scheme,
)
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
...@@ -29,17 +24,10 @@ router = APIRouter() ...@@ -29,17 +24,10 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse]) @router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): async def get_modelfiles(skip: int = 0,
token = cred.credentials limit: int = 50,
user = Users.get_user_by_token(token) user=Depends(get_current_user)):
return Modelfiles.get_modelfiles(skip, limit)
if user:
return Modelfiles.get_modelfiles(skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
...@@ -48,36 +36,27 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch ...@@ -48,36 +36,27 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
@router.post("/create", response_model=Optional[ModelfileResponse]) @router.post("/create", response_model=Optional[ModelfileResponse])
async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_scheme)): async def create_new_modelfile(form_data: ModelfileForm,
token = cred.credentials user=Depends(get_current_user)):
user = Users.get_user_by_token(token) if user.role != "admin":
raise HTTPException(
if user: status_code=status.HTTP_401_UNAUTHORIZED,
# Admin Only detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
if user.role == "admin": )
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
if modelfile:
return ModelfileResponse( if modelfile:
**{ return ModelfileResponse(
**modelfile.model_dump(), **{
"modelfile": json.loads(modelfile.modelfile), **modelfile.model_dump(),
} "modelfile":
) json.loads(modelfile.modelfile),
else: })
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.DEFAULT(),
) )
...@@ -87,31 +66,21 @@ async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_sch ...@@ -87,31 +66,21 @@ async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_sch
@router.post("/", response_model=Optional[ModelfileResponse]) @router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name( async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme) user=Depends(get_current_user)):
): modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
token = cred.credentials
user = Users.get_user_by_token(token) if modelfile:
return ModelfileResponse(
if user: **{
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) **modelfile.model_dump(),
"modelfile":
if modelfile: json.loads(modelfile.modelfile),
return ModelfileResponse( })
**{
**modelfile.model_dump(),
"modelfile": json.loads(modelfile.modelfile),
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.NOT_FOUND,
) )
...@@ -121,45 +90,33 @@ async def get_modelfile_by_tag_name( ...@@ -121,45 +90,33 @@ async def get_modelfile_by_tag_name(
@router.post("/update", response_model=Optional[ModelfileResponse]) @router.post("/update", response_model=Optional[ModelfileResponse])
async def update_modelfile_by_tag_name( async def update_modelfile_by_tag_name(form_data: ModelfileUpdateForm,
form_data: ModelfileUpdateForm, cred=Depends(bearer_scheme) user=Depends(get_current_user)):
): if user.role != "admin":
token = cred.credentials raise HTTPException(
user = Users.get_user_by_token(token) status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
if user: )
if user.role == "admin": modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) if modelfile:
if modelfile: updated_modelfile = {
updated_modelfile = { **json.loads(modelfile.modelfile),
**json.loads(modelfile.modelfile), **form_data.modelfile,
**form_data.modelfile, }
}
modelfile = Modelfiles.update_modelfile_by_tag_name(
modelfile = Modelfiles.update_modelfile_by_tag_name( form_data.tag_name, updated_modelfile)
form_data.tag_name, updated_modelfile
) return ModelfileResponse(
**{
return ModelfileResponse( **modelfile.model_dump(),
**{ "modelfile":
**modelfile.model_dump(), json.loads(modelfile.modelfile),
"modelfile": json.loads(modelfile.modelfile), })
}
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
...@@ -169,23 +126,13 @@ async def update_modelfile_by_tag_name( ...@@ -169,23 +126,13 @@ async def update_modelfile_by_tag_name(
@router.delete("/delete", response_model=bool) @router.delete("/delete", response_model=bool)
async def delete_modelfile_by_tag_name( async def delete_modelfile_by_tag_name(form_data: ModelfileTagNameForm,
form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme) user=Depends(get_current_user)):
): if user.role != "admin":
token = cred.credentials
user = Users.get_user_by_token(token)
if user:
if user.role == "admin":
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
return result
from fastapi import Depends, FastAPI, HTTPException, status
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.web.models.prompts import Prompts, PromptForm, PromptModel
from utils.utils import get_current_user
from constants import ERROR_MESSAGES
router = APIRouter()
############################
# GetPrompts
############################
@router.get("/", response_model=List[PromptModel])
async def get_prompts(user=Depends(get_current_user)):
return Prompts.get_prompts()
############################
# CreateNewPrompt
############################
@router.post("/create", response_model=Optional[PromptModel])
async def create_new_prompt(form_data: PromptForm, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
prompt = Prompts.get_prompt_by_command(form_data.command)
if prompt == None:
prompt = Prompts.insert_new_prompt(user.id, form_data)
if prompt:
return prompt
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.DEFAULT(),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.COMMAND_TAKEN,
)
############################
# GetPromptByCommand
############################
@router.get("/command/{command}", response_model=Optional[PromptModel])
async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
prompt = Prompts.get_prompt_by_command(f"/{command}")
if prompt:
return prompt
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdatePromptByCommand
############################
@router.post("/command/{command}/update", response_model=Optional[PromptModel])
async def update_prompt_by_command(
command: str, form_data: PromptForm, user=Depends(get_current_user)
):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
prompt = Prompts.update_prompt_by_command(f"/{command}", form_data)
if prompt:
return prompt
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
############################
# DeletePromptByCommand
############################
@router.delete("/command/{command}/delete", response_model=bool)
async def delete_prompt_by_command(command: str, user=Depends(get_current_user)):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
result = Prompts.delete_prompt_by_command(f"/{command}")
return result
...@@ -8,15 +8,10 @@ from pydantic import BaseModel ...@@ -8,15 +8,10 @@ from pydantic import BaseModel
import time import time
import uuid import uuid
from apps.web.models.users import UserModel, UserRoleUpdateForm, Users from apps.web.models.users import UserModel, UserUpdateForm, UserRoleUpdateForm, Users
from apps.web.models.auths import Auths from apps.web.models.auths import Auths
from utils.utils import get_current_user, get_password_hash
from utils.utils import (
get_password_hash,
bearer_scheme,
create_token,
)
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
...@@ -27,23 +22,13 @@ router = APIRouter() ...@@ -27,23 +22,13 @@ router = APIRouter()
@router.get("/", response_model=List[UserModel]) @router.get("/", response_model=List[UserModel])
async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
token = cred.credentials if user.role != "admin":
user = Users.get_user_by_token(token)
if user:
if user.role == "admin":
return Users.get_users(skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
return Users.get_users(skip, limit)
############################ ############################
...@@ -52,28 +37,77 @@ async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)) ...@@ -52,28 +37,77 @@ async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme))
@router.post("/update/role", response_model=Optional[UserModel]) @router.post("/update/role", response_model=Optional[UserModel])
async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_scheme)): async def update_user_role(
token = cred.credentials form_data: UserRoleUpdateForm, user=Depends(get_current_user)
user = Users.get_user_by_token(token) ):
if user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
if user.id != form_data.id:
return Users.update_user_role_by_id(form_data.id, form_data.role)
else:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACTION_PROHIBITED,
)
############################
# UpdateUserById
############################
@router.post("/{user_id}/update", response_model=Optional[UserModel])
async def update_user_by_id(
user_id: str, form_data: UserUpdateForm, session_user=Depends(get_current_user)
):
if session_user.role != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
)
user = Users.get_user_by_id(user_id)
if user: if user:
if user.role == "admin": if form_data.email.lower() != user.email:
if user.id != form_data.id: email_user = Users.get_user_by_email(form_data.email.lower())
return Users.update_user_role_by_id(form_data.id, form_data.role) if email_user:
else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACTION_PROHIBITED, detail=ERROR_MESSAGES.EMAIL_TAKEN,
) )
if form_data.password:
hashed = get_password_hash(form_data.password)
print(hashed)
Auths.update_user_password_by_id(user_id, hashed)
Auths.update_email_by_id(user_id, form_data.email.lower())
updated_user = Users.update_user_by_id(
user_id,
{
"name": form_data.name,
"email": form_data.email.lower(),
"profile_image_url": form_data.profile_image_url,
},
)
if updated_user:
return updated_user
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.DEFAULT(),
) )
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.USER_NOT_FOUND,
) )
...@@ -83,34 +117,25 @@ async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_sc ...@@ -83,34 +117,25 @@ async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_sc
@router.delete("/{user_id}", response_model=bool) @router.delete("/{user_id}", response_model=bool)
async def delete_user_by_id(user_id: str, cred=Depends(bearer_scheme)): async def delete_user_by_id(user_id: str, user=Depends(get_current_user)):
token = cred.credentials if user.role == "admin":
user = Users.get_user_by_token(token) if user.id != user_id:
result = Auths.delete_auth_by_id(user_id)
if user: if result:
if user.role == "admin": return True
if user.id != user_id:
result = Auths.delete_auth_by_id(user_id)
if result:
return True
else:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.DELETE_USER_ERROR,
)
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=ERROR_MESSAGES.ACTION_PROHIBITED, detail=ERROR_MESSAGES.DELETE_USER_ERROR,
) )
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACTION_PROHIBITED,
) )
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_403_FORBIDDEN,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
...@@ -9,12 +9,10 @@ import os ...@@ -9,12 +9,10 @@ import os
import aiohttp import aiohttp
import json import json
from utils.misc import calculate_sha256 from utils.misc import calculate_sha256
from config import OLLAMA_API_BASE_URL from config import OLLAMA_API_BASE_URL
router = APIRouter() router = APIRouter()
...@@ -42,7 +40,10 @@ def parse_huggingface_url(hf_url): ...@@ -42,7 +40,10 @@ def parse_huggingface_url(hf_url):
return None return None
async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024): async def download_file_stream(url,
file_path,
file_name,
chunk_size=1024 * 1024):
done = False done = False
if os.path.exists(file_path): if os.path.exists(file_path):
...@@ -56,7 +57,8 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024 ...@@ -56,7 +57,8 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024
async with aiohttp.ClientSession(timeout=timeout) as session: async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.get(url, headers=headers) as response: async with session.get(url, headers=headers) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size total_size = int(response.headers.get("content-length",
0)) + current_size
with open(file_path, "ab+") as file: with open(file_path, "ab+") as file:
async for data in response.content.iter_chunked(chunk_size): async for data in response.content.iter_chunked(chunk_size):
...@@ -89,9 +91,7 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024 ...@@ -89,9 +91,7 @@ async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024
@router.get("/download") @router.get("/download")
async def download( async def download(url: str, ):
url: str,
):
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf" # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
file_name = parse_huggingface_url(url) file_name = parse_huggingface_url(url)
...@@ -161,4 +161,5 @@ async def upload(file: UploadFile = File(...)): ...@@ -161,4 +161,5 @@ async def upload(file: UploadFile = File(...)):
res = {"error": str(e)} res = {"error": str(e)}
yield f"data: {json.dumps(res)}\n\n" yield f"data: {json.dumps(res)}\n\n"
return StreamingResponse(file_write_stream(), media_type="text/event-stream") return StreamingResponse(file_write_stream(),
media_type="text/event-stream")
...@@ -19,19 +19,28 @@ ENV = os.environ.get("ENV", "dev") ...@@ -19,19 +19,28 @@ ENV = os.environ.get("ENV", "dev")
# OLLAMA_API_BASE_URL # OLLAMA_API_BASE_URL
#################################### ####################################
OLLAMA_API_BASE_URL = os.environ.get( OLLAMA_API_BASE_URL = os.environ.get("OLLAMA_API_BASE_URL",
"OLLAMA_API_BASE_URL", "http://localhost:11434/api" "http://localhost:11434/api")
)
if ENV == "prod": if ENV == "prod":
if OLLAMA_API_BASE_URL == "/ollama/api": if OLLAMA_API_BASE_URL == "/ollama/api":
OLLAMA_API_BASE_URL = "http://host.docker.internal:11434/api" OLLAMA_API_BASE_URL = "http://host.docker.internal:11434/api"
####################################
# OPENAI_API
####################################
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
if OPENAI_API_BASE_URL == "":
OPENAI_API_BASE_URL = "https://api.openai.com/v1"
#################################### ####################################
# WEBUI_VERSION # WEBUI_VERSION
#################################### ####################################
WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.42") WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.50")
#################################### ####################################
# WEBUI_AUTH (Required for security) # WEBUI_AUTH (Required for security)
......
...@@ -6,6 +6,7 @@ class MESSAGES(str, Enum): ...@@ -6,6 +6,7 @@ class MESSAGES(str, Enum):
class ERROR_MESSAGES(str, Enum): class ERROR_MESSAGES(str, Enum):
def __str__(self) -> str: def __str__(self) -> str:
return super().__str__() return super().__str__()
...@@ -17,19 +18,20 @@ class ERROR_MESSAGES(str, Enum): ...@@ -17,19 +18,20 @@ class ERROR_MESSAGES(str, Enum):
USERNAME_TAKEN = ( USERNAME_TAKEN = (
"Uh-oh! This username is already registered. Please choose another username." "Uh-oh! This username is already registered. Please choose another username."
) )
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
INVALID_TOKEN = ( INVALID_TOKEN = (
"Your session has expired or the token is invalid. Please sign in again." "Your session has expired or the token is invalid. Please sign in again."
) )
INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again." INVALID_CRED = "The email or password provided is incorrect. Please check for typos and try logging in again."
INVALID_EMAIL_FORMAT = "The email format you entered is invalid. Please double-check and make sure you're using a valid email address (e.g., yourname@example.com)."
INVALID_PASSWORD = ( INVALID_PASSWORD = (
"The password provided is incorrect. Please check for typos and try again." "The password provided is incorrect. Please check for typos and try again."
) )
UNAUTHORIZED = "401 Unauthorized" UNAUTHORIZED = "401 Unauthorized"
ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance." ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
ACTION_PROHIBITED = ( ACTION_PROHIBITED = (
"The requested action has been restricted as a security measure." "The requested action has been restricted as a security measure.")
)
NOT_FOUND = "We could not find what you're looking for :/" NOT_FOUND = "We could not find what you're looking for :/"
USER_NOT_FOUND = "We could not find what you're looking for :/" USER_NOT_FOUND = "We could not find what you're looking for :/"
API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
MALICIOUS = "Unusual activities detected, please try again in a few minutes." MALICIOUS = "Unusual activities detected, please try again in a few minutes."
...@@ -6,12 +6,15 @@ from fastapi.middleware.cors import CORSMiddleware ...@@ -6,12 +6,15 @@ from fastapi.middleware.cors import CORSMiddleware
from starlette.exceptions import HTTPException as StarletteHTTPException from starlette.exceptions import HTTPException as StarletteHTTPException
from apps.ollama.main import app as ollama_app from apps.ollama.main import app as ollama_app
from apps.openai.main import app as openai_app
from apps.web.main import app as webui_app from apps.web.main import app as webui_app
import time import time
class SPAStaticFiles(StaticFiles): class SPAStaticFiles(StaticFiles):
async def get_response(self, path: str, scope): async def get_response(self, path: str, scope):
try: try:
return await super().get_response(path, scope) return await super().get_response(path, scope)
...@@ -46,5 +49,9 @@ async def check_url(request: Request, call_next): ...@@ -46,5 +49,9 @@ async def check_url(request: Request, call_next):
app.mount("/api/v1", webui_app) app.mount("/api/v1", webui_app)
app.mount("/ollama/api", WSGIMiddleware(ollama_app)) app.mount("/ollama/api", ollama_app)
app.mount("/", SPAStaticFiles(directory="../build", html=True), name="spa-static-files") app.mount("/openai/api", openai_app)
app.mount("/",
SPAStaticFiles(directory="../build", html=True),
name="spa-static-files")
...@@ -18,3 +18,5 @@ bcrypt ...@@ -18,3 +18,5 @@ bcrypt
PyJWT PyJWT
pyjwt[crypto] pyjwt[crypto]
black
\ No newline at end of file
uvicorn main:app --host 0.0.0.0 --port 8080 --forwarded-allow-ips '*' #!/usr/bin/env bash
\ No newline at end of file
uvicorn main:app --host 0.0.0.0 --port 8080 --forwarded-allow-ips '*'
import hashlib import hashlib
import re
def get_gravatar_url(email): def get_gravatar_url(email):
...@@ -21,3 +22,9 @@ def calculate_sha256(file): ...@@ -21,3 +22,9 @@ def calculate_sha256(file):
for chunk in iter(lambda: file.read(8192), b""): for chunk in iter(lambda: file.read(8192), b""):
sha256.update(chunk) sha256.update(chunk)
return sha256.hexdigest() return sha256.hexdigest()
def validate_email_format(email: str) -> bool:
if not re.match(r"[^@]+@[^@]+\.[^@]+", email):
return False
return True
from fastapi.security import HTTPBasicCredentials, HTTPBearer from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends
from apps.web.models.users import Users
from pydantic import BaseModel from pydantic import BaseModel
from typing import Union, Optional from typing import Union, Optional
from constants import ERROR_MESSAGES
from passlib.context import CryptContext from passlib.context import CryptContext
from datetime import datetime, timedelta from datetime import datetime, timedelta
import requests import requests
import jwt import jwt
import logging
import config import config
logging.getLogger("passlib").setLevel(logging.ERROR)
JWT_SECRET_KEY = config.WEBUI_JWT_SECRET_KEY JWT_SECRET_KEY = config.WEBUI_JWT_SECRET_KEY
ALGORITHM = "HS256" ALGORITHM = "HS256"
...@@ -53,16 +58,18 @@ def extract_token_from_auth_header(auth_header: str): ...@@ -53,16 +58,18 @@ def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :] return auth_header[len("Bearer ") :]
def verify_token(request): def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
try: data = decode_token(auth_token.credentials)
bearer = request.headers["authorization"] if data != None and "email" in data:
if bearer: user = Users.get_user_by_email(data["email"])
token = bearer[len("Bearer ") :] if user is None:
decoded = jwt.decode( raise HTTPException(
token, JWT_SECRET_KEY, options={"verify_signature": False} status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
) )
return decoded return user
else: else:
return None raise HTTPException(
except Exception as e: status_code=status.HTTP_401_UNAUTHORIZED,
return None detail=ERROR_MESSAGES.UNAUTHORIZED,
)
File added
This image diff could not be displayed because it is too large. You can view the blob instead.
version: '3.6' version: '3.8'
services: services:
ollama: ollama:
# Expose Ollama API outside the container stack # Expose Ollama API outside the container stack
ports: ports:
- 11434:11434 - ${OLLAMA_WEBAPI_PORT-11434}:11434
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment