"git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "89b35e455647c1a5ec147a4a65967e7807229e95"
Commit a01b112f authored by Anuraag Jain's avatar Anuraag Jain
Browse files

feat(auth): add auth middleware

- refactored chat routes to use request.user instead of doing authentication in every route
parent 83704657
...@@ -4,4 +4,5 @@ _old ...@@ -4,4 +4,5 @@ _old
uploads uploads
.ipynb_checkpoints .ipynb_checkpoints
*.db *.db
_test _test
\ No newline at end of file Pipfile
\ No newline at end of file
from fastapi import FastAPI, Request, Depends, HTTPException from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.authentication import AuthenticationMiddleware
from apps.web.routers import auths, users, chats, modelfiles, utils from apps.web.routers import auths, users, chats, modelfiles, utils
from config import WEBUI_VERSION, WEBUI_AUTH from config import WEBUI_VERSION, WEBUI_AUTH
from apps.web.middlewares.auth import BearerTokenAuthBackend, on_auth_error
app = FastAPI() app = FastAPI()
...@@ -18,11 +19,12 @@ app.add_middleware( ...@@ -18,11 +19,12 @@ app.add_middleware(
app.include_router(auths.router, prefix="/auths", tags=["auths"]) app.include_router(auths.router, prefix="/auths", tags=["auths"])
app.add_middleware(AuthenticationMiddleware, backend=BearerTokenAuthBackend(), on_error=on_auth_error)
app.include_router(users.router, prefix="/users", tags=["users"]) app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"]) app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"]) app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
app.include_router(utils.router, prefix="/utils", tags=["utils"]) app.include_router(utils.router, prefix="/utils", tags=["utils"])
......
from apps.web.models.users import Users
from fastapi import Request, status
from starlette.authentication import (
AuthCredentials, AuthenticationBackend, AuthenticationError,
)
from starlette.requests import HTTPConnection
from utils.utils import verify_token
from starlette.responses import JSONResponse
from constants import ERROR_MESSAGES
class BearerTokenAuthBackend(AuthenticationBackend):
async def authenticate(self, conn: HTTPConnection):
if "Authorization" not in conn.headers:
return
data = verify_token(conn)
if data != None and 'email' in data:
user = Users.get_user_by_email(data['email'])
if user is None:
raise AuthenticationError('Invalid credentials')
return AuthCredentials([user.role]), user
else:
raise AuthenticationError('Invalid credentials')
def on_auth_error(request: Request, exc: Exception):
print('Authentication failed: ', exc)
return JSONResponse({"detail": ERROR_MESSAGES.INVALID_TOKEN}, status_code=status.HTTP_401_UNAUTHORIZED)
\ No newline at end of file
from fastapi import Response
from fastapi import Depends, FastAPI, HTTPException, status from fastapi import Depends, Request, HTTPException, status
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import List, Union, Optional from typing import List, Union, Optional
...@@ -30,17 +30,8 @@ router = APIRouter() ...@@ -30,17 +30,8 @@ router = APIRouter()
@router.get("/", response_model=List[ChatTitleIdResponse]) @router.get("/", response_model=List[ChatTitleIdResponse])
async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)): async def get_user_chats(request:Request, skip: int = 0, limit: int = 50):
token = cred.credentials return Chats.get_chat_lists_by_user_id(request.user.id, skip, limit)
user = Users.get_user_by_token(token)
if user:
return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
...@@ -49,20 +40,11 @@ async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch ...@@ -49,20 +40,11 @@ async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
@router.get("/all", response_model=List[ChatResponse]) @router.get("/all", response_model=List[ChatResponse])
async def get_all_user_chats(cred=Depends(bearer_scheme)): async def get_all_user_chats(request:Request,):
token = cred.credentials return [
user = Users.get_user_by_token(token)
if user:
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_all_chats_by_user_id(user.id) for chat in Chats.get_all_chats_by_user_id(request.user.id)
] ]
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
...@@ -71,18 +53,9 @@ async def get_all_user_chats(cred=Depends(bearer_scheme)): ...@@ -71,18 +53,9 @@ async def get_all_user_chats(cred=Depends(bearer_scheme)):
@router.post("/new", response_model=Optional[ChatResponse]) @router.post("/new", response_model=Optional[ChatResponse])
async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)): async def create_new_chat(form_data: ChatForm,request:Request):
token = cred.credentials chat = Chats.insert_new_chat(request.user.id, form_data)
user = Users.get_user_by_token(token) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
if user:
chat = Chats.insert_new_chat(user.id, form_data)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
...@@ -91,25 +64,14 @@ async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)): ...@@ -91,25 +64,14 @@ async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
@router.get("/{id}", response_model=Optional[ChatResponse]) @router.get("/{id}", response_model=Optional[ChatResponse])
async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)): async def get_chat_by_id(id: str, request:Request):
token = cred.credentials chat = Chats.get_chat_by_id_and_user_id(id, request.user.id)
user = Users.get_user_by_token(token)
if user: if chat:
chat = Chats.get_chat_by_id_and_user_id(id, user.id) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
if chat:
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
else: else:
raise HTTPException( raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND)
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
...@@ -118,27 +80,18 @@ async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)): ...@@ -118,27 +80,18 @@ async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
@router.post("/{id}", response_model=Optional[ChatResponse]) @router.post("/{id}", response_model=Optional[ChatResponse])
async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_scheme)): async def update_chat_by_id(id: str, form_data: ChatForm, request:Request):
token = cred.credentials chat = Chats.get_chat_by_id_and_user_id(id, request.user.id)
user = Users.get_user_by_token(token) if chat:
if user:
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
updated_chat = {**json.loads(chat.chat), **form_data.chat} updated_chat = {**json.loads(chat.chat), **form_data.chat}
chat = Chats.update_chat_by_id(id, updated_chat) chat = Chats.update_chat_by_id(id, updated_chat)
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)}) return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else: else:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.ACCESS_PROHIBITED, detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
) )
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
############################ ############################
...@@ -147,15 +100,6 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc ...@@ -147,15 +100,6 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc
@router.delete("/{id}", response_model=bool) @router.delete("/{id}", response_model=bool)
async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)): async def delete_chat_by_id(id: str, request: Request):
token = cred.credentials result = Chats.delete_chat_by_id_and_user_id(id, request.user.id)
user = Users.get_user_by_token(token) return result
\ No newline at end of file
if user:
result = Chats.delete_chat_by_id_and_user_id(id, user.id)
return result
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.INVALID_TOKEN,
)
...@@ -55,13 +55,13 @@ def extract_token_from_auth_header(auth_header: str): ...@@ -55,13 +55,13 @@ def extract_token_from_auth_header(auth_header: str):
def verify_token(request): def verify_token(request):
try: try:
bearer = request.headers["authorization"] authorization = request.headers["authorization"]
if bearer: if authorization:
token = bearer[len("Bearer ") :] _, token = authorization.split()
decoded = jwt.decode( decoded_token = jwt.decode(
token, JWT_SECRET_KEY, options={"verify_signature": False} token, JWT_SECRET_KEY, options={"verify_signature": False}
) )
return decoded return decoded_token
else: else:
return None return None
except Exception as e: except Exception as e:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment