Commit 77323d9b authored by Anuraag Jain's avatar Anuraag Jain
Browse files

refac: remove the verify_token and use get-current user for auth+user

parent 2d323b31
...@@ -3,7 +3,6 @@ from fastapi.routing import APIRoute ...@@ -3,7 +3,6 @@ from fastapi.routing import APIRoute
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from apps.web.routers import auths, users, chats, modelfiles, utils from apps.web.routers import auths, users, chats, modelfiles, utils
from config import WEBUI_VERSION, WEBUI_AUTH from config import WEBUI_VERSION, WEBUI_AUTH
from utils.utils import verify_auth_token
app = FastAPI() app = FastAPI()
...@@ -19,24 +18,9 @@ app.add_middleware( ...@@ -19,24 +18,9 @@ app.add_middleware(
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.include_router( app.include_router(users.router, prefix="/users", tags=["users"])
users.router, app.include_router(chats.router, prefix="/chats", tags=["chats"])
prefix="/users", app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
tags=["users"],
dependencies=[Depends(verify_auth_token)],
)
app.include_router(
chats.router,
prefix="/chats",
tags=["chats"],
dependencies=[Depends(verify_auth_token)],
)
app.include_router(
modelfiles.router,
prefix="/modelfiles",
tags=["modelfiles"],
dependencies=[Depends(verify_auth_token)],
)
app.include_router(utils.router, prefix="/utils", tags=["utils"]) app.include_router(utils.router, prefix="/utils", tags=["utils"])
......
...@@ -19,12 +19,7 @@ from apps.web.models.auths import ( ...@@ -19,12 +19,7 @@ from apps.web.models.auths import (
from apps.web.models.users import Users from apps.web.models.users import Users
from utils.utils import ( from utils.utils import get_password_hash, get_current_user, create_token
get_password_hash,
get_current_user,
create_token,
verify_auth_token,
)
from utils.misc import get_gravatar_url from utils.misc import get_gravatar_url
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
...@@ -36,7 +31,7 @@ router = APIRouter() ...@@ -36,7 +31,7 @@ router = APIRouter()
############################ ############################
@router.get("/", response_model=UserResponse, dependencies=[Depends(verify_auth_token)]) @router.get("/", response_model=UserResponse)
async def get_session_user(user=Depends(get_current_user)): async def get_session_user(user=Depends(get_current_user)):
return { return {
"id": user.id, "id": user.id,
...@@ -52,9 +47,7 @@ async def get_session_user(user=Depends(get_current_user)): ...@@ -52,9 +47,7 @@ async def get_session_user(user=Depends(get_current_user)):
############################ ############################
@router.post( @router.post("/update/password", response_model=bool)
"/update/password", response_model=bool, dependencies=[Depends(verify_auth_token)]
)
async def update_password( async def update_password(
form_data: UpdatePasswordForm, session_user=Depends(get_current_user) form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
): ):
......
...@@ -108,6 +108,7 @@ async def delete_chat_by_id(id: str, user=Depends(get_current_user)): ...@@ -108,6 +108,7 @@ async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
result = Chats.delete_chat_by_id_and_user_id(id, user.id) result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result return result
############################ ############################
# DeleteAllChats # DeleteAllChats
############################ ############################
......
...@@ -5,8 +5,6 @@ from typing import List, Union, Optional ...@@ -5,8 +5,6 @@ from typing import List, Union, Optional
from fastapi import APIRouter from fastapi import APIRouter
from pydantic import BaseModel from pydantic import BaseModel
import json import json
from apps.web.models.users import Users
from apps.web.models.modelfiles import ( from apps.web.models.modelfiles import (
Modelfiles, Modelfiles,
ModelfileForm, ModelfileForm,
...@@ -15,7 +13,7 @@ from apps.web.models.modelfiles import ( ...@@ -15,7 +13,7 @@ from apps.web.models.modelfiles import (
ModelfileResponse, ModelfileResponse,
) )
from utils.utils import bearer_scheme, get_current_user from utils.utils import get_current_user
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
router = APIRouter() router = APIRouter()
...@@ -26,7 +24,7 @@ router = APIRouter() ...@@ -26,7 +24,7 @@ router = APIRouter()
@router.get("/", response_model=List[ModelfileResponse]) @router.get("/", response_model=List[ModelfileResponse])
async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
return Modelfiles.get_modelfiles(skip, limit) return Modelfiles.get_modelfiles(skip, limit)
...@@ -67,7 +65,7 @@ async def create_new_modelfile( ...@@ -67,7 +65,7 @@ async def create_new_modelfile(
@router.post("/", response_model=Optional[ModelfileResponse]) @router.post("/", response_model=Optional[ModelfileResponse])
async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm): async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depends(get_current_user)):
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name) modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
if modelfile: if modelfile:
......
...@@ -55,7 +55,7 @@ def extract_token_from_auth_header(auth_header: str): ...@@ -55,7 +55,7 @@ def extract_token_from_auth_header(auth_header: str):
return auth_header[len("Bearer ") :] return auth_header[len("Bearer ") :]
def verify_auth_token(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())): def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
data = decode_token(auth_token.credentials) data = decode_token(auth_token.credentials)
if data != None and "email" in data: if data != None and "email" in data:
user = Users.get_user_by_email(data["email"]) user = Users.get_user_by_email(data["email"])
...@@ -64,14 +64,9 @@ def verify_auth_token(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBea ...@@ -64,14 +64,9 @@ def verify_auth_token(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBea
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN, detail=ERROR_MESSAGES.INVALID_TOKEN,
) )
return return user
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.UNAUTHORIZED, detail=ERROR_MESSAGES.UNAUTHORIZED,
) )
def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
data = decode_token(auth_token.credentials)
return Users.get_user_by_email(data["email"])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment