# utils.py
from datetime import datetime, timedelta, timezone
from typing import Union, Optional
import logging
import uuid

from fastapi import HTTPException, status, Depends, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from passlib.context import CryptContext
from pydantic import BaseModel
from sqlalchemy.orm import Session
import jwt
import requests

from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
import config

# Only surface passlib messages at ERROR level or above.
logging.getLogger("passlib").setLevel(logging.ERROR)


# Secret used to sign (create_token) and verify (decode_token) JWTs.
SESSION_SECRET = config.WEBUI_SECRET_KEY
# JWT signing algorithm (HMAC with SHA-256).
ALGORITHM = "HS256"

##############
# Auth Utils
##############

# auto_error=False: a missing/non-Bearer Authorization header is passed to the
# dependency as None instead of failing, so get_current_user can fall back to
# the "token" cookie.
bearer_security = HTTPBearer(auto_error=False)
# bcrypt password hashing; deprecated="auto" treats every non-default scheme as
# deprecated.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")


def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hash.

    Args:
        plain_password: The password supplied by the user.
        hashed_password: The stored hash to compare against; may be empty/None.

    Returns:
        The result of ``pwd_context.verify`` when a hash is present, or
        ``None`` when ``hashed_password`` is falsy. Callers should therefore
        test the result for truthiness rather than ``is False``.
    """
    if not hashed_password:
        # Nothing to compare against; keep the historical ``None`` result.
        return None
    return pwd_context.verify(plain_password, hashed_password)


def get_password_hash(password):
    """Hash *password* with the module-level ``pwd_context`` and return the hash."""
    hashed = pwd_context.hash(password)
    return hashed


def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
    """Encode *data* as a signed JWT.

    Args:
        data: Claims to embed. The dict is copied, so the caller's object is
            never mutated.
        expires_delta: Token lifetime. When truthy, an ``exp`` claim of
            "now (UTC) + expires_delta" is added. When ``None`` -- or a zero
            ``timedelta``, which is falsy -- the token carries no ``exp``.

    Returns:
        The encoded token, signed with ``SESSION_SECRET`` using ``ALGORITHM``.
    """
    payload = data.copy()

    if expires_delta:
        # Timezone-aware "now": datetime.utcnow() is deprecated since Python 3.12.
        # PyJWT converts an aware datetime to the same UTC epoch seconds.
        payload["exp"] = datetime.now(timezone.utc) + expires_delta

    return jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)


def decode_token(token: str) -> Optional[dict]:
    """Verify and decode a JWT signed with ``SESSION_SECRET``.

    Returns:
        The decoded claims dict, or ``None`` if decoding/verification fails for
        any reason.
    """
    try:
        return jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
    except Exception:
        # Deliberately best-effort: any failure means "no valid token", and
        # get_current_user turns a None result into a 401.
        return None


def extract_token_from_auth_header(auth_header: str):
    """Return the credentials part of a ``Bearer <token>`` Authorization header.

    The first ``len("Bearer ")`` characters are dropped unconditionally; the
    scheme is not validated, so only pass headers that use the Bearer scheme.
    """
    return auth_header[len("Bearer ") :]


def create_api_key():
    """Generate a new API key: ``"sk-"`` followed by 32 lowercase hex digits.

    The random part is ``uuid.uuid4()`` without dashes (``UUID.hex``). The
    ``sk-`` prefix is what get_current_user uses to tell API keys from JWTs.
    """
    return f"sk-{uuid.uuid4().hex}"


def get_http_authorization_cred(auth_header: str):
    """Parse a raw ``"<scheme> <credentials>"`` Authorization header value.

    Returns:
        An ``HTTPAuthorizationCredentials`` with the scheme and credentials.

    Raises:
        ValueError: carrying ``ERROR_MESSAGES.INVALID_TOKEN`` if the value is
            not a string of exactly two space-separated parts, or if the
            credentials object cannot be built from them.
    """
    try:
        scheme, credentials = auth_header.split(" ")
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
    except (AttributeError, TypeError, ValueError) as err:
        # Catch parse/validation failures only (pydantic's ValidationError
        # subclasses ValueError); KeyboardInterrupt/SystemExit must propagate.
        raise ValueError(ERROR_MESSAGES.INVALID_TOKEN) from err


def get_current_user(
    request: Request,
    auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
    """FastAPI dependency resolving the authenticated user for a request.

    The token comes from the Bearer ``Authorization`` header if present,
    otherwise from the ``token`` cookie. Tokens starting with ``sk-`` are
    handled as API keys. Anything else is decoded as a JWT whose ``id``
    claim identifies the user, whose last-active time is then updated.

    Raises:
        HTTPException: 403 if no token was supplied; 401 if the JWT cannot be
            decoded, lacks an ``id`` claim, or names an unknown user.
    """
    # Header first, then cookie fallback (bearer_security uses
    # auto_error=False, so a missing header arrives here as None).
    token = auth_token.credentials if auth_token is not None else None
    if token is None:
        token = request.cookies.get("token")

    if token is None:
        raise HTTPException(status_code=403, detail="Not authenticated")

    # auth by api key
    if token.startswith("sk-"):
        # NOTE(review): confirm this single-argument call matches
        # get_current_user_by_api_key's signature.
        return get_current_user_by_api_key(token)

    # auth by jwt token
    data = decode_token(token)
    if data is None or "id" not in data:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    user = Users.get_user_by_id(data["id"])
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    Users.update_user_last_active_by_id(user.id)
    return user

def get_current_user_by_api_key(
    db: Optional[Session] = None, api_key: Optional[str] = None
):
    """Resolve the user owning an API key and mark them active.

    Two call shapes are accepted:

    * ``get_current_user_by_api_key(db, api_key)``: the declared form. ``db``
      is forwarded to the ``Users`` calls exactly as before.
    * ``get_current_user_by_api_key(api_key)``: how get_current_user calls it.
      The lone string argument lands in ``db`` and is re-read as the key.
      ``Users`` is then called without a session, the same way
      get_current_user calls it. This shape used to raise ``TypeError``.

    Raises:
        HTTPException: 401 ``INVALID_TOKEN`` if no key is given or no user
            owns it.
    """
    if api_key is None and isinstance(db, str):
        # Single-argument call: the key arrived in the ``db`` slot.
        db, api_key = None, db

    if api_key is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    # Forward the session only when the caller supplied one.
    # NOTE(review): the session-less form mirrors get_current_user's Users calls;
    # confirm against the Users model.
    session_args = () if db is None else (db,)

    user = Users.get_user_by_api_key(*session_args, api_key)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )

    Users.update_user_last_active_by_id(*session_args, user.id)
    return user

def get_verified_user(user=Depends(get_current_user)):
    """FastAPI dependency: the current user, if their role is "user" or "admin".

    Raises:
        HTTPException: 401 ``ACCESS_PROHIBITED`` for any other role.
    """
    # NOTE(review): 403 is the conventional status for an authenticated user
    # lacking permission; 401 is kept because clients may depend on it.
    if user.role not in {"user", "admin"}:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    return user


def get_admin_user(user=Depends(get_current_user)):
    """FastAPI dependency: the current user, if their role is "admin".

    Raises:
        HTTPException: 401 ``ACCESS_PROHIBITED`` for any non-admin role.
    """
    # NOTE(review): 403 is the conventional status for an authenticated user
    # lacking permission; 401 is kept because clients may depend on it.
    if user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )
    return user